use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, NonFiniteScalarPayload, OutOfRangePayload, RankMismatchPayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like, zeros_like_map},
},
ops::{arithmetic, linalg_basic::addmm, linalg_full::norm_l2, shape::reshape},
};
fn validate_momentum_finite(momentum: f32) -> Result<()> {
if !momentum.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Muon: momentum",
momentum as f64,
)));
}
Ok(())
}
fn validate_weight_decay(weight_decay: f32) -> Result<()> {
if !weight_decay.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"Muon: weight_decay",
weight_decay as f64,
)));
}
if weight_decay < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"Muon: weight_decay",
"must be >= 0.0",
format_smolstr!("{weight_decay}"),
)));
}
Ok(())
}
/// Wraps `v` in an `f32` [`Array`] created with an empty (zero-dimensional) shape.
///
/// # Errors
///
/// Propagates whatever error `Array::full` reports.
fn scalar(v: f32) -> Result<Array> {
    // A zero-length shape: the resulting array has no axes.
    let empty_shape = [0i32; 0];
    Array::full::<f32>(&empty_shape, v)
}
/// Muon optimizer: momentum SGD whose update for every parameter of rank >= 2
/// is passed through a Newton-Schulz iteration before being applied.
///
/// Parameters of rank 0 or 1 receive the plain momentum (optionally Nesterov)
/// update.
pub struct Muon {
/// Learning-rate source; resolved to a concrete value per step via `try_current(step)`.
learning_rate: LearningRate,
/// Momentum coefficient `mu` in `v = mu * v + (1 - mu) * g`.
momentum: f32,
/// Coefficient multiplied by the parameter and added to its gradient; `0.0` disables the term.
weight_decay: f32,
/// When `true` the update is `(1 - mu) * g + mu * v`; otherwise it is `v` itself.
nesterov: bool,
/// Number of Newton-Schulz iterations run on each rank >= 2 update.
ns_steps: usize,
/// Number of `apply_gradients` calls that have passed the learning-rate preflight.
step_count: usize,
/// Learning rate most recently resolved from `learning_rate`.
current_lr: f32,
/// Step index for which `current_lr` was resolved; lets `preflight` skip a repeated lookup.
lr_resolved_for_step: Option<usize>,
/// Momentum buffer per parameter, keyed by the same keys as the gradient map.
state: HashMap<String, Array>,
}
impl Muon {
pub fn new(
learning_rate: impl Into<LearningRate>,
momentum: f32,
weight_decay: f32,
nesterov: bool,
ns_steps: usize,
) -> Result<Self> {
validate_momentum_finite(momentum)?;
validate_weight_decay(weight_decay)?;
let lr = learning_rate.into();
let current_lr = lr.try_current(0)?;
Ok(Self {
learning_rate: lr,
momentum,
weight_decay,
nesterov,
ns_steps,
step_count: 0,
current_lr,
lr_resolved_for_step: Some(0),
state: HashMap::new(),
})
}
pub fn default_with_lr(learning_rate: impl Into<LearningRate>) -> Result<Self> {
Self::new(learning_rate, 0.95, 0.01, true, 5)
}
#[inline(always)]
pub fn learning_rate_ref(&self) -> &LearningRate {
&self.learning_rate
}
#[inline(always)]
pub fn momentum(&self) -> f32 {
self.momentum
}
#[inline(always)]
pub fn weight_decay(&self) -> f32 {
self.weight_decay
}
#[inline(always)]
pub fn nesterov(&self) -> bool {
self.nesterov
}
#[inline(always)]
pub fn ns_steps(&self) -> usize {
self.ns_steps
}
pub fn with_learning_rate(mut self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
let lr = learning_rate.into();
let current_lr = lr.try_current(self.step_count)?;
self.learning_rate = lr;
self.current_lr = current_lr;
self.lr_resolved_for_step = Some(self.step_count);
Ok(self)
}
pub fn with_momentum(mut self, momentum: f32) -> Result<Self> {
validate_momentum_finite(momentum)?;
self.momentum = momentum;
Ok(self)
}
pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self> {
validate_weight_decay(weight_decay)?;
self.weight_decay = weight_decay;
Ok(self)
}
#[must_use]
pub fn with_nesterov(mut self, nesterov: bool) -> Self {
self.nesterov = nesterov;
self
}
#[must_use]
pub fn with_ns_steps(mut self, ns_steps: usize) -> Self {
self.ns_steps = ns_steps;
self
}
fn newton_schulz5(&self, x: &Array, steps: usize) -> Result<Array> {
let shape = x.shape();
if shape.len() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"Muon.newton_schulz5: expected 2D input",
shape.len() as u32,
shape.to_vec(),
)));
}
let (a, b, c) = (3.4445_f32, -4.7750_f32, 2.0315_f32);
let transpose_needed = shape[shape.len() - 2] > shape[shape.len() - 1];
let mut x = if transpose_needed {
x.transpose()?
} else {
x.try_clone()?
};
let n = norm_l2(&x, &[], true)?;
let denom = arithmetic::add(&n, &scalar(1e-7)?)?;
x = arithmetic::divide(&x, &denom)?;
for _ in 0..steps {
let xt = x.transpose()?;
let a_mat = crate::ops::linalg_basic::matmul(&x, &xt)?;
let b_a = arithmetic::multiply(&scalar(b)?, &a_mat)?;
let big_b = addmm(&b_a, &a_mat, &a_mat, c, 1.0)?;
let a_x = arithmetic::multiply(&scalar(a)?, &x)?;
x = addmm(&a_x, &big_b, &x, 1.0, 1.0)?;
}
if transpose_needed {
x = x.transpose()?;
}
Ok(x)
}
}
impl Optimizer for Muon {
fn init(&mut self, params: &Weights) -> Result<()> {
self.state = zeros_like_map(params)?;
Ok(())
}
fn preflight(&mut self) -> Result<()> {
if self.lr_resolved_for_step == Some(self.step_count) {
return Ok(()); }
self.current_lr = self.learning_rate.try_current(self.step_count)?;
self.lr_resolved_for_step = Some(self.step_count);
Ok(())
}
fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
if self.state.is_empty() {
self.init(gradients)?;
}
self.preflight()?;
self.step_count += 1;
let mu_s = scalar(self.momentum)?;
let one_minus_mu = scalar(1.0 - self.momentum)?;
for (key, grad) in gradients {
let Some(param) = params.get(key) else {
continue;
};
let g_eff = if self.weight_decay != 0.0 {
let wd_s = scalar(self.weight_decay)?;
let wd_term = arithmetic::multiply(&wd_s, param)?;
arithmetic::add(grad, &wd_term)?
} else {
grad.try_clone()?
};
let prev_v = match self.state.get(key) {
Some(v) => v.try_clone()?,
None => zeros_like(param)?,
};
let v_scaled = arithmetic::multiply(&mu_s, &prev_v)?;
let g_scaled = arithmetic::multiply(&one_minus_mu, &g_eff)?;
let v_new = arithmetic::add(&v_scaled, &g_scaled)?;
let mut update = if self.nesterov {
let g_term = arithmetic::multiply(&g_eff, &one_minus_mu)?;
let v_term = arithmetic::multiply(&v_new, &mu_s)?;
arithmetic::add(&g_term, &v_term)?
} else {
v_new.try_clone()?
};
let mut lr_eff = self.current_lr;
let original_shape = update.shape();
if update.ndim() >= 2 {
let reshape_needed = update.ndim() > 2;
if reshape_needed {
let m_dim = original_shape[0];
let n_dim: usize = original_shape[1..].iter().product();
update = reshape(&update, &(m_dim, n_dim))?;
}
update = self.newton_schulz5(&update, self.ns_steps)?;
if reshape_needed {
update = reshape(&update, &original_shape.as_slice())?;
}
let updated_shape = update.shape();
let m_d = updated_shape[updated_shape.len() - 2] as f32;
let n_d = updated_shape[updated_shape.len() - 1] as f32;
lr_eff *= (1.0_f32.max(m_d / n_d)).sqrt();
}
let lr_s = scalar(lr_eff)?;
let step_term = arithmetic::multiply(&lr_s, &update)?;
let new_w = arithmetic::subtract(param, &step_term)?;
params.insert(key.clone(), new_w);
self.state.insert(key.clone(), v_new);
}
Ok(())
}
fn step(&self) -> usize {
self.step_count
}
fn learning_rate(&self) -> f32 {
self.current_lr
}
}
#[cfg(test)]
mod tests;