use etensor_core::tensor::Tensor;
use etensor_core::buffer::Buffer;
use etensor_core::shape::Shape;
use etensor_core::device::Device;
use etensor_core::dtypes::DType;
use etensor_core::autograd::engine::backward;
use etensor_core::autograd::tape;
use etensor_core::dispatch::Dispatcher;
fn make_scalar(val: f32, requires_grad: bool) -> Tensor {
Tensor::new(
Buffer::from_f32_vec(vec![val]),
Shape::new(vec![1]),
Device::Cpu,
DType::F32,
requires_grad,
)
}
/// Builds a CPU `F32` tensor from the flat `data` buffer with shape `dims`.
///
/// The matmul test in this file expects `data` to be laid out row-major
/// (its hand-computed gradients only match under that layout).
///
/// # Panics
/// Panics if `data.len()` differs from the product of `dims`. A malformed
/// fixture then fails here with a clear message instead of somewhere inside
/// an op or the backward pass.
fn make_matrix(data: Vec<f32>, dims: Vec<usize>, requires_grad: bool) -> Tensor {
    let expected_len: usize = dims.iter().product();
    assert_eq!(
        data.len(),
        expected_len,
        "make_matrix: {} values cannot fill shape {:?}",
        data.len(),
        dims
    );
    Tensor::new(
        Buffer::from_f32_vec(data),
        Shape::new(dims),
        Device::Cpu,
        DType::F32,
        requires_grad,
    )
}
/// Gradients through a two-level graph `out = (x + y) * z` at x=2, y=3, z=4.
///
/// Expected: d(out)/dz = x + y = 5 and d(out)/dx = d(out)/dy = z = 4.
#[test]
fn test_nested_autograd_history() {
    // NOTE(review): presumably resets the autograd tape so ops recorded by
    // other tests do not leak into this graph — confirm against `tape::take`.
    let _ = tape::take();

    let x = make_scalar(2.0, true).with_name("x");
    let y = make_scalar(3.0, true).with_name("y");
    let z = make_scalar(4.0, true).with_name("z");

    let partial_sum = &x + &y;
    let out = &partial_sum * &z;

    let grads = backward(&out).expect("Backward pass failed!");
    // First (only) element of the gradient recorded for `t`; panics if absent.
    let grad_of = |t: &Tensor| -> f32 { grads.get(&t.id).unwrap().as_f32_slice().unwrap()[0] };

    assert_eq!(grad_of(&z), 5.0, "Gradient for z is incorrect");
    assert_eq!(grad_of(&x), 4.0, "Gradient for x is incorrect");
    assert_eq!(grad_of(&y), 4.0, "Gradient for y is incorrect");
}
/// A tensor used as both operands of one op must accumulate both gradient
/// contributions: for `w = x * x` at x = 3, dw/dx = 2x = 6.
#[test]
fn test_gradient_accumulation_chain_rule() {
    // NOTE(review): presumably resets the autograd tape so ops recorded by
    // other tests do not leak into this graph — confirm against `tape::take`.
    let _ = tape::take();

    let x = make_scalar(3.0, true);
    let squared = &x * &x;

    let grads = backward(&squared).unwrap();
    let x_grad = grads.get(&x.id).unwrap().as_f32_slice().unwrap();
    assert_eq!(
        x_grad[0],
        6.0,
        "Gradient accumulation failed on identical variables!"
    );
}
/// Backprop through `loss = sum(relu(x @ w))` with x = [[1, 2]] (no grad) and
/// w = [[0.5, -0.5], [1, 2]].
///
/// h = x @ w = [2.5, 3.5]. Both units are active, so ReLU passes the upstream
/// gradient through unchanged and dL/dw = xᵀ · [1, 1] = [[1, 1], [2, 2]].
/// `x` was built with `requires_grad = false`, so no gradient may be reported for it.
#[test]
fn test_mini_neural_network_backward() {
    // NOTE(review): presumably resets the autograd tape so ops recorded by
    // other tests do not leak into this graph — confirm against `tape::take`.
    let _ = tape::take();
    let x = make_matrix(vec![1.0, 2.0], vec![1, 2], false);
    let w = make_matrix(vec![0.5, -0.5, 1.0, 2.0], vec![2, 2], true);

    let h = Dispatcher::matmul(&x, &w).unwrap();
    let act = Dispatcher::relu(&h).unwrap();
    let loss = Dispatcher::sum_all(&act).unwrap();

    let grads = backward(&loss).unwrap();
    let dw = grads.get(&w.id).unwrap().as_f32_slice().unwrap();
    assert_eq!(dw, &[1.0, 1.0, 2.0, 2.0], "Matrix Chain Rule failed!");
    assert!(
        grads.get(&x.id).is_none(),
        "Computed gradients for an input that didn't request them!"
    );
}

/// Same network as `test_mini_neural_network_backward`, but with w chosen so
/// that one unit is inactive. This exercises the ReLU backward mask, which the
/// all-positive case above cannot detect.
///
/// w = [[0.5, -2], [1, 0.5]], so h = [1·0.5 + 2·1, 1·(-2) + 2·0.5] = [2.5, -1.0].
/// ReLU must zero the gradient of the negative unit, giving dL/dh = [1, 0] and
/// dL/dw = xᵀ · [1, 0] = [[1, 0], [2, 0]]. All values are exact in f32.
#[test]
fn test_mini_neural_network_backward_relu_mask() {
    // NOTE(review): presumably resets the autograd tape so ops recorded by
    // other tests do not leak into this graph — confirm against `tape::take`.
    let _ = tape::take();
    let x = make_matrix(vec![1.0, 2.0], vec![1, 2], false);
    let w = make_matrix(vec![0.5, -2.0, 1.0, 0.5], vec![2, 2], true);

    let h = Dispatcher::matmul(&x, &w).unwrap();
    let act = Dispatcher::relu(&h).unwrap();
    let loss = Dispatcher::sum_all(&act).unwrap();

    let grads = backward(&loss).unwrap();
    let dw = grads.get(&w.id).unwrap().as_f32_slice().unwrap();
    assert_eq!(
        dw,
        &[1.0, 0.0, 2.0, 0.0],
        "ReLU backward did not mask the inactive unit!"
    );
}