use super::*;
/// Extracts the single `f32` held by `a`.
///
/// `item` is invoked on a fresh clone (the original needed a `mut` binding for
/// it), so the caller's array is never touched.
fn read_scalar(a: &Array) -> Result<f32> {
    a.try_clone()?.item::<f32>()
}
/// Builds a one-entry parameter map and a matching one-entry gradient map,
/// both keyed by `"w"`, holding the scalars `p` and `g` respectively.
fn p_g(p: f32, g: f32) -> Result<(Weights, Weights)> {
    let params: Weights = HashMap::from([("w".into(), scalar(p)?)]);
    let grads: Weights = HashMap::from([("w".into(), scalar(g)?)]);
    Ok((params, grads))
}
/// One Adam step from `default_with_lr` (β = (0.9, 0.999), no bias
/// correction) on w = 1.0 with g = 0.5.
///
/// Expected value, assuming the standard first-step update
/// m = (1-β1)·g = 0.05, v = (1-β2)·g² = 0.00025, w -= lr·m/(√v + ε):
/// 1 - 0.001·0.05/0.0158114 ≈ 0.9968377.
#[test]
fn adam_single_step_no_bias_correction_matches_python_ref() -> Result<()> {
    let mut adam = Adam::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adam.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.996_837_7).abs() < 1e-5, "got {got}");
    Ok(())
}
/// One bias-corrected Adam step on w = 1.0 with g = 0.5.
///
/// At step 1, bias correction turns the moments back into m̂ = g and v̂ = g²,
/// so the update is lr·g/|g| = lr and the weight lands on 1 - 0.001 = 0.999.
/// The tolerance matches the sibling single-step tests (1e-5). A looser bound
/// would accept a noticeably wrong update magnitude.
#[test]
fn adam_bias_correction_step1_matches_python_ref() -> Result<()> {
    let mut adam = Adam::new(0.001, (0.9, 0.999), 1e-8, true)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adam.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.999).abs() < 1e-5, "got {got}");
    Ok(())
}
/// One AdamW step (wd = 0.01, no bias correction) on w = 1.0 with g = 0.5.
///
/// Expected value, assuming decoupled decay is applied first:
/// w·(1 - lr·wd) = 0.99999, then the plain Adam step of ≈ 0.0031623 is
/// subtracted, giving ≈ 0.9968277.
#[test]
fn adamw_decoupled_weight_decay_applies_before_step() -> Result<()> {
    let mut adamw = AdamW::new(0.001, (0.9, 0.999), 1e-8, 0.01, false)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adamw.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.996_827_7).abs() < 1e-5, "got {got}");
    Ok(())
}
/// One Adamax step from `default_with_lr` on w = 1.0 with g = 0.5.
///
/// Expected value, assuming the infinity-norm update m = (1-β1)·g = 0.05,
/// u = max(β2·0, |g|) = 0.5, w -= lr·m/(u + ε): 1 - 0.001·0.1 = 0.9999.
#[test]
fn adamax_single_step_matches_python_ref() -> Result<()> {
    let mut adamax = Adamax::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adamax.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.9999).abs() < 1e-5, "got {got}");
    Ok(())
}
/// Two Adam steps with the same positive gradient keep pushing the weight
/// down, and the step counter reaches 2. This shows the optimizer state
/// carries over between `apply_gradients` calls.
#[test]
fn adam_two_consecutive_steps_advance_state() -> Result<()> {
    let mut adam = Adam::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adam.apply_gradients(&grads, &mut params)?;
    let after_one = read_scalar(&params["w"])?;
    adam.apply_gradients(&grads, &mut params)?;
    let after_two = read_scalar(&params["w"])?;
    assert!(after_two < after_one, "weight should keep decreasing");
    assert_eq!(adam.step(), 2);
    Ok(())
}
/// `Adamax::with_eps` refuses a negative epsilon.
#[test]
fn adamax_builder_with_eps_rejects_negative() {
    let built = Adamax::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_eps(-1e-8));
    assert!(outcome.is_err());
}
/// `Adam::new` rejects betas that are out of range at the top end: β2 above
/// one, and also β1 exactly equal to one.
#[test]
fn adam_new_rejects_betas_above_one() {
    let invalid: [(f32, f32); 2] = [(0.9, 1.1), (1.0, 0.999)];
    for betas in invalid {
        assert!(Adam::new(0.001, betas, 1e-8, false).is_err());
    }
}
/// `Adam::new` rejects a negative value in either beta slot.
#[test]
fn adam_new_rejects_betas_negative() {
    let invalid: [(f32, f32); 2] = [(-0.1, 0.999), (0.9, -0.1)];
    for betas in invalid {
        assert!(Adam::new(0.001, betas, 1e-8, false).is_err());
    }
}
/// `Adam::new` rejects a NaN β1 and an infinite β2.
#[test]
fn adam_new_rejects_non_finite_betas() {
    let invalid: [(f32, f32); 2] = [(f32::NAN, 0.999), (0.9, f32::INFINITY)];
    for betas in invalid {
        assert!(Adam::new(0.001, betas, 1e-8, false).is_err());
    }
}
/// `Adam::with_betas` refuses a β2 greater than one.
#[test]
fn adam_with_betas_rejects_above_one() {
    let built = Adam::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_betas((0.9, 1.1)));
    assert!(outcome.is_err());
}
/// `Adam::with_betas` refuses a NaN β1.
#[test]
fn adam_with_betas_rejects_non_finite() {
    let built = Adam::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_betas((f32::NAN, 0.999)));
    assert!(outcome.is_err());
}
/// `Adam::with_eps` refuses a negative epsilon.
#[test]
fn adam_with_eps_rejects_negative() {
    let built = Adam::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_eps(-1e-8));
    assert!(outcome.is_err());
}
/// `Adam::with_eps` refuses a NaN epsilon.
#[test]
fn adam_with_eps_rejects_non_finite() {
    let built = Adam::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_eps(f32::NAN));
    assert!(outcome.is_err());
}
/// `AdamW::new` refuses a negative weight decay.
#[test]
fn adamw_new_rejects_negative_weight_decay() {
    let outcome = AdamW::new(0.001, (0.9, 0.999), 1e-8, -0.01, false);
    assert!(outcome.is_err());
}
/// `AdamW::new` refuses a NaN weight decay.
#[test]
fn adamw_new_rejects_non_finite_weight_decay() {
    let outcome = AdamW::new(0.001, (0.9, 0.999), 1e-8, f32::NAN, false);
    assert!(outcome.is_err());
}
/// `AdamW::with_weight_decay` refuses a negative decay.
#[test]
fn adamw_with_weight_decay_rejects_negative() {
    let built = AdamW::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_weight_decay(-0.1));
    assert!(outcome.is_err());
}
/// `AdamW::with_weight_decay` refuses an infinite decay.
#[test]
fn adamw_with_weight_decay_rejects_non_finite() {
    let built = AdamW::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_weight_decay(f32::INFINITY));
    assert!(outcome.is_err());
}
/// `Adamax::new` rejects betas that are out of range at the top end: β2 above
/// one, and also β1 exactly equal to one.
#[test]
fn adamax_new_rejects_betas_above_one() {
    let invalid: [(f32, f32); 2] = [(0.9, 1.1), (1.0, 0.999)];
    for betas in invalid {
        assert!(Adamax::new(0.001, betas, 1e-8).is_err());
    }
}
/// `Adamax::new` rejects a NaN β1 and an infinite β2.
#[test]
fn adamax_new_rejects_non_finite_betas() {
    let invalid: [(f32, f32); 2] = [(f32::NAN, 0.999), (0.9, f32::INFINITY)];
    for betas in invalid {
        assert!(Adamax::new(0.001, betas, 1e-8).is_err());
    }
}
/// `Adamax::with_betas` refuses a β2 greater than one.
#[test]
fn adamax_with_betas_rejects_above_one() {
    let built = Adamax::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_betas((0.9, 1.1)));
    assert!(outcome.is_err());
}
/// `Adamax::with_betas` refuses a NaN β1.
#[test]
fn adamax_with_betas_rejects_non_finite() {
    let built = Adamax::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_betas((f32::NAN, 0.999)));
    assert!(outcome.is_err());
}
/// `Adamax::with_eps` refuses a NaN epsilon.
#[test]
fn adamax_with_eps_rejects_non_finite() {
    let built = Adamax::default_with_lr(0.001);
    let outcome = built.and_then(|opt| opt.with_eps(f32::NAN));
    assert!(outcome.is_err());
}
/// `Adam::with_learning_rate` refuses a `Fixed` schedule holding NaN.
#[test]
fn adam_with_learning_rate_rejects_fixed_nan() {
    let nan_rate = LearningRate::Fixed(f32::NAN);
    let res = Adam::default_with_lr(0.001).and_then(|opt| opt.with_learning_rate(nan_rate));
    assert!(res.is_err(), "Adam::with_learning_rate must reject Fixed(NaN)");
}
/// `AdamW::with_learning_rate` refuses a `Fixed` schedule holding NaN.
#[test]
fn adamw_with_learning_rate_rejects_fixed_nan() {
    let nan_rate = LearningRate::Fixed(f32::NAN);
    let res = AdamW::default_with_lr(0.001).and_then(|opt| opt.with_learning_rate(nan_rate));
    assert!(res.is_err(), "AdamW::with_learning_rate must reject Fixed(NaN)");
}
/// `Adamax::with_learning_rate` refuses a `Fixed` schedule holding NaN.
#[test]
fn adamax_with_learning_rate_rejects_fixed_nan() {
    let nan_rate = LearningRate::Fixed(f32::NAN);
    let res = Adamax::default_with_lr(0.001).and_then(|opt| opt.with_learning_rate(nan_rate));
    assert!(res.is_err(), "Adamax::with_learning_rate must reject Fixed(NaN)");
}
/// Every argument passed to `Adam::new` comes back unchanged from its getter,
/// and a fresh optimizer reports step 0.
#[test]
fn adam_getters_echo_inputs() -> Result<()> {
    let adam = Adam::new(LearningRate::Fixed(0.25), (0.8, 0.99), 1e-6, true)?;

    // Schedule: still the `Fixed` variant, and it resolves to 0.25.
    assert!(
        adam.learning_rate_ref().is_fixed(),
        "learning_rate_ref must echo the Fixed schedule"
    );
    assert_eq!(adam.learning_rate(), 0.25);

    // Hyper-parameters.
    assert_eq!(adam.betas(), (0.8, 0.99));
    assert_eq!(adam.eps(), 1e-6);
    assert!(adam.bias_correction());

    // No update applied yet.
    assert_eq!(adam.step(), 0);
    Ok(())
}
/// `Adam::default_with_lr` fills in β = (0.9, 0.999), ε = 1e-8, bias
/// correction off, and a `Fixed` learning-rate schedule.
#[test]
fn adam_default_with_lr_getters() -> Result<()> {
    let opt = Adam::default_with_lr(0.001)?;
    assert_eq!(opt.betas(), (0.9, 0.999));
    assert_eq!(opt.eps(), 1e-8);
    let corrected = opt.bias_correction();
    assert!(!corrected);
    assert!(opt.learning_rate_ref().is_fixed());
    Ok(())
}
/// Each `Adam` builder method, given a valid value, stores it so the matching
/// getter returns it.
#[test]
fn adam_builder_success_paths_echo() -> Result<()> {
    let opt = Adam::default_with_lr(0.001)?;
    let opt = opt.with_learning_rate(LearningRate::Fixed(0.05))?;
    let opt = opt.with_betas((0.7, 0.95))?;
    let opt = opt.with_eps(2e-7)?;
    let opt = opt.with_bias_correction(true);

    assert_eq!(opt.learning_rate(), 0.05);
    assert!(opt.learning_rate_ref().is_fixed());
    assert_eq!(opt.betas(), (0.7, 0.95));
    assert_eq!(opt.eps(), 2e-7);
    assert!(opt.bias_correction());
    Ok(())
}
/// `with_bias_correction(false)` turns off a flag that `new` had turned on.
#[test]
fn adam_with_bias_correction_toggles_off() -> Result<()> {
    let opt = Adam::new(0.001, (0.9, 0.999), 1e-8, true)?;
    let opt = opt.with_bias_correction(false);
    assert!(!opt.bias_correction());
    Ok(())
}
/// Mixes an explicitly initialised key (`a`, via `init`) with one that has no
/// state yet (`b`) in a single `apply_gradients` call.
///
/// Both keys start from w = 1.0 with g = 0.5, so both should land on the
/// plain first-step Adam value (≈ 0.9968377). That means the missing-state
/// path and the `init` path agree.
#[test]
fn adam_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
    let mut adam = Adam::default_with_lr(0.001)?;
    let mut init_params: Weights = HashMap::new();
    init_params.insert("a".into(), scalar(1.0)?);
    adam.init(&init_params)?;
    assert!(
        !adam.state.is_empty(),
        "explicit init populated state for 'a'"
    );
    let mut params: Weights = HashMap::new();
    params.insert("a".into(), scalar(1.0)?);
    params.insert("b".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("a".into(), scalar(0.5)?);
    grads.insert("b".into(), scalar(0.5)?);
    adam.apply_gradients(&grads, &mut params)?;
    let got_b = read_scalar(&params["b"])?;
    assert!((got_b - 0.996_837_7).abs() < 1e-5, "b got {got_b}");
    let got_a = read_scalar(&params["a"])?;
    assert!((got_a - 0.996_837_7).abs() < 1e-5, "a got {got_a}");
    Ok(())
}
/// A gradient whose key has no matching parameter is ignored. The matching
/// key still gets the normal first-step update, and the unmatched key is not
/// inserted into `params`.
#[test]
fn adam_skips_grad_key_absent_from_params() -> Result<()> {
    let mut adam = Adam::default_with_lr(0.001)?;
    let mut params: Weights = HashMap::new();
    params.insert("present".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), scalar(0.5)?);
    grads.insert("absent".into(), scalar(0.5)?);
    adam.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["present"])?;
    assert!((got - 0.996_837_7).abs() < 1e-5, "present got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added to params"
    );
    Ok(())
}
/// `AdamW::new` echoes its weight decay and learning rate, and a fresh
/// optimizer reports step 0.
#[test]
fn adamw_getters_echo_inputs() -> Result<()> {
    let opt = AdamW::new(LearningRate::Fixed(0.02), (0.8, 0.99), 1e-7, 0.05, true)?;
    let (decay, rate, steps) = (opt.weight_decay(), opt.learning_rate(), opt.step());
    assert_eq!(decay, 0.05);
    assert_eq!(rate, 0.02);
    assert_eq!(steps, 0);
    Ok(())
}
/// `AdamW::default_with_lr` uses a weight decay of 0.01.
#[test]
fn adamw_default_with_lr_weight_decay_default() -> Result<()> {
    let decay = AdamW::default_with_lr(0.001)?.weight_decay();
    assert_eq!(decay, 0.01);
    Ok(())
}
/// `AdamW`'s learning-rate and weight-decay builders store valid values so the
/// getters return them.
#[test]
fn adamw_builder_success_paths_echo() -> Result<()> {
    let opt = AdamW::default_with_lr(0.001)?;
    let opt = opt.with_learning_rate(LearningRate::Fixed(0.03))?;
    let opt = opt.with_weight_decay(0.2)?;
    assert_eq!(opt.learning_rate(), 0.03);
    assert_eq!(opt.weight_decay(), 0.2);
    Ok(())
}
/// Explicit `init` leaves the step counter at 0. One `apply_gradients` call
/// then bumps it to 1 and keeps the fixed learning rate. It also produces the
/// same weight as the single-step AdamW test that skips `init` (≈ 0.9968277).
#[test]
fn adamw_init_then_step_advances_trait_methods() -> Result<()> {
    let mut adamw = AdamW::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adamw.init(&params)?;
    assert_eq!(adamw.step(), 0);
    adamw.apply_gradients(&grads, &mut params)?;
    assert_eq!(adamw.step(), 1);
    assert_eq!(adamw.learning_rate(), 0.001);
    let got = read_scalar(&params["w"])?;
    assert!((got - 0.996_827_7).abs() < 1e-5, "got {got}");
    Ok(())
}
/// AdamW ignores a gradient whose key has no matching parameter. The matching
/// key still gets the decayed first-step update, and the unmatched key is not
/// inserted into `params`.
#[test]
fn adamw_skips_grad_key_absent_from_params() -> Result<()> {
    let mut adamw = AdamW::new(0.001, (0.9, 0.999), 1e-8, 0.01, false)?;
    let mut params: Weights = HashMap::new();
    params.insert("present".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), scalar(0.5)?);
    grads.insert("absent".into(), scalar(0.5)?);
    adamw.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["present"])?;
    assert!((got - 0.996_827_7).abs() < 1e-5, "present got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added to params"
    );
    Ok(())
}
/// Two AdamW steps with the same positive gradient keep decreasing the weight,
/// and the step counter reaches 2.
#[test]
fn adamw_two_steps_advance_state() -> Result<()> {
    let mut adamw = AdamW::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adamw.apply_gradients(&grads, &mut params)?;
    let after_one = read_scalar(&params["w"])?;
    adamw.apply_gradients(&grads, &mut params)?;
    let after_two = read_scalar(&params["w"])?;
    assert!(after_two < after_one, "weight should keep decreasing");
    assert_eq!(adamw.step(), 2);
    Ok(())
}
/// Every argument passed to `Adamax::new` comes back unchanged from its
/// getter, and a fresh optimizer reports step 0.
#[test]
fn adamax_getters_echo_inputs() -> Result<()> {
    let opt = Adamax::new(LearningRate::Fixed(0.02), (0.8, 0.99), 1e-7)?;

    // Schedule variant and its resolved value.
    assert!(opt.learning_rate_ref().is_fixed());
    // Hyper-parameters.
    assert_eq!(opt.betas(), (0.8, 0.99));
    assert_eq!(opt.eps(), 1e-7);
    assert_eq!(opt.learning_rate(), 0.02);
    // No update applied yet.
    assert_eq!(opt.step(), 0);
    Ok(())
}
/// `Adamax::default_with_lr` fills in β = (0.9, 0.999), ε = 1e-8 and a
/// `Fixed` learning-rate schedule.
#[test]
fn adamax_default_with_lr_getters() -> Result<()> {
    let opt = Adamax::default_with_lr(0.001)?;
    let expected_betas = (0.9, 0.999);
    assert_eq!(opt.betas(), expected_betas);
    assert_eq!(opt.eps(), 1e-8);
    assert!(opt.learning_rate_ref().is_fixed());
    Ok(())
}
/// Each `Adamax` builder method, given a valid value, stores it so the
/// matching getter returns it.
#[test]
fn adamax_builder_success_paths_echo() -> Result<()> {
    let opt = Adamax::default_with_lr(0.001)?;
    let opt = opt.with_learning_rate(LearningRate::Fixed(0.05))?;
    let opt = opt.with_betas((0.7, 0.95))?;
    let opt = opt.with_eps(2e-7)?;

    assert_eq!(opt.learning_rate(), 0.05);
    assert!(opt.learning_rate_ref().is_fixed());
    assert_eq!(opt.betas(), (0.7, 0.95));
    assert_eq!(opt.eps(), 2e-7);
    Ok(())
}
/// After an explicit `init`, two Adamax steps advance the counter 0 → 1 → 2,
/// keep the fixed learning rate, and keep decreasing the weight.
#[test]
fn adamax_init_then_two_steps_preflight_re_resolves() -> Result<()> {
    let mut adamax = Adamax::default_with_lr(0.001)?;
    let (mut params, grads) = p_g(1.0, 0.5)?;
    adamax.init(&params)?;
    assert_eq!(adamax.step(), 0);
    adamax.apply_gradients(&grads, &mut params)?;
    let after_one = read_scalar(&params["w"])?;
    assert_eq!(adamax.step(), 1);
    adamax.apply_gradients(&grads, &mut params)?;
    let after_two = read_scalar(&params["w"])?;
    assert_eq!(adamax.step(), 2);
    assert_eq!(adamax.learning_rate(), 0.001);
    assert!(after_two < after_one, "weight should keep decreasing");
    Ok(())
}
/// Mixes an explicitly initialised key (`a`, via `init`) with one that has no
/// state yet (`b`) in a single Adamax `apply_gradients` call.
///
/// Both keys start from w = 1.0 with g = 0.5, so both should land on the
/// first-step Adamax value (0.9999). Checking `a` too, as the Adam counterpart
/// does, shows the missing-state path and the `init` path agree.
#[test]
fn adamax_step_none_state_arm_via_uninit_grad_key() -> Result<()> {
    let mut adamax = Adamax::default_with_lr(0.001)?;
    let mut init_params: Weights = HashMap::new();
    init_params.insert("a".into(), scalar(1.0)?);
    adamax.init(&init_params)?;
    let mut params: Weights = HashMap::new();
    params.insert("a".into(), scalar(1.0)?);
    params.insert("b".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("a".into(), scalar(0.5)?);
    grads.insert("b".into(), scalar(0.5)?);
    adamax.apply_gradients(&grads, &mut params)?;
    let got_b = read_scalar(&params["b"])?;
    assert!((got_b - 0.9999).abs() < 1e-5, "b got {got_b}");
    let got_a = read_scalar(&params["a"])?;
    assert!((got_a - 0.9999).abs() < 1e-5, "a got {got_a}");
    Ok(())
}
/// Adamax ignores a gradient whose key has no matching parameter. The matching
/// key still gets the normal first-step update, and the unmatched key is not
/// inserted into `params`.
#[test]
fn adamax_skips_grad_key_absent_from_params() -> Result<()> {
    let mut adamax = Adamax::default_with_lr(0.001)?;
    let mut params: Weights = HashMap::new();
    params.insert("present".into(), scalar(1.0)?);
    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), scalar(0.5)?);
    grads.insert("absent".into(), scalar(0.5)?);
    adamax.apply_gradients(&grads, &mut params)?;
    let got = read_scalar(&params["present"])?;
    assert!((got - 0.9999).abs() < 1e-5, "present got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added to params"
    );
    Ok(())
}