use std::collections::HashMap;
use crate::Optimizer;
use crate::common::zeros_entry;
/// Momentum-based optimizer state and hyper-parameters.
///
/// Parameters whose shape has exactly two dimensions have their update
/// direction passed through a Newton–Schulz iteration before it is applied;
/// every other parameter receives a plain momentum update.
#[derive(Debug, Clone)]
pub struct Muon {
    /// Step size that multiplies every update.
    pub lr: f32,
    /// Decay factor applied to the stored momentum buffer on each step.
    pub momentum: f32,
    /// When `true` the update direction is `grad + momentum * buffer`;
    /// otherwise the buffer itself is used.
    pub nesterov: bool,
    /// Coefficient of the `weight_decay * param` term added to the update
    /// direction (so it is scaled by `lr` as well).
    pub weight_decay: f32,
    /// Number of Newton–Schulz iterations run for two-dimensional parameters.
    pub ns_steps: u32,
    /// Coefficients `(a, b, c)` of the Newton–Schulz polynomial.
    pub ns_coeffs: (f32, f32, f32),
    /// Momentum buffers, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
}

impl Muon {
    /// Creates an optimizer with learning rate `lr` and the default settings:
    /// momentum `0.95`, Nesterov enabled, no weight decay, and five
    /// Newton–Schulz steps with coefficients `(3.4445, -4.7750, 2.0315)`.
    pub fn new(lr: f32) -> Self {
        let ns_coeffs = (3.4445, -4.7750, 2.0315);
        Self {
            m: HashMap::new(),
            ns_coeffs,
            ns_steps: 5,
            weight_decay: 0.0,
            nesterov: true,
            momentum: 0.95,
            lr,
        }
    }

    /// Returns the optimizer with its momentum factor replaced by `mu`.
    pub fn with_momentum(self, mu: f32) -> Self {
        Self { momentum: mu, ..self }
    }

