Skip to main content

rlx_optim/
adafactor.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Adafactor (Shazeer & Stern, 2018, "Adafactor: Adaptive Learning
17//! Rates with Sublinear Memory Cost").
18//!
19//! # Idea
20//!
21//! Adam's `v_t` is the same shape as θ — for a 70B-parameter model
22//! that's 280 GB of optimizer state. Adafactor *factorizes* the
23//! second-moment matrix: for a 2-D parameter of shape `m × n`, instead
24//! of an `m·n` buffer it stores a row-statistic `R ∈ ℝᵐ` and a
25//! column-statistic `C ∈ ℝⁿ`, then reconstructs
26//! `V̂_{ij} ≈ R_i · C_j / Σ_k R_k`. State drops from `O(m·n)` to
27//! `O(m + n)`.
28//!
29//! # Update rule (this impl: factored 2nd-moment, no 1st-moment)
30//!
31//! Let `β₂_t = 1 − t^{decay_rate}` (default decay_rate = −0.8). For a
32//! 2-D parameter:
33//!
34//! ```text
35//! R_i = β₂_t·R_i + (1−β₂_t)·mean_j(g_ij² + ε₁)
36//! C_j = β₂_t·C_j + (1−β₂_t)·mean_i(g_ij² + ε₁)
37//! V̂_{ij} = R_i · C_j / Σ_k R_k
38//! u_{ij}  = g_{ij} / √V̂_{ij}
39//! u ← u / max(1, RMS(u) / clip_threshold)        // RMS-of-update clip
40//! lr_t    = manual_lr OR  min(1/√t, 1e-2) · max(ε₂, RMS(θ))   // relative step
41//! θ_t     = θ_{t-1} − lr_t · ( u + λ·θ_{t-1} )
42//! ```
43//!
44//! For non-2-D parameters (bias vectors, 4-D conv weights) we fall
45//! back to a full per-element EMA — the savings are negligible there
46//! anyway. The optional first-moment EMA is **not** implemented
47//! (matches the recommended T5 configuration).
48//!
49//! # When to use
50//!
51//! When you don't have memory for Adam-style optimizer state — large
52//! models, low-VRAM fine-tuning, sequence-length scaling experiments.
53//! State cost per matrix = `m + n` floats vs Adam's `2·m·n`.
54
55use std::collections::HashMap;
56
57use crate::Optimizer;
58use crate::common::{l2_norm, zeros_entry};
59
/// Adafactor optimizer with a factored second-moment estimate.
///
/// State kept per named tensor: for a 2-D parameter, one vector of
/// length `rows` and one of length `cols` (sublinear in `rows·cols`);
/// for any other rank, a full element-wise EMA.
#[derive(Clone, Debug)]
pub struct Adafactor {
    /// Manual learning rate, if any. When `None` (the default) the
    /// paper's "relative step" rule
    /// `min(1/√t, 1e-2) · max(ε₂, RMS(θ))` is used instead.
    pub lr: Option<f32>,
    /// Exponent `x` in `β₂_t = 1 − tˣ`. Default `-0.8`; with a negative
    /// exponent β₂_t is 0 on the first step and approaches 1 as `t` grows.
    pub beta2_decay: f32,
    /// ε₁: stability constant added to each squared gradient before the
    /// row / column averages are taken. Default `1e-30`.
    pub eps1: f32,
    /// ε₂: lower bound on RMS(θ) in the relative-step rule. Default `1e-3`.
    pub eps2: f32,
    /// Threshold for clipping the RMS of the update (Shazeer & Stern §6).
    /// Default `1.0`.
    pub clip_threshold: f32,
    /// λ: decoupled weight-decay coefficient. Default `0.0`.
    pub weight_decay: f32,
    /// Number of completed iterations.
    step: u64,
    /// Row factors of 2-D parameters, keyed by parameter name (length `rows`).
    r: HashMap<String, Vec<f32>>,
    /// Column factors of 2-D parameters, keyed by parameter name (length `cols`).
    c: HashMap<String, Vec<f32>>,
    /// Full second-moment EMA of non-2-D parameters, keyed by parameter name.
    v: HashMap<String, Vec<f32>>,
}
88
89impl Adafactor {
90    /// Construct with paper defaults (no manual lr ⇒ relative step,
91    /// `decay_rate = -0.8`, `ε₁=1e-30, ε₂=1e-3, clip=1.0, λ=0.0`).
92    pub fn new() -> Self {
93        Self {
94            lr: None,
95            beta2_decay: -0.8,
96            eps1: 1e-30,
97            eps2: 1e-3,
98            clip_threshold: 1.0,
99            weight_decay: 0.0,
100            step: 0,
101            r: HashMap::new(),
102            c: HashMap::new(),
103            v: HashMap::new(),
104        }
105    }
106
107    /// Switch from the relative-step rule to a manual learning rate.
108    pub fn with_lr(mut self, lr: f32) -> Self {
109        self.lr = Some(lr);
110        self
111    }
112
113    /// Override the decoupled-decay coefficient.
114    pub fn with_weight_decay(mut self, wd: f32) -> Self {
115        self.weight_decay = wd;
116        self
117    }
118}
119
120impl Default for Adafactor {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Optimizer for Adafactor {
127    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
128        debug_assert_eq!(param.len(), grad.len());
129        let t = (self.step + 1) as f64;
130        // β₂_t = 1 − t^{beta2_decay}, decay_rate ∈ (-1, 0).
131        let beta2_t = 1.0 - t.powf(self.beta2_decay as f64);
132        let eps1 = self.eps1 as f64;
133        let clip = self.clip_threshold as f64;
134        let n = param.len();
135
136        // ── Update second-moment estimate ──────────────────────────
137        let mut update = vec![0.0f32; n];
138        if shape.len() == 2 {
139            let (rows, cols) = (shape[0], shape[1]);
140            debug_assert_eq!(rows * cols, n);
141            let r = zeros_entry(&mut self.r, name, rows);
142            // Row factor: average of g² across columns, then EMA.
143            let mut row_buf = vec![0.0f64; rows];
144            for i in 0..rows {
145                let mut s = 0.0f64;
146                for j in 0..cols {
147                    let g = grad[i * cols + j] as f64;
148                    s += g * g + eps1;
149                }
150                row_buf[i] = s / cols as f64;
151            }
152            for i in 0..rows {
153                r[i] = (beta2_t * r[i] as f64 + (1.0 - beta2_t) * row_buf[i]) as f32;
154            }
155            let r_snapshot: Vec<f32> = r.clone();
156
157            // Column factor: average of g² across rows, then EMA.
158            let c = zeros_entry(&mut self.c, name, cols);
159            let mut col_buf = vec![0.0f64; cols];
160            for j in 0..cols {
161                let mut s = 0.0f64;
162                for i in 0..rows {
163                    let g = grad[i * cols + j] as f64;
164                    s += g * g + eps1;
165                }
166                col_buf[j] = s / rows as f64;
167            }
168            for j in 0..cols {
169                c[j] = (beta2_t * c[j] as f64 + (1.0 - beta2_t) * col_buf[j]) as f32;
170            }
171            let r_sum: f64 = r_snapshot.iter().map(|&x| x as f64).sum();
172            // v_ij = r_i * c_j / (sum_k r_k). Build update = g / sqrt(v).
173            for i in 0..rows {
174                for j in 0..cols {
175                    let v_ij = r_snapshot[i] as f64 * c[j] as f64 / r_sum.max(eps1);
176                    let g = grad[i * cols + j] as f64;
177                    update[i * cols + j] = (g / v_ij.sqrt().max(eps1.sqrt())) as f32;
178                }
179            }
180        } else {
181            // Non-2D: full per-element EMA.
182            let v = zeros_entry(&mut self.v, name, n);
183            for i in 0..n {
184                let g = grad[i] as f64;
185                v[i] = (beta2_t * v[i] as f64 + (1.0 - beta2_t) * (g * g + eps1)) as f32;
186                update[i] = (g / (v[i] as f64).sqrt().max(eps1.sqrt())) as f32;
187            }
188        }
189
190        // RMS-of-update clipping (Shazeer & Stern §6).
191        let u_rms = (l2_norm(&update) as f64 / (n as f64).sqrt()).max(1.0 / clip);
192        let scale = (1.0 / (u_rms * clip)).min(1.0);
193        for u in update.iter_mut() {
194            *u = (*u as f64 * scale) as f32;
195        }
196
197        // Learning rate (relative-step or manual).
198        let lr = match self.lr {
199            Some(x) => x as f64,
200            None => {
201                let p_rms = (l2_norm(param) as f64 / (n as f64).sqrt()).max(self.eps2 as f64);
202                (1.0 / t.sqrt()).min(1e-2) * p_rms
203            }
204        };
205        let wd = self.weight_decay as f64;
206        for i in 0..n {
207            let p = param[i] as f64;
208            param[i] = (p - lr * (update[i] as f64 + wd * p)) as f32;
209        }
210    }
211
212    fn end_iteration(&mut self) {
213        self.step += 1;
214    }
215}