rlx_optim/mars.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! MARS — Make vAriance Reduction Shine (Yuan, Liu, Wu, Su, Gu, 2024).
17//!
18//! # Idea
19//!
20//! Variance reduction (SVRG, SARAH) lowers gradient noise by mixing in
21//! a *previous* gradient — at the cost of an extra forward/backward
22//! pass per snapshot. MARS shows that you don't need a snapshot:
23//! using just `g_{t−1}` (the previous mini-batch's gradient) as the
24//! "control variate" gives most of the benefit of full variance
25//! reduction, for free.
26//!
27//! # Update rule
28//!
29//! ```text
30//! c_t = g_t + γ · β₁/(1−β₁) · (g_t − g_{t−1}) // VR-corrected grad
31//! [optional: clip c_t to unit norm per tensor]
32//! m_t = β₁·m_{t-1} + (1 − β₁)·c_t
33//! v_t = β₂·v_{t-1} + (1 − β₂)·c_t²
34//! θ_t = θ_{t-1} − lr · ( m̂_t/(√v̂_t + ε) + λ·θ_{t-1} ) // AdamW-style
35//! ```
36//!
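//! γ = 0 removes the variance-reduction term, leaving AdamW (applied to
//! per-tensor unit-norm-clipped gradients while `clip_c` is on, as it is
//! by default). γ = 1 is the "full" Yuan et al. prescription; the
//! recommended sweet spot is `γ ≈ 0.025`.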
39//!
40//! # When to use
41//!
42//! Anywhere variance-reduced SGD/Adam variants would help — noisy
43//! gradients, small batches, RL-style on-policy training. State cost
44//! per parameter: three buffers (`m`, `v`, previous-gradient cache).
45
46use std::collections::HashMap;
47
48use crate::Optimizer;
49use crate::common::zeros_entry;
50
51/// MARS — variance-reduced AdamW. Per-tensor state: three `f32`
52/// buffers (`m`, `v`, previous-gradient cache).
53#[derive(Debug, Clone)]
54pub struct Mars {
55 /// Learning rate.
56 pub lr: f32,
57 /// First-moment EMA decay β₁. Default `0.95`.
58 pub beta1: f32,
59 /// Second-moment EMA decay β₂. Default `0.99`.
60 pub beta2: f32,
61 /// Denominator stability constant. Default `1e-8`.
62 pub eps: f32,
63 /// Decoupled weight-decay coefficient λ. Default `0.0`.
64 pub weight_decay: f32,
65 /// Variance-reduction strength γ (Yuan et al. eq. 7). `0.0`
66 /// collapses MARS to plain AdamW; `1.0` is the full prescription.
67 /// Default `0.025`.
68 pub gamma: f32,
69 /// If `true`, clip the variance-reduced surrogate `c_t` to unit
70 /// norm per tensor (matches the "MARS-AdamW" recipe and keeps
71 /// the VR kick from exploding when `g_{t-1}` is unrelated noise
72 /// on early steps).
73 pub clip_c: bool,
74 step: u64,
75 m: HashMap<String, Vec<f32>>,
76 v: HashMap<String, Vec<f32>>,
77 prev_g: HashMap<String, Vec<f32>>,
78 /// Reusable scratch for the variance-reduced surrogate `c_t`.
79 scratch: HashMap<String, Vec<f32>>,
80}
81
82impl Mars {
83 /// Construct with `(β₁, β₂, ε, λ, γ, clip_c) = (0.95, 0.99, 1e-8, 0.0, 0.025, true)`.
84 pub fn new(lr: f32) -> Self {
85 Self {
86 lr,
87 beta1: 0.95,
88 beta2: 0.99,
89 eps: 1e-8,
90 weight_decay: 0.0,
91 gamma: 0.025,
92 clip_c: true,
93 step: 0,
94 m: HashMap::new(),
95 v: HashMap::new(),
96 prev_g: HashMap::new(),
97 scratch: HashMap::new(),
98 }
99 }
100
101 /// Override (β₁, β₂).
102 pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
103 self.beta1 = b1;
104 self.beta2 = b2;
105 self
106 }
107
108 /// Override the decoupled-decay coefficient.
109 pub fn with_weight_decay(mut self, wd: f32) -> Self {
110 self.weight_decay = wd;
111 self
112 }
113}
114
115impl Optimizer for Mars {
116 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
117 debug_assert_eq!(param.len(), grad.len());
118 let t = (self.step + 1) as f64;
119 let b1 = self.beta1 as f64;
120 let b2 = self.beta2 as f64;
121 let bc1 = 1.0 - b1.powf(t);
122 let bc2 = 1.0 - b2.powf(t);
123 let scale = self.gamma as f64 * b1 / (1.0 - b1);
124 let eps = self.eps as f64;
125 let lr = self.lr as f64;
126 let wd = self.weight_decay as f64;
127 // Four distinct fields ⇒ borrows can coexist.
128 let prev = zeros_entry(&mut self.prev_g, name, param.len());
129 let c = zeros_entry(&mut self.scratch, name, param.len());
130 let m = zeros_entry(&mut self.m, name, param.len());
131 let v = zeros_entry(&mut self.v, name, param.len());
132 // c_t = g_t + scale * (g_t - g_{t-1})
133 let mut c_sq_norm = 0.0f64;
134 for i in 0..param.len() {
135 let g = grad[i] as f64;
136 let pg = prev[i] as f64;
137 let ci = g + scale * (g - pg);
138 c[i] = ci as f32;
139 c_sq_norm += ci * ci;
140 prev[i] = grad[i];
141 }
142 // Optional per-tensor norm clip on c (keeps the VR kick from
143 // exploding on early steps when g_{t-1} is unrelated noise).
144 if self.clip_c && c_sq_norm > 1.0 {
145 let s = (1.0 / c_sq_norm.sqrt()) as f32;
146 for ci in c.iter_mut() {
147 *ci *= s;
148 }
149 }
150 for i in 0..param.len() {
151 let ci = c[i] as f64;
152 let mi = b1 * m[i] as f64 + (1.0 - b1) * ci;
153 let vi = b2 * v[i] as f64 + (1.0 - b2) * ci * ci;
154 m[i] = mi as f32;
155 v[i] = vi as f32;
156 let m_hat = mi / bc1;
157 let v_hat = vi / bc2;
158 let p = param[i] as f64;
159 param[i] = (p - lr * (m_hat / (v_hat.sqrt() + eps) + wd * p)) as f32;
160 }
161 }
162
163 fn end_iteration(&mut self) {
164 self.step += 1;
165 }
166}