use super::*;
/// `Adafactor::new` must return an error when either component of `eps` is negative.
#[test]
fn adafactor_new_rejects_negative_eps() {
    // Every argument other than `eps` is identical in both cases.
    let build = |eps: (f32, f32)| {
        Adafactor::new(None, eps, 1.0, -0.8, None, 0.0, true, true, false)
    };

    let negative_first = build((-1e-30, 1e-3));
    assert!(negative_first.is_err());

    let negative_second = build((1e-30, -1e-3));
    assert!(negative_second.is_err());
}
/// `Adafactor::new` must return an error when `eps` contains NaN or infinity.
#[test]
fn adafactor_new_rejects_non_finite_eps() {
    // Every argument other than `eps` is identical in both cases.
    let build = |eps: (f32, f32)| {
        Adafactor::new(None, eps, 1.0, -0.8, None, 0.0, true, true, false)
    };

    let nan_first = build((f32::NAN, 1e-3));
    assert!(nan_first.is_err());

    let infinite_second = build((1e-30, f32::INFINITY));
    assert!(infinite_second.is_err());
}
/// `Adafactor::new` must return an error for a zero or negative clip threshold.
#[test]
fn adafactor_new_rejects_non_positive_clip_threshold() {
    // Only the clip threshold varies between the two constructions.
    let build = |clip_threshold: f32| {
        Adafactor::new(
            None,
            (1e-30, 1e-3),
            clip_threshold,
            -0.8,
            None,
            0.0,
            true,
            true,
            false,
        )
    };

    let with_zero = build(0.0);
    assert!(with_zero.is_err());

    let with_negative = build(-1.0);
    assert!(with_negative.is_err());
}
/// `Adafactor::new` must return an error when the clip threshold is NaN.
#[test]
fn adafactor_new_rejects_non_finite_clip_threshold() {
    let nan_threshold = f32::NAN;
    let outcome = Adafactor::new(
        None,
        (1e-30, 1e-3),
        nan_threshold,
        -0.8,
        None,
        0.0,
        true,
        true,
        false,
    );
    assert!(outcome.is_err());
}
/// `Adafactor::new` must return an error when the decay rate is NaN or infinite.
#[test]
fn adafactor_new_rejects_non_finite_decay_rate() {
    // Only the decay rate varies between the two constructions.
    let build = |decay_rate: f32| {
        Adafactor::new(
            None,
            (1e-30, 1e-3),
            1.0,
            decay_rate,
            None,
            0.0,
            true,
            true,
            false,
        )
    };

    let with_nan = build(f32::NAN);
    assert!(with_nan.is_err());

    let with_infinity = build(f32::INFINITY);
    assert!(with_infinity.is_err());
}
/// `Adafactor::new` must return an error when the weight decay is negative.
#[test]
fn adafactor_new_rejects_negative_weight_decay() {
    let outcome = Adafactor::new(
        None,
        (1e-30, 1e-3),
        1.0,
        -0.8,
        None,
        -0.1, // negative weight decay under test
        true,
        true,
        false,
    );
    assert!(outcome.is_err());
}
/// `Adafactor::new` must return an error when the weight decay is NaN.
#[test]
fn adafactor_new_rejects_non_finite_weight_decay() {
    let nan_weight_decay = f32::NAN;
    let outcome = Adafactor::new(
        None,
        (1e-30, 1e-3),
        1.0,
        -0.8,
        None,
        nan_weight_decay,
        true,
        true,
        false,
    );
    assert!(outcome.is_err());
}
/// `with_eps` must return an error when the first `eps` component is negative.
#[test]
fn adafactor_with_eps_rejects_negative() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_eps((-1e-30, 1e-3)));
    assert!(outcome.is_err());
}
/// `with_eps` must return an error when the first `eps` component is NaN.
#[test]
fn adafactor_with_eps_rejects_non_finite() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_eps((f32::NAN, 1e-3)));
    assert!(outcome.is_err());
}
/// `with_clip_threshold` must return an error for zero and for a negative value.
#[test]
fn adafactor_with_clip_threshold_rejects_non_positive() {
    // Each attempt starts from a fresh `default_python()` optimizer.
    let try_threshold = |value: f32| {
        Adafactor::default_python().and_then(|optimizer| optimizer.with_clip_threshold(value))
    };

    let with_zero = try_threshold(0.0);
    assert!(with_zero.is_err());

    let with_negative = try_threshold(-1.0);
    assert!(with_negative.is_err());
}
/// `with_clip_threshold` must return an error when given NaN.
#[test]
fn adafactor_with_clip_threshold_rejects_non_finite() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_clip_threshold(f32::NAN));
    assert!(outcome.is_err());
}
/// `with_decay_rate` must return an error for NaN and for positive infinity.
#[test]
fn adafactor_with_decay_rate_rejects_non_finite() {
    // Each attempt starts from a fresh `default_python()` optimizer.
    let try_decay_rate = |value: f32| {
        Adafactor::default_python().and_then(|optimizer| optimizer.with_decay_rate(value))
    };

    let with_nan = try_decay_rate(f32::NAN);
    assert!(with_nan.is_err());

    let with_infinity = try_decay_rate(f32::INFINITY);
    assert!(with_infinity.is_err());
}
/// `with_weight_decay` must return an error when given a negative value.
#[test]
fn adafactor_with_weight_decay_rejects_negative() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_weight_decay(-0.1));
    assert!(outcome.is_err());
}
/// `with_weight_decay` must return an error when given NaN.
#[test]
fn adafactor_with_weight_decay_rejects_non_finite() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_weight_decay(f32::NAN));
    assert!(outcome.is_err());
}
/// Applies one gradient step to a length-3 parameter using an optimizer built by
/// `Adafactor::default_python()` and checks that the first element is no longer 1.0.
#[test]
fn adafactor_1d_param_runs_one_step_without_error() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;

    let weight = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3])?;
    let mut params = Weights::new();
    params.insert("w".into(), weight);

    let gradient = Array::from_slice::<f32>(&[0.1, 0.2, 0.3], &[3])?;
    let mut grads = Weights::new();
    grads.insert("w".into(), gradient);

    optimizer.apply_gradients(&grads, &mut params)?;

    // Read the values back from a clone of the updated parameter.
    let mut updated = params["w"].try_clone()?;
    let values: Vec<f32> = updated.to_vec()?;
    let first = values[0];
    assert!((first - 1.0).abs() > 1e-8, "expected w[0] to move, got {}", first);
    Ok(())
}
/// Applies one gradient step to a 2x2 parameter using an optimizer built by
/// `Adafactor::default_python()`; passes as long as the step and the read-back succeed.
#[test]
fn adafactor_2d_param_runs_one_step_without_error() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;

    let weight = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2))?;
    let mut params = Weights::new();
    params.insert("w".into(), weight);

    let gradient = Array::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &(2, 2))?;
    let mut grads = Weights::new();
    grads.insert("w".into(), gradient);

    optimizer.apply_gradients(&grads, &mut params)?;

    // The values themselves are not checked; only that they can be read back.
    let mut updated = params["w"].try_clone()?;
    let _: Vec<f32> = updated.to_vec()?;
    Ok(())
}
/// `Adafactor::new` must return an error when `beta_1` is `Some(NaN)`.
#[test]
fn adafactor_new_rejects_nan_beta_1() {
    let nan_beta_1 = Some(f32::NAN);
    let outcome = Adafactor::new(
        None,
        (1e-30, 1e-3),
        1.0,
        -0.8,
        nan_beta_1,
        0.0,
        true,
        true,
        false,
    );
    assert!(outcome.is_err());
}
/// `with_beta_1` must return an error when given `Some(NaN)`.
#[test]
fn adafactor_with_beta_1_rejects_nan_some() {
    let outcome = Adafactor::default_python()
        .and_then(|optimizer| optimizer.with_beta_1(Some(f32::NAN)));
    assert!(outcome.is_err());
}
/// `with_beta_1` must return an error for `Some(1.0)` and for `Some(1.5)`.
#[test]
fn adafactor_with_beta_1_rejects_above_one_some() {
    // Each attempt starts from a fresh `default_python()` optimizer.
    let try_beta_1 = |value: f32| {
        Adafactor::default_python().and_then(|optimizer| optimizer.with_beta_1(Some(value)))
    };

    let at_one = try_beta_1(1.0);
    assert!(at_one.is_err());

    let above_one = try_beta_1(1.5);
    assert!(above_one.is_err());
}
/// `with_beta_1(None)` must succeed on a freshly built `default_python()` optimizer.
#[test]
fn adafactor_with_beta_1_accepts_none() -> Result<()> {
    let optimizer = Adafactor::default_python()?;
    let _without_beta_1 = optimizer.with_beta_1(None)?;
    Ok(())
}
/// `try_set_beta_1` must return an error for `Some(NaN)` before any gradient step has run.
#[test]
fn adafactor_try_set_beta_1_rejects_nan_pre_init() {
    let mut optimizer = Adafactor::default_python().unwrap();
    let outcome = optimizer.try_set_beta_1(Some(f32::NAN));
    assert!(outcome.is_err());
}
/// After one `apply_gradients` call, `with_beta_1(Some(0.9))` must return an error.
#[test]
fn adafactor_with_beta_1_rejects_post_init() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;

    let weight = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2))?;
    let mut params = HashMap::from([("w".to_string(), weight)]);

    let gradient = Array::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &(2, 2))?;
    let grads = HashMap::from([("w".to_string(), gradient)]);

    // Take one step first so the builder is called on an optimizer that has already run.
    optimizer.apply_gradients(&grads, &mut params)?;

    let outcome = optimizer.with_beta_1(Some(0.9));
    assert!(outcome.is_err());
    Ok(())
}
/// `with_learning_rate` must return an error for `Some(LearningRate::Fixed(NaN))`.
#[test]
fn adafactor_with_learning_rate_rejects_fixed_nan() {
    let outcome = Adafactor::default_python().and_then(|optimizer| {
        let nan_rate = LearningRate::Fixed(f32::NAN);
        optimizer.with_learning_rate(Some(nan_rate))
    });
    assert!(
        outcome.is_err(),
        "with_learning_rate must reject Some(Fixed(NaN))"
    );
}
/// A rejected `try_set_beta_1` after one step must leave `beta_1` unchanged and `state`
/// non-empty, and a further `apply_gradients` call must still succeed.
#[test]
fn adafactor_try_set_beta_1_preserves_state_on_error() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;
    let beta_1_before = optimizer.beta_1;

    let weight = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2))?;
    let mut params = HashMap::from([("w".to_string(), weight)]);

    let gradient = Array::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &(2, 2))?;
    let grads = HashMap::from([("w".to_string(), gradient)]);

    optimizer.apply_gradients(&grads, &mut params)?;

    // The setter is expected to fail now that a step has been taken.
    let outcome = optimizer.try_set_beta_1(Some(0.9));
    assert!(outcome.is_err());

    assert_eq!(optimizer.beta_1, beta_1_before);
    assert!(!optimizer.state.is_empty(), "state preserved on error");

    // The optimizer must remain usable after the failed setter call.
    optimizer.apply_gradients(&grads, &mut params)?;
    Ok(())
}
/// Clones `a` and reads the clone's single `f32` value via `item`.
///
/// Returns any error produced by `try_clone` or `item`.
fn read_scalar(a: &Array) -> Result<f32> {
    a.try_clone()?.item::<f32>()
}
/// Clones `a` and reads the clone's contents as a `Vec<f32>` via `to_vec`.
///
/// Returns any error produced by `try_clone` or `to_vec`.
fn read_vec(a: &Array) -> Result<Vec<f32>> {
    a.try_clone()?.to_vec::<f32>()
}
fn scalar_p_g(p: f32, g: f32) -> Result<(Weights, Weights)> {
let mut params: Weights = HashMap::new();
params.insert("w".into(), Array::full::<f32>(&[0i32; 0], p)?);
let mut grads: Weights = HashMap::new();
grads.insert("w".into(), Array::full::<f32>(&[0i32; 0], g)?);
Ok((params, grads))
}
/// `Adafactor::new` must reject a decay rate of 0.5 with `Error::OutOfRange`.
#[test]
fn adafactor_new_rejects_positive_decay_rate() {
    let outcome = Adafactor::new(None, (1e-30, 1e-3), 1.0, 0.5, None, 0.0, true, true, false);
    let err = outcome.err().expect("decay_rate > 0 must be rejected");

    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "expected Error::OutOfRange, got {err:?}");
}
/// `with_decay_rate(0.5)` must be rejected with `Error::OutOfRange`.
#[test]
fn adafactor_with_decay_rate_rejects_positive() {
    let outcome =
        Adafactor::default_python().and_then(|optimizer| optimizer.with_decay_rate(0.5));
    let err = outcome.err().expect("with_decay_rate(0.5) must be rejected");

    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "got {err:?}");
}
/// An optimizer constructed with `Some(LearningRate::Fixed(0.1))` must report 0.1 from
/// `learning_rate()` and `Some` from `learning_rate_ref()` before any step is taken.
#[test]
fn adafactor_new_with_fixed_lr_echoes_step0() -> Result<()> {
    let fixed_rate = Some(LearningRate::Fixed(0.1));
    let optimizer = Adafactor::new(
        fixed_rate,
        (1e-30, 1e-3),
        1.0,
        -0.8,
        None,
        0.0,
        true,
        true,
        false,
    )?;

    assert_eq!(optimizer.learning_rate(), 0.1);
    assert!(optimizer.learning_rate_ref().is_some());
    Ok(())
}
/// `Adafactor::new` must return an error for `Some(LearningRate::Fixed(NaN))`.
#[test]
fn adafactor_new_rejects_fixed_nan_lr() {
    let nan_rate = Some(LearningRate::Fixed(f32::NAN));
    let outcome = Adafactor::new(
        nan_rate,
        (1e-30, 1e-3),
        1.0,
        -0.8,
        None,
        0.0,
        true,
        true,
        false,
    );
    assert!(
        outcome.is_err(),
        "Some(Fixed(NaN)) must be rejected at step 0"
    );
}
/// Every getter must report the value that was passed to `Adafactor::new`.
#[test]
fn adafactor_getters_echo_inputs() -> Result<()> {
    // Non-default values are used for every argument.
    let optimizer = Adafactor::new(
        Some(LearningRate::Fixed(0.25)),
        (3e-30, 4e-3),
        2.5,
        -0.6,
        Some(0.7),
        0.125,
        false,
        false,
        true,
    )?;

    assert!(optimizer.learning_rate_ref().is_some());

    // Numeric settings.
    assert_eq!(optimizer.eps(), (3e-30, 4e-3));
    assert_eq!(optimizer.clip_threshold(), 2.5);
    assert_eq!(optimizer.decay_rate(), -0.6);
    assert_eq!(optimizer.beta_1(), Some(0.7));
    assert_eq!(optimizer.weight_decay(), 0.125);

    // Boolean flags.
    assert!(!optimizer.scale_parameter());
    assert!(!optimizer.relative_step());
    assert!(optimizer.warmup_init());
    Ok(())
}
/// Pins the values reported by the getters of an `Adafactor::default_python()` optimizer.
#[test]
fn adafactor_default_python_getters() -> Result<()> {
    let optimizer = Adafactor::default_python()?;

    assert!(optimizer.learning_rate_ref().is_none());

    // Numeric settings.
    assert_eq!(optimizer.eps(), (1e-30, 1e-3));
    assert_eq!(optimizer.clip_threshold(), 1.0);
    assert_eq!(optimizer.decay_rate(), -0.8);
    assert_eq!(optimizer.beta_1(), None);
    assert_eq!(optimizer.weight_decay(), 0.0);

    // Boolean flags.
    assert!(optimizer.scale_parameter());
    assert!(optimizer.relative_step());
    assert!(!optimizer.warmup_init());
    Ok(())
}
/// Each `with_*` builder must accept a valid value, and the matching getter must report it.
#[test]
fn adafactor_builder_success_paths_echo() -> Result<()> {
    // Fallible builders are applied one at a time, in the same order as a single chain.
    let optimizer = Adafactor::default_python()?;
    let optimizer = optimizer.with_learning_rate(Some(LearningRate::Fixed(0.05)))?;
    let optimizer = optimizer.with_eps((2e-30, 5e-3))?;
    let optimizer = optimizer.with_clip_threshold(3.0)?;
    let optimizer = optimizer.with_decay_rate(-0.5)?;
    let optimizer = optimizer.with_weight_decay(0.2)?;
    // The flag builders are not followed by `?`.
    let optimizer = optimizer
        .with_scale_parameter(false)
        .with_relative_step(false)
        .with_warmup_init(true);

    assert_eq!(optimizer.learning_rate(), 0.05);
    assert!(optimizer.learning_rate_ref().is_some());
    assert_eq!(optimizer.eps(), (2e-30, 5e-3));
    assert_eq!(optimizer.clip_threshold(), 3.0);
    assert_eq!(optimizer.decay_rate(), -0.5);
    assert_eq!(optimizer.weight_decay(), 0.2);
    assert!(!optimizer.scale_parameter());
    assert!(!optimizer.relative_step());
    assert!(optimizer.warmup_init());
    Ok(())
}
/// Setting a fixed learning rate and then `None` must leave `learning_rate()` at 0.0 and
/// `learning_rate_ref()` at `None`.
#[test]
fn adafactor_with_learning_rate_none_sets_zero() -> Result<()> {
    let with_fixed_rate =
        Adafactor::default_python()?.with_learning_rate(Some(LearningRate::Fixed(0.3)))?;
    let optimizer = with_fixed_rate.with_learning_rate(None)?;

    assert_eq!(optimizer.learning_rate(), 0.0);
    assert!(optimizer.learning_rate_ref().is_none());
    Ok(())
}
/// Before any gradient step, `try_set_beta_1` must accept `Some(0.9)` and then `None`,
/// with `beta_1()` reporting each new value.
#[test]
fn adafactor_try_set_beta_1_pre_init_succeeds() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;

    optimizer.try_set_beta_1(Some(0.9))?;
    assert_eq!(optimizer.beta_1(), Some(0.9));

    optimizer.try_set_beta_1(None)?;
    assert_eq!(optimizer.beta_1(), None);

    Ok(())
}
/// One step on a `scalar_p_g(1.0, 0.5)` parameter with a fixed learning rate of 0.1,
/// no `beta_1`, no weight decay, and all three boolean flags off.
///
/// The updated parameter is expected to be within 1e-4 of 0.9.
#[test]
fn adafactor_nonfactored_scalar_no_beta1_no_wd_step1() -> Result<()> {
    let mut af = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        None,
        0.0,
        false,
        false,
        false,
    )?;
    let (mut params, grads) = scalar_p_g(1.0, 0.5)?;
    af.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.9).abs() < 1e-4, "got {got}");
    Ok(())
}
/// One step on a `scalar_p_g(1.0, 0.5)` parameter with a fixed learning rate of 0.1,
/// `beta_1 = Some(0.5)`, no weight decay, and all three boolean flags off.
///
/// The updated parameter is expected to be within 1e-4 of 0.95.
#[test]
fn adafactor_nonfactored_scalar_with_beta1_step1() -> Result<()> {
    let mut af = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        Some(0.5),
        0.0,
        false,
        false,
        false,
    )?;
    let (mut params, grads) = scalar_p_g(1.0, 0.5)?;
    af.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.95).abs() < 1e-4, "got {got}");
    Ok(())
}
/// One step on a `scalar_p_g(1.0, 0.5)` parameter with a fixed learning rate of 0.1,
/// no `beta_1`, a weight decay of 0.5, and all three boolean flags off.
///
/// The updated parameter is expected to be within 1e-4 of 0.85.
#[test]
fn adafactor_nonfactored_scalar_with_weight_decay_step1() -> Result<()> {
    let mut af = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        None,
        0.5,
        false,
        false,
        false,
    )?;
    let (mut params, grads) = scalar_p_g(1.0, 0.5)?;
    af.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["w"])?;
    assert!((got - 0.85).abs() < 1e-4, "got {got}");
    Ok(())
}
/// One step on a 2x2 parameter of all 1.0 with a gradient of all 0.5, using a fixed
/// learning rate of 0.1, `beta_1 = Some(0.5)`, no weight decay, and all three boolean
/// flags off.
///
/// Every element of the updated parameter is expected to be within 1e-4 of 0.95.
#[test]
fn adafactor_factored_2d_with_beta1_step1() -> Result<()> {
    let mut af = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        Some(0.5),
        0.0,
        false,
        false,
        false,
    )?;

    let mut params: Weights = HashMap::new();
    params.insert(
        "w".into(),
        Array::from_slice::<f32>(&[1.0, 1.0, 1.0, 1.0], &(2, 2))?,
    );
    let mut grads: Weights = HashMap::new();
    grads.insert(
        "w".into(),
        Array::from_slice::<f32>(&[0.5, 0.5, 0.5, 0.5], &(2, 2))?,
    );

    af.apply_gradients(&grads, &mut params)?;

    let v = read_vec(&params["w"])?;
    for (i, x) in v.iter().enumerate() {
        assert!((x - 0.95).abs() < 1e-4, "w[{i}] = {x}, expected 0.95");
    }
    Ok(())
}
/// One step on a 2x2 parameter using `Adafactor::default_python()` with
/// `beta_1 = Some(0.9)`; the first element must no longer equal 1.0 afterwards.
#[test]
fn adafactor_relative_step_scaled_runs_and_moves() -> Result<()> {
    let mut af = Adafactor::default_python()?.with_beta_1(Some(0.9))?;

    let mut params: Weights = HashMap::new();
    params.insert(
        "w".into(),
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2))?,
    );
    let mut grads: Weights = HashMap::new();
    grads.insert(
        "w".into(),
        Array::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &(2, 2))?,
    );

    af.apply_gradients(&grads, &mut params)?;

    let v = read_vec(&params["w"])?;
    assert!(
        (v[0] - 1.0).abs() > 1e-8,
        "expected w[0] to move, got {}",
        v[0]
    );
    Ok(())
}
/// With a fixed learning rate of 0.1, `step()` must read 0, 1 and 2 around two
/// consecutive `apply_gradients` calls, and `learning_rate()` must still report 0.1.
#[test]
fn adafactor_two_steps_preflight_some_lr_arm() -> Result<()> {
    let mut optimizer = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        None,
        0.0,
        false,
        false,
        false,
    )?;
    let (mut params, grads) = scalar_p_g(1.0, 0.5)?;

    // Before any update.
    assert_eq!(optimizer.step(), 0);

    // After the first update.
    optimizer.apply_gradients(&grads, &mut params)?;
    assert_eq!(optimizer.step(), 1);

    // After the second update.
    optimizer.apply_gradients(&grads, &mut params)?;
    assert_eq!(optimizer.step(), 2);

    assert_eq!(optimizer.learning_rate(), 0.1);
    Ok(())
}
/// With an `Adafactor::default_python()` optimizer, two `apply_gradients` calls must
/// leave `step()` at 2 and `learning_rate()` at 0.0.
#[test]
fn adafactor_two_steps_preflight_none_arm() -> Result<()> {
    let mut optimizer = Adafactor::default_python()?;
    let (mut params, grads) = scalar_p_g(1.0, 0.5)?;

    // Two consecutive updates with the same gradients.
    optimizer.apply_gradients(&grads, &mut params)?;
    optimizer.apply_gradients(&grads, &mut params)?;

    assert_eq!(optimizer.step(), 2);
    assert_eq!(optimizer.learning_rate(), 0.0);
    Ok(())
}
/// A gradient whose key has no matching parameter must not cause an error and must not
/// add that key to `params`.
///
/// The `"present"` parameter (1.0, gradient 0.5, fixed learning rate 0.1) is expected to
/// be within 1e-4 of 0.9 after the step.
#[test]
fn adafactor_skips_grad_key_absent_from_params() -> Result<()> {
    let mut af = Adafactor::new(
        Some(LearningRate::Fixed(0.1)),
        (0.0, 1e-3),
        1.0,
        -0.8,
        None,
        0.0,
        false,
        false,
        false,
    )?;

    let mut params: Weights = HashMap::new();
    params.insert("present".into(), Array::full::<f32>(&[0i32; 0], 1.0)?);

    // `grads` carries one extra key, "absent", that `params` does not have.
    let mut grads: Weights = HashMap::new();
    grads.insert("present".into(), Array::full::<f32>(&[0i32; 0], 0.5)?);
    grads.insert("absent".into(), Array::full::<f32>(&[0i32; 0], 0.5)?);

    af.apply_gradients(&grads, &mut params)?;

    let got = read_scalar(&params["present"])?;
    assert!((got - 0.9).abs() < 1e-4, "got {got}");
    assert!(
        !params.contains_key("absent"),
        "absent grad must not be added"
    );
    Ok(())
}