use super::core::Trainer;
use crate::optim::clip_grad_norm;
use crate::train::Batch;
use crate::Tensor;
use provable_contracts_macros::ensures;

impl Trainer {
    /// Runs one full optimization step: zero gradients, forward pass, loss,
    /// backward pass, optional gradient-norm clipping, optimizer update, and
    /// a metrics bump. Returns the scalar loss for this batch.
    ///
    /// Panics if no loss function has been set via `set_loss`.
    #[ensures(ret.is_finite())]
    pub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
    where
        F: FnOnce(&Tensor) -> Tensor,
    {
        // Fail fast, before any work, if no loss function is configured.
        let loss_fn = self
            .loss_fn
            .as_ref()
            .expect("Loss function must be set before training");
        // Clear gradients left over from the previous step.
        self.optimizer.zero_grad(&mut self.params);
        let predictions = forward_fn(&batch.inputs);
        let loss = loss_fn.forward(&predictions, &batch.targets);
        let loss_val = loss.data()[0];
        // Propagate gradients back through the graph rooted at the loss.
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        // Optionally clip the global gradient norm before the update.
        if let Some(max_norm) = self.config.max_grad_norm {
            clip_grad_norm(&mut self.params, max_norm);
        }
        self.optimizer.step(&mut self.params);
        self.metrics.increment_step();
        loss_val
    }

    /// Runs a forward/backward pass without zeroing gradients or stepping
    /// the optimizer, so gradients from successive calls accumulate. The
    /// caller is responsible for the eventual clip/step (see the sketch
    /// below this impl block). Returns the scalar loss for this batch.
    ///
    /// Panics if no loss function has been set via `set_loss`.
    pub(crate) fn accumulate_gradients<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
    where
        F: FnOnce(&Tensor) -> Tensor,
    {
        let loss_fn = self
            .loss_fn
            .as_ref()
            .expect("Loss function must be set before training");
        let predictions = forward_fn(&batch.inputs);
        let loss = loss_fn.forward(&predictions, &batch.targets);
        let loss_val = loss.data()[0];
        // Backward pass only: gradients are left in place to accumulate.
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        loss_val
    }
}
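
// A minimal sketch of how `accumulate_gradients` composes with the optimizer
// for micro-batching. It assumes the backward pass adds into existing
// gradients rather than overwriting them; the method name
// `train_step_accumulated` and the unscaled gradient sum are illustrative
// assumptions, not existing API. Only calls already used above appear here.
impl Trainer {
    #[allow(dead_code)]
    pub(crate) fn train_step_accumulated<F>(
        &mut self,
        micro_batches: &[Batch],
        forward_fn: F,
    ) -> f32
    where
        F: Fn(&Tensor) -> Tensor,
    {
        assert!(!micro_batches.is_empty(), "Need at least one micro-batch");
        // Zero once up front; each pass then adds its gradients on top.
        self.optimizer.zero_grad(&mut self.params);
        let mut total_loss = 0.0;
        for batch in micro_batches {
            // `&forward_fn` satisfies the `FnOnce` bound since `&F: Fn` when `F: Fn`.
            total_loss += self.accumulate_gradients(batch, &forward_fn);
        }
        // Clip and update exactly as `train_step` does.
        if let Some(max_norm) = self.config.max_grad_norm {
            clip_grad_norm(&mut self.params, max_norm);
        }
        self.optimizer.step(&mut self.params);
        self.metrics.increment_step();
        // Report the mean loss over the micro-batches.
        total_loss / micro_batches.len() as f32
    }
}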

#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    #[test]
    fn test_train_step() {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();
        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![2.0, 3.0, 4.0], false);
        let batch = Batch::new(inputs, targets);
        // The identity forward (`Clone::clone`) keeps the test independent
        // of any model: the loss compares the inputs directly to the targets.
        let loss = trainer.train_step(&batch, std::clone::Clone::clone);
        assert!(loss > 0.0);
        assert!(loss.is_finite());
        assert_eq!(trainer.metrics.steps, 1);
    }
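
    // Hedged sketch of a test for `accumulate_gradients`: it assumes the
    // metrics step counter starts at 0 (as `test_train_step` implies) and
    // that accumulation passes never advance it.
    #[test]
    fn test_accumulate_gradients_does_not_step() {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();
        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![2.0, 3.0, 4.0], false);
        let batch = Batch::new(inputs, targets);
        // Two accumulation passes: losses are finite and, since the optimizer
        // never steps, the step counter must remain at zero.
        let loss_a = trainer.accumulate_gradients(&batch, std::clone::Clone::clone);
        let loss_b = trainer.accumulate_gradients(&batch, std::clone::Clone::clone);
        assert!(loss_a.is_finite());
        assert!(loss_b.is_finite());
        assert_eq!(trainer.metrics.steps, 0);
    }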

    #[test]
    #[should_panic(expected = "Loss function must be set")]
    fn test_train_step_without_loss() {
        let params = vec![Tensor::zeros(10, true)];
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();
        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        let batch = Batch::new(Tensor::zeros(10, false), Tensor::zeros(10, false));
        trainer.train_step(&batch, std::clone::Clone::clone);
    }
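
    // Hedged sketch: exercises the gradient-clipping branch. It assumes
    // `TrainConfig::max_grad_norm` is an assignable `Option<f32>` field from
    // inside the crate (its use in `train_step` confirms the type, not the
    // visibility); swap in a builder call if the field is private.
    #[test]
    fn test_train_step_with_grad_clipping() {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let mut config = TrainConfig::default();
        config.max_grad_norm = Some(1.0);
        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));
        // A large input/target gap makes the unclipped gradients sizeable.
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![100.0, 100.0, 100.0], false);
        let batch = Batch::new(inputs, targets);
        let loss = trainer.train_step(&batch, std::clone::Clone::clone);
        // The loss is computed before clipping and the parameter update,
        // so the clip itself does not affect the returned value.
        assert!(loss.is_finite());
        assert!(loss > 0.0);
        assert_eq!(trainer.metrics.steps, 1);
    }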
}