//! The Muon optimizer: stochastic gradient descent with momentum whose 2D
//! parameter updates are orthogonalized via a Newton-Schulz iteration before
//! being applied.
use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use burn::config::Config;
use burn::tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
    ops::Device,
};
use serde::{Deserialize, Serialize};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::WeightDecayConfig,
momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
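
/// Strategy for adjusting the learning rate per parameter, based on the
/// parameter's shape.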
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
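    /// `sqrt(max(1, rows / cols))`, the adjustment used by the original Muon
    /// implementation.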
#[default]
Original,
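    /// `0.2 * sqrt(max(rows, cols))`, which scales updates to roughly match
    /// the root-mean-square magnitude of AdamW updates.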
MatchRmsAdamW,
}
impl AdjustLrFn {
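    /// Multiplier applied to the learning rate for a parameter of the given
    /// shape. Tensors with fewer than two dimensions are left unscaled.
    ///
    /// For example, a `[1024, 512]` weight yields `sqrt(2)` under
    /// [AdjustLrFn::Original] and `0.2 * sqrt(1024)` under
    /// [AdjustLrFn::MatchRmsAdamW].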
fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
if shape.len() < 2 {
return 1.0;
}
let a = shape[0] as f64;
let b = shape[1] as f64;
match self {
Self::Original => {
let ratio = a / b;
ratio.max(1.0).sqrt()
}
Self::MatchRmsAdamW => {
0.2 * a.max(b).sqrt()
}
}
}
}
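
/// Configuration to create the [Muon] optimizer.
///
/// Muon applies SGD with (Nesterov) momentum and orthogonalizes each 2D
/// update matrix with a Newton-Schulz iteration before applying it. It only
/// accepts 2D parameters (e.g. weight matrices); 1D parameters such as biases
/// must be handled by a different optimizer.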
#[derive(Config, Debug)]
pub struct MuonConfig {
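    /// Optional [weight decay](WeightDecayConfig), applied in decoupled form.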
weight_decay: Option<WeightDecayConfig>,
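    /// Momentum configuration; defaults to Nesterov momentum of 0.95 with no
    /// dampening.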
#[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
momentum: MomentumConfig,
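    /// Coefficients `(a, b, c)` of the quintic Newton-Schulz polynomial.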
#[config(default = "(3.4445, -4.775, 2.0315)")]
ns_coefficients: (f32, f32, f32),
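    /// Epsilon used when normalizing the update, to avoid division by zero.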
#[config(default = 1e-7)]
epsilon: f32,
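    /// Number of Newton-Schulz iterations to run per step.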
#[config(default = 5)]
ns_steps: usize,
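    /// Shape-based learning-rate adjustment strategy.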
#[config(default = "AdjustLrFn::Original")]
adjust_lr_fn: AdjustLrFn,
}
impl MuonConfig {
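    /// Initialize the [Muon] optimizer, wrapped in an [OptimizerAdaptor] so
    /// it can optimize any [AutodiffModule].
    ///
    /// A minimal usage sketch, mirroring the tests below (`Linear`, `model`,
    /// and `x` are assumed to be in scope; they are not provided by this
    /// module):
    ///
    /// ```rust,ignore
    /// let mut optim = MuonConfig::new()
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    ///     .init::<B, Linear<B>>();
    /// let grads = GradientsParams::from_grads(model.forward(x).backward(), &model);
    /// let model = optim.step(1e-2, model, grads);
    /// ```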
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
let momentum = Momentum::new(&self.momentum);
let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);
let optim = Muon {
momentum,
ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
weight_decay_penalty,
epsilon: self.epsilon,
adjust_lr_fn: self.adjust_lr_fn,
};
OptimizerAdaptor::from(optim)
}
}
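
/// Coefficients `(a, b, c)` and step count for the quintic Newton-Schulz
/// iteration.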
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
a: f32,
b: f32,
c: f32,
steps: usize,
}
impl NewtonSchulzParams {
fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
Self {
a: coefficients.0,
b: coefficients.1,
c: coefficients.2,
steps,
}
}
}
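
/// Optimizer that implements Muon: SGD-momentum whose 2D updates are
/// orthogonalized with a Newton-Schulz iteration before being applied.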
#[derive(Clone)]
pub struct Muon<B: Backend> {
momentum: Momentum<B>,
ns_params: NewtonSchulzParams,
weight_decay_penalty: Option<f32>,
epsilon: f32,
adjust_lr_fn: AdjustLrFn,
}
impl<B: Backend> Muon<B> {
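    /// Scale the learning rate by the shape-dependent ratio of the configured
    /// [AdjustLrFn].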
fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
lr * self.adjust_lr_fn.adjustment_ratio(shape)
}
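    /// Approximate the nearest semi-orthogonal matrix to `g` (the `U V^T`
    /// factor of its SVD `G = U S V^T`) by running a quintic Newton-Schulz
    /// iteration on the Frobenius-normalized input.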
fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
        let shape = g.shape();
        let dim_m2 = shape[D - 2];
        let dim_m1 = shape[D - 1];
        // Work on the wide orientation (rows <= cols) so the Gram matrix
        // `X X^T` below is as small as possible; transpose back at the end.
        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
            (g.swap_dims(D - 2, D - 1), true)
        } else {
            (g, false)
        };
        // Normalize to unit Frobenius norm so the iteration starts inside its
        // region of convergence; `epsilon` guards against division by zero.
        let norm = x
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .clamp_min(self.epsilon)
            .unsqueeze();
        x = x.div(norm);
        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;
        // Quintic iteration `X <- a*X + (b*A + c*A^2) X` with `A = X X^T`,
        // which pushes the singular values of `X` toward 1.
        for _ in 0..steps {
            let x_t = x.clone().swap_dims(D - 2, D - 1);
            let a_matrix = x.clone().matmul(x_t);
            let a_squared = a_matrix.clone().matmul(a_matrix.clone());
            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));
            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
        }
if needs_transpose {
x = x.swap_dims(D - 2, D - 1);
}
x
}
}
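
/// The state of the [Muon] optimizer, holding the momentum buffer for a
/// parameter.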
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
pub momentum: MomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
type State<const D: usize> = MuonState<B, D>;
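    /// Perform a single Muon step: momentum update, Newton-Schulz
    /// orthogonalization, shape-based learning-rate adjustment, and optional
    /// decoupled weight decay.
    ///
    /// # Panics
    ///
    /// Panics if the parameter tensor is not 2D.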
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
assert!(
D == 2,
"Newton-Schulz iteration requires 2D tensors, got {}D",
D
);
        // Momentum update first; the smoothed gradient is then orthogonalized.
        let state_momentum = state.map(|s| s.momentum);
        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);
        let update = self.zeropower_via_newtonschulz(grad);
        // The update uses the shape-adjusted learning rate...
        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());
        // ...while decoupled weight decay uses the unadjusted learning rate.
        let tensor = if let Some(penalty) = self.weight_decay_penalty {
            let decay_factor = 1.0 - lr * penalty as f64;
            tensor.mul_scalar(decay_factor)
        } else {
            tensor
        };
        let delta = update.mul_scalar(adjusted_lr);
let new_state = MuonState::new(new_momentum_state);
(tensor - delta, Some(new_state))
}
fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
state.momentum = state.momentum.to_device(device);
state
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
type TestBackend = burn_ndarray::NdArray<f32>;
const TOLERANCE: f64 = 1e-8;
fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None,
        };
LinearConfig::new(4, 4)
.with_bias(false)
.init(&device)
.load_record(record)
}
#[test]
fn test_adjust_lr_fn_original() {
let method = AdjustLrFn::Original;
let ratio = method.adjustment_ratio(&[512, 512]);
assert!((ratio - 1.0).abs() < TOLERANCE);
let ratio = method.adjustment_ratio(&[1024, 512]);
let expected = (2.0f64).sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
let ratio = method.adjustment_ratio(&[512, 1024]);
assert!((ratio - 1.0).abs() < TOLERANCE);
}
#[test]
fn test_adjust_lr_fn_match_rms_adamw() {
let method = AdjustLrFn::MatchRmsAdamW;
let ratio = method.adjustment_ratio(&[1024, 512]);
let expected = 0.2 * 1024.0f64.sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
let ratio = method.adjustment_ratio(&[512, 512]);
let expected = 0.2 * 512.0f64.sqrt();
assert!((ratio - expected).abs() < TOLERANCE);
}
#[test]
#[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
fn test_1d_tensor_panics() {
let device = Default::default();
let config = MuonConfig::new();
let optim: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);
let _ = optim.step(0.01, tensor_1d, grad_1d, None);
}
#[test]
fn test_muon_optimizer_save_load_state() {
let device = Default::default();
        let linear = LinearConfig::new(6, 6)
            .with_bias(false)
            .init::<TestAutodiffBackend>(&device);
let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
let mut optimizer =
MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let _linear = optimizer.step(0.01, linear, grads);
let state_before = optimizer.to_record();
let state_before_copy = optimizer.to_record();
let optimizer_new =
MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let optimizer_loaded = optimizer_new.load_record(state_before_copy);
let state_after = optimizer_loaded.to_record();
assert_eq!(state_before.len(), state_after.len());
}
#[test]
fn test_muon_with_weight_decay() {
let device = Default::default();
let linear = given_linear_layer_no_bias(TensorData::from([
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
]));
let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
[[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
&device,
)
.require_grad();
let mut optimizer = MuonConfig::new()
.with_weight_decay(Some(WeightDecayConfig::new(0.01)))
.init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
let grads = linear.forward(x).backward();
let grads = GradientsParams::from_grads(grads, &linear);
let linear = optimizer.step(0.01, linear, grads);
let state = linear.into_record();
let weight = state.weight.to_data();
for val in weight.as_slice::<f32>().unwrap() {
assert!(
*val < 1.0,
"Weight should be reduced by weight decay, got {}",
val
);
}
}
#[test]
fn test_newton_schulz_orthogonalization() {
let device = Default::default();
let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
let o_t = orthogonalized.clone().transpose();
let product = orthogonalized.matmul(o_t);
let data = product.into_data();
let values = data.as_slice::<f32>().unwrap();
assert!(
(values[0] - 1.0).abs() < 0.1,
"Product[0,0] should be ~1.0, got {}",
values[0]
);
assert!(
(values[3] - 1.0).abs() < 0.1,
"Product[1,1] should be ~1.0, got {}",
values[3]
);
}
#[test]
fn test_tall_matrix_transpose() {
let device = Default::default();
let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1],
[0.2, 0.4, 0.1, 0.3],
[0.3, 0.1, 0.4, 0.2],
],
&device,
);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());
let original_shape = tall_matrix.shape();
let result_shape = orthogonalized.shape();
assert_eq!(
original_shape.dims::<2>(),
result_shape.dims::<2>(),
"Shape should be preserved: [8, 4]"
);
let original_data = tall_matrix.into_data();
let result_data = orthogonalized.into_data();
assert_ne!(
original_data.as_slice::<f32>().unwrap(),
result_data.as_slice::<f32>().unwrap(),
"Orthogonalized matrix should differ from input"
);
let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
[0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
[0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
],
&device,
);
let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());
let wide_original_shape = wide_matrix.shape();
let wide_result_shape = orthogonalized_wide.shape();
assert_eq!(
wide_original_shape.dims::<2>(),
wide_result_shape.dims::<2>(),
"Wide matrix shape should be preserved: [4, 8]"
);
}
#[test]
fn test_zero_gradient() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
],
&device,
);
let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);
let config = MuonConfig::new();
let muon: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: None,
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);
assert!(state.is_some());
let original_data = tensor.into_data();
let updated_data = updated_tensor.clone().into_data();
let original_vals = original_data.as_slice::<f32>().unwrap();
let updated_vals = updated_data.as_slice::<f32>().unwrap();
for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
assert!(
(orig - upd).abs() < 1e-6,
"With zero gradient, tensor should remain unchanged (or very close)"
);
}
for val in updated_vals {
assert!(
!val.is_nan(),
"Result should not contain NaN values with zero gradient"
);
}
let muon_with_decay: Muon<TestBackend> = Muon {
momentum: Momentum::new(&config.momentum),
ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
weight_decay_penalty: Some(0.01),
epsilon: config.epsilon,
adjust_lr_fn: config.adjust_lr_fn,
};
let tensor2 = Tensor::<TestBackend, 2>::from_floats(
[
[1.0, 0.5, 0.3, 0.2],
[0.5, 1.0, 0.4, 0.1],
[0.3, 0.4, 1.0, 0.5],
[0.2, 0.1, 0.5, 1.0],
],
&device,
);
let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);
let (updated_tensor_decay, _) =
muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);
let updated_decay_data = updated_tensor_decay.into_data();
let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();
for val in updated_decay_vals {
assert!(
!val.is_nan(),
"Result should not contain NaN with zero gradient and weight decay"
);
}
let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
if orig.abs() > 1e-6 {
assert!(
upd.abs() < orig.abs(),
"Weight decay should reduce magnitude: original={}, updated={}",
orig,
upd
);
}
}
}
}