pub(crate) use super::*;
pub(crate) use crate::autograd::clear_graph;
/// Plain SGD: verifies the gradient of `sum(x^2)` and the outcome of a single
/// update with learning rate 0.1.
#[test]
fn test_sgd_basic() {
    clear_graph();

    let mut weights = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
    let weights_id = weights.id();

    // Forward and backward pass through sum(x^2); the gradient should be 2x.
    let objective = weights.pow(2.0).sum();
    objective.backward();
    let gradient = get_grad(weights_id).expect("Should have gradient");
    assert_eq!(gradient.data(), &[2.0, 4.0, 6.0]);

    let mut optimizer = SGD::new(vec![&mut weights], 0.1);
    optimizer.step_with_params(&mut [&mut weights]);

    // x - 0.1 * 2x for each component.
    let expected = [0.8, 1.6, 2.4];
    for (p, e) in weights.data().iter().zip(&expected) {
        let err = (p - e).abs();
        assert!(err < 1e-5, "Expected {e}, got {p}");
    }
}
/// SGD with momentum 0.9: checks the parameter value after the first and the
/// second update on `sum(x^2)`.
#[test]
fn test_sgd_with_momentum() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();

    // First step: 1.0 - 0.1 * 2.0 = 0.8.
    let first_loss = theta.pow(2.0).sum();
    first_loss.backward();
    let mut optimizer = SGD::with_momentum(vec![&mut theta], 0.1, 0.9);
    optimizer.step_with_params(&mut [&mut theta]);
    let after_first = theta.data()[0];
    assert!((after_first - 0.8).abs() < 1e-5);

    // Second step: 0.8 - 0.1 * (0.9 * 2.0 + 1.6) = 0.46.
    clear_graph();
    let second_loss = theta.pow(2.0).sum();
    second_loss.backward();
    optimizer.step_with_params(&mut [&mut theta]);
    let after_second = theta.data()[0];
    assert!((after_second - 0.46).abs() < 1e-5);
}
/// A single Adam update on `sum(x^2)` must move each (positive) component
/// downward.
#[test]
fn test_adam_basic() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = Adam::new(vec![&mut theta], 0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    let first = theta.data()[0];
    assert!(first < 1.0);
    let second = theta.data()[1];
    assert!(second < 2.0);
}
/// Runs 100 Adam iterations (lr 0.5) minimising `sum(x^2)` from 5.0 and
/// expects the parameter to end up within 0.1 of zero.
#[test]
fn test_adam_convergence() {
    const ITERATIONS: usize = 100;

    clear_graph();
    let mut theta = Tensor::from_slice(&[5.0]).requires_grad();
    let mut optimizer = Adam::new(vec![&mut theta], 0.5);

    for _ in 0..ITERATIONS {
        // Start every iteration from a fresh graph.
        clear_graph();
        let objective = theta.pow(2.0).sum();
        objective.backward();
        optimizer.step_with_params(&mut [&mut theta]);
    }

    let final_value = theta.data()[0];
    assert!(
        final_value.abs() < 0.1,
        "Parameter should converge to 0, got {}",
        final_value
    );
}
/// AdamW configured with weight decay 0.1 must decrease a positive parameter
/// after one step on `sum(x^2)`.
#[test]
fn test_adamw_weight_decay() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[10.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = AdamW::new(vec![&mut theta], 0.1).weight_decay(0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 10.0);
}
/// One RMSprop step on `sum(x^2)` must decrease a positive parameter.
#[test]
fn test_rmsprop_basic() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[3.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 3.0);
}
/// `SGD::zero_grad` must remove the gradient recorded by `backward`, so that
/// `get_grad` no longer returns one for the parameter.
#[test]
fn test_zero_grad() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    let theta_id = theta.id();

    let objective = theta.pow(2.0).sum();
    objective.backward();
    let grad_before = get_grad(theta_id);
    assert!(grad_before.is_some());

    let mut optimizer = SGD::new(vec![&mut theta], 0.1);
    optimizer.zero_grad();
    let grad_after = get_grad(theta_id);
    assert!(grad_after.is_none());
}
/// `SGD::lr` reports the constructor's learning rate and reflects `set_lr`.
#[test]
fn test_learning_rate_change() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = SGD::new(vec![&mut theta], 0.1);

    let initial_lr = optimizer.lr();
    assert!((initial_lr - 0.1).abs() < 1e-6);

    optimizer.set_lr(0.01);
    let updated_lr = optimizer.lr();
    assert!((updated_lr - 0.01).abs() < 1e-6);
}
/// SGD with momentum 0.9 and Nesterov enabled: one step from 2.0 on
/// `sum(x^2)` is expected to land on 1.24.
#[test]
fn test_sgd_nesterov() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[2.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = SGD::with_momentum(vec![&mut theta], 0.1, 0.9).nesterov();
    optimizer.step_with_params(&mut [&mut theta]);

    // Expected: 2.0 - 0.1 * (4.0 + 0.9 * 4.0) = 1.24, with 4.0 the gradient 2x at x = 2.
    let updated = theta.data()[0];
    assert!(
        (updated - 1.24).abs() < 1e-5,
        "Nesterov update failed: {}",
        updated
    );
}
/// SGD with weight decay 0.1: one step from 5.0 on `sum(x^2)` is expected to
/// land on 3.95.
#[test]
fn test_sgd_weight_decay() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[5.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = SGD::new(vec![&mut theta], 0.1).weight_decay(0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    // Expected: 5.0 - 0.1 * (10.0 + 0.1 * 5.0) = 3.95.
    let updated = theta.data()[0];
    assert!(
        (updated - 3.95).abs() < 1e-5,
        "Weight decay update failed: {}",
        updated
    );
}
/// Adam built with `betas(0.8, 0.99)` must still decrease a positive
/// parameter after one step on `sum(x^2)`.
#[test]
fn test_adam_with_custom_betas() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = Adam::new(vec![&mut theta], 0.1).betas(0.8, 0.99);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 1.0);
}
/// Adam built with `eps(1e-6)` must still decrease a positive parameter after
/// one step on `sum(x^2)`.
#[test]
fn test_adam_with_eps() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = Adam::new(vec![&mut theta], 0.1).eps(1e-6);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 1.0);
}
/// Adam built with `weight_decay(0.1)` must decrease a positive parameter
/// after one step on `sum(x^2)`.
#[test]
fn test_adam_with_weight_decay() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[10.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = Adam::new(vec![&mut theta], 0.1).weight_decay(0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 10.0);
}
/// AdamW built with custom betas (0.85, 0.995) and eps 1e-7 must decrease a
/// positive parameter after one step on `sum(x^2)`.
#[test]
fn test_adamw_with_custom_betas_and_eps() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[3.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let builder = AdamW::new(vec![&mut theta], 0.1);
    let mut optimizer = builder.betas(0.85, 0.995).eps(1e-7);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 3.0);
}
/// `AdamW::lr` reports the constructor's learning rate and reflects `set_lr`.
#[test]
fn test_adamw_lr_methods() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = AdamW::new(vec![&mut theta], 0.01);

    let initial_lr = optimizer.lr();
    assert!((initial_lr - 0.01).abs() < 1e-6);

    optimizer.set_lr(0.001);
    let updated_lr = optimizer.lr();
    assert!((updated_lr - 0.001).abs() < 1e-6);
}
/// `AdamW::zero_grad` must remove the gradient recorded by `backward`.
#[test]
fn test_adamw_zero_grad() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[2.0]).requires_grad();
    let theta_id = theta.id();

    let objective = theta.pow(2.0).sum();
    objective.backward();
    let grad_before = get_grad(theta_id);
    assert!(grad_before.is_some());

    let mut optimizer = AdamW::new(vec![&mut theta], 0.1);
    optimizer.zero_grad();
    let grad_after = get_grad(theta_id);
    assert!(grad_after.is_none());
}
/// Calls `AdamW::step` with no backward pass run in this test and checks that
/// the optimizer is flagged as initialized and its step counter `t` is 1.
#[test]
fn test_adamw_step_trait() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = AdamW::new(vec![&mut theta], 0.1);

    optimizer.step();

    assert!(optimizer.initialized);
    assert_eq!(optimizer.t, 1);
}
/// RMSprop built with `alpha(0.9)` must decrease a positive parameter after
/// one step on `sum(x^2)`.
#[test]
fn test_rmsprop_with_alpha() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[2.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1).alpha(0.9);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 2.0);
}
/// RMSprop built with `eps(1e-6)` must decrease a positive parameter after
/// one step on `sum(x^2)`.
#[test]
fn test_rmsprop_with_eps() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[2.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1).eps(1e-6);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 2.0);
}
/// RMSprop with momentum 0.9: two consecutive steps on `sum(x^2)` must each
/// lower the parameter further.
#[test]
fn test_rmsprop_with_momentum() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[3.0]).requires_grad();

    // First update.
    let first_loss = theta.pow(2.0).sum();
    first_loss.backward();
    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1).momentum(0.9);
    optimizer.step_with_params(&mut [&mut theta]);
    let after_first = theta.data()[0];
    assert!(after_first < 3.0);

    // Second update on a fresh graph.
    clear_graph();
    let second_loss = theta.pow(2.0).sum();
    second_loss.backward();
    optimizer.step_with_params(&mut [&mut theta]);
    let after_second = theta.data()[0];
    assert!(after_second < after_first);
}
/// RMSprop built with `weight_decay(0.1)` must decrease a positive parameter
/// after one step on `sum(x^2)`.
#[test]
fn test_rmsprop_with_weight_decay() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[5.0]).requires_grad();
    let objective = theta.pow(2.0).sum();
    objective.backward();

    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1).weight_decay(0.1);
    optimizer.step_with_params(&mut [&mut theta]);

    let updated = theta.data()[0];
    assert!(updated < 5.0);
}
/// `RMSprop::lr` reports the constructor's learning rate and reflects
/// `set_lr`.
#[test]
fn test_rmsprop_lr_methods() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = RMSprop::new(vec![&mut theta], 0.01);

    let initial_lr = optimizer.lr();
    assert!((initial_lr - 0.01).abs() < 1e-6);

    optimizer.set_lr(0.001);
    let updated_lr = optimizer.lr();
    assert!((updated_lr - 0.001).abs() < 1e-6);
}
/// `RMSprop::zero_grad` must remove the gradient recorded by `backward`.
#[test]
fn test_rmsprop_zero_grad() {
    clear_graph();
    let mut theta = Tensor::from_slice(&[2.0]).requires_grad();
    let theta_id = theta.id();

    let objective = theta.pow(2.0).sum();
    objective.backward();
    let grad_before = get_grad(theta_id);
    assert!(grad_before.is_some());

    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1);
    optimizer.zero_grad();
    let grad_after = get_grad(theta_id);
    assert!(grad_after.is_none());
}
/// Calls `RMSprop::step` with no backward pass run in this test and checks
/// that the optimizer is flagged as initialized afterwards.
#[test]
fn test_rmsprop_step_trait() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = RMSprop::new(vec![&mut theta], 0.1);

    optimizer.step();

    assert!(optimizer.initialized);
}
/// Calls `SGD::step` with no backward pass run in this test and checks that
/// the optimizer is flagged as initialized afterwards.
#[test]
fn test_sgd_step_trait() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = SGD::new(vec![&mut theta], 0.1);

    optimizer.step();

    assert!(optimizer.initialized);
}
/// Calls `Adam::step` with no backward pass run in this test and checks that
/// the optimizer is flagged as initialized and its step counter `t` is 1.
#[test]
fn test_adam_step_trait() {
    let mut theta = Tensor::from_slice(&[1.0]).requires_grad();
    let mut optimizer = Adam::new(vec![&mut theta], 0.1);

    optimizer.step();

    assert!(optimizer.initialized);
    assert_eq!(optimizer.t, 1);
}
// Further test submodules. Each `#[path]` attribute names the source file the
// following `mod` declaration is loaded from.
#[path = "tests_adam.rs"]
mod tests_adam;
#[path = "tests_adamw_contract.rs"]
mod tests_adamw_contract;
#[path = "tests_large_tensors.rs"]
mod tests_large_tensors;
#[path = "tests_state_resize.rs"]
mod tests_state_resize;
/// Regression test: after `forward` + `backward` through a `Linear` layer,
/// the layer's weight (and its bias, when present) must have a gradient
/// registered in the autograd graph.
#[test]
fn nn_linear_backward_populates_weight_grad() {
    use crate::autograd::{clear_graph, get_grad, Tensor};
    use crate::nn::{Linear, Module};

    let (batch, in_features, out_features) = (2usize, 4usize, 3usize);

    clear_graph();
    let layer = Linear::with_seed(in_features, out_features, Some(1));
    let input = Tensor::from_vec(vec![0.5; batch * in_features], &[batch, in_features]);

    let output = layer.forward(&input);
    output.sum().backward();

    let weight_grad = get_grad(layer.weight().id());
    assert!(
        weight_grad.is_some(),
        "Linear weight received NO gradient through forward — the autograd path to \
         `weight` is broken (cached construction-time transpose wiped by clear_graph)."
    );

    // The bias is optional; only check it when the layer has one.
    if let Some(bias) = layer.bias() {
        let bias_grad = get_grad(bias.id());
        assert!(bias_grad.is_some(), "Linear bias received no gradient either.");
    }
}
/// End-to-end regression test: a two-layer MLP (Linear -> ReLU -> Linear)
/// trained with SGD on a synthetic linear-regression task must reduce its MSE
/// to below 20% of the initial value within 500 full-batch steps.
#[test]
fn nn_mlp_training_converges() {
    use crate::autograd::{clear_graph, Tensor};
    use crate::nn::{Linear, MSELoss, Module, ReLU, Sequential, SGD};

    // n samples, din input features, dh hidden units.
    let (n, din, dh) = (256usize, 16usize, 8usize);

    // Deterministic xorshift generator so the data set is identical on every
    // run; the top 24 bits are mapped to a value in [-0.5, 0.5).
    let mut s: u64 = 7;
    let mut rng = || {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        ((s >> 40) as f32 / (1u64 << 24) as f32) - 0.5
    };

    // Ground-truth weights; each target is the dot product of a sample with `w`.
    let w: Vec<f32> = (0..din).map(|_| rng()).collect();

    // Sizes are known up front, so allocate once.
    let mut xd = Vec::with_capacity(n * din);
    let mut yd = Vec::with_capacity(n);
    for _ in 0..n {
        let mut t = 0.0;
        // Iterate the weights directly instead of indexing `w[j]`; one feature
        // is drawn per weight, in the same order as before.
        for &wj in &w {
            let v = rng();
            xd.push(v);
            t += v * wj;
        }
        yd.push(t);
    }
    let x = Tensor::from_vec(xd, &[n, din]);
    let y = Tensor::from_vec(yd, &[n, 1]);

    let mut model = Sequential::new()
        .add(Linear::with_seed(din, dh, Some(1)))
        .add(ReLU)
        .add(Linear::with_seed(dh, 1, Some(2)));
    let crit = MSELoss::new();
    let mut sgd = SGD::new(model.parameters_mut(), 0.1);

    // Loss at step 0 and at the final step.
    let mut first = 0.0;
    let mut last = 0.0;
    for step in 0..500 {
        clear_graph();
        let loss = crit.forward(&model.forward(&x), &y);
        last = loss.item();
        if step == 0 {
            first = last;
        }
        loss.backward();
        // Parameters are re-borrowed every step and handed to the optimizer.
        let mut p = model.parameters_mut();
        sgd.step_with_params(&mut p);
    }

    assert!(
        first > 0.05,
        "sanity: initial MSE should be non-trivial, got {first}"
    );
    assert!(
        last < first * 0.2,
        "MLP training did not converge: MSE {first:.5} -> {last:.5} (need < {:.5}). \
         Likely a regression of the Linear live-transpose gradient-path fix.",
        first * 0.2
    );
}