use scivex_core::Float;
use crate::error::Result;
use crate::training::callbacks::{Callback, CallbackAction};
/// Summary of a training run as produced by `Trainer::fit`.
///
/// Only returned for runs that did not fail; a failing epoch makes `fit`
/// return its error instead.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingHistory<T: Float> {
    /// Loss returned for each epoch that ran; `losses[i]` belongs to epoch `i`.
    pub losses: Vec<T>,
    /// `true` if a callback returned `CallbackAction::Stop`, ending the loop.
    /// This is set even when the stop was requested on the final scheduled epoch.
    pub stopped_early: bool,
    /// Index into `losses` of the lowest loss (earliest epoch on ties);
    /// `0` when `losses` is empty.
    pub best_epoch: usize,
}
/// Drives an epoch-based training loop and notifies registered callbacks.
///
/// Build one with `Trainer::new`, register callbacks with
/// `Trainer::add_callback`, then run the loop with `Trainer::fit`.
pub struct Trainer<T: Float> {
    /// Number of epochs `fit` schedules; a callback may end the loop sooner.
    epochs: usize,
    /// Registered callbacks, stored (and later invoked) in registration order.
    callbacks: Vec<Box<dyn Callback<T> + 'static>>,
}
impl<T: Float> Trainer<T> {
    /// Creates a trainer that schedules `epochs` epochs and has no callbacks.
    pub fn new(epochs: usize) -> Self {
        Self {
            epochs,
            callbacks: Vec::new(),
        }
    }

    /// Registers a callback. Callbacks are invoked in the order they were added.
    ///
    /// Returns `&mut Self` so registrations can be chained.
    pub fn add_callback(&mut self, cb: Box<dyn Callback<T>>) -> &mut Self {
        self.callbacks.push(cb);
        self
    }

    /// Runs the training loop, calling `train_fn(epoch)` once per epoch.
    ///
    /// Callback lifecycle:
    /// - every callback's `on_train_begin` runs once before the first epoch;
    /// - after each epoch, every callback's `on_epoch_end` receives the epoch
    ///   index and its loss;
    /// - every callback's `on_train_end` runs once after the loop ends, whether
    ///   it ran to completion or was stopped early.
    ///
    /// If any callback returns [`CallbackAction::Stop`], the loop ends after that
    /// epoch and `stopped_early` is set. This happens even when that epoch was the
    /// last one scheduled. All callbacks still see the epoch on which the stop
    /// was requested.
    ///
    /// `best_epoch` is the index of the lowest loss seen, choosing the earliest
    /// epoch on ties. A NaN loss never displaces a comparable best. If the first
    /// loss is NaN, the first non-NaN loss that follows replaces it. If no epochs
    /// run, `losses` is empty and `best_epoch` is `0`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `train_fn`. No further epochs run,
    /// and `on_train_end` is **not** invoked on the callbacks in that case.
    pub fn fit<F>(&mut self, mut train_fn: F) -> Result<TrainingHistory<T>>
    where
        F: FnMut(usize) -> Result<T>,
    {
        for cb in &mut self.callbacks {
            cb.on_train_begin();
        }

        let mut losses: Vec<T> = Vec::with_capacity(self.epochs);
        let mut stopped_early = false;
        let mut best_epoch: usize = 0;
        let mut best_loss: Option<T> = None;

        for epoch in 0..self.epochs {
            let loss = train_fn(epoch)?;
            losses.push(loss);

            if Self::is_improvement(loss, best_loss) {
                best_loss = Some(loss);
                best_epoch = epoch;
            }

            // Deliberately not short-circuiting: every callback observes this
            // epoch, even if an earlier one has already asked to stop.
            let mut should_stop = false;
            for cb in &mut self.callbacks {
                if cb.on_epoch_end(epoch, loss) == CallbackAction::Stop {
                    should_stop = true;
                }
            }
            if should_stop {
                stopped_early = true;
                break;
            }
        }

        for cb in &mut self.callbacks {
            cb.on_train_end();
        }

        Ok(TrainingHistory {
            losses,
            stopped_early,
            best_epoch,
        })
    }

    /// Returns `true` when `candidate` should replace `best` as the best loss.
    ///
    /// Plain `candidate < prev` is not enough on its own. Nothing compares less
    /// than NaN, so a NaN recorded first would block every later loss. A
    /// comparable candidate therefore also wins over a NaN best. A NaN candidate
    /// never beats a comparable best.
    fn is_improvement(candidate: T, best: Option<T>) -> bool {
        // For IEEE floats, only NaN is unordered with itself.
        let is_nan = |x: T| x.partial_cmp(&x).is_none();
        match best {
            None => true,
            Some(prev) => candidate < prev || (is_nan(prev) && !is_nan(candidate)),
        }
    }
}