use crate::{TrainResult, TrainingState};
/// Hook interface for observing a training run.
///
/// Every method has a default implementation: the `on_*` hooks do nothing and
/// return `Ok(())`, and [`Callback::should_stop`] returns `false`. Implementors
/// override only the hooks they care about.
pub trait Callback {
    /// Hook for the start of training. Default: no-op returning `Ok(())`.
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the end of training. Default: no-op returning `Ok(())`.
    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the start of epoch `_epoch`. Default: no-op returning `Ok(())`.
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the end of epoch `_epoch`. Default: no-op returning `Ok(())`.
    fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the start of batch `_batch`. Default: no-op returning `Ok(())`.
    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the end of batch `_batch`. Default: no-op returning `Ok(())`.
    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Hook for the end of a validation pass. Default: no-op returning `Ok(())`.
    fn on_validation_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }

    /// Reports whether this callback wants training to stop.
    ///
    /// The default never requests a stop.
    fn should_stop(&self) -> bool {
        false
    }
}
/// An ordered collection of boxed [`Callback`]s that are driven together.
pub struct CallbackList {
    /// Registered callbacks, stored in the order they were added.
    callbacks: Vec<Box<dyn Callback>>,
}
impl CallbackList {
    /// Creates a list with no callbacks registered.
    pub fn new() -> Self {
        Self { callbacks: Vec::new() }
    }

    /// Appends `callback` to the end of the list.
    pub fn add(&mut self, callback: Box<dyn Callback>) {
        self.callbacks.push(callback);
    }

    /// Calls `on_train_begin` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_train_begin(&mut self, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_train_begin(state))
    }

    /// Calls `on_train_end` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_train_end(&mut self, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_train_end(state))
    }

    /// Calls `on_epoch_begin` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_epoch_begin(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_epoch_begin(epoch, state))
    }

    /// Calls `on_epoch_end` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_epoch_end(epoch, state))
    }

    /// Calls `on_batch_begin` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_batch_begin(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_batch_begin(batch, state))
    }

    /// Calls `on_batch_end` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_batch_end(batch, state))
    }

    /// Calls `on_validation_end` on each callback in insertion order.
    ///
    /// # Errors
    /// Returns the first error produced; callbacks after it are not invoked.
    pub fn on_validation_end(&mut self, state: &TrainingState) -> TrainResult<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|cb| cb.on_validation_end(state))
    }

    /// Returns `true` as soon as any registered callback requests a stop,
    /// and `false` when none does (including when the list is empty).
    pub fn should_stop(&self) -> bool {
        for cb in &self.callbacks {
            if cb.should_stop() {
                return true;
            }
        }
        false
    }
}
impl Default for CallbackList {
fn default() -> Self {
Self::new()
}
}
/// Callback that prints a per-epoch loss summary when `verbose` is set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EpochCallback {
    /// When `true`, the `on_epoch_end` hook prints the epoch summary;
    /// when `false`, the hook prints nothing.
    pub verbose: bool,
}

impl EpochCallback {
    /// Creates the callback with the given verbosity flag.
    pub fn new(verbose: bool) -> Self {
        Self { verbose }
    }
}
impl Callback for EpochCallback {
    /// Prints the epoch number, training loss and validation loss to stdout
    /// when `verbose` is enabled; otherwise does nothing. Always returns `Ok(())`.
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if !self.verbose {
            return Ok(());
        }

        // A missing validation loss is substituted with NaN for display.
        let val_loss = state.val_loss.unwrap_or(f64::NAN);
        println!(
            "Epoch {}: loss={:.6}, val_loss={:.6}",
            epoch, state.train_loss, val_loss
        );
        Ok(())
    }
}
/// Callback that prints the batch loss every `log_frequency` batches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchCallback {
    /// The `on_batch_end` hook logs when the batch index is a multiple of
    /// this value. With `0`, only batch index `0` counts as a multiple.
    pub log_frequency: usize,
}

impl BatchCallback {
    /// Creates the callback with the given logging interval (in batches).
    pub fn new(log_frequency: usize) -> Self {
        Self { log_frequency }
    }
}
impl Callback for BatchCallback {
    /// Prints the batch index and batch loss to stdout when `batch` is a
    /// multiple of `log_frequency`. Always returns `Ok(())`.
    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        // Skip batches that fall between logging points.
        if !batch.is_multiple_of(self.log_frequency) {
            return Ok(());
        }

        println!("Batch {}: loss={:.6}", batch, state.batch_loss);
        Ok(())
    }
}
/// Callback that prints the validation loss every `validation_frequency` epochs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationCallback {
    /// The `on_epoch_end` hook reports when the epoch index is a multiple of
    /// this value. With `0`, only epoch index `0` counts as a multiple.
    pub validation_frequency: usize,
}

impl ValidationCallback {
    /// Creates the callback with the given reporting interval (in epochs).
    pub fn new(validation_frequency: usize) -> Self {
        Self {
            validation_frequency,
        }
    }
}
impl Callback for ValidationCallback {
    /// Prints the validation loss to stdout when `epoch` is a multiple of
    /// `validation_frequency` and a validation loss is present in `state`.
    /// Always returns `Ok(())`.
    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        match state.val_loss {
            Some(val_loss) if epoch.is_multiple_of(self.validation_frequency) => {
                println!("Validation at epoch {}: val_loss={:.6}", epoch, val_loss);
            }
            // Either no validation loss is available or this epoch is off-schedule.
            _ => {}
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `TrainingState` with fixed values for use in the tests below.
    fn create_test_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 1.0,
            val_loss: Some(0.8),
            batch_loss: 0.5,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }

    /// Test-only callback that requests a stop once it has seen an epoch end.
    struct StopAfterEpoch {
        triggered: bool,
    }

    impl Callback for StopAfterEpoch {
        fn on_epoch_end(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
            self.triggered = true;
            Ok(())
        }

        fn should_stop(&self) -> bool {
            self.triggered
        }
    }

    /// Every dispatcher hook succeeds for a non-verbose `EpochCallback`, and
    /// no stop is requested because it keeps the default `should_stop`.
    #[test]
    fn test_callback_list() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));
        let state = create_test_state();

        callbacks
            .on_train_begin(&state)
            .expect("on_train_begin should succeed");
        callbacks
            .on_epoch_begin(0, &state)
            .expect("on_epoch_begin should succeed");
        callbacks
            .on_batch_begin(0, &state)
            .expect("on_batch_begin should succeed");
        callbacks
            .on_batch_end(0, &state)
            .expect("on_batch_end should succeed");
        callbacks
            .on_validation_end(&state)
            .expect("on_validation_end should succeed");
        callbacks
            .on_epoch_end(0, &state)
            .expect("on_epoch_end should succeed");
        callbacks
            .on_train_end(&state)
            .expect("on_train_end should succeed");

        assert!(!callbacks.should_stop());
    }

    /// An empty list never requests a stop, whether built via `new` or `default`.
    #[test]
    fn test_empty_list_does_not_stop() {
        assert!(!CallbackList::new().should_stop());
        assert!(!CallbackList::default().should_stop());
    }

    /// A stop request from any single callback is surfaced by the list.
    #[test]
    fn test_should_stop_propagates() {
        let mut callbacks = CallbackList::new();
        callbacks.add(Box::new(EpochCallback::new(false)));
        callbacks.add(Box::new(StopAfterEpoch { triggered: false }));
        let state = create_test_state();

        assert!(!callbacks.should_stop());
        callbacks
            .on_epoch_end(0, &state)
            .expect("on_epoch_end should succeed");
        assert!(callbacks.should_stop());
    }
}