Skip to main content

rlx_optim/
mars.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! MARS — Make vAriance Reduction Shine (Yuan, Liu, Wu, Su, Gu, 2024).
17//!
18//! # Idea
19//!
20//! Variance reduction (SVRG, SARAH) lowers gradient noise by mixing in
21//! a *previous* gradient — at the cost of an extra forward/backward
22//! pass per snapshot. MARS shows that you don't need a snapshot:
23//! using just `g_{t−1}` (the previous mini-batch's gradient) as the
24//! "control variate" gives most of the benefit of full variance
25//! reduction, for free.
26//!
27//! # Update rule
28//!
29//! ```text
30//! c_t = g_t + γ · β₁/(1−β₁) · (g_t − g_{t−1})      // VR-corrected grad
31//! [optional: clip c_t to unit norm per tensor]
32//! m_t = β₁·m_{t-1} + (1 − β₁)·c_t
33//! v_t = β₂·v_{t-1} + (1 − β₂)·c_t²
34//! θ_t = θ_{t-1} − lr · ( m̂_t/(√v̂_t + ε) + λ·θ_{t-1} )    // AdamW-style
35//! ```
36//!
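//! γ = 0 removes the correction term (`c_t = g_t`), so MARS reduces to
//! AdamW — exactly when `clip_c` is off, otherwise AdamW on the
//! norm-clipped gradient. γ = 1 is the "full" Yuan et al. prescription;
//! the recommended sweet spot is `γ ≈ 0.025`.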
39//!
40//! # When to use
41//!
42//! Anywhere variance-reduced SGD/Adam variants would help — noisy
43//! gradients, small batches, RL-style on-policy training. State cost
44//! per parameter: three buffers (`m`, `v`, previous-gradient cache).
45
46use std::collections::HashMap;
47
48use crate::Optimizer;
49use crate::common::zeros_entry;
50
51/// MARS — variance-reduced AdamW. Per-tensor state: three `f32`
52/// buffers (`m`, `v`, previous-gradient cache).
53#[derive(Debug, Clone)]
54pub struct Mars {
55    /// Learning rate.
56    pub lr: f32,
57    /// First-moment EMA decay β₁. Default `0.95`.
58    pub beta1: f32,
59    /// Second-moment EMA decay β₂. Default `0.99`.
60    pub beta2: f32,
61    /// Denominator stability constant. Default `1e-8`.
62    pub eps: f32,
63    /// Decoupled weight-decay coefficient λ. Default `0.0`.
64    pub weight_decay: f32,
65    /// Variance-reduction strength γ (Yuan et al. eq. 7). `0.0`
66    /// collapses MARS to plain AdamW; `1.0` is the full prescription.
67    /// Default `0.025`.
68    pub gamma: f32,
69    /// If `true`, clip the variance-reduced surrogate `c_t` to unit
70    /// norm per tensor (matches the "MARS-AdamW" recipe and keeps
71    /// the VR kick from exploding when `g_{t-1}` is unrelated noise
72    /// on early steps).
73    pub clip_c: bool,
74    step: u64,
75    m: HashMap<String, Vec<f32>>,
76    v: HashMap<String, Vec<f32>>,
77    prev_g: HashMap<String, Vec<f32>>,
78    /// Reusable scratch for the variance-reduced surrogate `c_t`.
79    scratch: HashMap<String, Vec<f32>>,
80}
81
82impl Mars {
83    /// Construct with `(β₁, β₂, ε, λ, γ, clip_c) = (0.95, 0.99, 1e-8, 0.0, 0.025, true)`.
84    pub fn new(lr: f32) -> Self {
85        Self {
86            lr,
87            beta1: 0.95,
88            beta2: 0.99,
89            eps: 1e-8,
90            weight_decay: 0.0,
91            gamma: 0.025,
92            clip_c: true,
93            step: 0,
94            m: HashMap::new(),
95            v: HashMap::new(),
96            prev_g: HashMap::new(),
97            scratch: HashMap::new(),
98        }
99    }
100
101    /// Override (β₁, β₂).
102    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
103        self.beta1 = b1;
104        self.beta2 = b2;
105        self
106    }
107
108    /// Override the decoupled-decay coefficient.
109    pub fn with_weight_decay(mut self, wd: f32) -> Self {
110        self.weight_decay = wd;
111        self
112    }
113}
114
115impl Optimizer for Mars {
116    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
117        debug_assert_eq!(param.len(), grad.len());
118        let t = (self.step + 1) as f64;
119        let b1 = self.beta1 as f64;
120        let b2 = self.beta2 as f64;
121        let bc1 = 1.0 - b1.powf(t);
122        let bc2 = 1.0 - b2.powf(t);
123        let scale = self.gamma as f64 * b1 / (1.0 - b1);
124        let eps = self.eps as f64;
125        let lr = self.lr as f64;
126        let wd = self.weight_decay as f64;
127        // Four distinct fields ⇒ borrows can coexist.
128        let prev = zeros_entry(&mut self.prev_g, name, param.len());
129        let c = zeros_entry(&mut self.scratch, name, param.len());
130        let m = zeros_entry(&mut self.m, name, param.len());
131        let v = zeros_entry(&mut self.v, name, param.len());
132        // c_t = g_t + scale * (g_t - g_{t-1})
133        let mut c_sq_norm = 0.0f64;
134        for i in 0..param.len() {
135            let g = grad[i] as f64;
136            let pg = prev[i] as f64;
137            let ci = g + scale * (g - pg);
138            c[i] = ci as f32;
139            c_sq_norm += ci * ci;
140            prev[i] = grad[i];
141        }
142        // Optional per-tensor norm clip on c (keeps the VR kick from
143        // exploding on early steps when g_{t-1} is unrelated noise).
144        if self.clip_c && c_sq_norm > 1.0 {
145            let s = (1.0 / c_sq_norm.sqrt()) as f32;
146            for ci in c.iter_mut() {
147                *ci *= s;
148            }
149        }
150        for i in 0..param.len() {
151            let ci = c[i] as f64;
152            let mi = b1 * m[i] as f64 + (1.0 - b1) * ci;
153            let vi = b2 * v[i] as f64 + (1.0 - b2) * ci * ci;
154            m[i] = mi as f32;
155            v[i] = vi as f32;
156            let m_hat = mi / bc1;
157            let v_hat = vi / bc2;
158            let p = param[i] as f64;
159            param[i] = (p - lr * (m_hat / (v_hat.sqrt() + eps) + wd * p)) as f32;
160        }
161    }
162
163    fn end_iteration(&mut self) {
164        self.step += 1;
165    }
166}