rlx_optim/adafactor.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Adafactor (Shazeer & Stern, 2018, "Adafactor: Adaptive Learning
17//! Rates with Sublinear Memory Cost").
18//!
19//! # Idea
20//!
21//! Adam's `v_t` is the same shape as θ — for a 70B-parameter model
22//! that's 280 GB of optimizer state. Adafactor *factorizes* the
23//! second-moment matrix: for a 2-D parameter of shape `m × n`, instead
24//! of an `m·n` buffer it stores a row-statistic `R ∈ ℝᵐ` and a
25//! column-statistic `C ∈ ℝⁿ`, then reconstructs
26//! `V̂_{ij} ≈ R_i · C_j / Σ_k R_k`. State drops from `O(m·n)` to
27//! `O(m + n)`.
28//!
29//! # Update rule (this impl: factored 2nd-moment, no 1st-moment)
30//!
31//! Let `β₂_t = 1 − t^{decay_rate}` (default decay_rate = −0.8). For a
32//! 2-D parameter:
33//!
34//! ```text
35//! R_i = β₂_t·R_i + (1−β₂_t)·mean_j(g_ij² + ε₁)
36//! C_j = β₂_t·C_j + (1−β₂_t)·mean_i(g_ij² + ε₁)
37//! V̂_{ij} = R_i · C_j / Σ_k R_k
38//! u_{ij} = g_{ij} / √V̂_{ij}
39//! u ← u / max(1, RMS(u) / clip_threshold) // RMS-of-update clip
40//! lr_t = manual_lr OR min(1/√t, 1e-2) · max(ε₂, RMS(θ)) // relative step
41//! θ_t = θ_{t-1} − lr_t · ( u + λ·θ_{t-1} )
42//! ```
43//!
44//! For non-2-D parameters (bias vectors, 4-D conv weights) we fall
45//! back to a full per-element EMA — the savings are negligible there
46//! anyway. The optional first-moment EMA is **not** implemented
47//! (matches the recommended T5 configuration).
48//!
49//! # When to use
50//!
51//! When you don't have memory for Adam-style optimizer state — large
52//! models, low-VRAM fine-tuning, sequence-length scaling experiments.
53//! State cost per matrix = `m + n` floats vs Adam's `2·m·n`.
54
55use std::collections::HashMap;
56
57use crate::Optimizer;
58use crate::common::{l2_norm, zeros_entry};
59
/// Adafactor optimizer: hyper-parameters plus per-tensor second-moment state.
///
/// A 2-D parameter keeps one `rows`-long vector and one `cols`-long
/// vector of statistics instead of a full `rows·cols` buffer; every
/// other shape keeps a full per-element EMA.
#[derive(Debug, Clone)]
pub struct Adafactor {
    /// Manual learning rate. When `None` (the default) the paper's
    /// "relative step" rule `min(1/√t, 1e-2) · max(ε₂, RMS(θ))` is
    /// used instead.
    pub lr: Option<f32>,
    /// Exponent `x` in `β₂_t = 1 − tˣ`. With the default `x = -0.8`,
    /// `β₂_t` is 0 on the first step and approaches 1 as `t` grows.
    pub beta2_decay: f32,
    /// ε₁ — stability constant added to each squared gradient before
    /// the row / column averages are taken. Default `1e-30`.
    pub eps1: f32,
    /// ε₂ — lower bound on `RMS(θ)` in the relative-step rule.
    /// Default `1e-3`.
    pub eps2: f32,
    /// Threshold for RMS-of-update clipping (Shazeer & Stern §6).
    /// Default `1.0`.
    pub clip_threshold: f32,
    /// Decoupled weight-decay coefficient λ. Default `0.0`.
    pub weight_decay: f32,
    /// Iteration counter; the update rule uses `t = step + 1`.
    step: u64,
    /// Row factors of 2-D parameters (one entry of length `rows` per
    /// parameter name).
    r: HashMap<String, Vec<f32>>,
    /// Column factors of 2-D parameters (one entry of length `cols`
    /// per parameter name).
    c: HashMap<String, Vec<f32>>,
    /// Full per-element second-moment EMA for non-2-D parameters.
    v: HashMap<String, Vec<f32>>,
}
88
89impl Adafactor {
90 /// Construct with paper defaults (no manual lr ⇒ relative step,
91 /// `decay_rate = -0.8`, `ε₁=1e-30, ε₂=1e-3, clip=1.0, λ=0.0`).
92 pub fn new() -> Self {
93 Self {
94 lr: None,
95 beta2_decay: -0.8,
96 eps1: 1e-30,
97 eps2: 1e-3,
98 clip_threshold: 1.0,
99 weight_decay: 0.0,
100 step: 0,
101 r: HashMap::new(),
102 c: HashMap::new(),
103 v: HashMap::new(),
104 }
105 }
106
107 /// Switch from the relative-step rule to a manual learning rate.
108 pub fn with_lr(mut self, lr: f32) -> Self {
109 self.lr = Some(lr);
110 self
111 }
112
113 /// Override the decoupled-decay coefficient.
114 pub fn with_weight_decay(mut self, wd: f32) -> Self {
115 self.weight_decay = wd;
116 self
117 }
118}
119
120impl Default for Adafactor {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Optimizer for Adafactor {
127 fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
128 debug_assert_eq!(param.len(), grad.len());
129 let t = (self.step + 1) as f64;
130 // β₂_t = 1 − t^{beta2_decay}, decay_rate ∈ (-1, 0).
131 let beta2_t = 1.0 - t.powf(self.beta2_decay as f64);
132 let eps1 = self.eps1 as f64;
133 let clip = self.clip_threshold as f64;
134 let n = param.len();
135
136 // ── Update second-moment estimate ──────────────────────────
137 let mut update = vec![0.0f32; n];
138 if shape.len() == 2 {
139 let (rows, cols) = (shape[0], shape[1]);
140 debug_assert_eq!(rows * cols, n);
141 let r = zeros_entry(&mut self.r, name, rows);
142 // Row factor: average of g² across columns, then EMA.
143 let mut row_buf = vec![0.0f64; rows];
144 for i in 0..rows {
145 let mut s = 0.0f64;
146 for j in 0..cols {
147 let g = grad[i * cols + j] as f64;
148 s += g * g + eps1;
149 }
150 row_buf[i] = s / cols as f64;
151 }
152 for i in 0..rows {
153 r[i] = (beta2_t * r[i] as f64 + (1.0 - beta2_t) * row_buf[i]) as f32;
154 }
155 let r_snapshot: Vec<f32> = r.clone();
156
157 // Column factor: average of g² across rows, then EMA.
158 let c = zeros_entry(&mut self.c, name, cols);
159 let mut col_buf = vec![0.0f64; cols];
160 for j in 0..cols {
161 let mut s = 0.0f64;
162 for i in 0..rows {
163 let g = grad[i * cols + j] as f64;
164 s += g * g + eps1;
165 }
166 col_buf[j] = s / rows as f64;
167 }
168 for j in 0..cols {
169 c[j] = (beta2_t * c[j] as f64 + (1.0 - beta2_t) * col_buf[j]) as f32;
170 }
171 let r_sum: f64 = r_snapshot.iter().map(|&x| x as f64).sum();
172 // v_ij = r_i * c_j / (sum_k r_k). Build update = g / sqrt(v).
173 for i in 0..rows {
174 for j in 0..cols {
175 let v_ij = r_snapshot[i] as f64 * c[j] as f64 / r_sum.max(eps1);
176 let g = grad[i * cols + j] as f64;
177 update[i * cols + j] = (g / v_ij.sqrt().max(eps1.sqrt())) as f32;
178 }
179 }
180 } else {
181 // Non-2D: full per-element EMA.
182 let v = zeros_entry(&mut self.v, name, n);
183 for i in 0..n {
184 let g = grad[i] as f64;
185 v[i] = (beta2_t * v[i] as f64 + (1.0 - beta2_t) * (g * g + eps1)) as f32;
186 update[i] = (g / (v[i] as f64).sqrt().max(eps1.sqrt())) as f32;
187 }
188 }
189
190 // RMS-of-update clipping (Shazeer & Stern §6).
191 let u_rms = (l2_norm(&update) as f64 / (n as f64).sqrt()).max(1.0 / clip);
192 let scale = (1.0 / (u_rms * clip)).min(1.0);
193 for u in update.iter_mut() {
194 *u = (*u as f64 * scale) as f32;
195 }
196
197 // Learning rate (relative-step or manual).
198 let lr = match self.lr {
199 Some(x) => x as f64,
200 None => {
201 let p_rms = (l2_norm(param) as f64 / (n as f64).sqrt()).max(self.eps2 as f64);
202 (1.0 / t.sqrt()).min(1e-2) * p_rms
203 }
204 };
205 let wd = self.weight_decay as f64;
206 for i in 0..n {
207 let p = param[i] as f64;
208 param[i] = (p - lr * (update[i] as f64 + wd * p)) as f32;
209 }
210 }
211
212 fn end_iteration(&mut self) {
213 self.step += 1;
214 }
215}