// entrenar/train/callback/traits.rs

/// Snapshot of trainer state passed (by reference) to every
/// [`TrainerCallback`] hook.
///
/// All fields are public plain data. `Default` yields zeroed counters, zero
/// loss and learning rate, zero elapsed time, and no best/validation loss.
// `Default` is derived rather than hand-written: each field's own default
// (`0`, `0.0`, `None`) is exactly the desired initial state
// (clippy: `derivable_impls`).
#[derive(Clone, Debug, Default)]
pub struct CallbackContext {
    /// Current epoch index (presumably 0-based — confirm against the trainer).
    pub epoch: usize,
    /// Configured number of epochs for this run.
    pub max_epochs: usize,
    /// Step index within the current epoch.
    pub step: usize,
    /// Number of steps that make up one epoch.
    pub steps_per_epoch: usize,
    /// Step counter across all epochs. NOTE(review): the unit tests build
    /// values consistent with `epoch * steps_per_epoch + step`; confirm the
    /// trainer maintains that relation.
    pub global_step: usize,
    /// Most recent training loss.
    pub loss: f32,
    /// Current learning rate.
    pub lr: f32,
    /// Best loss recorded so far, or `None` if none has been recorded yet.
    pub best_loss: Option<f32>,
    /// Latest validation loss, or `None` if no validation value is available.
    pub val_loss: Option<f32>,
    /// Elapsed wall-clock time in seconds (presumably since training began).
    pub elapsed_secs: f64,
}
49
/// Decision returned by a [`TrainerCallback`] hook to steer the training loop.
///
/// How each variant is honored is decided by the trainer loop that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Carry on normally; this is what every default hook returns.
    Continue,
    /// Ask the trainer to stop training.
    Stop,
    /// Ask the trainer to skip the remainder of the current epoch.
    SkipEpoch,
}
60
61pub trait TrainerCallback: Send {
67 fn on_train_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
69 CallbackAction::Continue
70 }
71
72 fn on_train_end(&mut self, _ctx: &CallbackContext) {}
74
75 fn on_epoch_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
77 CallbackAction::Continue
78 }
79
80 fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
82 CallbackAction::Continue
83 }
84
85 fn on_step_begin(&mut self, _ctx: &CallbackContext) -> CallbackAction {
87 CallbackAction::Continue
88 }
89
90 fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
92 CallbackAction::Continue
93 }
94
95 fn on_validation(&mut self, _ctx: &CallbackContext) -> CallbackAction {
97 CallbackAction::Continue
98 }
99
100 fn name(&self) -> &'static str {
102 "TrainerCallback"
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context in which every field holds a non-default value.
    fn populated_context() -> CallbackContext {
        CallbackContext {
            epoch: 5,
            max_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            global_step: 550,
            loss: 0.5,
            lr: 0.001,
            best_loss: Some(0.4),
            val_loss: Some(0.6),
            elapsed_secs: 100.0,
        }
    }

    #[test]
    fn test_callback_context_default() {
        let ctx = CallbackContext::default();
        assert_eq!(ctx.epoch, 0);
        assert_eq!(ctx.max_epochs, 0);
        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.steps_per_epoch, 0);
        assert_eq!(ctx.global_step, 0);
        assert_eq!(ctx.loss, 0.0);
        assert_eq!(ctx.lr, 0.0);
        assert!(ctx.best_loss.is_none());
        assert!(ctx.val_loss.is_none());
        assert_eq!(ctx.elapsed_secs, 0.0);
    }

    #[test]
    fn test_callback_action_clone_copy() {
        let action = CallbackAction::Continue;
        // Using `action` after this move-like binding only compiles if the
        // type is `Copy`.
        let cloned = action;
        assert_eq!(action, cloned);
        assert_ne!(CallbackAction::Stop, CallbackAction::SkipEpoch);
        assert_ne!(CallbackAction::Continue, CallbackAction::Stop);
    }

    #[test]
    fn test_callback_context_clone() {
        let ctx = populated_context();
        let cloned = ctx.clone();
        // Compare every field: checking a single field would not catch a
        // `Clone` that dropped or reset the others.
        assert_eq!(cloned.epoch, ctx.epoch);
        assert_eq!(cloned.max_epochs, ctx.max_epochs);
        assert_eq!(cloned.step, ctx.step);
        assert_eq!(cloned.steps_per_epoch, ctx.steps_per_epoch);
        assert_eq!(cloned.global_step, ctx.global_step);
        assert_eq!(cloned.loss, ctx.loss);
        assert_eq!(cloned.lr, ctx.lr);
        assert_eq!(cloned.best_loss, ctx.best_loss);
        assert_eq!(cloned.val_loss, ctx.val_loss);
        assert_eq!(cloned.elapsed_secs, ctx.elapsed_secs);
    }

    #[test]
    fn test_default_trainer_callback_impl() {
        struct MinimalCallback;
        impl TrainerCallback for MinimalCallback {
            fn name(&self) -> &'static str {
                "MinimalCallback"
            }
        }

        let mut cb = MinimalCallback;
        let ctx = CallbackContext::default();
        assert_eq!(cb.on_train_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_begin(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.on_validation(&ctx), CallbackAction::Continue);
        cb.on_train_end(&ctx);

        // The trait must stay usable as a trait object.
        let boxed: Box<dyn TrainerCallback> = Box::new(MinimalCallback);
        assert_eq!(boxed.name(), "MinimalCallback");
    }

    #[test]
    fn test_overridden_hook_and_default_name() {
        /// Requests a stop once the reported epoch reaches the threshold.
        struct StopAtEpoch(usize);
        impl TrainerCallback for StopAtEpoch {
            fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
                if ctx.epoch >= self.0 {
                    CallbackAction::Stop
                } else {
                    CallbackAction::Continue
                }
            }
        }

        let mut cb = StopAtEpoch(5);
        assert_eq!(cb.name(), "TrainerCallback");

        let early = CallbackContext {
            epoch: 4,
            ..Default::default()
        };
        assert_eq!(cb.on_epoch_end(&early), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&populated_context()), CallbackAction::Stop);
        // Hooks that were not overridden keep their defaults.
        assert_eq!(cb.on_step_end(&populated_context()), CallbackAction::Continue);
    }
}