Skip to main content

rlx_optim/
kron_psgd.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Kron-PSGD — Preconditioned SGD with a Kronecker-factored
17//! preconditioner (Li, 2018; "Preconditioned Stochastic Gradient
18//! Descent").
19//!
20//! # Idea
21//!
22//! Approximate the inverse Hessian as a Kronecker product
23//! `P ≈ P_L ⊗ P_R` where `P_L = Q_LᵀQ_L` and `P_R = Q_RᵀQ_R` for two
24//! upper-triangular factors. The factors are updated by a *Lie-group
25//! descent* on a whitening criterion — no eigendecomposition needed,
26//! and updates are stable by construction (the upper-triangular
27//! manifold).
28//!
29//! # Update rule
30//!
31//! For a 2-D parameter `W ∈ ℝ^{m×n}`:
32//!
33//! ```text
34//! A = Q_L · G · Q_Rᵀ                          // m×n
35//! B = Q_L⁻ᵀ · G · Q_R⁻¹                       // m×n (triangular solves)
//! dQ_L ∝ triu(A·Aᵀ − B·Bᵀ);  Q_L ← Q_L − η_p · ΔQ_L
//! dQ_R ∝ triu(Aᵀ·A − Bᵀ·B);  Q_R ← Q_R − η_p · ΔQ_R
//! (ΔQ is the product of dQ with the current Q; see `update_factor`)
38//! P_L = Q_LᵀQ_L;   P_R = Q_RᵀQ_R
39//! p_g = P_L · G · P_R                         // preconditioned grad
//! [rescale so that max|p_g| ≤ clip, then SGD+momentum on p_g]
41//! ```
42//!
43//! Li (2018) Algorithm 1 uses an HVP probe `v` and its perturbed
44//! gradient to update `Q_L, Q_R`. This crate has no HVP oracle, so we
45//! use the gradient itself as the probe — the "PSGD-Affine"
46//! approximation — which is cheap and still gives strong empirical
47//! preconditioning on convex and mildly non-convex problems.
48//! Non-2-D parameters fall back to plain SGD-with-momentum.
49//!
50//! # When to use
51//!
52//! Ill-conditioned problems where Adam's coordinate-wise
53//! preconditioner is too weak (RNNs, deep MLPs, certain inverse
54//! problems). State cost per matrix: `m² + n²` plus a velocity buffer.
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::{matmul, zeros_entry};
60
/// Per-parameter preconditioner state: the two triangular Kronecker
/// factors, each stored as a dense row-major square matrix.
#[derive(Debug, Clone)]
struct KronState {
    /// Left factor `Q_L`: `m × m`, upper-triangular.
    ql: Vec<f32>,
    /// Right factor `Q_R`: `n × n`, upper-triangular.
    qr: Vec<f32>,
}
66
/// Kron-PSGD — Kronecker-factored preconditioned SGD.
///
/// Hyper-parameters are public fields; the per-parameter state is
/// private and created lazily on the first `step` for each name.
#[derive(Debug, Clone)]
pub struct KronPsgd {
    /// Learning rate.
    pub lr: f32,
    /// Learning rate for the **preconditioner** update (Lie-group
    /// descent on Q_L / Q_R). Default `0.1`. Too high ⇒ Q drifts;
    /// too low ⇒ preconditioner lags.
    pub precond_lr: f32,
    /// Polyak momentum for the preconditioned-gradient SGD step.
    /// Default `0.9`.
    pub momentum: f32,
    /// L2 weight-decay coefficient. `weight_decay * param` is added to
    /// the update direction just before the momentum accumulation (for
    /// matrices: after preconditioning and clipping). Default `0.0`.
    pub weight_decay: f32,
    /// Numerical floor on the preconditioner-update normalizer (added
    /// to the norm in the denominator of the factor step size).
    /// Default `1e-8`.
    pub eps: f32,
    /// Cap the per-coordinate magnitude of the preconditioned update
    /// (defensive — early Q estimates can be ill-conditioned). The whole
    /// update is rescaled uniformly when its largest absolute entry
    /// exceeds this value. Default `1.0`.
    pub clip: f32,
    /// Kronecker factors, keyed by parameter name (rank-2 parameters only).
    state: HashMap<String, KronState>,
    /// Momentum (velocity) buffers, keyed by parameter name.
    mom: HashMap<String, Vec<f32>>,
}
91
92impl KronPsgd {
93    /// Construct with `(precond_lr, μ, λ, ε, clip) = (0.1, 0.9, 0.0, 1e-8, 1.0)`.
94    pub fn new(lr: f32) -> Self {
95        Self {
96            lr,
97            precond_lr: 0.1,
98            momentum: 0.9,
99            weight_decay: 0.0,
100            eps: 1e-8,
101            clip: 1.0,
102            state: HashMap::new(),
103            mom: HashMap::new(),
104        }
105    }
106
107    /// Override the Polyak momentum.
108    pub fn with_momentum(mut self, mu: f32) -> Self {
109        self.momentum = mu;
110        self
111    }
112
113    /// Override the weight-decay coefficient.
114    pub fn with_weight_decay(mut self, wd: f32) -> Self {
115        self.weight_decay = wd;
116        self
117    }
118}
119
120impl Optimizer for KronPsgd {
121    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
122        debug_assert_eq!(param.len(), grad.len());
123        let lr = self.lr;
124        let wd = self.weight_decay;
125
126        if shape.len() != 2 {
127            // Non-matrix: SGD + momentum fallback.
128            let v = zeros_entry(&mut self.mom, name, param.len());
129            let mu = self.momentum;
130            for i in 0..param.len() {
131                v[i] = mu * v[i] + grad[i] + wd * param[i];
132                param[i] -= lr * v[i];
133            }
134            return;
135        }
136        let (m, n) = (shape[0], shape[1]);
137        debug_assert_eq!(m * n, param.len());
138        let st = self
139            .state
140            .entry(name.to_owned())
141            .or_insert_with(|| KronState {
142                ql: identity_triangular(m),
143                qr: identity_triangular(n),
144            });
145
146        // ── 1. Update Q_L, Q_R via Li (2018) Lie-group rule. ──────
147        // Use g itself as the probe; the affine variant requires:
148        //   A = Q_L · g · Q_Rᵀ        (m × n)
149        //   B = Q_L⁻ᵀ · g · Q_R⁻¹     (m × n; cheap because Q is triangular)
150        // dQ_L ∝ tril(A·Aᵀ − B·Bᵀ); dQ_R ∝ tril(Aᵀ·A − Bᵀ·B).
151        let a = matmul_3(&st.ql, grad, &st.qr, m, n, /*trans_q_r=*/ true);
152        let b = matmul_3_inv(&st.ql, grad, &st.qr, m, n);
153        update_factor(&mut st.ql, &a, &b, m, n, true, self.precond_lr, self.eps);
154        update_factor(&mut st.qr, &a, &b, m, n, false, self.precond_lr, self.eps);
155
156        // ── 2. Preconditioned gradient: p_g = Q_Lᵀ · Q_L · g · Q_R · Q_Rᵀ ──
157        // Build Q_Lᵀ Q_L (m×m, symmetric)
158        let mut ql_t_ql = vec![0.0f32; m * m];
159        for i in 0..m {
160            for j in 0..m {
161                let mut s = 0.0f32;
162                for p in 0..m {
163                    s += st.ql[p * m + i] * st.ql[p * m + j];
164                }
165                ql_t_ql[i * m + j] = s;
166            }
167        }
168        let mut qr_qr_t = vec![0.0f32; n * n];
169        for i in 0..n {
170            for j in 0..n {
171                let mut s = 0.0f32;
172                for p in 0..n {
173                    s += st.qr[i * n + p] * st.qr[j * n + p];
174                }
175                qr_qr_t[i * n + j] = s;
176            }
177        }
178        // p_g = (Q_Lᵀ Q_L) · g · (Q_R Q_Rᵀ)
179        let mut tmp = vec![0.0f32; m * n];
180        matmul(&ql_t_ql, grad, m, m, n, &mut tmp);
181        let mut p_g = vec![0.0f32; m * n];
182        matmul(&tmp, &qr_qr_t, m, n, n, &mut p_g);
183
184        // ── 3. Spectral clip + momentum + apply. ─────────────────
185        let mut max_abs = 0.0f32;
186        for &x in &p_g {
187            if x.abs() > max_abs {
188                max_abs = x.abs();
189            }
190        }
191        let scale = if max_abs > self.clip {
192            self.clip / max_abs
193        } else {
194            1.0
195        };
196        let v = zeros_entry(&mut self.mom, name, param.len());
197        let mu = self.momentum;
198        for i in 0..param.len() {
199            let g = scale * p_g[i] + wd * param[i];
200            v[i] = mu * v[i] + g;
201            param[i] -= lr * v[i];
202        }
203    }
204}
205
/// Returns the `n × n` identity matrix as a dense row-major buffer;
/// used as the initial value of each triangular factor.
fn identity_triangular(n: usize) -> Vec<f32> {
    let mut eye = vec![0.0; n * n];
    // In row-major storage consecutive diagonal entries are `n + 1` apart.
    eye.iter_mut().step_by(n + 1).for_each(|diag| *diag = 1.0);
    eye
}
213
214/// Compute `Q_L · G · Q_Rᵀ` (or `Q_L · G · Q_R` if `trans_q_r=false`).
215fn matmul_3(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize, trans_q_r: bool) -> Vec<f32> {
216    let mut t1 = vec![0.0f32; m * n];
217    matmul(ql, g, m, m, n, &mut t1);
218    let mut out = vec![0.0f32; m * n];
219    if trans_q_r {
220        // out = t1 · Q_Rᵀ  ⇒  out[i,j] = sum_p t1[i,p] · Q_R[j,p]
221        for i in 0..m {
222            for j in 0..n {
223                let mut s = 0.0f32;
224                for p in 0..n {
225                    s += t1[i * n + p] * qr[j * n + p];
226                }
227                out[i * n + j] = s;
228            }
229        }
230    } else {
231        matmul(&t1, qr, m, n, n, &mut out);
232    }
233    out
234}
235
/// Returns `Q_L⁻ᵀ · G · Q_R⁻¹` (`m × n`) for upper-triangular `Q_L`
/// (`m × m`) and `Q_R` (`n × n`) without forming either inverse: two
/// in-place triangular solves are run on a copy of `G`.
///
/// A diagonal entry whose magnitude is not above `1e-12` is treated as
/// singular and the corresponding solution entry is set to `0.0`.
fn matmul_3_inv(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize) -> Vec<f32> {
    const PIVOT_FLOOR: f32 = 1e-12;
    let divide = |num: f32, pivot: f32| {
        if pivot.abs() > PIVOT_FLOOR {
            num / pivot
        } else {
            0.0
        }
    };

    let mut sol = g.to_vec();

    // Stage 1: Q_Lᵀ · X = G. Q_Lᵀ is lower-triangular, so each column is
    // solved top-down; rows above `row` already hold their solution.
    for col in 0..n {
        for row in 0..m {
            let residual = (0..row).fold(sol[row * n + col], |acc, k| {
                acc - ql[k * m + row] * sol[k * n + col]
            });
            sol[row * n + col] = divide(residual, ql[row * m + row]);
        }
    }

    // Stage 2: Y · Q_R = X. Entry j of a row depends only on entries
    // 0..j of that row (Q_R is upper-triangular), so each row is solved
    // left to right — again a forward substitution.
    for row in 0..m {
        let base = row * n;
        for col in 0..n {
            let residual = (0..col).fold(sol[base + col], |acc, k| {
                acc - sol[base + k] * qr[k * n + col]
            });
            sol[base + col] = divide(residual, qr[col * n + col]);
        }
    }
    sol
}
267
/// Lie-group update of one triangular factor, in place.
///
/// * `q`     — the factor to update (`dim × dim`, row-major,
///   upper-triangular), where `dim = m` if `which` else `n`.
/// * `a`, `b` — the `m × n` matrices `A = Q_L·G·Q_Rᵀ` and
///   `B = Q_L⁻ᵀ·G·Q_R⁻¹`.
/// * `which` — `true` updates Q_L using `M = A·Aᵀ − B·Bᵀ` (m×m);
///   `false` updates Q_R using `M = Aᵀ·A − Bᵀ·B` (n×n).
/// * `plr`   — preconditioner learning rate.
/// * `eps`   — floor added to the normalizer.
///
/// The step is `Q ← Q − s · triu(M) · Q` with `s = plr / (‖M‖_F + eps)`.
/// `triu(M)` multiplies from the **left**: `M` is the gradient of the
/// whitening criterion with respect to `E` under the perturbation
/// `Q → (I + E)·Q`, so `triu(M)·Q` is the matching descent step (this is
/// also the form used by Li's reference implementation). A product of
/// upper-triangular matrices is upper-triangular, so `Q` keeps its
/// structure.
///
/// If the normalizer or the resulting step size is not finite (for
/// example a NaN/∞ gradient, or a zero norm with `eps == 0`), the factor
/// is left untouched rather than being overwritten with NaNs, since it
/// persists across steps.
fn update_factor(
    q: &mut [f32],
    a: &[f32],
    b: &[f32],
    m: usize,
    n: usize,
    which: bool,
    plr: f32,
    eps: f32,
) {
    let dim = if which { m } else { n };
    debug_assert_eq!(q.len(), dim * dim);

    // Build triu(M) and accumulate ‖M‖_F² over the full symmetric M.
    let mut upper = vec![0.0f32; dim * dim];
    let mut norm_sq = 0.0f64;
    for i in 0..dim {
        for j in 0..dim {
            let mut a_term = 0.0f32;
            let mut b_term = 0.0f32;
            if which {
                // Rows i and j of A / B.
                for p in 0..n {
                    a_term += a[i * n + p] * a[j * n + p];
                    b_term += b[i * n + p] * b[j * n + p];
                }
            } else {
                // Columns i and j of A / B.
                for p in 0..m {
                    a_term += a[p * n + i] * a[p * n + j];
                    b_term += b[p * n + i] * b[p * n + j];
                }
            }
            let d = a_term - b_term;
            norm_sq += f64::from(d) * f64::from(d);
            if j >= i {
                upper[i * dim + j] = d; // upper-triangular projection
            }
        }
    }

    let norm = norm_sq.sqrt() as f32;
    let scale = plr / (norm + eps);
    if !norm.is_finite() || !scale.is_finite() {
        return;
    }

    // delta = triu(M) · Q. Row i of triu(M) is zero left of the diagonal,
    // so the inner sum starts at p = i.
    let mut delta = vec![0.0f32; dim * dim];
    for i in 0..dim {
        for p in i..dim {
            let t = upper[i * dim + p];
            for j in 0..dim {
                delta[i * dim + j] += t * q[p * dim + j];
            }
        }
    }
    for (qk, dk) in q.iter_mut().zip(&delta) {
        *qk -= scale * dk;
    }
}
322}