use crate::adam::{Adam, AdamConfig, AdamW, AdamWConfig};
use crate::lookahead::Lookahead;
use crate::sgd::SGD;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
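
/// Deterministic pseudo-random values in [0.0, 1.0) from a linear
/// congruential generator (Knuth's MMIX multiplier and increment).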
fn lcg_next(s: &mut u64) -> f32 {
    *s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    (*s % 1000) as f32 / 1000.0
}
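
// Helper constructors so each test builds its base optimizer the same way.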
fn make_adam() -> Adam {
    Adam::from_config(AdamConfig::default())
}

fn make_adamw() -> AdamW {
    AdamW::from_config(AdamWConfig::default())
}
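
// Assumed argument order: SGD::new(lr, momentum, weight_decay, nesterov),
// i.e. plain SGD with lr = 0.01 and momentum, decay, and Nesterov disabled.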
fn make_sgd() -> SGD {
    SGD::new(0.01, 0.0, 0.0, false)
}
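
// Lookahead(base, k, alpha) keeps a set of slow weights alongside the base
// optimizer's fast weights; after k fast steps the slow weights move by
// slow += alpha * (fast - slow), as in Zhang et al. (2019).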
#[test]
fn test_lookahead_new_get_lr_delegates_to_base() {
    let base = make_adam();
    let la = Lookahead::new(base, 5, 0.5);
    assert!((la.get_lr() - AdamConfig::default().lr).abs() < 1e-9);
}

#[test]
fn test_lookahead_set_lr_delegates_to_base() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    la.set_lr(0.05);
    assert!((la.get_lr() - 0.05).abs() < 1e-9);
}

#[test]
fn test_lookahead_zero_grad_no_panic() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    la.zero_grad();
}

#[test]
fn test_lookahead_step_no_panic() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    la.step();
    la.step();
}
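
// Both Adam moments start at zero, so the first fast update with a positive
// gradient must move the parameter downhill.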
#[test]
fn test_lookahead_update_positive_grad_decreases_param() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let before = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let after = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!(
        after < before,
        "param should decrease with positive grad: before={before} after={after}"
    );
}
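
// With the assumed plain-SGD arguments (lr = 0.01, grad = 1.0), one fast step
// lands near 2.0 - 0.01 = 1.99; the test only asserts the direction.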
#[test]
fn test_lookahead_update_with_sgd_base() {
    let base = make_sgd();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param = Tensor::new(vec![2.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let val = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!(val < 2.0, "param should decrease: {val}");
}
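
// A positive gradient moves the weight down; AdamW's decoupled weight decay
// (if enabled in the default config) only reinforces that for a positive weight.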
#[test]
fn test_lookahead_update_with_adamw_base() {
    let base = make_adamw();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param = Tensor::new(vec![1.5_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(vec![0.5_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let val = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!(val < 1.5, "param should decrease: {val}");
}
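
// The accessor should surface the wrapped optimizer's state unchanged.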
#[test]
fn test_lookahead_base_optimizer_accessor() {
    let base = Adam::from_config(AdamConfig {
        lr: 0.042,
        ..AdamConfig::default()
    });
    let la = Lookahead::new(base, 5, 0.5);
    assert!((la.base_optimizer().get_lr() - 0.042).abs() < 1e-9);
}
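
// Mutations made through base_optimizer_mut() must be visible through the
// wrapper's own delegating get_lr().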
#[test]
fn test_lookahead_base_optimizer_mut_accessor() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    la.base_optimizer_mut().set_lr(0.99);
    assert!((la.get_lr() - 0.99).abs() < 1e-9);
}
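
// Ten steps with k = 3 crosses the slow-weight sync boundary several times.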
#[test]
fn test_lookahead_multiple_step_calls() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 3, 0.5);
    for _ in 0..10 {
        la.step();
    }
}
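
// Five rounds of deterministic LCG gradients; the only contract checked is
// that every parameter component stays finite.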
#[test]
fn test_lookahead_lcg_gradient_updates() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut s = 77u64;
    let mut param = Tensor::new(vec![lcg_next(&mut s), lcg_next(&mut s), lcg_next(&mut s)])
        .unwrap_or_else(|_| panic!("tensor failed"));
    for _ in 0..5 {
        let grads: Vec<f32> = (0..3).map(|_| lcg_next(&mut s)).collect();
        let grad = Tensor::new(grads).unwrap_or_else(|_| panic!("tensor failed"));
        la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
        la.step();
    }
    let data = match &param {
        Tensor::F32(a) => a.iter().cloned().collect::<Vec<_>>(),
        _ => panic!("wrong type"),
    };
    for (i, v) in data.iter().enumerate() {
        assert!(v.is_finite(), "param[{i}] should be finite: {v}");
    }
}
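
// After exactly k = 3 fast steps, an explicit slow_step should interpolate
// the slow weights toward the fast weights and report success.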
#[test]
fn test_lookahead_slow_step_after_k_steps() {
    let base = make_sgd();
    let mut la = Lookahead::new(base, 3, 0.5);
    let mut param = Tensor::new(vec![10.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    for _ in 0..3 {
        let grad = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
        la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
        la.step();
    }
    let result = la.slow_step(&mut param);
    assert!(
        result.is_ok(),
        "slow_step should succeed: {:?}",
        result.err()
    );
}
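
// k = 1 is the degenerate case: every fast step is also a sync point.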
#[test]
fn test_lookahead_k_equals_1() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 1, 0.5);
    let mut param = Tensor::new(vec![5.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    la.step();
    let val = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!(val.is_finite(), "param should be finite: {val}");
}
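
// With alpha = 1.0 the interpolation slow += 1.0 * (fast - slow) lands
// exactly on the fast weights, so slow and fast coincide after the sync.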
#[test]
fn test_lookahead_alpha_one_slow_equals_fast() {
    let base = make_sgd();
    let mut la = Lookahead::new(base, 3, 1.0);
    let mut param = Tensor::new(vec![1.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    for _ in 0..3 {
        let grad = Tensor::new(vec![0.5_f32]).unwrap_or_else(|_| panic!("tensor failed"));
        la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
        la.step();
    }
    let result = la.slow_step(&mut param);
    assert!(
        result.is_ok(),
        "slow_step should succeed: {:?}",
        result.err()
    );
}
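
// Minimize f(x) = x^2 via its analytic gradient 2x. Each plain-SGD fast step
// with lr = 0.1 scales x by 1 - 0.2 = 0.8, so 100 rounds must shrink |x|
// well below the 2.0 bound even with slow-weight interpolation in between.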
#[test]
fn test_lookahead_convergence_toward_zero() {
    let base = SGD::new(0.1, 0.0, 0.0, false);
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param = Tensor::new(vec![4.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    for _ in 0..100 {
        let val = match &param {
            Tensor::F32(a) => a[0],
            _ => panic!("wrong type"),
        };
        let grad_val = 2.0 * val;
        let grad = Tensor::new(vec![grad_val]).unwrap_or_else(|_| panic!("tensor failed"));
        la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
        la.step();
    }
    let final_val = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!(
        final_val.abs() < 2.0,
        "Lookahead should converge, got {final_val}"
    );
}
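
// One Adam update across four components; only finiteness is asserted.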
#[test]
fn test_lookahead_multivariate_update() {
    let base = make_adam();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param =
        Tensor::new(vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad =
        Tensor::new(vec![0.1_f32, 0.2, 0.3, 0.4]).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let data = match &param {
        Tensor::F32(a) => a.iter().cloned().collect::<Vec<_>>(),
        _ => panic!("wrong type"),
    };
    for (i, &v) in data.iter().enumerate() {
        assert!(v.is_finite(), "param[{i}] should be finite: {v}");
    }
}
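
// 128-element stress case with LCG-generated parameters and gradients.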
#[test]
fn test_lookahead_large_tensor_update() {
    let base = make_sgd();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut s = 55u64;
    let n = 128_usize;
    let params: Vec<f32> = (0..n).map(|_| lcg_next(&mut s)).collect();
    let grads: Vec<f32> = (0..n).map(|_| lcg_next(&mut s)).collect();
    let mut param = Tensor::new(params).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(grads).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let data = match &param {
        Tensor::F32(a) => a.iter().cloned().collect::<Vec<_>>(),
        _ => panic!("wrong type"),
    };
    for (i, &v) in data.iter().enumerate() {
        assert!(v.is_finite(), "param[{i}] should be finite: {v}");
    }
}
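
// A minimal extra sketch, assuming Lookahead::update delegates straight to the
// base optimizer before any slow-weight sync: with plain SGD a zero gradient
// yields a zero step, so the first update should leave the parameter unchanged.
#[test]
fn test_lookahead_sgd_zero_grad_leaves_param_unchanged() {
    let base = make_sgd();
    let mut la = Lookahead::new(base, 5, 0.5);
    let mut param = Tensor::new(vec![3.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    let grad = Tensor::new(vec![0.0_f32]).unwrap_or_else(|_| panic!("tensor failed"));
    la.update(&mut param, &grad).unwrap_or_else(|e| panic!("update failed: {e}"));
    let val = match &param {
        Tensor::F32(a) => a[0],
        _ => panic!("wrong type"),
    };
    assert!((val - 3.0).abs() < 1e-6, "zero grad should not move param: {val}");
}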