rlx_optim/soap.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SOAP — ShampoO with Adam in the Preconditioner's eigenbasis
17//! (Vyas, Morwani, Anil, et al., 2024).
18//!
19//! # Idea
20//!
21//! Shampoo (Gupta et al. 2018) preconditions a 2-D parameter's
22//! gradient by `L⁻¹ᐟ⁴ · G · R⁻¹ᐟ⁴`, where `L = E[G·Gᵀ]` and
//! `R = E[Gᵀ·G]` are Kronecker-factor covariances. SOAP builds on the
//! observation that Shampoo, implemented with the ½ power, is equivalent to
//! running Adafactor in the eigenbasis of its preconditioner; SOAP instead
//! runs **Adam in the eigenbasis** of `L` and `R`, and recomputes that
//! eigenbasis only every K steps. This targets Shampoo's quality at roughly
//! Adam's per-step cost between eigenbasis recomputations.
28//!
29//! # Update rule (for a 2-D parameter `W ∈ ℝ^{m×n}`)
30//!
31//! ```text
32//! L_t = sb·L_{t-1} + (1−sb)·G·Gᵀ // m×m
33//! R_t = sb·R_{t-1} + (1−sb)·Gᵀ·G // n×n
34//! every K steps: // K = precond_freq
35//! Q_L, _ = eigh(L_t) // m×m eigenbasis
36//! Q_R, _ = eigh(R_t) // n×n eigenbasis
37//! G' = Q_Lᵀ · G · Q_R // rotated gradient
38//! [per-element Adam on G' → U']
39//! U = Q_L · U' · Q_Rᵀ // rotate back
40//! θ_t = θ_{t-1} − lr · ( U + λ·θ_{t-1} )
41//! ```
42//!
43//! For non-2-D parameters we fall back to plain AdamW.
44//!
45//! # When to use
46//!
47//! When you want Shampoo's quality on transformers / large dense
48//! models and can afford the eigendecomposition cost amortized over
49//! `precond_freq` steps. State cost per matrix:
50//! `L (m²) + R (n²) + Q_L (m²) + Q_R (n²) + m_rot (m·n) + v_rot (m·n)`.
51
52use std::collections::HashMap;
53
54use crate::Optimizer;
55use crate::common::{jacobi_eigh_sym, matmul, zeros_entry};
56
/// Per-parameter SOAP state for one rank-2 (`m × n`) parameter.
///
/// All matrices are stored row-major as flat `Vec<f32>`s.
#[derive(Debug, Clone)]
struct SoapState {
    // m × m left covariance EMA (≈ E[G·Gᵀ]), decayed by `shampoo_beta`.
    l: Vec<f32>,
    // n × n right covariance EMA (≈ E[Gᵀ·G]), decayed by `shampoo_beta`.
    r: Vec<f32>,
    // m × m eigenbasis of `l`. The rotations in `step` read column i as
    // eigenvector i (they compute G' = Q_Lᵀ·G·Q_R).
    // NOTE(review): confirm `jacobi_eigh_sym` writes eigenvectors as columns.
    ql: Vec<f32>,
    // n × n eigenbasis of `r`, same layout as `ql`.
    qr: Vec<f32>,
    // First moment in rotated basis (m·n).
    m_rot: Vec<f32>,
    // Second moment in rotated basis (m·n).
    v_rot: Vec<f32>,
    // False until the first eigendecomposition for this parameter has run;
    // forces a rediagonalization on the parameter's first step.
    initialized_basis: bool,
}
67
/// SOAP — Shampoo-in-Adam-basis optimizer.
///
/// Rank-2 parameters get the eigenbasis-rotated Adam update described in the
/// module docs. Parameters of any other rank use a plain AdamW fallback driven
/// by the same `lr`, `beta1`, `beta2`, `eps` and `weight_decay`.
///
/// Construct with [`Soap::new`]. The hyperparameters are public fields and can
/// be adjusted afterwards.
#[derive(Debug, Clone)]
pub struct Soap {
    /// Learning rate.
    pub lr: f32,
    /// First-moment EMA decay (in the *rotated* basis; the AdamW fallback
    /// uses it too). Default `0.95`.
    pub beta1: f32,
    /// Second-moment EMA decay (in the *rotated* basis; the AdamW fallback
    /// uses it too). Default `0.95`.
    pub beta2: f32,
    /// Decay for the L/R covariance EMAs. Often equal to β₂. Default `0.95`.
    pub shampoo_beta: f32,
    /// Denominator stability constant, added to `sqrt(v̂)`. Default `1e-8`.
    pub eps: f32,
    /// Decoupled weight-decay coefficient λ. Default `0.01`.
    pub weight_decay: f32,
    /// Recompute the eigenbasis whenever the global step is a multiple of
    /// `precond_freq`. A parameter's basis is also always computed the first
    /// time that parameter is stepped.
    ///
    /// Larger values amortize the Jacobi cost but lag the preconditioner.
    /// With `0`, only global step 0 matches the schedule, so each basis is
    /// computed once on first sight and never refreshed. Default `10`.
    pub precond_freq: u64,
    /// Max Jacobi sweeps per rediagonalization. Default `30`.
    pub jacobi_sweeps: u32,
    // Optimizer-wide step counter. Advanced by `end_iteration`; drives Adam
    // bias correction and the rediagonalization schedule.
    step: u64,
    // Per-parameter SOAP state for rank-2 parameters, keyed by name.
    state: HashMap<String, SoapState>,
    // Fallback Adam state for non-2D parameters.
    // First moments, keyed by parameter name.
    fb_m: HashMap<String, Vec<f32>>,
    // Second moments, keyed by parameter name.
    fb_v: HashMap<String, Vec<f32>>,
}
95
96impl Soap {
97 /// Construct with `(β₁, β₂, sb, ε, λ, freq, sweeps) =
98 /// (0.95, 0.95, 0.95, 1e-8, 0.01, 10, 30)`.
99 pub fn new(lr: f32) -> Self {
100 Self {
101 lr,
102 beta1: 0.95,
103 beta2: 0.95,
104 shampoo_beta: 0.95,
105 eps: 1e-8,
106 weight_decay: 0.01,
107 precond_freq: 10,
108 jacobi_sweeps: 30,
109 step: 0,
110 state: HashMap::new(),
111 fb_m: HashMap::new(),
112 fb_v: HashMap::new(),
113 }
114 }
115
116 /// Override the decoupled-decay coefficient.
117 pub fn with_weight_decay(mut self, wd: f32) -> Self {
118 self.weight_decay = wd;
119 self
120 }
121}
122
123impl Optimizer for Soap {
124 fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
125 debug_assert_eq!(param.len(), grad.len());
126 if shape.len() != 2 {
127 // Fallback: plain AdamW for non-matrix parameters.
128 adamw_fallback(self, name, param, grad);
129 return;
130 }
131 let (m, n) = (shape[0], shape[1]);
132 debug_assert_eq!(m * n, param.len());
133 let t = (self.step + 1) as f64;
134 let b1 = self.beta1 as f64;
135 let b2 = self.beta2 as f64;
136 let bc1 = 1.0 - b1.powf(t);
137 let bc2 = 1.0 - b2.powf(t);
138 let sb = self.shampoo_beta as f64;
139 let eps = self.eps;
140 let lr = self.lr;
141 let wd = self.weight_decay;
142
143 let st = self
144 .state
145 .entry(name.to_owned())
146 .or_insert_with(|| SoapState {
147 l: vec![0.0; m * m],
148 r: vec![0.0; n * n],
149 ql: identity(m),
150 qr: identity(n),
151 m_rot: vec![0.0; m * n],
152 v_rot: vec![0.0; m * n],
153 initialized_basis: false,
154 });
155
156 // ── 1. Update L, R covariances ─────────────────────────────
157 // L += (1-sb)·G·Gᵀ; R += (1-sb)·Gᵀ·G (with β2 decay).
158 for i in 0..m {
159 for j in 0..m {
160 let mut s = 0.0f64;
161 for p in 0..n {
162 s += grad[i * n + p] as f64 * grad[j * n + p] as f64;
163 }
164 let lij = sb * st.l[i * m + j] as f64 + (1.0 - sb) * s;
165 st.l[i * m + j] = lij as f32;
166 }
167 }
168 for i in 0..n {
169 for j in 0..n {
170 let mut s = 0.0f64;
171 for p in 0..m {
172 s += grad[p * n + i] as f64 * grad[p * n + j] as f64;
173 }
174 let rij = sb * st.r[i * n + j] as f64 + (1.0 - sb) * s;
175 st.r[i * n + j] = rij as f32;
176 }
177 }
178
179 // ── 2. Rediagonalize periodically (and once on the first step). ──
180 let need_rediag = !st.initialized_basis || self.step.is_multiple_of(self.precond_freq);
181 if need_rediag {
182 let mut l_copy = st.l.clone();
183 let mut r_copy = st.r.clone();
184 jacobi_eigh_sym(&mut l_copy, m, &mut st.ql, self.jacobi_sweeps, 1e-6);
185 jacobi_eigh_sym(&mut r_copy, n, &mut st.qr, self.jacobi_sweeps, 1e-6);
186 st.initialized_basis = true;
187 }
188
189 // ── 3. Rotate gradient: G' = Qₗᵀ · G · Q_r ────────────────
190 let mut tmp = vec![0.0f32; m * n];
191 // tmp = Qₗᵀ · G ⇒ tmp[i,j] = sum_p Qₗ[p,i] · G[p,j]
192 for i in 0..m {
193 for j in 0..n {
194 let mut s = 0.0f32;
195 for p in 0..m {
196 s += st.ql[p * m + i] * grad[p * n + j];
197 }
198 tmp[i * n + j] = s;
199 }
200 }
201 let mut g_rot = vec![0.0f32; m * n];
202 matmul(&tmp, &st.qr, m, n, n, &mut g_rot);
203
204 // ── 4. Per-element Adam on rotated grad ──────────────────
205 let mut u_rot = vec![0.0f32; m * n];
206 for k in 0..m * n {
207 let g = g_rot[k] as f64;
208 let mi = b1 * st.m_rot[k] as f64 + (1.0 - b1) * g;
209 let vi = b2 * st.v_rot[k] as f64 + (1.0 - b2) * g * g;
210 st.m_rot[k] = mi as f32;
211 st.v_rot[k] = vi as f32;
212 let m_hat = mi / bc1;
213 let v_hat = vi / bc2;
214 u_rot[k] = (m_hat / (v_hat.sqrt() + eps as f64)) as f32;
215 }
216
217 // ── 5. Rotate back: U = Qₗ · U' · Q_rᵀ ───────────────────
218 // tmp = Qₗ · U'
219 matmul(&st.ql, &u_rot, m, m, n, &mut tmp);
220 // U = tmp · Q_rᵀ ⇒ U[i,j] = sum_p tmp[i,p] · Q_r[j,p]
221 let mut u = vec![0.0f32; m * n];
222 for i in 0..m {
223 for j in 0..n {
224 let mut s = 0.0f32;
225 for p in 0..n {
226 s += tmp[i * n + p] * st.qr[j * n + p];
227 }
228 u[i * n + j] = s;
229 }
230 }
231
232 // ── 6. Decoupled weight decay + parameter update ─────────
233 for i in 0..m * n {
234 param[i] -= lr * (u[i] + wd * param[i]);
235 }
236 }
237
238 fn end_iteration(&mut self) {
239 self.step += 1;
240 }
241}
242
/// Returns the `n × n` identity matrix as a row-major `Vec<f32>`.
///
/// The result has `n * n` entries: ones on the diagonal, zeros elsewhere.
/// For `n == 0` the result is empty.
fn identity(n: usize) -> Vec<f32> {
    // Flat index k sits at row k / n, column k % n; it is on the diagonal
    // exactly when the two agree. The closure never runs when n == 0.
    (0..n * n)
        .map(|k| if k / n == k % n { 1.0 } else { 0.0 })
        .collect()
}
250
251// Plain AdamW for the non-matrix fallback path.
252fn adamw_fallback(opt: &mut Soap, name: &str, param: &mut [f32], grad: &[f32]) {
253 let t = (opt.step + 1) as f64;
254 let b1 = opt.beta1 as f64;
255 let b2 = opt.beta2 as f64;
256 let bc1 = 1.0 - b1.powf(t);
257 let bc2 = 1.0 - b2.powf(t);
258 let m = zeros_entry(&mut opt.fb_m, name, param.len());
259 let v = zeros_entry(&mut opt.fb_v, name, param.len());
260 let eps = opt.eps as f64;
261 let lr = opt.lr as f64;
262 let wd = opt.weight_decay as f64;
263 for i in 0..param.len() {
264 let g = grad[i] as f64;
265 let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
266 let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
267 m[i] = mi as f32;
268 v[i] = vi as f32;
269 let p = param[i] as f64;
270 param[i] = (p - lr * (mi / bc1 / ((vi / bc2).sqrt() + eps) + wd * p)) as f32;
271 }
272}