use crate::autograd::{self, AutogradError};
use crate::optim::Optimizer;
use crate::tensor::Tensor;

/// Drives the training loop: runs forward passes, backpropagates, applies
/// optimizer updates, and tracks the running loss for the current epoch.
pub struct Trainer<O: Optimizer> {
    optimizer: O,
    /// Sum of per-step losses accumulated since the last `reset_epoch`.
    epoch_loss: f32,
    /// Number of steps taken since the last `reset_epoch`.
    epoch_steps: usize,
}

impl<O: Optimizer> Trainer<O> {
    /// Creates a trainer that applies parameter updates through `optimizer`.
    pub fn new(optimizer: O) -> Self {
        Self {
            optimizer,
            epoch_loss: 0.0,
            epoch_steps: 0,
        }
    }

    /// Runs one training step: evaluates `forward_fn` to produce a scalar
    /// loss tensor, backpropagates, and applies an optimizer update.
    ///
    /// Returns the scalar loss value for this step.
    pub fn train_step<F>(&mut self, forward_fn: F) -> Result<f32, AutogradError>
    where
        F: FnOnce() -> Tensor,
    {
        let loss = forward_fn();
        // Read the scalar loss before backprop; the loss tensor is expected
        // to hold a single element.
        let loss_val = {
            let g = loss.data();
            g[0]
        };
        // Backpropagate from the loss and let the optimizer consume the
        // resulting gradients.
        let mut grads = autograd::backward(&loss)?;
        self.optimizer.step(&mut grads)?;
        // Accumulate per-epoch statistics.
        self.epoch_loss += loss_val;
        self.epoch_steps += 1;
        Ok(loss_val)
    }

    /// Mean loss over all steps since the last `reset_epoch`;
    /// returns 0.0 if no steps have been taken yet.
    pub fn epoch_avg_loss(&self) -> f32 {
        if self.epoch_steps == 0 {
            0.0
        } else {
            self.epoch_loss / self.epoch_steps as f32
        }
    }

    /// Clears the accumulated loss statistics at the start of a new epoch.
    pub fn reset_epoch(&mut self) {
        self.epoch_loss = 0.0;
        self.epoch_steps = 0;
    }

    /// Mutable access to the underlying optimizer, e.g. to adjust the
    /// learning rate between epochs.
    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}
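
// A minimal usage sketch for `Trainer`. `Sgd`, `num_epochs`, `dataset`, and
// `model_forward` are hypothetical stand-ins for an optimizer, epoch count,
// data iterator, and model code elsewhere in the crate; only the `Trainer`
// calls themselves come from this module.
//
// let mut trainer = Trainer::new(Sgd::new(0.01));
// for _epoch in 0..num_epochs {
//     trainer.reset_epoch();
//     for batch in dataset.iter() {
//         trainer.train_step(|| model_forward(&batch))?;
//     }
//     println!("epoch avg loss: {}", trainer.epoch_avg_loss());
// }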