rlx_optim/kron_psgd.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Kron-PSGD — Preconditioned SGD with a Kronecker-factored
17//! preconditioner (Li, 2018; "Preconditioned Stochastic Gradient
18//! Descent").
19//!
20//! # Idea
21//!
22//! Approximate the inverse Hessian as a Kronecker product
23//! `P ≈ P_L ⊗ P_R` where `P_L = Q_LᵀQ_L` and `P_R = Q_RᵀQ_R` for two
24//! upper-triangular factors. The factors are updated by a *Lie-group
25//! descent* on a whitening criterion — no eigendecomposition needed,
26//! and updates are stable by construction (the upper-triangular
27//! manifold).
28//!
29//! # Update rule
30//!
31//! For a 2-D parameter `W ∈ ℝ^{m×n}`:
32//!
33//! ```text
34//! A = Q_L · G · Q_Rᵀ // m×n
35//! B = Q_L⁻ᵀ · G · Q_R⁻¹ // m×n (triangular solves)
//! dQ_L ∝ triu(A·Aᵀ − B·Bᵀ); Q_L ← Q_L − η_p · Q_L · dQ_L
//! dQ_R ∝ triu(Aᵀ·A − Bᵀ·B); Q_R ← Q_R − η_p · Q_R · dQ_R
38//! P_L = Q_LᵀQ_L; P_R = Q_RᵀQ_R
39//! p_g = P_L · G · P_R // preconditioned grad
//! [rescale p_g so that max|p_g[i,j]| ≤ clip, then SGD+momentum on p_g]
41//! ```
42//!
43//! Li (2018) Algorithm 1 uses an HVP probe `v` and its perturbed
44//! gradient to update `Q_L, Q_R`. This crate has no HVP oracle, so we
45//! use the gradient itself as the probe — the "PSGD-Affine"
46//! approximation — which is cheap and still gives strong empirical
47//! preconditioning on convex and mildly non-convex problems.
48//! Non-2-D parameters fall back to plain SGD-with-momentum.
49//!
50//! # When to use
51//!
52//! Ill-conditioned problems where Adam's coordinate-wise
53//! preconditioner is too weak (RNNs, deep MLPs, certain inverse
54//! problems). State cost per matrix: `m² + n²` plus a velocity buffer.
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::{matmul, zeros_entry};
60
/// Per-parameter Kronecker preconditioner factors.
///
/// Both factors are stored as flat row-major square matrices and start out
/// as identity matrices the first time a parameter is seen.
#[derive(Debug, Clone)]
struct KronState {
    /// Left factor `Q_L`: `m × m`, upper-triangular.
    ql: Vec<f32>,
    /// Right factor `Q_R`: `n × n`, upper-triangular.
    qr: Vec<f32>,
}
66
/// Kron-PSGD — Kronecker-factored preconditioned SGD.
///
/// Hyper-parameters are public fields; per-parameter state is private and
/// keyed by the parameter name passed to `step`.
#[derive(Debug, Clone)]
pub struct KronPsgd {
    /// Learning rate: the factor applied to the momentum buffer when the
    /// parameter is updated.
    pub lr: f32,
    /// Learning rate for the **preconditioner** update (Lie-group
    /// descent on Q_L / Q_R). Default `0.1`. Too high ⇒ Q drifts;
    /// too low ⇒ preconditioner lags.
    pub precond_lr: f32,
    /// Polyak momentum for the preconditioned-gradient SGD step.
    /// Default `0.9`.
    pub momentum: f32,
    /// L2 weight-decay coefficient. `weight_decay * param` is added to the
    /// update direction before it enters the momentum buffer; on the
    /// matrix path this happens after preconditioning and clipping.
    /// Default `0.0`.
    pub weight_decay: f32,
    /// Numerical floor on the preconditioner-update normalizer: it is
    /// added to the Frobenius norm that divides `precond_lr`.
    /// Default `1e-8`.
    pub eps: f32,
    /// Cap on the per-coordinate magnitude of the preconditioned update
    /// (defensive — early Q estimates can be ill-conditioned). When the
    /// largest absolute entry exceeds `clip`, the whole update is rescaled
    /// so that entry equals `clip`. Default `1.0`.
    pub clip: f32,
    /// Preconditioner factors for every rank-2 parameter seen so far.
    state: HashMap<String, KronState>,
    /// Momentum (velocity) buffers, one per parameter name.
    mom: HashMap<String, Vec<f32>>,
}
91
92impl KronPsgd {
93 /// Construct with `(precond_lr, μ, λ, ε, clip) = (0.1, 0.9, 0.0, 1e-8, 1.0)`.
94 pub fn new(lr: f32) -> Self {
95 Self {
96 lr,
97 precond_lr: 0.1,
98 momentum: 0.9,
99 weight_decay: 0.0,
100 eps: 1e-8,
101 clip: 1.0,
102 state: HashMap::new(),
103 mom: HashMap::new(),
104 }
105 }
106
107 /// Override the Polyak momentum.
108 pub fn with_momentum(mut self, mu: f32) -> Self {
109 self.momentum = mu;
110 self
111 }
112
113 /// Override the weight-decay coefficient.
114 pub fn with_weight_decay(mut self, wd: f32) -> Self {
115 self.weight_decay = wd;
116 self
117 }
118}
119
120impl Optimizer for KronPsgd {
121 fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
122 debug_assert_eq!(param.len(), grad.len());
123 let lr = self.lr;
124 let wd = self.weight_decay;
125
126 if shape.len() != 2 {
127 // Non-matrix: SGD + momentum fallback.
128 let v = zeros_entry(&mut self.mom, name, param.len());
129 let mu = self.momentum;
130 for i in 0..param.len() {
131 v[i] = mu * v[i] + grad[i] + wd * param[i];
132 param[i] -= lr * v[i];
133 }
134 return;
135 }
136 let (m, n) = (shape[0], shape[1]);
137 debug_assert_eq!(m * n, param.len());
138 let st = self
139 .state
140 .entry(name.to_owned())
141 .or_insert_with(|| KronState {
142 ql: identity_triangular(m),
143 qr: identity_triangular(n),
144 });
145
146 // ── 1. Update Q_L, Q_R via Li (2018) Lie-group rule. ──────
147 // Use g itself as the probe; the affine variant requires:
148 // A = Q_L · g · Q_Rᵀ (m × n)
149 // B = Q_L⁻ᵀ · g · Q_R⁻¹ (m × n; cheap because Q is triangular)
150 // dQ_L ∝ tril(A·Aᵀ − B·Bᵀ); dQ_R ∝ tril(Aᵀ·A − Bᵀ·B).
151 let a = matmul_3(&st.ql, grad, &st.qr, m, n, /*trans_q_r=*/ true);
152 let b = matmul_3_inv(&st.ql, grad, &st.qr, m, n);
153 update_factor(&mut st.ql, &a, &b, m, n, true, self.precond_lr, self.eps);
154 update_factor(&mut st.qr, &a, &b, m, n, false, self.precond_lr, self.eps);
155
156 // ── 2. Preconditioned gradient: p_g = Q_Lᵀ · Q_L · g · Q_R · Q_Rᵀ ──
157 // Build Q_Lᵀ Q_L (m×m, symmetric)
158 let mut ql_t_ql = vec![0.0f32; m * m];
159 for i in 0..m {
160 for j in 0..m {
161 let mut s = 0.0f32;
162 for p in 0..m {
163 s += st.ql[p * m + i] * st.ql[p * m + j];
164 }
165 ql_t_ql[i * m + j] = s;
166 }
167 }
168 let mut qr_qr_t = vec![0.0f32; n * n];
169 for i in 0..n {
170 for j in 0..n {
171 let mut s = 0.0f32;
172 for p in 0..n {
173 s += st.qr[i * n + p] * st.qr[j * n + p];
174 }
175 qr_qr_t[i * n + j] = s;
176 }
177 }
178 // p_g = (Q_Lᵀ Q_L) · g · (Q_R Q_Rᵀ)
179 let mut tmp = vec![0.0f32; m * n];
180 matmul(&ql_t_ql, grad, m, m, n, &mut tmp);
181 let mut p_g = vec![0.0f32; m * n];
182 matmul(&tmp, &qr_qr_t, m, n, n, &mut p_g);
183
184 // ── 3. Spectral clip + momentum + apply. ─────────────────
185 let mut max_abs = 0.0f32;
186 for &x in &p_g {
187 if x.abs() > max_abs {
188 max_abs = x.abs();
189 }
190 }
191 let scale = if max_abs > self.clip {
192 self.clip / max_abs
193 } else {
194 1.0
195 };
196 let v = zeros_entry(&mut self.mom, name, param.len());
197 let mu = self.momentum;
198 for i in 0..param.len() {
199 let g = scale * p_g[i] + wd * param[i];
200 v[i] = mu * v[i] + g;
201 param[i] -= lr * v[i];
202 }
203 }
204}
205
/// Build the `n × n` identity matrix as a flat row-major buffer.
///
/// The identity is trivially upper-triangular, which makes it the starting
/// point for a triangular factor.
fn identity_triangular(n: usize) -> Vec<f32> {
    let mut eye = vec![0.0f32; n * n];
    // In row-major layout consecutive diagonal entries are `n + 1` apart.
    eye.iter_mut().step_by(n + 1).for_each(|diag| *diag = 1.0);
    eye
}
213
214/// Compute `Q_L · G · Q_Rᵀ` (or `Q_L · G · Q_R` if `trans_q_r=false`).
215fn matmul_3(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize, trans_q_r: bool) -> Vec<f32> {
216 let mut t1 = vec![0.0f32; m * n];
217 matmul(ql, g, m, m, n, &mut t1);
218 let mut out = vec![0.0f32; m * n];
219 if trans_q_r {
220 // out = t1 · Q_Rᵀ ⇒ out[i,j] = sum_p t1[i,p] · Q_R[j,p]
221 for i in 0..m {
222 for j in 0..n {
223 let mut s = 0.0f32;
224 for p in 0..n {
225 s += t1[i * n + p] * qr[j * n + p];
226 }
227 out[i * n + j] = s;
228 }
229 }
230 } else {
231 matmul(&t1, qr, m, n, n, &mut out);
232 }
233 out
234}
235
/// Compute `Q_L⁻ᵀ · G · Q_R⁻¹` for upper-triangular `Q_L` (`m × m`) and
/// `Q_R` (`n × n`) without forming any inverse.
///
/// Two in-place forward substitutions are run on a copy of `G` (`m × n`,
/// row-major): first `Q_Lᵀ · X = G`, then `Y · Q_R = X`. A diagonal pivot
/// whose magnitude is not above `1e-12` yields `0.0` for that entry rather
/// than a division.
fn matmul_3_inv(ql: &[f32], g: &[f32], qr: &[f32], m: usize, n: usize) -> Vec<f32> {
    const PIVOT_FLOOR: f32 = 1e-12;
    let guarded_div = |num: f32, pivot: f32| -> f32 {
        if pivot.abs() > PIVOT_FLOOR {
            num / pivot
        } else {
            0.0
        }
    };

    let mut sol = g.to_vec();

    // Stage 1 — Q_Lᵀ · X = G. Q_Lᵀ is lower-triangular with
    // (Q_Lᵀ)[row, k] = Q_L[k, row], so entry (row, col) only needs the
    // already-solved entries above it in the same column. Columns are
    // independent, which makes the row-outer sweep equivalent to solving
    // one column at a time.
    for row in 0..m {
        for col in 0..n {
            let residual = (0..row).fold(sol[row * n + col], |acc, k| {
                acc - ql[k * m + row] * sol[k * n + col]
            });
            sol[row * n + col] = guarded_div(residual, ql[row * m + row]);
        }
    }

    // Stage 2 — Y · Q_R = X, one row at a time. Q_R is upper-triangular, so
    // entry `col` of a row only needs the already-solved entries to its
    // left.
    for row in 0..m {
        let base = row * n;
        for col in 0..n {
            let residual = (0..col).fold(sol[base + col], |acc, k| {
                acc - sol[base + k] * qr[k * n + col]
            });
            sol[base + col] = guarded_div(residual, qr[col * n + col]);
        }
    }

    sol
}
267
268/// Lie-group update of a triangular factor. `which = true` updates Q_L
269/// using `A·Aᵀ − B·Bᵀ` (m×m), `which = false` updates Q_R using
270/// `Aᵀ·A − Bᵀ·B` (n×n). The descent direction is then projected onto
271/// the upper-triangular tangent space.
272fn update_factor(
273 q: &mut [f32],
274 a: &[f32],
275 b: &[f32],
276 m: usize,
277 n: usize,
278 which: bool,
279 plr: f32,
280 eps: f32,
281) {
282 let dim = if which { m } else { n };
283 let mut grad_q = vec![0.0f32; dim * dim];
284 // Build A·Aᵀ − B·Bᵀ or Aᵀ·A − Bᵀ·B.
285 let mut norm = 0.0f64;
286 for i in 0..dim {
287 for j in 0..dim {
288 let mut a_term = 0.0f32;
289 let mut b_term = 0.0f32;
290 if which {
291 for p in 0..n {
292 a_term += a[i * n + p] * a[j * n + p];
293 b_term += b[i * n + p] * b[j * n + p];
294 }
295 } else {
296 for p in 0..m {
297 a_term += a[p * n + i] * a[p * n + j];
298 b_term += b[p * n + i] * b[p * n + j];
299 }
300 }
301 let d = a_term - b_term;
302 grad_q[i * dim + j] = d;
303 norm += d as f64 * d as f64;
304 }
305 }
306 let scale = plr / ((norm.sqrt() as f32) + eps);
307 // Project onto upper-triangular: Q ← Q · (I − 0.5·scale·tril(grad_q + grad_qᵀ))
308 // (Simplified Lie-group projection; full version solves a tiny matrix
309 // exponential, but a single linearized step is the standard choice.)
310 for i in 0..dim {
311 for j in 0..dim {
312 if j < i {
313 grad_q[i * dim + j] = 0.0; // upper-triangular projection
314 }
315 }
316 }
317 let mut q_new = vec![0.0f32; dim * dim];
318 matmul(q, &grad_q, dim, dim, dim, &mut q_new);
319 for k in 0..dim * dim {
320 q[k] -= scale * q_new[k];
321 }
322}