use dfdx::{
    losses::mse_loss,
    nn::builders::*,
    optim::{Momentum, Optimizer, Sgd, SgdConfig},
    shapes::Rank2,
    tensor::{AsArray, AutoDevice, SampleTensor, Tensor, Trace},
    tensor_ops::Backward,
};
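
// A small MLP: 5 input features -> two hidden layers of 32 -> 2 outputs in [-1, 1] (Tanh).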
type Mlp = (
    (Linear<5, 32>, ReLU),
    (Linear<32, 32>, ReLU),
    (Linear<32, 2>, Tanh),
);

fn main() {
    // Pick the best available device: CUDA when enabled, otherwise the CPU.
    let dev = AutoDevice::default();

    // Initialize the model's f32 parameters on that device.
    let mut mlp = dev.build_module::<Mlp, f32>();

    // Allocate a gradient buffer covering every parameter in the model.
    let mut grads = mlp.alloc_grads();

    // SGD with Nesterov momentum; weight decay is disabled.
    let mut sgd = Sgd::new(
        &mlp,
        SgdConfig {
            lr: 1e-1,
            momentum: Some(Momentum::Nesterov(0.9)),
            weight_decay: None,
        },
    );

    // A random batch of 3 examples: 5 input features and 2 target values each.
    let x: Tensor<Rank2<3, 5>, f32, _> = dev.sample_normal();
    let y: Tensor<Rank2<3, 2>, f32, _> = dev.sample_normal();

    // Trace gradients through the forward pass; `trace` takes ownership of `grads`.
    let prediction = mlp.forward_mut(x.trace(grads));

    // Compute the loss and backpropagate; `backward` returns the filled gradient buffer.
    let loss = mse_loss(prediction, y.clone());
    dbg!(loss.array());
    grads = loss.backward();

    // Apply the gradients with the optimizer, then zero them for the next step.
    sgd.update(&mut mlp, &grads)
        .expect("Oops, there were some unused params");
    mlp.zero_grads(&mut grads);

    // Repeat the same trace -> loss -> backward -> update cycle; the loss should fall.
    for i in 0..5 {
        let prediction = mlp.forward_mut(x.trace(grads));
        let loss = mse_loss(prediction, y.clone());
        println!("Loss after update {i}: {:?}", loss.array());
        grads = loss.backward();
        sgd.update(&mut mlp, &grads)
            .expect("Oops, there were some unused params");
        mlp.zero_grads(&mut grads);
    }
}