use dfdx::{optim::Sgd, prelude::*, tensor::AutoDevice};
/// Generic classification training loop with gradient accumulation.
///
/// Runs `model` forward over every `(input, label)` pair yielded by `data`,
/// computes a loss via `criterion`, backpropagates, and applies an optimizer
/// step every `batch_accum` batches (accumulating gradients in between).
///
/// # Panics
/// - If `batch_accum` is 0.
/// - If the optimizer update fails (its error type cannot be propagated as
///   `D::Err`, so it is treated as a bug).
///
/// # Errors
/// Returns `D::Err` on any device error during gradient allocation, the
/// forward pass, the backward pass, or zeroing gradients.
fn classification_train<
    Inp: Trace<E, D>,
    Lbl,
    Model: ModuleMut<Inp::Traced, Error = D::Err> + TensorCollection<E, D>,
    Opt: Optimizer<Model, D, E>,
    Data: Iterator<Item = (Inp, Lbl)>,
    Criterion: FnMut(Model::Output, Lbl) -> Loss,
    Loss: Backward<E, D, Err = D::Err> + AsArray<Array = E>,
    E: Dtype,
    D: Device<E>,
>(
    model: &mut Model,
    opt: &mut Opt,
    mut criterion: Criterion,
    data: Data,
    batch_accum: usize,
) -> Result<(), D::Err> {
    // `i % 0` below would panic with a cryptic divide-by-zero; fail fast instead.
    assert!(batch_accum > 0, "batch_accum must be at least 1");
    let mut grads = model.try_alloc_grads()?;
    for (i, (inp, lbl)) in data.enumerate() {
        let y = model.try_forward_mut(inp.traced(grads))?;
        let loss = criterion(y, lbl);
        // Read the scalar loss before `try_backward` consumes the loss tensor.
        let loss_value = loss.array();
        grads = loss.try_backward()?;
        // Fix: step every `batch_accum` batches. The previous `i % batch_accum == 0`
        // fired at i == 0, i.e. after accumulating only a single batch.
        if (i + 1) % batch_accum == 0 {
            // The optimizer's error type is distinct from `D::Err` and cannot be
            // propagated through this function's Result; a failure here is a bug.
            opt.update(model, &grads).expect("optimizer update failed");
            model.try_zero_grads(&mut grads)?;
        }
        println!("batch {i} | loss = {loss_value:?}");
    }
    // NOTE(review): gradients accumulated since the last full step are discarded
    // here; fine for an example, but a trailing partial update may be desirable.
    Ok(())
}
/// Example entry point: trains a tiny linear classifier on synthetic data.
fn main() {
    // Use whichever backend dfdx was built for (CPU or CUDA).
    let dev = AutoDevice::default();

    type Model = Linear<10, 2>;
    type Dtype = f32;

    let mut model = dev.build_module::<Model, Dtype>();
    let mut opt = Sgd::new(&model, Default::default());

    // 100 synthetic batches: random-normal inputs paired with a fixed label batch.
    let dataset: Vec<_> = (0..100)
        .map(|_| {
            let inp = dev.sample_normal::<Rank2<5, 10>>();
            let lbl = dev.tensor([[0.0, 1.0]; 5]);
            (inp, lbl)
        })
        .collect();

    classification_train(
        &mut model,
        &mut opt,
        cross_entropy_with_logits_loss,
        dataset.into_iter(),
        1,
    )
    .unwrap();
}