Skip to main content

entrenar/train/callback/
progress.rs

//! Progress callback for logging training progress
2
3use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
/// Training callback that reports progress on standard output.
///
/// Epoch start and end summaries are printed on every epoch. Per-step lines
/// are printed only for steps greater than zero that are a multiple of
/// `log_interval`; an interval of `0` therefore produces no per-step output.
#[derive(Clone, Debug)]
pub struct ProgressCallback {
    /// Number of steps between two consecutive per-step log lines.
    log_interval: usize,
}

impl ProgressCallback {
    /// Step interval used by the [`Default`] implementation.
    const DEFAULT_LOG_INTERVAL: usize = 10;

    /// Creates a progress callback that logs every `log_interval` steps.
    #[must_use]
    pub fn new(log_interval: usize) -> Self {
        Self { log_interval }
    }
}

impl Default for ProgressCallback {
    /// Returns a callback that logs every 10 steps.
    fn default() -> Self {
        // Delegate to the constructor so there is a single construction path.
        Self::new(Self::DEFAULT_LOG_INTERVAL)
    }
}
24
25impl TrainerCallback for ProgressCallback {
26    fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
27        println!("Epoch {}/{} starting (lr: {:.2e})", ctx.epoch + 1, ctx.max_epochs, ctx.lr);
28        CallbackAction::Continue
29    }
30
31    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
32        let val_str = ctx.val_loss.map(|v| format!(", val_loss: {v:.4}")).unwrap_or_default();
33
34        println!(
35            "Epoch {}/{}: loss: {:.4}{} ({:.1}s)",
36            ctx.epoch + 1,
37            ctx.max_epochs,
38            ctx.loss,
39            val_str,
40            ctx.elapsed_secs
41        );
42        CallbackAction::Continue
43    }
44
45    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
46        if ctx.step > 0 && ctx.step.is_multiple_of(self.log_interval) {
47            println!("  Step {}/{}: loss: {:.4}", ctx.step, ctx.steps_per_epoch, ctx.loss);
48        }
49        CallbackAction::Continue
50    }
51
52    fn name(&self) -> &'static str {
53        "ProgressCallback"
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// Every hook returns `Continue` for a typical mid-epoch context.
    #[test]
    fn test_progress_callback() {
        let mut progress = ProgressCallback::new(5);
        let ctx = CallbackContext {
            epoch: 0,
            max_epochs: 10,
            step: 5,
            steps_per_epoch: 100,
            loss: 0.5,
            lr: 0.001,
            ..Default::default()
        };

        // Should not panic
        assert_eq!(progress.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(progress.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(progress.on_epoch_end(&ctx), CallbackAction::Continue);
    }

    /// The default callback logs every 10 steps.
    #[test]
    fn test_progress_callback_default() {
        let pc = ProgressCallback::default();
        assert_eq!(pc.log_interval, 10);
    }

    /// The callback reports its fixed name.
    #[test]
    fn test_progress_callback_name() {
        let pc = ProgressCallback::new(5);
        assert_eq!(pc.name(), "ProgressCallback");
    }

    /// The epoch summary handles a present validation loss.
    #[test]
    fn test_progress_callback_with_val_loss() {
        let mut pc = ProgressCallback::new(5);
        let ctx = CallbackContext {
            epoch: 0,
            max_epochs: 10,
            loss: 0.5,
            val_loss: Some(0.6),
            lr: 0.001,
            elapsed_secs: 1.0,
            ..Default::default()
        };
        assert_eq!(pc.on_epoch_end(&ctx), CallbackAction::Continue);
    }

    /// Cloning preserves the configured interval.
    #[test]
    fn test_progress_callback_clone() {
        let pc = ProgressCallback::new(5);
        let cloned = pc.clone();
        assert_eq!(pc.log_interval, cloned.log_interval);
    }

    /// A step that lands exactly on the interval still continues training.
    #[test]
    fn test_progress_callback_on_step_end_at_interval() {
        let mut cb = ProgressCallback::new(5);
        // Struct-update syntax instead of mutating a defaulted value afterwards.
        let ctx = CallbackContext { step: 5, steps_per_epoch: 10, ..Default::default() };

        let action = cb.on_step_end(&ctx);
        assert_eq!(action, CallbackAction::Continue);
    }

    /// A zero interval must not panic in the step hook.
    #[test]
    fn test_progress_callback_zero_interval_does_not_panic() {
        let mut cb = ProgressCallback::new(0);
        let ctx = CallbackContext { step: 3, steps_per_epoch: 10, ..Default::default() };

        assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
    }
}
125
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Progress callback should always continue
        ///
        /// Whatever epoch, step or loss is generated, every hook exercised
        /// here must return `CallbackAction::Continue`.
        #[test]
        fn progress_callback_never_stops(
            epoch in 0usize..100,
            step in 0usize..1000,
            loss in -100.0f32..100.0,
        ) {
            let mut progress = ProgressCallback::new(10);
            let ctx = CallbackContext {
                epoch,
                max_epochs: 100,
                step,
                steps_per_epoch: 100,
                loss,
                lr: 0.001,
                ..Default::default()
            };

            // Hooks are called in the order a training loop would reach them.
            prop_assert_eq!(progress.on_train_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_epoch_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_step_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_step_end(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_epoch_end(&ctx), CallbackAction::Continue);
        }
    }
}
157}