Skip to main content

rlx_optim/
soap.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SOAP — ShampoO with Adam in the Preconditioner's eigenbasis
//! (Vyas, Morwani, et al., 2024 — "SOAP: Improving and Stabilizing Shampoo using Adam").
18//!
19//! # Idea
20//!
21//! Shampoo (Gupta et al. 2018) preconditions a 2-D parameter's
22//! gradient by `L⁻¹ᐟ⁴ · G · R⁻¹ᐟ⁴`, where `L = E[G·Gᵀ]` and
//! `R = E[Gᵀ·G]` are Kronecker-factor covariances. SOAP builds on the
//! observation that Shampoo (in its exponent-½ variant) is equivalent to
//! running Adafactor in the **eigenbasis** of `L` and `R`, and instead runs
//! **Adam** in that eigenbasis — recomputing the eigenbasis only every K
//! steps. This delivers Shampoo-like quality at roughly Adam's per-step cost
//! (between eigenbasis refreshes).
28//!
29//! # Update rule (for a 2-D parameter `W ∈ ℝ^{m×n}`)
30//!
31//! ```text
32//! L_t = sb·L_{t-1} + (1−sb)·G·Gᵀ           // m×m
33//! R_t = sb·R_{t-1} + (1−sb)·Gᵀ·G           // n×n
34//! every K steps:                            // K = precond_freq
35//!   Q_L, _ = eigh(L_t)                     // m×m eigenbasis
36//!   Q_R, _ = eigh(R_t)                     // n×n eigenbasis
37//! G' = Q_Lᵀ · G · Q_R                       // rotated gradient
38//! [per-element Adam on G' → U']
39//! U  = Q_L · U' · Q_Rᵀ                      // rotate back
40//! θ_t = θ_{t-1} − lr · ( U + λ·θ_{t-1} )
41//! ```
42//!
43//! For non-2-D parameters we fall back to plain AdamW.
44//!
45//! # When to use
46//!
47//! When you want Shampoo's quality on transformers / large dense
48//! models and can afford the eigendecomposition cost amortized over
49//! `precond_freq` steps. State cost per matrix:
50//! `L (m²) + R (n²) + Q_L (m²) + Q_R (n²) + m_rot (m·n) + v_rot (m·n)`.
51
52use std::collections::HashMap;
53
54use crate::Optimizer;
55use crate::common::{jacobi_eigh_sym, matmul, zeros_entry};
56
/// Per-parameter SOAP state for one 2-D parameter of shape `m × n`.
///
/// Every matrix is stored as a flat, row-major `Vec<f32>`.
#[derive(Debug, Clone)]
struct SoapState {
    /// `m × m` EMA of `G·Gᵀ` (left Kronecker factor).
    l: Vec<f32>,
    /// `n × n` EMA of `Gᵀ·G` (right Kronecker factor).
    r: Vec<f32>,
    /// `m × m` eigenbasis of `l`, written by `jacobi_eigh_sym`; starts as identity.
    ///
    /// NOTE(review): the rotation code in this file computes `Qₗᵀ·G` as
    /// `Σ_p Qₗ[p,i]·G[p,j]`, i.e. it reads *column* i as eigenvector i.
    /// Confirm this matches the layout `jacobi_eigh_sym` actually produces.
    ql: Vec<f32>,
    /// `n × n` eigenbasis of `r`; starts as identity (same layout caveat as `ql`).
    qr: Vec<f32>,
    /// First-moment EMA of the rotated gradient (`m·n` entries).
    m_rot: Vec<f32>,
    /// Second-moment EMA of the rotated gradient (`m·n` entries).
    v_rot: Vec<f32>,
    /// `true` once `ql`/`qr` have been computed from the covariances at least once.
    initialized_basis: bool,
}
67
68/// SOAP — Shampoo-in-Adam-basis optimizer.
69#[derive(Debug, Clone)]
70pub struct Soap {
71    /// Learning rate.
72    pub lr: f32,
73    /// First-moment EMA decay (in the *rotated* basis). Default `0.95`.
74    pub beta1: f32,
75    /// Second-moment EMA decay (in the *rotated* basis). Default `0.95`.
76    pub beta2: f32,
77    /// Decay for the L/R covariance EMAs. Often equal to β₂.
78    pub shampoo_beta: f32,
79    /// Denominator stability constant. Default `1e-8`.
80    pub eps: f32,
81    /// Decoupled weight-decay coefficient λ. Default `0.01`.
82    pub weight_decay: f32,
83    /// Recompute the eigenbasis every `precond_freq` steps. Larger
84    /// values amortize the Jacobi cost but lag the preconditioner.
85    /// Default `10`.
86    pub precond_freq: u64,
87    /// Max Jacobi sweeps per rediagonalization. Default `30`.
88    pub jacobi_sweeps: u32,
89    step: u64,
90    state: HashMap<String, SoapState>,
91    // Fallback Adam state for non-2D parameters.
92    fb_m: HashMap<String, Vec<f32>>,
93    fb_v: HashMap<String, Vec<f32>>,
94}
95
96impl Soap {
97    /// Construct with `(β₁, β₂, sb, ε, λ, freq, sweeps) =
98    /// (0.95, 0.95, 0.95, 1e-8, 0.01, 10, 30)`.
99    pub fn new(lr: f32) -> Self {
100        Self {
101            lr,
102            beta1: 0.95,
103            beta2: 0.95,
104            shampoo_beta: 0.95,
105            eps: 1e-8,
106            weight_decay: 0.01,
107            precond_freq: 10,
108            jacobi_sweeps: 30,
109            step: 0,
110            state: HashMap::new(),
111            fb_m: HashMap::new(),
112            fb_v: HashMap::new(),
113        }
114    }
115
116    /// Override the decoupled-decay coefficient.
117    pub fn with_weight_decay(mut self, wd: f32) -> Self {
118        self.weight_decay = wd;
119        self
120    }
121}
122
123impl Optimizer for Soap {
124    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
125        debug_assert_eq!(param.len(), grad.len());
126        if shape.len() != 2 {
127            // Fallback: plain AdamW for non-matrix parameters.
128            adamw_fallback(self, name, param, grad);
129            return;
130        }
131        let (m, n) = (shape[0], shape[1]);
132        debug_assert_eq!(m * n, param.len());
133        let t = (self.step + 1) as f64;
134        let b1 = self.beta1 as f64;
135        let b2 = self.beta2 as f64;
136        let bc1 = 1.0 - b1.powf(t);
137        let bc2 = 1.0 - b2.powf(t);
138        let sb = self.shampoo_beta as f64;
139        let eps = self.eps;
140        let lr = self.lr;
141        let wd = self.weight_decay;
142
143        let st = self
144            .state
145            .entry(name.to_owned())
146            .or_insert_with(|| SoapState {
147                l: vec![0.0; m * m],
148                r: vec![0.0; n * n],
149                ql: identity(m),
150                qr: identity(n),
151                m_rot: vec![0.0; m * n],
152                v_rot: vec![0.0; m * n],
153                initialized_basis: false,
154            });
155
156        // ── 1. Update L, R covariances ─────────────────────────────
157        // L += (1-sb)·G·Gᵀ; R += (1-sb)·Gᵀ·G  (with β2 decay).
158        for i in 0..m {
159            for j in 0..m {
160                let mut s = 0.0f64;
161                for p in 0..n {
162                    s += grad[i * n + p] as f64 * grad[j * n + p] as f64;
163                }
164                let lij = sb * st.l[i * m + j] as f64 + (1.0 - sb) * s;
165                st.l[i * m + j] = lij as f32;
166            }
167        }
168        for i in 0..n {
169            for j in 0..n {
170                let mut s = 0.0f64;
171                for p in 0..m {
172                    s += grad[p * n + i] as f64 * grad[p * n + j] as f64;
173                }
174                let rij = sb * st.r[i * n + j] as f64 + (1.0 - sb) * s;
175                st.r[i * n + j] = rij as f32;
176            }
177        }
178
179        // ── 2. Rediagonalize periodically (and once on the first step). ──
180        let need_rediag = !st.initialized_basis || self.step.is_multiple_of(self.precond_freq);
181        if need_rediag {
182            let mut l_copy = st.l.clone();
183            let mut r_copy = st.r.clone();
184            jacobi_eigh_sym(&mut l_copy, m, &mut st.ql, self.jacobi_sweeps, 1e-6);
185            jacobi_eigh_sym(&mut r_copy, n, &mut st.qr, self.jacobi_sweeps, 1e-6);
186            st.initialized_basis = true;
187        }
188
189        // ── 3. Rotate gradient: G' = Qₗᵀ · G · Q_r ────────────────
190        let mut tmp = vec![0.0f32; m * n];
191        // tmp = Qₗᵀ · G  ⇒ tmp[i,j] = sum_p Qₗ[p,i] · G[p,j]
192        for i in 0..m {
193            for j in 0..n {
194                let mut s = 0.0f32;
195                for p in 0..m {
196                    s += st.ql[p * m + i] * grad[p * n + j];
197                }
198                tmp[i * n + j] = s;
199            }
200        }
201        let mut g_rot = vec![0.0f32; m * n];
202        matmul(&tmp, &st.qr, m, n, n, &mut g_rot);
203
204        // ── 4. Per-element Adam on rotated grad ──────────────────
205        let mut u_rot = vec![0.0f32; m * n];
206        for k in 0..m * n {
207            let g = g_rot[k] as f64;
208            let mi = b1 * st.m_rot[k] as f64 + (1.0 - b1) * g;
209            let vi = b2 * st.v_rot[k] as f64 + (1.0 - b2) * g * g;
210            st.m_rot[k] = mi as f32;
211            st.v_rot[k] = vi as f32;
212            let m_hat = mi / bc1;
213            let v_hat = vi / bc2;
214            u_rot[k] = (m_hat / (v_hat.sqrt() + eps as f64)) as f32;
215        }
216
217        // ── 5. Rotate back: U = Qₗ · U' · Q_rᵀ ───────────────────
218        // tmp = Qₗ · U'
219        matmul(&st.ql, &u_rot, m, m, n, &mut tmp);
220        // U = tmp · Q_rᵀ ⇒ U[i,j] = sum_p tmp[i,p] · Q_r[j,p]
221        let mut u = vec![0.0f32; m * n];
222        for i in 0..m {
223            for j in 0..n {
224                let mut s = 0.0f32;
225                for p in 0..n {
226                    s += tmp[i * n + p] * st.qr[j * n + p];
227                }
228                u[i * n + j] = s;
229            }
230        }
231
232        // ── 6. Decoupled weight decay + parameter update ─────────
233        for i in 0..m * n {
234            param[i] -= lr * (u[i] + wd * param[i]);
235        }
236    }
237
238    fn end_iteration(&mut self) {
239        self.step += 1;
240    }
241}
242
/// Row-major `n × n` identity matrix (empty for `n == 0`).
fn identity(n: usize) -> Vec<f32> {
    let mut eye = vec![0.0f32; n * n];
    // Diagonal entries sit at flat indices 0, n+1, 2(n+1), …
    eye.iter_mut().step_by(n + 1).for_each(|d| *d = 1.0);
    eye
}
250
251// Plain AdamW for the non-matrix fallback path.
252fn adamw_fallback(opt: &mut Soap, name: &str, param: &mut [f32], grad: &[f32]) {
253    let t = (opt.step + 1) as f64;
254    let b1 = opt.beta1 as f64;
255    let b2 = opt.beta2 as f64;
256    let bc1 = 1.0 - b1.powf(t);
257    let bc2 = 1.0 - b2.powf(t);
258    let m = zeros_entry(&mut opt.fb_m, name, param.len());
259    let v = zeros_entry(&mut opt.fb_v, name, param.len());
260    let eps = opt.eps as f64;
261    let lr = opt.lr as f64;
262    let wd = opt.weight_decay as f64;
263    for i in 0..param.len() {
264        let g = grad[i] as f64;
265        let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
266        let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
267        m[i] = mi as f32;
268        v[i] = vi as f32;
269        let p = param[i] as f64;
270        param[i] = (p - lr * (mi / bc1 / ((vi / bc2).sqrt() + eps) + wd * p)) as f32;
271    }
272}