/// Training-state snapshot handed by reference to each [`TrainerCallback`] hook.
///
/// Every field is a public plain value. `Default` yields zero for each numeric
/// field and `None` for both optional losses, so a partially filled context can
/// be written with struct-update syntax:
/// `CallbackContext { epoch: 3, ..Default::default() }`.
///
/// NOTE(review): the code that populates this struct is not visible from here,
/// so the per-field descriptions below are inferred from the field names and
/// should be confirmed against the trainer.
#[derive(Clone, Debug, Default)]
pub struct CallbackContext {
    /// Presumably the current epoch counter (0- vs 1-based is not established here).
    pub epoch: usize,
    /// Presumably the total number of epochs planned for the run.
    pub max_epochs: usize,
    /// Presumably the step index within the current epoch.
    pub step: usize,
    /// Presumably the number of steps that make up one epoch.
    pub steps_per_epoch: usize,
    /// Presumably the step counter accumulated across all epochs.
    pub global_step: usize,
    /// Presumably the most recent training loss.
    pub loss: f32,
    /// Presumably the current learning rate.
    pub lr: f32,
    /// `None` by default; presumably the best loss observed so far once set.
    pub best_loss: Option<f32>,
    /// `None` by default; presumably the latest validation loss once set.
    pub val_loss: Option<f32>,
    /// Presumably elapsed time in seconds, going by the field name.
    pub elapsed_secs: f64,
}
/// Value returned by most [`TrainerCallback`] hooks.
///
/// The type is a small `Copy` enum that can be compared with `==`.
///
/// NOTE(review): the code that consumes these values is not visible from here;
/// the variant descriptions are inferred from the variant names — confirm
/// against the trainer loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Returned by every default hook implementation; presumably "carry on as normal".
    Continue,
    /// Presumably asks the trainer to end training.
    Stop,
    /// Presumably asks the trainer to skip the rest of the current epoch.
    SkipEpoch,
}
/// Set of optional hooks a training loop can call on a callback object.
///
/// Every method has a default body, so an implementor overrides only the hooks
/// it cares about. All hooks except [`on_train_end`](Self::on_train_end) return
/// a [`CallbackAction`]; the defaults always return `CallbackAction::Continue`.
/// Implementors must be `Send`.
///
/// NOTE(review): when and in which order the hooks are invoked is decided by
/// the caller, which is not visible from here; the hook names suggest the
/// intended moments.
pub trait TrainerCallback: Send {
    /// Train-begin hook. The default ignores the context and returns `Continue`.
    fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Train-end hook. Returns nothing; the default is a no-op.
    fn on_train_end(&mut self, _ctx: &CallbackContext) {}

    /// Epoch-begin hook. The default ignores the context and returns `Continue`.
    fn on_epoch_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Epoch-end hook. The default ignores the context and returns `Continue`.
    fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Step-begin hook. The default ignores the context and returns `Continue`.
    fn on_step_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Step-end hook. The default ignores the context and returns `Continue`.
    fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Validation hook. The default ignores the context and returns `Continue`.
    fn on_validation(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        CallbackAction::Continue
    }

    /// Static label for this callback; the default is `"TrainerCallback"`.
    fn name(&self) -> &'static str {
        "TrainerCallback"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context in which every field holds a non-default value, so a
    /// field dropped by `clone` would be noticed.
    fn sample_context() -> CallbackContext {
        CallbackContext {
            epoch: 5,
            max_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            global_step: 550,
            loss: 0.5,
            lr: 0.001,
            best_loss: Some(0.4),
            val_loss: Some(0.6),
            elapsed_secs: 100.0,
        }
    }

    /// `Default` must zero every numeric field and leave both optional losses unset.
    #[test]
    fn test_callback_context_default() {
        let ctx = CallbackContext::default();
        assert_eq!(ctx.epoch, 0);
        assert_eq!(ctx.max_epochs, 0);
        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.steps_per_epoch, 0);
        assert_eq!(ctx.global_step, 0);
        assert_eq!(ctx.loss, 0.0);
        assert_eq!(ctx.lr, 0.0);
        assert!(ctx.best_loss.is_none());
        assert!(ctx.val_loss.is_none());
        assert_eq!(ctx.elapsed_secs, 0.0);
    }

    /// `CallbackAction` is `Copy`/`Clone` and its variants compare as distinct.
    #[test]
    fn test_callback_action_clone_copy() {
        let action = CallbackAction::Continue;
        let copied = action;
        let cloned = Clone::clone(&action);
        assert_eq!(action, copied);
        assert_eq!(action, cloned);
        assert_ne!(CallbackAction::Stop, CallbackAction::SkipEpoch);
        assert_ne!(CallbackAction::Continue, CallbackAction::Stop);
        assert_ne!(CallbackAction::Continue, CallbackAction::SkipEpoch);
    }

    /// Cloning must carry over every field, not only `epoch`.
    #[test]
    fn test_callback_context_clone() {
        let ctx = sample_context();
        let cloned = ctx.clone();
        // Exact float equality is intended: a clone copies the bits unchanged.
        assert_eq!(ctx.epoch, cloned.epoch);
        assert_eq!(ctx.max_epochs, cloned.max_epochs);
        assert_eq!(ctx.step, cloned.step);
        assert_eq!(ctx.steps_per_epoch, cloned.steps_per_epoch);
        assert_eq!(ctx.global_step, cloned.global_step);
        assert_eq!(ctx.loss, cloned.loss);
        assert_eq!(ctx.lr, cloned.lr);
        assert_eq!(ctx.best_loss, cloned.best_loss);
        assert_eq!(ctx.val_loss, cloned.val_loss);
        assert_eq!(ctx.elapsed_secs, cloned.elapsed_secs);
    }

    /// A callback that overrides only `name` inherits `Continue` from every hook.
    #[test]
    fn test_default_trainer_callback_impl() {
        struct MinimalCallback;
        impl TrainerCallback for MinimalCallback {
            fn name(&self) -> &'static str {
                "MinimalCallback"
            }
        }

        let mut cb = MinimalCallback;
        let ctx = CallbackContext::default();
        assert_eq!(cb.on_train_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_validation(&ctx), CallbackAction::Continue);
        cb.on_train_end(&ctx);
        assert_eq!(cb.name(), "MinimalCallback");
    }

    /// An empty impl falls back to the trait's default `name`.
    #[test]
    fn test_trainer_callback_default_name() {
        struct Unnamed;
        impl TrainerCallback for Unnamed {}

        assert_eq!(Unnamed.name(), "TrainerCallback");
    }

    /// The trait is usable as `Box<dyn TrainerCallback>`, and an overridden hook
    /// is dispatched while the others keep their defaults.
    #[test]
    fn test_trainer_callback_as_trait_object() {
        struct StopOnEpochEnd;
        impl TrainerCallback for StopOnEpochEnd {
            fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
        }

        let ctx = sample_context();
        let mut boxed: Box<dyn TrainerCallback> = Box::new(StopOnEpochEnd);
        assert_eq!(boxed.on_epoch_end(&ctx), CallbackAction::Stop);
        assert_eq!(boxed.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(boxed.name(), "TrainerCallback");
    }
}