    /// Returns the optimizer with its weight-decay coefficient replaced by `wd`.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self { weight_decay: wd, ..self }
    }

    /// Returns the optimizer with its Newton–Schulz step count replaced by `n`.
    pub fn with_ns_steps(self, n: u32) -> Self {
        Self { ns_steps: n, ..self }
    }
}
impl Optimizer for Muon {
    /// Applies one optimizer update to the parameter registered under `name`.
    ///
    /// The momentum buffer for `name` is first refreshed as
    /// `m = momentum * m + grad`. The update direction is then
    /// `grad + momentum * m` when `nesterov` is set and `m` otherwise.
    ///
    /// * If `shape` does not have exactly two entries, the direction is used
    ///   as is.
    /// * Otherwise the direction is treated as a row-major
    ///   `shape[0] x shape[1]` matrix, passed through `newton_schulz_orth`,
    ///   and scaled by `sqrt(max(rows, cols))`.
    ///
    /// In both cases `weight_decay * param` is added to the direction before
    /// the sum is multiplied by `lr` and subtracted from `param`.
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());
        let n = param.len();
        let (lr, mu, wd) = (self.lr, self.momentum, self.weight_decay);
        let nesterov = self.nesterov;

        // Re-slicing to `n` keeps the out-of-bounds panic that indexed access
        // would raise on undersized buffers instead of silently truncating.
        let buffer = zeros_entry(&mut self.m, name, n);
        let velocity = &mut buffer[..n];
        let grad = &grad[..n];

        for (v, &g) in velocity.iter_mut().zip(grad) {
            *v = mu * *v + g;
        }

        match shape {
            &[rows, cols] => {
                debug_assert_eq!(rows * cols, n);
                let len = rows * cols;
                let direction: Vec<f32> = if nesterov {
                    grad[..len]
                        .iter()
                        .zip(&velocity[..len])
                        .map(|(&g, &v)| g + mu * v)
                        .collect()
                } else {
                    velocity[..len].to_vec()
                };
                let ortho =
                    newton_schulz_orth(&direction, rows, cols, self.ns_steps, self.ns_coeffs);
                let scale = (rows.max(cols) as f32).sqrt();
                for (p, &o) in param.iter_mut().zip(&ortho[..n]) {
                    *p -= lr * (scale * o + wd * *p);
                }
            }
            _ => {
                for ((p, &g), &v) in param.iter_mut().zip(grad).zip(velocity.iter()) {
                    let direction = if nesterov { g + mu * v } else { v };
                    *p -= lr * (direction + wd * *p);
                }
            }
        }
    }
}
/// Approximately orthogonalizes a row-major `rows x cols` matrix with a
/// quintic Newton–Schulz iteration.
///
/// The input is first divided by its Frobenius norm (clamped below at
/// `1e-12`), then `steps` rounds of
///
/// ```text
/// A = X Xᵀ
/// X = a·X + (b·A + c·A²)·X
/// ```
///
/// are applied, where `(a, b, c)` is the tuple `c`.
///
/// The iteration commutes with transposition, so it is always run on the
/// orientation with the fewer rows. That keeps the Gram matrix `A` at
/// `min(rows, cols)²` entries instead of `max(rows, cols)²`.
///
/// # Parameters
/// * `g` – matrix entries in row-major order; expected length `rows * cols`.
/// * `rows`, `cols` – matrix dimensions.
/// * `steps` – number of iterations; `0` returns the normalized input.
/// * `c` – polynomial coefficients `(a, b, c)`.
///
/// # Returns
/// A new row-major `rows x cols` buffer holding the iterated matrix.
///
/// # Panics
/// May panic on out-of-bounds access if `g` is shorter than `rows * cols`.
fn newton_schulz_orth(
    g: &[f32],
    rows: usize,
    cols: usize,
    steps: u32,
    c: (f32, f32, f32),
) -> Vec<f32> {
    /// Returns the row-major transpose of a row-major `rows x cols` matrix.
    fn transposed(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
        dst
    }

    debug_assert_eq!(g.len(), rows * cols);

    // Accumulate the squared norm in f64 so large matrices do not lose precision.
    let sq_sum: f64 = g.iter().map(|&v| f64::from(v) * f64::from(v)).sum();
    let norm = (sq_sum.sqrt() as f32).max(1e-12);
    let normalized: Vec<f32> = g.iter().map(|&v| v / norm).collect();

    // Work on the wide orientation (r <= k) so the r x r Gram matrix is as
    // small as possible.
    let flip = rows > cols;
    let (mut x, r, k) = if flip {
        (transposed(&normalized, rows, cols), cols, rows)
    } else {
        (normalized, rows, cols)
    };

    let (a, b, cc) = c;
    let mut gram = vec![0.0f32; r * r];
    let mut poly = vec![0.0f32; r * r];
    let mut next = vec![0.0f32; r * k];

    for _ in 0..steps {
        // gram = X Xᵀ; symmetric, so only the upper triangle is computed.
        for i in 0..r {
            let row_i = &x[i * k..(i + 1) * k];
            for j in i..r {
                let row_j = &x[j * k..(j + 1) * k];
                let dot = row_i
                    .iter()
                    .zip(row_j)
                    .map(|(&p, &q)| p * q)
                    .sum::<f32>();
                gram[i * r + j] = dot;
                gram[j * r + i] = dot;
            }
        }

        // poly = b·gram + c·gram²
        for i in 0..r {
            for j in 0..r {
                let mut sq = 0.0f32;
                for p in 0..r {
                    sq += gram[i * r + p] * gram[p * r + j];
                }
                poly[i * r + j] = b * gram[i * r + j] + cc * sq;
            }
        }

        // next = a·X + poly·X, accumulated one source row at a time so the
        // inner loop walks contiguous memory.
        for i in 0..r {
            let out_row = &mut next[i * k..(i + 1) * k];
            for (o, &v) in out_row.iter_mut().zip(&x[i * k..(i + 1) * k]) {
                *o = a * v;
            }
            for p in 0..r {
                let w = poly[i * r + p];
                for (o, &v) in out_row.iter_mut().zip(&x[p * k..(p + 1) * k]) {
                    *o += w * v;
                }
            }
        }

        std::mem::swap(&mut x, &mut next);
    }

    if flip {
        // `x` is cols x rows here; restore the caller's rows x cols layout.
        transposed(&x, r, k)
    } else {
        x
    }
}