use super::*;
/// Builds an `f32` array with an empty shape (no dimensions) filled with `v`.
fn scalar(v: f32) -> Result<Array> {
    let no_dims: [i32; 0] = [];
    Array::full::<f32>(&no_dims, v)
}
/// Extracts the `f32` item of `a` by calling `item::<f32>()` on a clone of it.
///
/// Any error from cloning or from the item read is propagated to the caller.
fn read_scalar(a: &Array) -> Result<f32> {
    let mut owned = a.try_clone()?;
    let value = owned.item::<f32>()?;
    Ok(value)
}
/// One vanilla SGD step with lr = 0.1 on w = 1.0, g = 0.5 must yield
/// w = 1.0 - 0.1 * 0.5 = 0.95.
#[test]
fn vanilla_sgd_single_step_matches_python_ref() -> Result<()> {
    let mut sgd = SGD::vanilla(0.1)?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.95).abs() < 1e-6, "expected 0.95, got {got}");
    // After a single call the step counter reads 1 and the fixed LR is still 0.1.
    assert_eq!(sgd.step(), 1);
    assert!((sgd.learning_rate() - 0.1).abs() < 1e-6);
    Ok(())
}
/// One SGD step with momentum 0.9 (no decay, no dampening, no Nesterov):
/// the parameter moves to 0.95 and the stored state for "w" equals 0.5.
#[test]
fn sgd_with_momentum_single_step_matches_python_ref() -> Result<()> {
    let mut sgd = SGD::new(0.1, 0.9, 0.0, 0.0, false)?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    assert!((read_scalar(&params["w"])? - 0.95).abs() < 1e-6);
    // The optimizer's per-key state is inspected directly to check the stored value.
    let v = read_scalar(&sgd.state["w"])?;
    assert!((v - 0.5).abs() < 1e-6, "expected v=0.5, got {v}");
    Ok(())
}
/// One SGD step with weight decay 0.5 on w = 2.0, g = 1.0, lr = 0.1 must
/// yield w = 1.8 (consistent with an effective gradient of 1.0 + 0.5 * 2.0).
#[test]
fn sgd_with_weight_decay_matches_python_ref() -> Result<()> {
    let mut sgd = SGD::new(0.1, 0.0, 0.5, 0.0, false)?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(2.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(1.0)?);

    sgd.apply_gradients(&grads, &mut params)?;

    assert!((read_scalar(&params["w"])? - 1.8).abs() < 1e-6);
    Ok(())
}
/// `SGD::new` with `nesterov = true` and momentum 0.0 must fail with
/// `Error::InvariantViolation`, carrying the expected context and requirement text.
#[test]
fn sgd_nesterov_precondition_rejects_zero_momentum() {
    let outcome = SGD::new(0.1, 0.0, 0.0, 0.0, true);
    let payload = match outcome {
        Ok(_) => panic!("nesterov with momentum=0 must be rejected"),
        Err(Error::InvariantViolation(payload)) => payload,
        Err(other) => panic!("expected InvariantViolation, got: {other:?}"),
    };
    assert_eq!(payload.context(), "SGD: Nesterov momentum");
    assert_eq!(
        payload.requirement(),
        "requires momentum > 0 (finite) and dampening == 0"
    );
}
/// Turning Nesterov on via the builder on an optimizer made by `SGD::vanilla`
/// must produce an error.
#[test]
fn sgd_builder_with_nesterov_rejects_zero_momentum() {
    let outcome = match SGD::vanilla(0.1) {
        Ok(optimizer) => optimizer.with_nesterov(true),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err());
}
/// Setting momentum to 0.0 via the builder on a Nesterov-enabled optimizer
/// must produce an error.
#[test]
fn sgd_builder_with_momentum_zero_rejects_nesterov() {
    let nesterov_sgd = SGD::new(0.1, 0.9, 0.0, 0.0, true);
    let outcome = nesterov_sgd.and_then(|optimizer| optimizer.with_momentum(0.0));
    assert!(outcome.is_err());
}
/// Setting a non-zero dampening (0.1) via the builder on a Nesterov-enabled
/// optimizer must produce an error.
#[test]
fn sgd_builder_with_dampening_rejects_nesterov() {
    let outcome = match SGD::new(0.1, 0.9, 0.0, 0.0, true) {
        Ok(optimizer) => optimizer.with_dampening(0.1),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err());
}
/// A schedule-backed learning rate is re-evaluated on each `apply_gradients` call.
///
/// The schedule is `0.1 / max(step, 1)`; the test expects the reported rate to
/// be 0.1, 0.1 and 0.05 after the first, second and third calls respectively.
#[test]
fn sgd_schedule_advances_lr_each_step() -> Result<()> {
    let schedule: Box<dyn Fn(usize) -> f32> = Box::new(|step| 0.1 / (step as f32).max(1.0));
    let mut optimizer = SGD::vanilla(LearningRate::Schedule(schedule))?;

    let mut weights = Weights::new();
    weights.insert("w".into(), scalar(1.0)?);
    let mut gradients = Weights::new();
    gradients.insert("w".into(), scalar(1.0)?);

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr_first = optimizer.learning_rate();
    assert!((lr_first - 0.1).abs() < 1e-6);

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr_second = optimizer.learning_rate();
    assert!((lr_second - 0.1).abs() < 1e-6);

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr_third = optimizer.learning_rate();
    assert!((lr_third - 0.05).abs() < 1e-6);

    Ok(())
}
/// With the identity schedule `lr(step) = step`, the rate reported after the
/// n-th `apply_gradients` call must be `n - 1`: the schedule is evaluated with
/// the step value from before that call's increment.
///
/// Gradients are zero, so the parameter value plays no role in the checks.
#[test]
fn optimizer_lr_schedule_resolves_at_pre_increment_step() -> Result<()> {
    let identity: Box<dyn Fn(usize) -> f32> = Box::new(|step| step as f32);
    let mut optimizer = SGD::vanilla(LearningRate::Schedule(identity))?;

    let mut weights = Weights::new();
    weights.insert("w".into(), scalar(1.0)?);
    let mut gradients = Weights::new();
    gradients.insert("w".into(), scalar(0.0)?);

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr = optimizer.learning_rate();
    assert!(
        (lr - 0.0).abs() < 1e-6,
        "first call must see step 0, got {}",
        lr
    );

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr = optimizer.learning_rate();
    assert!(
        (lr - 1.0).abs() < 1e-6,
        "second call must see step 1, got {}",
        lr
    );

    optimizer.apply_gradients(&gradients, &mut weights)?;
    let lr = optimizer.learning_rate();
    assert!(
        (lr - 2.0).abs() < 1e-6,
        "third call must see step 2, got {}",
        lr
    );

    assert_eq!(optimizer.step(), 3);
    Ok(())
}
/// `SGD::new` must return an error when weight decay is NaN.
#[test]
fn sgd_new_rejects_nan_weight_decay() {
    let built = SGD::new(0.1, 0.0, f32::NAN, 0.0, false);
    assert!(built.is_err());
}
/// `SGD::new` must return an error when weight decay is positive infinity.
#[test]
fn sgd_new_rejects_inf_weight_decay() {
    let built = SGD::new(0.1, 0.0, f32::INFINITY, 0.0, false);
    assert!(built.is_err());
}
/// `SGD::new` must return an error when weight decay is negative.
#[test]
fn sgd_new_rejects_negative_weight_decay() {
    let built = SGD::new(0.1, 0.0, -0.1, 0.0, false);
    assert!(built.is_err());
}
/// The `with_weight_decay` builder must return an error for NaN.
#[test]
fn sgd_with_weight_decay_rejects_nan() {
    let outcome = match SGD::vanilla(0.1) {
        Ok(optimizer) => optimizer.with_weight_decay(f32::NAN),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err());
}
/// The `with_weight_decay` builder must return an error for positive infinity.
#[test]
fn sgd_with_weight_decay_rejects_inf() {
    let base = SGD::vanilla(0.1);
    let outcome = base.and_then(|optimizer| optimizer.with_weight_decay(f32::INFINITY));
    assert!(outcome.is_err());
}
/// The `with_weight_decay` builder must return an error for a negative value.
#[test]
fn sgd_with_weight_decay_rejects_negative() {
    let outcome = match SGD::vanilla(0.1) {
        Ok(optimizer) => optimizer.with_weight_decay(-0.1),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err());
}
/// `SGD::new` must return an error when dampening is NaN.
#[test]
fn sgd_new_rejects_nan_dampening() {
    let built = SGD::new(0.1, 0.0, 0.0, f32::NAN, false);
    assert!(built.is_err());
}
/// `SGD::new` must return an error when dampening is positive infinity.
#[test]
fn sgd_new_rejects_inf_dampening() {
    let built = SGD::new(0.1, 0.0, 0.0, f32::INFINITY, false);
    assert!(built.is_err());
}
/// `SGD::new` must return an error when dampening is negative.
#[test]
fn sgd_new_rejects_negative_dampening() {
    let built = SGD::new(0.1, 0.0, 0.0, -0.1, false);
    assert!(built.is_err());
}
/// The `with_dampening` builder must return an error for NaN.
#[test]
fn sgd_with_dampening_rejects_nan() {
    let base = SGD::vanilla(0.1);
    let outcome = base.and_then(|optimizer| optimizer.with_dampening(f32::NAN));
    assert!(outcome.is_err());
}
/// The `with_dampening` builder must return an error for positive infinity.
#[test]
fn sgd_with_dampening_rejects_inf() {
    let outcome = match SGD::vanilla(0.1) {
        Ok(optimizer) => optimizer.with_dampening(f32::INFINITY),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err());
}
/// The `with_dampening` builder must return an error for a negative value.
#[test]
fn sgd_with_dampening_rejects_negative() {
    let base = SGD::vanilla(0.1);
    let outcome = base.and_then(|optimizer| optimizer.with_dampening(-0.1));
    assert!(outcome.is_err());
}
/// NaN momentum must be rejected both by `SGD::new` (with Nesterov enabled) and
/// by the `with_momentum` builder.
#[test]
fn sgd_validate_nesterov_rejects_nan_momentum() {
    let via_constructor = SGD::new(0.1, f32::NAN, 0.0, 0.0, true);
    assert!(via_constructor.is_err());

    // NOTE(review): this base optimizer is built with nesterov = false, so the
    // builder path below may not reach Nesterov-specific validation — confirm
    // whether `true` was intended here.
    let base = SGD::new(0.1, 0.9, 0.0, 0.0, false).unwrap();
    let with_path = base.with_momentum(f32::NAN);
    assert!(with_path.is_err());
}
/// Non-finite momentum (NaN or infinity) must be rejected even when Nesterov is
/// disabled, through both `SGD::new` and the `with_momentum` builder.
#[test]
fn sgd_rejects_nan_momentum_even_without_nesterov() {
    let nan_momentum = SGD::new(0.1, f32::NAN, 0.0, 0.0, false);
    assert!(nan_momentum.is_err());

    let inf_momentum = SGD::new(0.1, f32::INFINITY, 0.0, 0.0, false);
    assert!(inf_momentum.is_err());

    let base = SGD::vanilla(0.1).unwrap();
    let with_path = base.with_momentum(f32::NAN);
    assert!(with_path.is_err());
}
/// The `with_learning_rate` builder must return an error for `LearningRate::Fixed(NaN)`.
#[test]
fn sgd_with_learning_rate_rejects_fixed_nan() {
    let outcome = match SGD::vanilla(0.1) {
        Ok(optimizer) => optimizer.with_learning_rate(LearningRate::Fixed(f32::NAN)),
        Err(err) => Err(err),
    };
    assert!(outcome.is_err(), "with_learning_rate must reject Fixed(NaN)");
}
/// Every getter on a freshly constructed optimizer must return the value that
/// was passed to `SGD::new`, and the step counter must start at 0.
#[test]
fn sgd_getters_echo_inputs() -> Result<()> {
    let optimizer = SGD::new(LearningRate::Fixed(0.25), 0.8, 0.05, 0.1, false)?;

    let lr_is_fixed = optimizer.learning_rate_ref().is_fixed();
    assert!(
        lr_is_fixed,
        "learning_rate_ref must echo the Fixed schedule"
    );

    assert_eq!(optimizer.momentum(), 0.8);
    assert_eq!(optimizer.weight_decay(), 0.05);
    assert_eq!(optimizer.dampening(), 0.1);

    let nesterov_enabled = optimizer.nesterov();
    assert!(!nesterov_enabled);

    assert_eq!(optimizer.learning_rate(), 0.25);
    assert_eq!(optimizer.step(), 0);
    Ok(())
}
/// The `nesterov()` getter must return `true` when the optimizer was built with
/// Nesterov enabled.
#[test]
fn sgd_nesterov_getter_reflects_true() -> Result<()> {
    let optimizer = SGD::new(0.1, 0.9, 0.0, 0.0, true)?;
    let nesterov_enabled = optimizer.nesterov();
    assert!(nesterov_enabled);
    Ok(())
}
/// Each builder call given a valid value must succeed, and the matching getter
/// must afterwards report the value that was set.
#[test]
fn sgd_builder_success_paths_echo() -> Result<()> {
    // Apply the builders one at a time, in the same order as a chained call.
    let optimizer = SGD::vanilla(0.1)?;
    let optimizer = optimizer.with_learning_rate(LearningRate::Fixed(0.05))?;
    let optimizer = optimizer.with_momentum(0.7)?;
    let optimizer = optimizer.with_weight_decay(0.2)?;
    let optimizer = optimizer.with_dampening(0.3)?;

    assert_eq!(optimizer.learning_rate(), 0.05);
    let lr_is_fixed = optimizer.learning_rate_ref().is_fixed();
    assert!(lr_is_fixed);
    assert_eq!(optimizer.momentum(), 0.7);
    assert_eq!(optimizer.weight_decay(), 0.2);
    assert_eq!(optimizer.dampening(), 0.3);
    Ok(())
}
/// `with_nesterov(true)` on an optimizer with momentum 0.9 and zero dampening
/// must succeed and turn the flag on.
#[test]
fn sgd_with_nesterov_success_enables_flag() -> Result<()> {
    let plain = SGD::new(0.1, 0.9, 0.0, 0.0, false)?;
    let optimizer = plain.with_nesterov(true)?;
    assert!(optimizer.nesterov());
    Ok(())
}
/// One SGD step with momentum 0.9 and dampening 0.5 on w = 1.0, g = 0.5:
/// the stored state for "w" must be 0.25 and the parameter must be 0.975
/// (i.e. 1.0 - 0.1 * 0.25).
#[test]
fn sgd_with_dampening_single_step_matches_python_ref() -> Result<()> {
    let mut sgd = SGD::new(0.1, 0.9, 0.0, 0.5, false)?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.975).abs() < 1e-6, "expected 0.975, got {got}");
    // The optimizer's per-key state is inspected directly to check the stored value.
    let v = read_scalar(&sgd.state["w"])?;
    assert!((v - 0.25).abs() < 1e-6, "expected v=0.25, got {v}");
    Ok(())
}
/// One Nesterov SGD step (lr = 0.1, momentum = 0.9) on w = 1.0, g = 0.5 must
/// yield w = 0.905, consistent with an update of g + 0.9 * g = 0.95.
#[test]
fn sgd_nesterov_single_step_matches_python_ref() -> Result<()> {
    let mut sgd = SGD::new(0.1, 0.9, 0.0, 0.0, true)?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.905).abs() < 1e-6, "expected 0.905, got {got}");
    Ok(())
}
/// A gradient key that was not covered by an explicit `init` must still be
/// updated by a momentum step.
///
/// The optimizer is initialised with only "a"; a step is then applied to both
/// "a" and "b". The result for "b" (0.95) matches a first momentum step from
/// scratch. Presumably this exercises the branch where no state exists yet for
/// a key, as the test name suggests — confirm against the implementation.
#[test]
fn sgd_momentum_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
    let mut sgd = SGD::new(0.1, 0.9, 0.0, 0.0, false)?;

    // Explicit init with only "a", so "b" is not covered by this call.
    let mut init_params: Weights = HashMap::new();
    init_params.insert("a".into(), scalar(1.0)?);
    sgd.init(&init_params)?;
    assert!(
        !sgd.state.is_empty(),
        "explicit init populated state for 'a'"
    );

    let mut params: Weights = HashMap::new();
    params.insert("a".into(), scalar(1.0)?);
    params.insert("b".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("a".into(), scalar(0.5)?);
    grads.insert("b".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    let got_b = read_scalar(&params["b"])?;
    assert!((got_b - 0.95).abs() < 1e-6, "b got {got_b}");
    Ok(())
}
/// A gradient whose key has no matching parameter must be ignored: the present
/// parameter is updated normally and the absent key is not inserted into `params`.
#[test]
fn sgd_skips_grad_key_absent_from_params() -> Result<()> {
    let mut sgd = SGD::vanilla(0.1)?;
    let mut params: Weights = HashMap::new();
    params.insert("present".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), scalar(0.5)?);
    // This gradient has no counterpart in `params`.
    grads.insert("absent".into(), scalar(0.5)?);

    sgd.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["present"])?;
    assert!((got - 0.95).abs() < 1e-6, "present got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added to params"
    );
    Ok(())
}
/// When the schedule starts returning NaN, `apply_gradients` must fail and
/// leave the parameters unchanged.
///
/// The schedule yields 0.1 at step 0 and NaN afterwards, so the first call
/// succeeds and the second must return an error without changing "w".
#[test]
fn sgd_apply_gradients_rejects_schedule_that_goes_nan() -> Result<()> {
    let sched: Box<dyn Fn(usize) -> f32> = Box::new(|step| if step == 0 { 0.1 } else { f32::NAN });
    let mut sgd = SGD::vanilla(LearningRate::Schedule(sched))?;
    let mut params: Weights = HashMap::new();
    params.insert("w".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("w".into(), scalar(0.5)?);

    // First call is resolved with the finite rate and must succeed.
    sgd.apply_gradients(&grads, &mut params)?;

    // Snapshot the parameter via the shared helper before the failing call.
    let w_before = read_scalar(&params["w"])?;
    let result = sgd.apply_gradients(&grads, &mut params);
    assert!(
        result.is_err(),
        "apply_gradients must reject schedule NaN at step 1"
    );

    let w_after = read_scalar(&params["w"])?;
    assert!(
        (w_before - w_after).abs() < 1e-9,
        "params must not be mutated when LR goes NaN: before={w_before}, after={w_after}"
    );
    Ok(())
}