use super::checkpoint::Checkpoint;
use super::metrics::{EpochMetrics, TrainingMetrics};
use crate::train::config::TrainingConfig;
/// Lifecycle hooks a trainer can notify about training progress.
///
/// Each method ships with an empty default body, so an implementor only
/// overrides the events it is interested in. Implementors must be
/// `Send + Sync`.
pub trait TrainingCallback: Send + Sync {
    /// Hook for the start of training; receives the training configuration.
    /// The default implementation does nothing.
    fn on_train_start(&mut self, _cfg: &TrainingConfig) {}

    /// Hook for the end of training; receives the aggregated metrics.
    /// The default implementation does nothing.
    fn on_train_end(&mut self, _summary: &TrainingMetrics) {}

    /// Hook for the start of an epoch; receives the epoch index.
    /// The default implementation does nothing.
    fn on_epoch_start(&mut self, _epoch_idx: usize) {}

    /// Hook for the end of an epoch; receives the epoch index and its metrics.
    /// The default implementation does nothing.
    fn on_epoch_end(&mut self, _epoch_idx: usize, _epoch_metrics: &EpochMetrics) {}

    /// Hook for the start of a batch; receives the batch index.
    /// The default implementation does nothing.
    fn on_batch_start(&mut self, _idx: usize) {}

    /// Hook for the end of a batch; receives the batch index and its loss.
    /// The default implementation does nothing.
    fn on_batch_end(&mut self, _idx: usize, _batch_loss: f32) {}

    /// Hook for a checkpoint event; receives the checkpoint.
    /// The default implementation does nothing.
    fn on_checkpoint(&mut self, _ckpt: &Checkpoint) {}
}
/// A [`TrainingCallback`] that ignores every event.
///
/// It overrides nothing and relies entirely on the trait's empty default
/// methods. As a stateless unit struct it is trivially `Copy` and comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NoOpCallback;

impl TrainingCallback for NoOpCallback {}
/// A training callback that prints progress to stdout.
#[derive(Debug, Clone)]
pub struct LoggingCallback {
    /// Batch-logging period: a batch is logged when its index is a multiple
    /// of this value.
    log_interval: usize,
}

impl LoggingCallback {
    /// Creates a logger that reports batch losses every `log_interval` batches.
    #[must_use]
    pub fn new(log_interval: usize) -> Self {
        Self { log_interval }
    }
}
impl TrainingCallback for LoggingCallback {
    /// Prints a one-line summary of the finished epoch to stdout.
    ///
    /// `train_accuracy` is multiplied by 100 and shown with a `%` sign. This
    /// assumes it is a fraction in `[0, 1]` (TODO confirm). The validation
    /// loss is appended only when `val_loss` is `Some`.
    fn on_epoch_end(&mut self, epoch: usize, metrics: &EpochMetrics) {
        let val_info = metrics
            .val_loss
            .map(|v| format!(", val_loss: {v:.4}"))
            .unwrap_or_default();
        println!(
            "Epoch {}: train_loss: {:.4}, train_acc: {:.2}%{} [{:.1}s]",
            epoch,
            metrics.train_loss,
            metrics.train_accuracy * 100.0,
            val_info,
            metrics.duration_secs
        );
    }

    /// Prints the batch loss whenever `batch_idx` is a multiple of
    /// `log_interval`.
    ///
    /// A `log_interval` of `0` disables per-batch logging instead of
    /// panicking.
    fn on_batch_end(&mut self, batch_idx: usize, loss: f32) {
        // Integer `% 0` panics, and `new` accepts 0, so guard it explicitly.
        if self.log_interval != 0 && batch_idx % self.log_interval == 0 {
            println!(" batch {batch_idx}: loss {loss:.4}");
        }
    }
}