Skip to main content

entrenar/train/callback/
traits.rs

//! Core traits and types for the callback system.
//!
//! This module provides the foundational types for training callbacks:
//! - [`CallbackContext`] - state passed to callbacks
//! - [`CallbackAction`] - actions a callback can request
//! - [`TrainerCallback`] - the trait all callbacks implement
7
/// Context passed to callbacks with the current training state.
///
/// `Default` yields zeroed counters, zero loss / learning rate / elapsed
/// time, and `None` for both `best_loss` and `val_loss`.
// Every field's default is its type's zero/`None` value, so the derived
// `Default` is identical to the hand-written impl it replaces (clippy:
// `derivable_impls`).
#[derive(Clone, Debug, Default)]
pub struct CallbackContext {
    /// Current epoch (0-indexed)
    pub epoch: usize,
    /// Total epochs planned
    pub max_epochs: usize,
    /// Current step within epoch
    pub step: usize,
    /// Total steps in epoch
    pub steps_per_epoch: usize,
    /// Global step count
    pub global_step: usize,
    /// Current loss value
    pub loss: f32,
    /// Current learning rate
    pub lr: f32,
    /// Best loss seen so far (`None` until one has been recorded)
    pub best_loss: Option<f32>,
    /// Validation loss (if available)
    pub val_loss: Option<f32>,
    /// Training duration in seconds
    pub elapsed_secs: f64,
}
49
/// What the trainer is asked to do once a callback hook returns.
///
/// The default hook implementations on [`TrainerCallback`] all return
/// [`CallbackAction::Continue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Keep training as usual.
    Continue,
    /// Stop training altogether (e.g. early stopping).
    Stop,
    /// Abandon the remainder of the current epoch.
    SkipEpoch,
}
60
61/// Trait for training callbacks
62///
63/// Implement this trait to hook into training events. All methods have
64/// default no-op implementations, so you only need to implement the
65/// events you care about.
66pub trait TrainerCallback: Send {
67    /// Called before training starts
68    fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
69        CallbackAction::Continue
70    }
71
72    /// Called after training ends
73    fn on_train_end(&mut self, _ctx: &CallbackContext) {}
74
75    /// Called before each epoch
76    fn on_epoch_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
77        CallbackAction::Continue
78    }
79
80    /// Called after each epoch
81    fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
82        CallbackAction::Continue
83    }
84
85    /// Called before each training step
86    fn on_step_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
87        CallbackAction::Continue
88    }
89
90    /// Called after each training step
91    fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
92        CallbackAction::Continue
93    }
94
95    /// Called when validation is performed
96    fn on_validation(&mut self, _ctx: &CallbackContext) -> CallbackAction {
97        CallbackAction::Continue
98    }
99
100    /// Get callback name for logging
101    fn name(&self) -> &'static str {
102        "TrainerCallback"
103    }
104}
105
#[cfg(test)]
mod tests {
    //! Unit tests for the callback context, action enum, and the default
    //! `TrainerCallback` method implementations.

    use super::*;

    #[test]
    fn test_callback_context_default() {
        let ctx = CallbackContext::default();
        // Check every field, not just a sample, so a changed default is caught.
        assert_eq!(ctx.epoch, 0);
        assert_eq!(ctx.max_epochs, 0);
        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.steps_per_epoch, 0);
        assert_eq!(ctx.global_step, 0);
        assert_eq!(ctx.loss, 0.0);
        assert_eq!(ctx.lr, 0.0);
        assert!(ctx.best_loss.is_none());
        assert!(ctx.val_loss.is_none());
        assert_eq!(ctx.elapsed_secs, 0.0);
    }

    #[test]
    fn test_callback_action_clone_copy() {
        let action = CallbackAction::Continue;
        // `CallbackAction` is `Copy`, so `action` stays usable after this move.
        let copied = action;
        assert_eq!(action, copied);
        // All variants must be pairwise distinct.
        assert_ne!(CallbackAction::Continue, CallbackAction::Stop);
        assert_ne!(CallbackAction::Continue, CallbackAction::SkipEpoch);
        assert_ne!(CallbackAction::Stop, CallbackAction::SkipEpoch);
    }

    #[test]
    fn test_callback_context_clone() {
        let ctx = CallbackContext {
            epoch: 5,
            max_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            global_step: 550,
            loss: 0.5,
            lr: 0.001,
            best_loss: Some(0.4),
            val_loss: Some(0.6),
            elapsed_secs: 100.0,
        };
        let cloned = ctx.clone();
        // Compare every field so the test actually verifies the clone.
        assert_eq!(cloned.epoch, ctx.epoch);
        assert_eq!(cloned.max_epochs, ctx.max_epochs);
        assert_eq!(cloned.step, ctx.step);
        assert_eq!(cloned.steps_per_epoch, ctx.steps_per_epoch);
        assert_eq!(cloned.global_step, ctx.global_step);
        assert_eq!(cloned.loss, ctx.loss);
        assert_eq!(cloned.lr, ctx.lr);
        assert_eq!(cloned.best_loss, ctx.best_loss);
        assert_eq!(cloned.val_loss, ctx.val_loss);
        assert_eq!(cloned.elapsed_secs, ctx.elapsed_secs);
    }

    #[test]
    fn test_default_trainer_callback_impl() {
        struct MinimalCallback;
        impl TrainerCallback for MinimalCallback {
            fn name(&self) -> &'static str {
                "MinimalCallback"
            }
        }

        let mut cb = MinimalCallback;
        let ctx = CallbackContext::default();
        assert_eq!(cb.on_train_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_validation(&ctx), CallbackAction::Continue);
        cb.on_train_end(&ctx);
        assert_eq!(cb.name(), "MinimalCallback");
    }

    #[test]
    fn test_default_callback_name() {
        // A callback that overrides nothing falls back to the trait's default name.
        struct UnnamedCallback;
        impl TrainerCallback for UnnamedCallback {}

        assert_eq!(UnnamedCallback.name(), "TrainerCallback");
    }
}