Skip to main content

rlx_optim/
muon.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Muon — MomentUm Orthogonalized by Newton–Schulz (Jordan, Bernstein,
17//! Vyas, Hubara, et al., 2024).
18//!
19//! # Idea
20//!
21//! For a 2-D parameter, replace the momentum buffer with its **closest
22//! semi-orthogonal matrix** before applying it as an update. The SVD
23//! `M = U·Σ·Vᵀ` has closest semi-orthogonal matrix `U·Vᵀ` — but the
//! SVD is expensive. A *Newton–Schulz quintic iteration* approximates
//! `U·Vᵀ` using only a few small matrix products per iteration (5
//! iterations by default). Empirically this
26//! gives a step-size-invariant update that punches above its weight on
27//! transformer training.
28//!
29//! # Update rule (2-D parameter `W ∈ ℝ^{m×n}`)
30//!
31//! ```text
32//! m_t = μ·m_{t-1} + g_t                              // Polyak momentum
33//! M   = m_t                  if !nesterov
34//!     = g_t + μ·m_t          if  nesterov
35//! M̂   = M / ‖M‖_F                                    // normalize for NS
36//! repeat ns_steps times:                              // ns_steps = 5
37//!     A = M̂ · M̂ᵀ
//!     M̂ ← a·M̂ + b·A·M̂ + c·A²·M̂                       // quintic NS iter
39//! U   = √max(m, n) · M̂                                // RMS-of-cols scaling
40//! θ_t = θ_{t-1} − lr · ( U + λ·θ_{t-1} )
41//! ```
42//!
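//! The (a, b, c) coefficients define the odd quintic polynomial
//! `σ ↦ a·σ + b·σ³ + c·σ⁵` that each iteration applies to every
//! singular value. The Frobenius normalization puts all singular values
//! in (0, 1]; the iteration then pushes them into a band around 1
//! rather than exactly to 1 (with the defaults, `σ = 1` maps to ≈ 0.70).
//! Defaults `(3.4445, −4.7750, 2.0315)` are from the original release.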
46//!
47//! Non-2-D parameters fall back to SGD-with-momentum (the original
48//! recipe routes them to AdamW; this crate stays dependency-free).
49//!
50//! # When to use
51//!
52//! Pre-training transformer matrix-shaped weights (Q/K/V/FFN
53//! projections). Often paired with AdamW for embeddings and biases.
54//! State cost: one momentum buffer per matrix.
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// Muon — Momentum-Orthogonalized-by-Newton-Schulz.
///
/// Bundles the optimizer's hyper-parameters (all public, so they can be
/// adjusted between steps) with its private per-parameter momentum state.
///
/// Per-tensor state: **one** momentum buffer per matrix (half of
/// Adam's footprint, like Lion).
#[derive(Debug, Clone)]
pub struct Muon {
    /// Learning rate. The Newton–Schulz update has roughly unit
    /// Frobenius norm per column, so this is on the same scale as
    /// SGD's lr — typically `2e-2` to `5e-2`.
    pub lr: f32,
    /// Polyak momentum coefficient μ used in `m ← μ·m + g`.
    /// Default `0.95`.
    pub momentum: f32,
    /// Use Nesterov lookahead (`g + μ·m` instead of `m`) when forming the
    /// matrix being orthogonalized. Default `true`.
    pub nesterov: bool,
    /// Decoupled weight-decay coefficient λ. Default `0.0`.
    pub weight_decay: f32,
    /// Newton–Schulz iteration count. `5` is the published default;
    /// `3` is enough for most well-conditioned matrices.
    pub ns_steps: u32,
    /// `(a, b, c)` coefficients of the Newton–Schulz iteration
    /// `X ← a·X + b·(XXᵀ)X + c·(XXᵀ)²X` (a quintic, i.e. degree-5,
    /// polynomial in `X`). Defaults match Jordan et al.
    pub ns_coeffs: (f32, f32, f32),
    /// Momentum buffers, one `Vec<f32>` per parameter, keyed by the
    /// parameter name. Starts empty.
    m: HashMap<String, Vec<f32>>,
}
86
87impl Muon {
88    /// Construct with `(μ, nesterov, λ, ns_steps) = (0.95, true, 0.0, 5)`
89    /// and the published NS coefficients.
90    pub fn new(lr: f32) -> Self {
91        Self {
92            lr,
93            momentum: 0.95,
94            nesterov: true,
95            weight_decay: 0.0,
96            ns_steps: 5,
97            ns_coeffs: (3.4445, -4.7750, 2.0315),
98            m: HashMap::new(),
99        }
100    }
101
102    /// Override the Polyak momentum coefficient.
103    pub fn with_momentum(mut self, mu: f32) -> Self {
104        self.momentum = mu;
105        self
106    }
107
108    /// Override the decoupled-decay coefficient.
109    pub fn with_weight_decay(mut self, wd: f32) -> Self {
110        self.weight_decay = wd;
111        self
112    }
113
114    /// Override the Newton–Schulz iteration count.
115    pub fn with_ns_steps(mut self, n: u32) -> Self {
116        self.ns_steps = n;
117        self
118    }
119}
120
121impl Optimizer for Muon {
122    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
123        debug_assert_eq!(param.len(), grad.len());
124        let mu = self.momentum;
125        let wd = self.weight_decay;
126        let lr = self.lr;
127        let m = zeros_entry(&mut self.m, name, param.len());
128        // EMA buffer (classical Polyak momentum: `m ← μ·m + g`).
129        for i in 0..param.len() {
130            m[i] = mu * m[i] + grad[i];
131        }
132        if shape.len() != 2 {
133            // Non-matrix: SGD-with-momentum update.
134            for i in 0..param.len() {
135                let g = if self.nesterov {
136                    grad[i] + mu * m[i]
137                } else {
138                    m[i]
139                };
140                param[i] -= lr * (g + wd * param[i]);
141            }
142            return;
143        }
144        let (rows, cols) = (shape[0], shape[1]);
145        debug_assert_eq!(rows * cols, param.len());
146        // Build the matrix to orthogonalize. With Nesterov:
147        //   G = grad + μ·m   (m has already been updated above)
148        let mut g_mat = vec![0.0f32; rows * cols];
149        if self.nesterov {
150            for i in 0..rows * cols {
151                g_mat[i] = grad[i] + mu * m[i];
152            }
153        } else {
154            g_mat.copy_from_slice(m);
155        }
156        let ortho = newton_schulz_orth(&g_mat, rows, cols, self.ns_steps, self.ns_coeffs);
157        // The Muon paper scales the update by sqrt(max(rows, cols)) so
158        // its effective magnitude matches a unit-norm column.
159        let s = (rows.max(cols) as f32).sqrt();
160        for i in 0..param.len() {
161            param[i] -= lr * (s * ortho[i] + wd * param[i]);
162        }
163    }
164}
165
/// Newton–Schulz semi-orthogonalization.
///
/// Takes a row-major `rows × cols` matrix `g` and returns a new row-major
/// `rows × cols` matrix approximating the closest semi-orthogonal matrix
/// to `g` (up to the polynomial truncation).
///
/// * `g` — input matrix; expected to hold `rows * cols` values.
/// * `steps` — number of iterations `X ← a·X + b·(XXᵀ)X + c·(XXᵀ)²X`.
///   With `steps == 0` the result is just `g` divided by its Frobenius
///   norm.
/// * `c` — the `(a, b, c)` polynomial coefficients.
///
/// The input is first divided by its Frobenius norm (floored at `1e-12`,
/// so an all-zero input yields an all-zero output) to keep every singular
/// value at most 1 before iterating.
///
/// The iteration acts only on singular values, so it gives the same
/// result on `X` and on `Xᵀ`. The work is therefore done on whichever
/// orientation has fewer rows, which keeps the Gram matrix `X·Xᵀ` at
/// `min(rows, cols)²` entries.
fn newton_schulz_orth(
    g: &[f32],
    rows: usize,
    cols: usize,
    steps: u32,
    c: (f32, f32, f32),
) -> Vec<f32> {
    debug_assert_eq!(g.len(), rows * cols);

    // Frobenius norm, accumulated in f64 to limit rounding error.
    let sum_sq = g
        .iter()
        .fold(0.0f64, |acc, &v| acc + f64::from(v) * f64::from(v));
    let norm = (sum_sq.sqrt() as f32).max(1e-12);

    // Work on an `r × k` matrix with `r <= k`: a tall input is transposed
    // so the Gram matrix is built on the short side and each of its inner
    // products runs over the long axis.
    let transposed = rows > cols;
    let (r, k) = if transposed { (cols, rows) } else { (rows, cols) };
    let mut x: Vec<f32> = if transposed {
        let mut t = vec![0.0f32; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                t[j * rows + i] = g[i * cols + j] / norm;
            }
        }
        t
    } else {
        g.iter().map(|&v| v / norm).collect()
    };

    let (a, b, q) = c;
    let mut gram = vec![0.0f32; r * r]; // A = X·Xᵀ
    let mut poly = vec![0.0f32; r * r]; // b·A + q·A²
    let mut next = vec![0.0f32; r * k]; // next iterate, same shape as X
    for _ in 0..steps {
        // A = X · Xᵀ. A is symmetric, so compute the upper triangle and
        // mirror it.
        for i in 0..r {
            let row_i = &x[i * k..(i + 1) * k];
            for j in i..r {
                let row_j = &x[j * k..(j + 1) * k];
                let dot = row_i
                    .iter()
                    .zip(row_j)
                    .fold(0.0f32, |acc, (&u, &v)| acc + u * v);
                gram[i * r + j] = dot;
                gram[j * r + i] = dot;
            }
        }
        // P = b·A + q·A², so the update below needs a single product.
        for i in 0..r {
            for j in 0..r {
                let mut sq = 0.0f32;
                for p in 0..r {
                    sq += gram[i * r + p] * gram[p * r + j];
                }
                poly[i * r + j] = b * gram[i * r + j] + q * sq;
            }
        }
        // X ← a·X + P·X, accumulated one output row at a time so the
        // innermost loop walks contiguous memory.
        for i in 0..r {
            let out_row = &mut next[i * k..(i + 1) * k];
            for (o, &v) in out_row.iter_mut().zip(&x[i * k..(i + 1) * k]) {
                *o = a * v;
            }
            for p in 0..r {
                let w = poly[i * r + p];
                for (o, &v) in out_row.iter_mut().zip(&x[p * k..(p + 1) * k]) {
                    *o += w * v;
                }
            }
        }
        std::mem::swap(&mut x, &mut next);
    }

    if transposed {
        // `x` is `cols × rows` here; write it back as `rows × cols`.
        let mut out = vec![0.0f32; rows * cols];
        for i in 0..r {
            for j in 0..k {
                out[j * r + i] = x[i * k + j];
            }
        }
        out
    } else {
        x
    }
}
252}