entrenar/train/callback/
progress.rs1use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
/// Training callback that prints progress to stdout.
///
/// Epoch start/end lines are printed on every epoch; per-step loss lines are
/// printed only every `log_interval` steps.
#[derive(Clone, Debug)]
pub struct ProgressCallback {
    /// Number of steps between per-step log lines.
    ///
    /// A step is reported when it is non-zero and an exact multiple of this
    /// value. With `0`, no non-zero step qualifies, so step logging is
    /// effectively disabled (no division-by-zero panic occurs).
    log_interval: usize,
}

impl ProgressCallback {
    /// Creates a progress callback that logs a step line every
    /// `log_interval` steps.
    pub fn new(log_interval: usize) -> Self {
        Self { log_interval }
    }
}

impl Default for ProgressCallback {
    /// Builds a callback that logs every 10 steps.
    fn default() -> Self {
        Self::new(10)
    }
}
24
25impl TrainerCallback for ProgressCallback {
26 fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
27 println!("Epoch {}/{} starting (lr: {:.2e})", ctx.epoch + 1, ctx.max_epochs, ctx.lr);
28 CallbackAction::Continue
29 }
30
31 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
32 let val_str = ctx.val_loss.map(|v| format!(", val_loss: {v:.4}")).unwrap_or_default();
33
34 println!(
35 "Epoch {}/{}: loss: {:.4}{} ({:.1}s)",
36 ctx.epoch + 1,
37 ctx.max_epochs,
38 ctx.loss,
39 val_str,
40 ctx.elapsed_secs
41 );
42 CallbackAction::Continue
43 }
44
45 fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
46 if ctx.step > 0 && ctx.step.is_multiple_of(self.log_interval) {
47 println!(" Step {}/{}: loss: {:.4}", ctx.step, ctx.steps_per_epoch, ctx.loss);
48 }
49 CallbackAction::Continue
50 }
51
52 fn name(&self) -> &'static str {
53 "ProgressCallback"
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_callback() {
        let mut progress = ProgressCallback::new(5);
        let ctx = CallbackContext {
            epoch: 0,
            max_epochs: 10,
            step: 5,
            steps_per_epoch: 100,
            loss: 0.5,
            lr: 0.001,
            ..Default::default()
        };

        assert_eq!(progress.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(progress.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(progress.on_epoch_end(&ctx), CallbackAction::Continue);
    }

    #[test]
    fn test_progress_callback_default() {
        let pc = ProgressCallback::default();
        assert_eq!(pc.log_interval, 10);
    }

    #[test]
    fn test_progress_callback_name() {
        let pc = ProgressCallback::new(5);
        assert_eq!(pc.name(), "ProgressCallback");
    }

    #[test]
    fn test_progress_callback_with_val_loss() {
        let mut pc = ProgressCallback::new(5);
        let ctx = CallbackContext {
            epoch: 0,
            max_epochs: 10,
            loss: 0.5,
            val_loss: Some(0.6),
            lr: 0.001,
            elapsed_secs: 1.0,
            ..Default::default()
        };
        assert_eq!(pc.on_epoch_end(&ctx), CallbackAction::Continue);
    }

    #[test]
    fn test_progress_callback_clone() {
        let pc = ProgressCallback::new(5);
        let cloned = pc.clone();
        assert_eq!(pc.log_interval, cloned.log_interval);
    }

    #[test]
    fn test_progress_callback_on_step_end_at_interval() {
        let mut cb = ProgressCallback::new(5);
        // Struct-update syntax rather than mutating a default value field by
        // field (clippy::field_reassign_with_default).
        let ctx = CallbackContext {
            step: 5,
            steps_per_epoch: 10,
            ..Default::default()
        };

        let action = cb.on_step_end(&ctx);
        assert_eq!(action, CallbackAction::Continue);
    }

    #[test]
    fn test_progress_callback_on_step_end_off_interval_and_step_zero() {
        // Step 0 is explicitly skipped; 1, 4 and 6 are not multiples of 5.
        let mut cb = ProgressCallback::new(5);
        for step in [0, 1, 4, 6] {
            let ctx = CallbackContext {
                step,
                steps_per_epoch: 10,
                ..Default::default()
            };
            assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
        }
    }

    #[test]
    fn test_progress_callback_zero_interval_does_not_panic() {
        // An interval of 0 must neither panic (as `step % 0` would) nor stop
        // training; `is_multiple_of(0)` is false for every non-zero step.
        let mut cb = ProgressCallback::new(0);
        for step in [0, 1, 7, 100] {
            let ctx = CallbackContext {
                step,
                steps_per_epoch: 100,
                ..Default::default()
            };
            assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
        }
    }
}
125
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        // The progress callback is purely observational: every hook must
        // return `Continue` for any epoch, step, loss and log interval,
        // including an interval of 0 (which disables step logging rather
        // than panicking on a zero divisor).
        fn progress_callback_never_stops(
            epoch in 0usize..100,
            step in 0usize..1000,
            loss in -100.0f32..100.0,
            log_interval in 0usize..=64,
        ) {
            let mut progress = ProgressCallback::new(log_interval);
            let ctx = CallbackContext {
                epoch,
                max_epochs: 100,
                step,
                steps_per_epoch: 100,
                loss,
                lr: 0.001,
                ..Default::default()
            };

            prop_assert_eq!(progress.on_train_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_epoch_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_step_begin(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_step_end(&ctx), CallbackAction::Continue);
            prop_assert_eq!(progress.on_epoch_end(&ctx), CallbackAction::Continue);
        }
    }
}