rlx_optim/muon.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Muon — MomentUm Orthogonalized by Newton–Schulz (Jordan, Bernstein,
17//! Vyas, Hubara, et al., 2024).
18//!
19//! # Idea
20//!
21//! For a 2-D parameter, replace the momentum buffer with its **closest
22//! semi-orthogonal matrix** before applying it as an update. The SVD
23//! `M = U·Σ·Vᵀ` has closest semi-orthogonal matrix `U·Vᵀ` — but the
//! SVD is expensive. A *Newton–Schulz quintic iteration* approximates
//! `U·Vᵀ` using only a few small matrix products per iteration (5
//! iterations by default). Empirically this
26//! gives a step-size-invariant update that punches above its weight on
27//! transformer training.
28//!
29//! # Update rule (2-D parameter `W ∈ ℝ^{m×n}`)
30//!
31//! ```text
//! m_t = μ·m_{t-1} + g_t                     // Polyak momentum
//! M   = m_t                                 if !nesterov
//!     = g_t + μ·m_t                         if nesterov
//! M̂   = M / ‖M‖_F                           // normalize for NS
//! repeat ns_steps times:                    // ns_steps = 5
//!     A = M̂ · M̂ᵀ
//!     M̂ ← a·M̂ + b·A·M̂ + c·A²·M̂              // quintic NS iter
//! U   = √max(m, n) · M̂                      // RMS-of-cols scaling
//! θ_t = θ_{t-1} − lr · ( U + λ·θ_{t-1} )
41//! ```
42//!
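//! The (a, b, c) coefficients define the odd quintic polynomial
//! `φ(σ) = a·σ + b·σ³ + c·σ⁵` that each iteration applies to the
//! singular values of `M̂`. They are chosen so that repeated application
//! pushes singular values in (0, 1] — the range guaranteed by the
//! Frobenius normalization — toward 1 (approximately: the tuned values
//! trade exact convergence for speed). Defaults
//! `(3.4445, −4.7750, 2.0315)` are from the original release.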
46//!
47//! Non-2-D parameters fall back to SGD-with-momentum (the original
48//! recipe routes them to AdamW; this crate stays dependency-free).
49//!
50//! # When to use
51//!
52//! Pre-training transformer matrix-shaped weights (Q/K/V/FFN
53//! projections). Often paired with AdamW for embeddings and biases.
54//! State cost: one momentum buffer per matrix.
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// Muon — Momentum-Orthogonalized-by-Newton-Schulz.
///
/// Per-tensor state: **one** momentum buffer per matrix (half of
/// Adam's footprint, like Lion).
///
/// All hyper-parameters are public fields; [`Muon::new`] fills in the
/// defaults and the `with_*` builders override individual values.
#[derive(Debug, Clone)]
pub struct Muon {
    /// Learning rate. The Newton–Schulz update has roughly unit
    /// Frobenius norm per column, so this is on the same scale as
    /// SGD's lr — typically `2e-2` to `5e-2`.
    pub lr: f32,
    /// Polyak momentum coefficient μ in `m ← μ·m + g`. Default `0.95`.
    pub momentum: f32,
    /// Use Nesterov lookahead (`g + μ·m` instead of `m`) for the matrix
    /// being orthogonalized. The same switch also selects the lookahead
    /// form on the non-matrix fallback path. Default `true`.
    pub nesterov: bool,
    /// Decoupled weight-decay coefficient λ: the step subtracts
    /// `lr · λ · θ` alongside the update. Default `0.0`.
    pub weight_decay: f32,
    /// Newton–Schulz iteration count. `5` is the published default;
    /// `3` is enough for most well-conditioned matrices.
    pub ns_steps: u32,
    /// `(a, b, c)` coefficients of the Newton–Schulz iteration
    /// `X ← a·X + b·(XXᵀ)X + c·(XXᵀ)²X` (degree 5 in `X`, i.e. a
    /// quintic polynomial on the singular values). Defaults match
    /// Jordan et al.
    pub ns_coeffs: (f32, f32, f32),
    // Momentum buffers, keyed by the parameter name passed to `step`.
    m: HashMap<String, Vec<f32>>,
}
86
87impl Muon {
88 /// Construct with `(μ, nesterov, λ, ns_steps) = (0.95, true, 0.0, 5)`
89 /// and the published NS coefficients.
90 pub fn new(lr: f32) -> Self {
91 Self {
92 lr,
93 momentum: 0.95,
94 nesterov: true,
95 weight_decay: 0.0,
96 ns_steps: 5,
97 ns_coeffs: (3.4445, -4.7750, 2.0315),
98 m: HashMap::new(),
99 }
100 }
101
102 /// Override the Polyak momentum coefficient.
103 pub fn with_momentum(mut self, mu: f32) -> Self {
104 self.momentum = mu;
105 self
106 }
107
108 /// Override the decoupled-decay coefficient.
109 pub fn with_weight_decay(mut self, wd: f32) -> Self {
110 self.weight_decay = wd;
111 self
112 }
113
114 /// Override the Newton–Schulz iteration count.
115 pub fn with_ns_steps(mut self, n: u32) -> Self {
116 self.ns_steps = n;
117 self
118 }
119}
120
121impl Optimizer for Muon {
122 fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
123 debug_assert_eq!(param.len(), grad.len());
124 let mu = self.momentum;
125 let wd = self.weight_decay;
126 let lr = self.lr;
127 let m = zeros_entry(&mut self.m, name, param.len());
128 // EMA buffer (classical Polyak momentum: `m ← μ·m + g`).
129 for i in 0..param.len() {
130 m[i] = mu * m[i] + grad[i];
131 }
132 if shape.len() != 2 {
133 // Non-matrix: SGD-with-momentum update.
134 for i in 0..param.len() {
135 let g = if self.nesterov {
136 grad[i] + mu * m[i]
137 } else {
138 m[i]
139 };
140 param[i] -= lr * (g + wd * param[i]);
141 }
142 return;
143 }
144 let (rows, cols) = (shape[0], shape[1]);
145 debug_assert_eq!(rows * cols, param.len());
146 // Build the matrix to orthogonalize. With Nesterov:
147 // G = grad + μ·m (m has already been updated above)
148 let mut g_mat = vec![0.0f32; rows * cols];
149 if self.nesterov {
150 for i in 0..rows * cols {
151 g_mat[i] = grad[i] + mu * m[i];
152 }
153 } else {
154 g_mat.copy_from_slice(m);
155 }
156 let ortho = newton_schulz_orth(&g_mat, rows, cols, self.ns_steps, self.ns_coeffs);
157 // The Muon paper scales the update by sqrt(max(rows, cols)) so
158 // its effective magnitude matches a unit-norm column.
159 let s = (rows.max(cols) as f32).sqrt();
160 for i in 0..param.len() {
161 param[i] -= lr * (s * ortho[i] + wd * param[i]);
162 }
163 }
164}
165
/// Newton–Schulz semi-orthogonalization.
///
/// Takes a row-major `rows × cols` matrix `g` and returns a new row-major
/// `rows × cols` matrix that approximates its closest semi-orthogonal
/// matrix (up to the polynomial truncation).
///
/// * `steps` — number of iterations; `0` returns the Frobenius-normalized
///   input unchanged.
/// * `c` — `(a, b, c)` coefficients of `X ← a·X + b·A·X + c·A²·X` with
///   `A = X·Xᵀ`.
///
/// The input is first divided by its Frobenius norm (floored at `1e-12`
/// so an all-zero matrix stays zero instead of producing NaN) to stay
/// inside the polynomial's region of convergence.
///
/// The iteration acts on singular values only, so running it on `X` or
/// on `Xᵀ` gives the same result. It is therefore run on whichever
/// orientation has the *fewer rows*: the Gram matrix `A` is then
/// `min(rows, cols)²` and every inner product runs over the longer axis.
fn newton_schulz_orth(
    g: &[f32],
    rows: usize,
    cols: usize,
    steps: u32,
    c: (f32, f32, f32),
) -> Vec<f32> {
    debug_assert_eq!(g.len(), rows * cols);

    // Frobenius norm, accumulated in f64 to limit rounding on large inputs.
    let sum_sq: f64 = g.iter().map(|&v| f64::from(v) * f64::from(v)).sum();
    let fro = (sum_sq.sqrt() as f32).max(1e-12);

    // Work on an r × k matrix with r <= k. A tall input is transposed so
    // that the r × r Gram matrix is built on the short side.
    let transposed = rows > cols;
    let (r, k) = if transposed { (cols, rows) } else { (rows, cols) };
    let mut x_mat: Vec<f32> = if transposed {
        let mut t = vec![0.0f32; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                t[j * rows + i] = g[i * cols + j] / fro;
            }
        }
        t
    } else {
        g.iter().map(|&v| v / fro).collect()
    };

    let (a, b, cc) = c;
    let mut next = vec![0.0f32; r * k];
    let mut gram = vec![0.0f32; r * r]; // A = X·Xᵀ
    let mut poly = vec![0.0f32; r * r]; // b·A + c·A²

    for _ in 0..steps {
        // A = X·Xᵀ. A is symmetric, so compute the lower triangle and
        // mirror it.
        for i in 0..r {
            let row_i = &x_mat[i * k..(i + 1) * k];
            for j in 0..=i {
                let row_j = &x_mat[j * k..(j + 1) * k];
                let mut dot = 0.0f32;
                for (&u, &v) in row_i.iter().zip(row_j) {
                    dot += u * v;
                }
                gram[i * r + j] = dot;
                gram[j * r + i] = dot;
            }
        }

        // poly = b·A + c·A². Because A is symmetric, (A²)[i][j] is the
        // dot product of rows i and j of A (contiguous access).
        for i in 0..r {
            let a_i = &gram[i * r..(i + 1) * r];
            for j in 0..r {
                let a_j = &gram[j * r..(j + 1) * r];
                let mut sq = 0.0f32;
                for (&u, &v) in a_i.iter().zip(a_j) {
                    sq += u * v;
                }
                poly[i * r + j] = b * a_i[j] + cc * sq;
            }
        }

        // X ← a·X + poly·X, accumulated row by row.
        for i in 0..r {
            let dst = &mut next[i * k..(i + 1) * k];
            let src = &x_mat[i * k..(i + 1) * k];
            for (d, &x) in dst.iter_mut().zip(src) {
                *d = a * x;
            }
            for p in 0..r {
                let w = poly[i * r + p];
                let x_p = &x_mat[p * k..(p + 1) * k];
                for (d, &x) in dst.iter_mut().zip(x_p) {
                    *d += w * x;
                }
            }
        }
        std::mem::swap(&mut x_mat, &mut next);
    }

    if transposed {
        // x_mat is cols × rows here (r = cols, k = rows); restore the
        // caller's rows × cols layout.
        let mut out = vec![0.0f32; rows * cols];
        for i in 0..r {
            for j in 0..k {
                out[j * r + i] = x_mat[i * k + j];
            }
        }
        out
    } else {
        x_mat
    }
}