use super::*;
/// Extracts a single `f32` from `a` while only borrowing it immutably.
///
/// The value is read from a copy produced by `try_clone`, which is bound mutably
/// because `item` is invoked on it. Any error from cloning or from `item` is
/// propagated to the caller.
fn read_scalar(a: &Array) -> Result<f32> {
    let mut scratch = a.try_clone()?;
    let value = scratch.item::<f32>()?;
    Ok(value)
}
/// `Muon::new` must fail when the momentum argument is NaN or infinite.
#[test]
fn muon_new_rejects_non_finite_momentum() {
    // Same constructor arguments each time; only the momentum slot varies.
    for bad_momentum in [f32::NAN, f32::INFINITY] {
        let built = Muon::new(0.01, bad_momentum, 0.01, true, 5);
        assert!(built.is_err());
    }
}
/// `Muon::new` must fail when the weight-decay argument is negative.
#[test]
fn muon_new_rejects_negative_weight_decay() {
    let outcome = Muon::new(0.01, 0.95, -0.1, true, 5);
    assert!(outcome.is_err());
}
/// `Muon::new` must fail when the weight-decay argument is NaN.
#[test]
fn muon_new_rejects_non_finite_weight_decay() {
    let outcome = Muon::new(0.01, 0.95, f32::NAN, true, 5);
    assert!(outcome.is_err());
}
/// The `with_momentum` builder must fail when handed NaN.
#[test]
fn muon_with_momentum_rejects_non_finite() {
    let base = Muon::default_with_lr(0.01);
    // A failure in either the base construction or the builder call counts as an error here.
    let rebuilt = base.and_then(|optimizer| optimizer.with_momentum(f32::NAN));
    assert!(rebuilt.is_err());
}
/// The `with_weight_decay` builder must fail when handed a negative value.
#[test]
fn muon_with_weight_decay_rejects_negative() {
    assert!(Muon::default_with_lr(0.01)
        .and_then(|optimizer| optimizer.with_weight_decay(-0.1))
        .is_err());
}
/// The `with_weight_decay` builder must fail when handed positive infinity.
#[test]
fn muon_with_weight_decay_rejects_non_finite() {
    let base = Muon::default_with_lr(0.01);
    let rebuilt = base.and_then(|optimizer| optimizer.with_weight_decay(f32::INFINITY));
    assert!(rebuilt.is_err());
}
/// The `with_learning_rate` builder must fail for a `LearningRate::Fixed` carrying NaN.
#[test]
fn muon_with_learning_rate_rejects_fixed_nan() {
    let base = Muon::default_with_lr(0.01);
    let rebuilt =
        base.and_then(|optimizer| optimizer.with_learning_rate(LearningRate::Fixed(f32::NAN)));
    let rejected = rebuilt.is_err();
    assert!(rejected, "with_learning_rate must reject Fixed(NaN)");
}
/// Applies one step to a rank-1 parameter and pins the first updated element.
///
/// Per the test name, a 1-D parameter is expected to be updated without the
/// Newton-Schulz path; the assertion itself only checks the value of element 0.
#[test]
fn muon_1d_param_runs_without_newton_schulz() -> Result<()> {
    let mut optimizer = Muon::default_with_lr(0.01)?;

    let weight = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3])?;
    let gradient = Array::from_slice::<f32>(&[0.1, 0.2, 0.3], &[3])?;

    let mut params: Weights = HashMap::new();
    params.insert("w".into(), weight);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), gradient);

    optimizer.apply_gradients(&grads, &mut params)?;

    // Read the values back from a clone of the updated parameter.
    let mut updated = params["w"].try_clone()?;
    let values = updated.to_vec::<f32>()?;
    let first = values[0];
    let drift = (first - 0.999_892_8).abs();
    assert!(drift < 1e-5, "expected ~0.9998928, got {}", first);
    Ok(())
}
/// Applies one step to a 2x2 parameter and checks that the diagonal stays finite and moves.
///
/// Per the test name, a rank-2 parameter is expected to exercise the Newton-Schulz branch.
#[test]
fn muon_2d_param_invokes_newton_schulz_branch() -> Result<()> {
    let mut optimizer = Muon::new(0.01, 0.0, 0.0, false, 5)?;

    let weight = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2))?;
    let gradient = Array::from_slice::<f32>(&[0.5, 0.0, 0.0, 0.5], &(2, 2))?;

    let mut params: Weights = HashMap::new();
    params.insert("w".into(), weight);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), gradient);

    optimizer.apply_gradients(&grads, &mut params)?;

    let mut updated = params["w"].try_clone()?;
    let values = updated.to_vec::<f32>()?;
    // Elements 0 and 3 are the diagonal of the row-major 2x2 input, both initially 1.0.
    let (top_left, bottom_right) = (values[0], values[3]);
    assert!(top_left.is_finite() && bottom_right.is_finite());
    let moved = (top_left - 1.0).abs() > 1e-6 || (bottom_right - 1.0).abs() > 1e-6;
    assert!(moved);
    Ok(())
}
/// Every getter on a freshly constructed `Muon` must report the value given to `Muon::new`,
/// and the step counter must start at zero.
#[test]
fn muon_getters_echo_inputs() -> Result<()> {
    let optimizer = Muon::new(LearningRate::Fixed(0.25), 0.8, 0.02, false, 7)?;

    let schedule_is_fixed = optimizer.learning_rate_ref().is_fixed();
    assert!(
        schedule_is_fixed,
        "learning_rate_ref must echo the Fixed schedule"
    );

    let momentum = optimizer.momentum();
    assert_eq!(momentum, 0.8);
    let weight_decay = optimizer.weight_decay();
    assert_eq!(weight_decay, 0.02);

    let nesterov_enabled = optimizer.nesterov();
    assert!(
        !nesterov_enabled,
        "nesterov getter must echo the `false` arm"
    );

    let ns_steps = optimizer.ns_steps();
    assert_eq!(ns_steps, 7);
    let step = optimizer.step();
    assert_eq!(step, 0);
    let learning_rate = optimizer.learning_rate();
    assert_eq!(learning_rate, 0.25);
    Ok(())
}
/// Pins the defaults reported by `Muon::default_with_lr`: momentum 0.95, weight decay 0.01,
/// Nesterov enabled, 5 Newton-Schulz steps, and a fixed learning-rate schedule.
#[test]
fn muon_default_with_lr_getters() -> Result<()> {
    let optimizer = Muon::default_with_lr(0.01)?;

    let momentum = optimizer.momentum();
    assert_eq!(momentum, 0.95);
    let weight_decay = optimizer.weight_decay();
    assert_eq!(weight_decay, 0.01);

    let nesterov_enabled = optimizer.nesterov();
    assert!(nesterov_enabled, "default nesterov is true");

    let ns_steps = optimizer.ns_steps();
    assert_eq!(ns_steps, 5);

    let schedule_is_fixed = optimizer.learning_rate_ref().is_fixed();
    assert!(schedule_is_fixed);
    Ok(())
}
/// Walks the success path of every `with_*` builder and checks each getter reflects the override.
#[test]
fn muon_builder_success_paths_echo() -> Result<()> {
    // Apply the builders one at a time, in the same order as a single chained expression.
    let optimizer = Muon::default_with_lr(0.01)?;
    let optimizer = optimizer.with_learning_rate(LearningRate::Fixed(0.05))?;
    let optimizer = optimizer.with_momentum(0.7)?;
    let optimizer = optimizer.with_weight_decay(0.2)?;
    let optimizer = optimizer.with_nesterov(false);
    let optimizer = optimizer.with_ns_steps(3);

    let learning_rate = optimizer.learning_rate();
    assert_eq!(learning_rate, 0.05);
    let schedule_is_fixed = optimizer.learning_rate_ref().is_fixed();
    assert!(schedule_is_fixed);
    let momentum = optimizer.momentum();
    assert_eq!(momentum, 0.7);
    let weight_decay = optimizer.weight_decay();
    assert_eq!(weight_decay, 0.2);
    let nesterov_enabled = optimizer.nesterov();
    assert!(!nesterov_enabled);
    let ns_steps = optimizer.ns_steps();
    assert_eq!(ns_steps, 3);
    Ok(())
}
/// `with_nesterov(true)` must switch Nesterov on for an optimizer built with it off.
#[test]
fn muon_with_nesterov_toggles_on() -> Result<()> {
    let without_nesterov = Muon::new(0.01, 0.95, 0.01, false, 5)?;
    let toggled = without_nesterov.with_nesterov(true);
    assert!(toggled.nesterov());
    Ok(())
}
/// `newton_schulz5` must return a rank-mismatch error for rank-1 and rank-3 inputs.
#[test]
fn muon_newton_schulz5_rejects_non_2d() -> Result<()> {
    let optimizer = Muon::default_with_lr(0.01)?;

    // Rank-1 input.
    let vector = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3])?;
    let err1 = optimizer.newton_schulz5(&vector, 5).unwrap_err();
    let vector_rejected = err1.is_rank_mismatch();
    assert!(
        vector_rejected,
        "1D input must yield RankMismatch, got {err1:?}"
    );

    // Rank-3 input.
    let cube = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 2, 2))?;
    let err3 = optimizer.newton_schulz5(&cube, 5).unwrap_err();
    let cube_rejected = err3.is_rank_mismatch();
    assert!(
        cube_rejected,
        "3D input must yield RankMismatch, got {err3:?}"
    );
    Ok(())
}
/// A 3x2 (tall) input to `newton_schulz5` must come back with shape `[3, 2]`.
///
/// Per the test name, tall inputs are expected to take an internal transpose path; only the
/// output shape is asserted here.
#[test]
fn muon_newton_schulz5_tall_matrix_transposes_and_preserves_shape() -> Result<()> {
    let optimizer = Muon::default_with_lr(0.01)?;
    let tall_input = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0, 1.0, 0.0], &(3, 2))?;
    let result = optimizer.newton_schulz5(&tall_input, 5)?;
    let result_shape = result.shape();
    assert_eq!(
        result_shape,
        vec![3, 2],
        "tall-matrix output shape preserved"
    );
    Ok(())
}
/// A 2x3 (wide) input to `newton_schulz5` must come back with shape `[2, 3]` and finite entries.
#[test]
fn muon_newton_schulz5_wide_matrix_no_transpose_preserves_shape() -> Result<()> {
    let optimizer = Muon::default_with_lr(0.01)?;
    let wide_input = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &(2, 3))?;
    let mut result = optimizer.newton_schulz5(&wide_input, 5)?;

    let result_shape = result.shape();
    assert_eq!(
        result_shape,
        vec![2, 3],
        "wide-matrix output shape preserved"
    );

    let v = result.to_vec::<f32>()?;
    let every_entry_finite = v.iter().all(|entry| entry.is_finite());
    assert!(every_entry_finite, "all entries finite: {v:?}");
    Ok(())
}
/// Runs two consecutive optimizer steps after an explicit `init` and checks the bookkeeping.
///
/// Verifies that the step counter goes 0 -> 1 -> 2, that the reported learning rate is still
/// `0.01` after both steps, and that element 0 of the weight is smaller after the second step
/// than after the first.
#[test]
fn muon_two_steps_preflight_re_resolves() -> Result<()> {
    let mut muon = Muon::default_with_lr(0.01)?;
    let mut params: Weights = HashMap::new();
    params.insert(
        "w".into(),
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3])?,
    );
    // Explicit initialization must not advance the step counter.
    muon.init(&params)?;
    assert_eq!(muon.step(), 0);

    let mut grads: Weights = HashMap::new();
    grads.insert(
        "w".into(),
        Array::from_slice::<f32>(&[0.1, 0.2, 0.3], &[3])?,
    );

    // First step.
    muon.apply_gradients(&grads, &mut params)?;
    let mut w1 = params["w"].try_clone()?;
    let after_one: f32 = w1.to_vec::<f32>()?[0];
    assert_eq!(muon.step(), 1);

    // Second step with the same gradients.
    muon.apply_gradients(&grads, &mut params)?;
    let mut w2 = params["w"].try_clone()?;
    let after_two: f32 = w2.to_vec::<f32>()?[0];
    assert_eq!(muon.step(), 2);

    assert_eq!(muon.learning_rate(), 0.01);
    assert!(after_two < after_one, "weight should keep decreasing");
    Ok(())
}
/// A gradient whose key has no matching parameter must be ignored by `apply_gradients`.
///
/// The matching parameter (`"present"`) is still updated to the pinned value, and the
/// unmatched key (`"absent"`) must not be inserted into the parameter map.
#[test]
fn muon_skips_grad_key_absent_from_params() -> Result<()> {
    let mut muon = Muon::default_with_lr(0.01)?;
    let mut params: Weights = HashMap::new();
    params.insert("present".into(), Array::from_slice::<f32>(&[1.0], &[1])?);

    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), Array::from_slice::<f32>(&[0.5], &[1])?);
    // This key exists only on the gradient side.
    grads.insert("absent".into(), Array::from_slice::<f32>(&[0.5], &[1])?);

    muon.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["present"])?;
    assert!((got - 0.999_503).abs() < 1e-5, "present got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added to params"
    );
    Ok(())
}
/// Steps a parameter that was never passed to `init` alongside one that was.
///
/// `init` is called with only `"a"`, then `apply_gradients` receives both `"a"` and `"b"`.
/// Per the test name, `"b"` is meant to hit the branch where no optimizer state exists yet.
/// Both parameters start from the same value and gradient and must land on the same pinned
/// result.
#[test]
fn muon_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
    let mut muon = Muon::default_with_lr(0.01)?;

    // Initialize state for "a" only.
    let mut init_params: Weights = HashMap::new();
    init_params.insert("a".into(), Array::from_slice::<f32>(&[1.0], &[1])?);
    muon.init(&init_params)?;
    assert!(
        !muon.state.is_empty(),
        "explicit init populated state for 'a'"
    );

    let mut params: Weights = HashMap::new();
    params.insert("a".into(), Array::from_slice::<f32>(&[1.0], &[1])?);
    params.insert("b".into(), Array::from_slice::<f32>(&[1.0], &[1])?);
    let mut grads: Weights = HashMap::new();
    grads.insert("a".into(), Array::from_slice::<f32>(&[0.5], &[1])?);
    grads.insert("b".into(), Array::from_slice::<f32>(&[0.5], &[1])?);

    muon.apply_gradients(&grads, &mut params)?;

    let got_b = read_scalar(&params["b"])?;
    assert!((got_b - 0.999_503).abs() < 1e-5, "b got {got_b}");
    let got_a = read_scalar(&params["a"])?;
    assert!((got_a - 0.999_503).abs() < 1e-5, "a got {got_a}");
    Ok(())
}
/// Applies one step to a `(2, 2, 1)` parameter; the shape must survive and the values must move.
///
/// Per the test name, rank-3 parameters are expected to go through a reshape before the
/// Newton-Schulz step; the assertions cover shape, finiteness, and that the update is non-trivial.
#[test]
fn muon_3d_param_invokes_reshape_branch() -> Result<()> {
    let mut optimizer = Muon::new(0.01, 0.0, 0.0, false, 5)?;

    let weight = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2, 1))?;
    let gradient = Array::from_slice::<f32>(&[0.5, 0.0, 0.0, 0.5], &(2, 2, 1))?;

    let mut params: Weights = HashMap::new();
    params.insert("w".into(), weight);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), gradient);

    optimizer.apply_gradients(&grads, &mut params)?;

    let mut updated = params["w"].try_clone()?;
    let updated_shape = updated.shape();
    assert_eq!(updated_shape, vec![2, 2, 1], "3D output shape preserved");

    let v = updated.to_vec::<f32>()?;
    let every_entry_finite = v.iter().all(|entry| entry.is_finite());
    assert!(every_entry_finite, "all entries finite: {v:?}");

    // Elements 0 and 3 both start at 1.0; at least one must have changed.
    let first_moved = (v[0] - 1.0).abs() > 1e-6;
    let last_moved = (v[3] - 1.0).abs() > 1e-6;
    assert!(
        first_moved || last_moved,
        "reshape+newton-schulz update must move the param: {v:?}"
    );
    Ok(())
}