use dfdx::{
nn::ZeroGrads,
prelude::*,
tensor::{AutoDevice, Gradients},
};
/// Shows that gradients accumulate across repeated backward passes and that
/// `ZeroGrads::zero_grads` resets them.
///
/// A small MLP is built and run forward/backward three times on the same
/// random batch, reusing one gradient buffer. The buffer is then zeroed, and
/// the first layer's weight gradient is checked to be all zeros.
fn main() {
    let dev = AutoDevice::default();

    // MLP mapping 2 -> 5 -> 10 -> 20 features, with ReLU and Tanh in between.
    type Model = (Linear<2, 5>, ReLU, Linear<5, 10>, Tanh, Linear<10, 20>);
    let model = dev.build_module::<Model, f32>();

    // A batch of 10 normally-distributed samples with 2 features each.
    let x: Tensor<Rank2<10, 2>, f32, _> = dev.sample_normal();

    // One gradient buffer for the model's parameters, allocated up front and
    // reused for every pass below.
    let mut grads: Gradients<f32, _> = model.alloc_grads();

    // Three forward/backward passes over the same input. The buffer is moved
    // into `trace` and handed back by `backward`, so every pass works on the
    // same gradients instead of starting from a fresh allocation.
    for _ in 0..3 {
        let loss = model.forward(x.trace(grads)).mean();
        grads = loss.backward();
    }

    // Clear whatever the passes above wrote into the model's gradients.
    model.zero_grads(&mut grads);

    // The first layer's weight is 5 rows x 2 columns (out x in), and its
    // gradient must now be entirely zero.
    let expected: [[f32; 2]; 5] = [[0.0; 2]; 5];
    assert_eq!(grads.get(&model.0.weight).array(), expected);
}