// Re-exported so the child test modules declared at the bottom of this file
// can reach the optimizer types and test helpers via `use super::*;`.
pub(crate) use super::*;
pub(crate) use crate::autograd::clear_graph;
#[test]
fn test_sgd_basic() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
let param_id = param.id();
let loss = param.pow(2.0).sum();
loss.backward();
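    // d/dp_i of sum(p^2) is 2 * p_i, so the gradient should be [2, 4, 6].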
let grad = get_grad(param_id).expect("Should have gradient");
assert_eq!(grad.data(), &[2.0, 4.0, 6.0]);
let mut sgd = SGD::new(vec![&mut param], 0.1);
sgd.step_with_params(&mut [&mut param]);
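    // One vanilla SGD step: p' = p - lr * g = p - 0.1 * (2 * p) = 0.8 * p.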
let expected = [0.8, 1.6, 2.4];
for (p, e) in param.data().iter().zip(expected.iter()) {
assert!((p - e).abs() < 1e-5, "Expected {e}, got {p}");
}
}
#[test]
fn test_sgd_with_momentum() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9);
sgd.step_with_params(&mut [&mut param]);
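    // First step with an empty buffer: v1 = g1 = 2, p = 1.0 - 0.1 * 2 = 0.8.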
assert!((param.data()[0] - 0.8).abs() < 1e-5);
clear_graph();
let loss = param.pow(2.0).sum();
loss.backward();
sgd.step_with_params(&mut [&mut param]);
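    // Second step: g2 = 2 * 0.8 = 1.6, v2 = 0.9 * 2 + 1.6 = 3.4, so
    // p = 0.8 - 0.1 * 3.4 = 0.46 (consistent with a PyTorch-style buffer v = mu * v + g).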
assert!((param.data()[0] - 0.46).abs() < 1e-5);
}
#[test]
fn test_adam_basic() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut adam = Adam::new(vec![&mut param], 0.1);
adam.step_with_params(&mut [&mut param]);
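    // With bias correction, Adam's first step is ~lr * sign(g) (up to eps),
    // so both entries should move toward zero by roughly 0.1.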
assert!(param.data()[0] < 1.0);
assert!(param.data()[1] < 2.0);
}
#[test]
fn test_adam_convergence() {
clear_graph();
let mut param = Tensor::from_slice(&[5.0]).requires_grad();
let mut adam = Adam::new(vec![&mut param], 0.5);
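    // Minimize f(p) = p^2 from p = 5.0, rebuilding the graph each iteration
    // so the gradient 2p is taken at the current parameter value.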
for _ in 0..100 {
clear_graph();
let loss = param.pow(2.0).sum();
loss.backward();
adam.step_with_params(&mut [&mut param]);
}
assert!(
param.data()[0].abs() < 0.1,
"Parameter should converge to 0, got {}",
param.data()[0]
);
}
#[test]
fn test_adamw_weight_decay() {
clear_graph();
let mut param = Tensor::from_slice(&[10.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
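    // AdamW's decay is decoupled: it shrinks the parameter directly rather
    // than being folded into the gradient, on top of the usual Adam step.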
let mut adamw = AdamW::new(vec![&mut param], 0.1).weight_decay(0.1);
adamw.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 10.0);
}
#[test]
fn test_rmsprop_basic() {
clear_graph();
let mut param = Tensor::from_slice(&[3.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
rmsprop.step_with_params(&mut [&mut param]);
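    // RMSprop scales g by 1 / sqrt(running avg of g^2), so the step still
    // points downhill and p should decrease.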
assert!(param.data()[0] < 3.0);
}
#[test]
fn test_zero_grad() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
let param_id = param.id();
let loss = param.pow(2.0).sum();
loss.backward();
assert!(get_grad(param_id).is_some());
let mut sgd = SGD::new(vec![&mut param], 0.1);
sgd.zero_grad();
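    // zero_grad should clear the stored gradient for every registered parameter.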
assert!(get_grad(param_id).is_none());
}
#[test]
fn test_learning_rate_change() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut sgd = SGD::new(vec![&mut param], 0.1);
assert!((sgd.lr() - 0.1).abs() < 1e-6);
sgd.set_lr(0.01);
assert!((sgd.lr() - 0.01).abs() < 1e-6);
}
#[test]
fn test_sgd_nesterov() {
clear_graph();
let mut param = Tensor::from_slice(&[2.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9).nesterov();
sgd.step_with_params(&mut [&mut param]);
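    // Nesterov from an empty buffer: v = g = 4, effective gradient
    // g + mu * v = 4 + 0.9 * 4 = 7.6, so p = 2.0 - 0.1 * 7.6 = 1.24.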
assert!(
(param.data()[0] - 1.24).abs() < 1e-5,
"Nesterov update failed: {}",
param.data()[0]
);
}
#[test]
fn test_sgd_weight_decay() {
clear_graph();
let mut param = Tensor::from_slice(&[5.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut sgd = SGD::new(vec![&mut param], 0.1).weight_decay(0.1);
sgd.step_with_params(&mut [&mut param]);
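    // Coupled L2 decay: g_eff = g + wd * p = 10 + 0.1 * 5 = 10.5,
    // so p = 5.0 - 0.1 * 10.5 = 3.95.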
assert!(
(param.data()[0] - 3.95).abs() < 1e-5,
"Weight decay update failed: {}",
param.data()[0]
);
}
#[test]
fn test_adam_with_custom_betas() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut adam = Adam::new(vec![&mut param], 0.1).betas(0.8, 0.99);
adam.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 1.0);
}
#[test]
fn test_adam_with_eps() {
clear_graph();
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut adam = Adam::new(vec![&mut param], 0.1).eps(1e-6);
adam.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 1.0);
}
#[test]
fn test_adam_with_weight_decay() {
clear_graph();
let mut param = Tensor::from_slice(&[10.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut adam_wd = Adam::new(vec![&mut param], 0.1).weight_decay(0.1);
adam_wd.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 10.0);
}
#[test]
fn test_adamw_with_custom_betas_and_eps() {
clear_graph();
let mut param = Tensor::from_slice(&[3.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut adamw = AdamW::new(vec![&mut param], 0.1)
.betas(0.85, 0.995)
.eps(1e-7);
adamw.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 3.0);
}
#[test]
fn test_adamw_lr_methods() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut adamw = AdamW::new(vec![&mut param], 0.01);
assert!((adamw.lr() - 0.01).abs() < 1e-6);
adamw.set_lr(0.001);
assert!((adamw.lr() - 0.001).abs() < 1e-6);
}
#[test]
fn test_adamw_zero_grad() {
clear_graph();
let mut param = Tensor::from_slice(&[2.0]).requires_grad();
let param_id = param.id();
let loss = param.pow(2.0).sum();
loss.backward();
assert!(get_grad(param_id).is_some());
let mut adamw = AdamW::new(vec![&mut param], 0.1);
adamw.zero_grad();
assert!(get_grad(param_id).is_none());
}
#[test]
fn test_adamw_step_trait() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut adamw = AdamW::new(vec![&mut param], 0.1);
adamw.step();
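    // step() through the Optimizer trait should lazily initialize state and
    // advance the internal timestep even though no gradient was computed.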
assert!(adamw.initialized);
assert_eq!(adamw.t, 1);
}
#[test]
fn test_rmsprop_with_alpha() {
clear_graph();
let mut param = Tensor::from_slice(&[2.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).alpha(0.9);
rmsprop.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 2.0);
}
#[test]
fn test_rmsprop_with_eps() {
clear_graph();
let mut param = Tensor::from_slice(&[2.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).eps(1e-6);
rmsprop.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 2.0);
}
#[test]
fn test_rmsprop_with_momentum() {
clear_graph();
let mut param = Tensor::from_slice(&[3.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).momentum(0.9);
rmsprop.step_with_params(&mut [&mut param]);
let after_first = param.data()[0];
assert!(after_first < 3.0);
clear_graph();
let loss = param.pow(2.0).sum();
loss.backward();
rmsprop.step_with_params(&mut [&mut param]);
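    // The momentum buffer accumulated on the first step keeps pushing p
    // toward the minimum on the second step.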
assert!(param.data()[0] < after_first);
}
#[test]
fn test_rmsprop_with_weight_decay() {
clear_graph();
let mut param = Tensor::from_slice(&[5.0]).requires_grad();
let loss = param.pow(2.0).sum();
loss.backward();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).weight_decay(0.1);
rmsprop.step_with_params(&mut [&mut param]);
assert!(param.data()[0] < 5.0);
}
#[test]
fn test_rmsprop_lr_methods() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.01);
assert!((rmsprop.lr() - 0.01).abs() < 1e-6);
rmsprop.set_lr(0.001);
assert!((rmsprop.lr() - 0.001).abs() < 1e-6);
}
#[test]
fn test_rmsprop_zero_grad() {
clear_graph();
let mut param = Tensor::from_slice(&[2.0]).requires_grad();
let param_id = param.id();
let loss = param.pow(2.0).sum();
loss.backward();
assert!(get_grad(param_id).is_some());
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
rmsprop.zero_grad();
assert!(get_grad(param_id).is_none());
}
#[test]
fn test_rmsprop_step_trait() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
rmsprop.step();
assert!(rmsprop.initialized);
}
#[test]
fn test_sgd_step_trait() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut sgd = SGD::new(vec![&mut param], 0.1);
sgd.step();
assert!(sgd.initialized);
}
#[test]
fn test_adam_step_trait() {
let mut param = Tensor::from_slice(&[1.0]).requires_grad();
let mut adam = Adam::new(vec![&mut param], 0.1);
adam.step();
assert!(adam.initialized);
assert_eq!(adam.t, 1);
}
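// Additional optimizer test suites, pulled in from sibling files as child modules.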
#[path = "tests_adam.rs"]
mod tests_adam;
#[path = "tests_adamw_contract.rs"]
mod tests_adamw_contract;
#[path = "tests_large_tensors.rs"]
mod tests_large_tensors;
#[path = "tests_state_resize.rs"]
mod tests_state_resize;