1use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
5pub struct CallbackManager {
7 callbacks: Vec<Box<dyn TrainerCallback>>,
8}
9
10impl CallbackManager {
11 pub fn new() -> Self {
13 Self { callbacks: Vec::new() }
14 }
15
16 pub fn add<C: TrainerCallback + 'static>(&mut self, callback: C) {
18 self.callbacks.push(Box::new(callback));
19 }
20
21 pub fn is_empty(&self) -> bool {
23 self.callbacks.is_empty()
24 }
25
26 pub fn len(&self) -> usize {
28 self.callbacks.len()
29 }
30
31 pub fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
33 for cb in &mut self.callbacks {
34 if cb.on_train_begin(ctx) == CallbackAction::Stop {
35 return CallbackAction::Stop;
36 }
37 }
38 CallbackAction::Continue
39 }
40
41 pub fn on_train_end(&mut self, ctx: &CallbackContext) {
43 for cb in &mut self.callbacks {
44 cb.on_train_end(ctx);
45 }
46 }
47
48 pub fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
50 for cb in &mut self.callbacks {
51 match cb.on_epoch_begin(ctx) {
52 CallbackAction::Stop => return CallbackAction::Stop,
53 CallbackAction::SkipEpoch => return CallbackAction::SkipEpoch,
54 CallbackAction::Continue => {}
55 }
56 }
57 CallbackAction::Continue
58 }
59
60 pub fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
62 for cb in &mut self.callbacks {
63 if cb.on_epoch_end(ctx) == CallbackAction::Stop {
64 return CallbackAction::Stop;
65 }
66 }
67 CallbackAction::Continue
68 }
69
70 pub fn on_step_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
72 for cb in &mut self.callbacks {
73 if cb.on_step_begin(ctx) == CallbackAction::Stop {
74 return CallbackAction::Stop;
75 }
76 }
77 CallbackAction::Continue
78 }
79
80 pub fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
82 for cb in &mut self.callbacks {
83 if cb.on_step_end(ctx) == CallbackAction::Stop {
84 return CallbackAction::Stop;
85 }
86 }
87 CallbackAction::Continue
88 }
89}
90
91impl Default for CallbackManager {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
// Unit tests for `CallbackManager`: registration bookkeeping, short-circuiting
// on `Stop` / `SkipEpoch`, and unconditional fan-out of `on_train_end`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::callback::{EarlyStopping, ProgressCallback};

    #[test]
    fn test_callback_manager_dispatch() {
        let mut manager = CallbackManager::new();

        let es = EarlyStopping::new(1, 0.001);
        manager.add(es);

        let mut ctx = CallbackContext::default();
        ctx.loss = 1.0;

        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);

        // Loss is unchanged on the next epoch, so with patience 1 the manager
        // is expected to relay EarlyStopping's Stop.
        ctx.epoch = 1;
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_len_and_empty() {
        let mut manager = CallbackManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);

        manager.add(ProgressCallback::new(10));
        assert!(!manager.is_empty());
        assert_eq!(manager.len(), 1);
    }

    #[test]
    fn test_callback_manager_on_train_begin_stop() {
        struct StopCallback;
        impl TrainerCallback for StopCallback {
            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(StopCallback);
        assert_eq!(manager.on_train_begin(&CallbackContext::default()), CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_on_train_end() {
        use std::sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        };

        struct EndCallback {
            called: Arc<AtomicBool>,
        }
        impl TrainerCallback for EndCallback {
            fn on_train_end(&mut self, _: &CallbackContext) {
                self.called.store(true, Ordering::SeqCst);
            }
            fn name(&self) -> &'static str {
                "EndCallback"
            }
        }

        let called = Arc::new(AtomicBool::new(false));
        let mut manager = CallbackManager::new();
        manager.add(EndCallback { called: called.clone() });
        manager.on_train_end(&CallbackContext::default());
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_callback_manager_on_epoch_begin_skip() {
        struct SkipCallback;
        impl TrainerCallback for SkipCallback {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::SkipEpoch
            }
            fn name(&self) -> &'static str {
                "SkipCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(SkipCallback);
        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::SkipEpoch);
    }

    #[test]
    fn test_callback_manager_on_epoch_begin_stop() {
        struct StopCallback;
        impl TrainerCallback for StopCallback {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(StopCallback);
        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_on_step_begin_stop() {
        struct StopCallback;
        impl TrainerCallback for StopCallback {
            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(StopCallback);
        assert_eq!(manager.on_step_begin(&CallbackContext::default()), CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_default() {
        let manager = CallbackManager::default();
        assert!(manager.is_empty());
    }

    #[test]
    fn test_callback_manager_stop_propagation() {
        struct StopCallback;
        impl TrainerCallback for StopCallback {
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(StopCallback);
        manager.add(ProgressCallback::new(10));

        let ctx = CallbackContext::default();
        let action = manager.on_epoch_end(&ctx);
        assert_eq!(action, CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_on_step_end_stop() {
        struct StopCallback;
        impl TrainerCallback for StopCallback {
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(StopCallback);
        assert_eq!(manager.on_step_end(&CallbackContext::default()), CallbackAction::Stop);
    }

    #[test]
    fn test_callback_manager_all_continue() {
        struct ContinueCallback;
        impl TrainerCallback for ContinueCallback {
            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Continue
            }
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Continue
            }
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Continue
            }
            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Continue
            }
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "ContinueCallback"
            }
        }

        let mut manager = CallbackManager::new();
        manager.add(ContinueCallback);
        manager.add(ContinueCallback);

        let ctx = CallbackContext::default();
        assert_eq!(manager.on_train_begin(&ctx), CallbackAction::Continue);
        assert_eq!(manager.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(manager.on_step_begin(&ctx), CallbackAction::Continue);
        assert_eq!(manager.on_step_end(&ctx), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_multiple_train_end() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountingEndCallback {
            count: Arc<AtomicUsize>,
        }

        impl TrainerCallback for CountingEndCallback {
            fn on_train_end(&mut self, _: &CallbackContext) {
                self.count.fetch_add(1, Ordering::SeqCst);
            }
            fn name(&self) -> &'static str {
                "CountingEndCallback"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(CountingEndCallback { count: count.clone() });
        manager.add(CountingEndCallback { count: count.clone() });
        manager.add(CountingEndCallback { count: count.clone() });

        manager.on_train_end(&CallbackContext::default());
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_callback_manager_stop_after_first() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountingStopCallback {
            count: Arc<AtomicUsize>,
        }

        impl TrainerCallback for CountingStopCallback {
            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "CountingStopCallback"
            }
        }

        struct CountingContinueCallback {
            count: Arc<AtomicUsize>,
        }

        impl TrainerCallback for CountingContinueCallback {
            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountingContinueCallback"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(CountingStopCallback { count: count.clone() });
        manager.add(CountingContinueCallback { count: count.clone() });

        let action = manager.on_train_begin(&CallbackContext::default());
        // Only the first (stopping) callback ran.
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_on_train_begin_continue() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.on_train_begin(&CallbackContext::default()), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_on_epoch_end_continue() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.on_epoch_end(&CallbackContext::default()), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_on_step_begin_continue() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.on_step_begin(&CallbackContext::default()), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_on_step_end_continue() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.on_step_end(&CallbackContext::default()), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_on_epoch_begin_continue() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::Continue);
    }

    #[test]
    fn test_callback_manager_stop_epoch_begin_does_not_call_second() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct StopEpochBegin {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopEpochBegin {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopEpochBegin"
            }
        }

        struct CountEpochBegin {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountEpochBegin {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountEpochBegin"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(StopEpochBegin { count: count.clone() });
        manager.add(CountEpochBegin { count: count.clone() });

        let action = manager.on_epoch_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_stop_epoch_end_does_not_call_second() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct StopEpochEnd {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopEpochEnd {
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopEpochEnd"
            }
        }

        struct CountEpochEnd {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountEpochEnd {
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountEpochEnd"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(StopEpochEnd { count: count.clone() });
        manager.add(CountEpochEnd { count: count.clone() });

        let action = manager.on_epoch_end(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_stop_step_begin_does_not_call_second() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct StopStepBegin {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopStepBegin {
            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopStepBegin"
            }
        }

        struct CountStepBegin {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountStepBegin {
            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountStepBegin"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(StopStepBegin { count: count.clone() });
        manager.add(CountStepBegin { count: count.clone() });

        let action = manager.on_step_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_stop_step_end_does_not_call_second() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct StopStepEnd {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopStepEnd {
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopStepEnd"
            }
        }

        struct CountStepEnd {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountStepEnd {
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountStepEnd"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(StopStepEnd { count: count.clone() });
        manager.add(CountStepEnd { count: count.clone() });

        let action = manager.on_step_end(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_skip_epoch_does_not_call_second() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct SkipCallback {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for SkipCallback {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::SkipEpoch
            }
            fn name(&self) -> &'static str {
                "SkipCallback"
            }
        }

        struct ContinueCallback {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for ContinueCallback {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "ContinueCallback"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(SkipCallback { count: count.clone() });
        manager.add(ContinueCallback { count: count.clone() });

        let action = manager.on_epoch_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::SkipEpoch);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_with_context_values() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct EpochTracker {
            last_epoch: Arc<AtomicUsize>,
        }
        impl TrainerCallback for EpochTracker {
            fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.last_epoch.store(ctx.epoch, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "EpochTracker"
            }
        }

        let last_epoch = Arc::new(AtomicUsize::new(999));
        let mut manager = CallbackManager::new();
        manager.add(EpochTracker { last_epoch: last_epoch.clone() });

        let mut ctx = CallbackContext::default();
        ctx.epoch = 42;
        assert_eq!(manager.on_epoch_begin(&ctx), CallbackAction::Continue);
        assert_eq!(last_epoch.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_callback_manager_train_end_all_called() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountCallback {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountCallback {
            fn on_train_end(&mut self, _: &CallbackContext) {
                self.count.fetch_add(1, Ordering::SeqCst);
            }
            fn name(&self) -> &'static str {
                "CountCallback"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        for _ in 0..5 {
            manager.add(CountCallback { count: count.clone() });
        }
        assert_eq!(manager.len(), 5);

        manager.on_train_end(&CallbackContext::default());
        assert_eq!(count.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn test_cov4_manager_full_lifecycle() {
        use std::sync::Arc;

        struct LifecycleCallback {
            events: Arc<std::sync::Mutex<Vec<String>>>,
        }
        impl TrainerCallback for LifecycleCallback {
            fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.events.lock().unwrap().push(format!("train_begin:{}", ctx.epoch));
                CallbackAction::Continue
            }
            fn on_train_end(&mut self, ctx: &CallbackContext) {
                self.events.lock().unwrap().push(format!("train_end:{}", ctx.epoch));
            }
            fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.events.lock().unwrap().push(format!("epoch_begin:{}", ctx.epoch));
                CallbackAction::Continue
            }
            fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.events.lock().unwrap().push(format!("epoch_end:{}", ctx.epoch));
                CallbackAction::Continue
            }
            fn on_step_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.events.lock().unwrap().push(format!("step_begin:{}", ctx.step));
                CallbackAction::Continue
            }
            fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
                self.events.lock().unwrap().push(format!("step_end:{}", ctx.step));
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "LifecycleCallback"
            }
        }

        let events = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut manager = CallbackManager::new();
        manager.add(LifecycleCallback { events: events.clone() });

        let mut ctx = CallbackContext::default();
        ctx.max_epochs = 2;
        ctx.steps_per_epoch = 3;

        assert_eq!(manager.on_train_begin(&ctx), CallbackAction::Continue);
        for epoch in 0..2 {
            ctx.epoch = epoch;
            assert_eq!(manager.on_epoch_begin(&ctx), CallbackAction::Continue);
            for step in 0..3 {
                ctx.step = step;
                assert_eq!(manager.on_step_begin(&ctx), CallbackAction::Continue);
                assert_eq!(manager.on_step_end(&ctx), CallbackAction::Continue);
            }
            assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
        }
        manager.on_train_end(&ctx);

        // Pin the exact event order: 1 train_begin + 2 epochs * (begin + 3 * 2 step
        // events + end) + 1 train_end = 18 events. A `len() >= 16` check would
        // accept dropped or duplicated events.
        let mut expected = vec!["train_begin:0".to_string()];
        for epoch in 0..2 {
            expected.push(format!("epoch_begin:{}", epoch));
            for step in 0..3 {
                expected.push(format!("step_begin:{}", step));
                expected.push(format!("step_end:{}", step));
            }
            expected.push(format!("epoch_end:{}", epoch));
        }
        expected.push("train_end:1".to_string());

        let ev = events.lock().unwrap();
        assert_eq!(ev.len(), 18);
        assert_eq!(*ev, expected);
    }

    #[test]
    fn test_cov4_manager_mixed_callbacks_epoch_end() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct ContinueTracker {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for ContinueTracker {
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "ContinueTracker"
            }
        }

        struct StopTracker {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopTracker {
            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopTracker"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(ContinueTracker { count: count.clone() });
        // The stopping callback is last, so both callbacks run.
        manager.add(StopTracker { count: count.clone() });

        let action = manager.on_epoch_end(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_cov4_manager_mixed_callbacks_step_end() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct ContinueCb {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for ContinueCb {
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "ContinueCb"
            }
        }

        struct StopCb {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for StopCb {
            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Stop
            }
            fn name(&self) -> &'static str {
                "StopCb"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(ContinueCb { count: count.clone() });
        manager.add(ContinueCb { count: count.clone() });
        manager.add(StopCb { count: count.clone() });

        let action = manager.on_step_end(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Stop);
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cov4_manager_ctx_with_rich_fields() {
        use std::sync::Arc;

        struct FieldChecker {
            verified: Arc<std::sync::atomic::AtomicBool>,
        }
        impl TrainerCallback for FieldChecker {
            fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
                if ctx.epoch == 3
                    && ctx.max_epochs == 10
                    && ctx.step == 7
                    && ctx.steps_per_epoch == 100
                    && ctx.global_step == 307
                    && (ctx.loss - 0.42).abs() < 1e-5
                    && (ctx.lr - 1e-4).abs() < 1e-8
                    && ctx.best_loss == Some(0.30)
                    && ctx.val_loss == Some(0.50)
                    && (ctx.elapsed_secs - 123.4).abs() < 0.1
                {
                    self.verified.store(true, std::sync::atomic::Ordering::SeqCst);
                }
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "FieldChecker"
            }
        }

        let verified = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let mut manager = CallbackManager::new();
        manager.add(FieldChecker { verified: verified.clone() });

        let ctx = CallbackContext {
            epoch: 3,
            max_epochs: 10,
            step: 7,
            steps_per_epoch: 100,
            global_step: 307,
            loss: 0.42,
            lr: 1e-4,
            best_loss: Some(0.30),
            val_loss: Some(0.50),
            elapsed_secs: 123.4,
        };

        manager.on_step_end(&ctx);
        assert!(verified.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_cov4_manager_multiple_adds() {
        let mut manager = CallbackManager::new();
        assert_eq!(manager.len(), 0);
        assert!(manager.is_empty());

        manager.add(ProgressCallback::new(10));
        manager.add(ProgressCallback::new(20));
        manager.add(EarlyStopping::new(5, 0.001));
        assert_eq!(manager.len(), 3);
        assert!(!manager.is_empty());
    }

    #[test]
    fn test_cov4_manager_train_begin_multiple_continue() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountCb {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountCb {
            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountCb"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(CountCb { count: count.clone() });
        manager.add(CountCb { count: count.clone() });
        manager.add(CountCb { count: count.clone() });

        let action = manager.on_train_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Continue);
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cov4_manager_step_begin_multiple_continue() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountCb {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountCb {
            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountCb"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(CountCb { count: count.clone() });
        manager.add(CountCb { count: count.clone() });

        let action = manager.on_step_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Continue);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_cov4_manager_epoch_begin_multiple_continue() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        struct CountCb {
            count: Arc<AtomicUsize>,
        }
        impl TrainerCallback for CountCb {
            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                self.count.fetch_add(1, Ordering::SeqCst);
                CallbackAction::Continue
            }
            fn name(&self) -> &'static str {
                "CountCb"
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let mut manager = CallbackManager::new();
        manager.add(CountCb { count: count.clone() });
        manager.add(CountCb { count: count.clone() });
        manager.add(CountCb { count: count.clone() });

        let action = manager.on_epoch_begin(&CallbackContext::default());
        assert_eq!(action, CallbackAction::Continue);
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cov4_manager_train_end_empty() {
        // Dispatching to an empty manager must be a harmless no-op.
        let mut manager = CallbackManager::new();
        manager.on_train_end(&CallbackContext::default());
        assert!(manager.is_empty());
    }

    #[test]
    fn test_cov4_manager_early_stopping_with_improvement() {
        let mut manager = CallbackManager::new();
        manager.add(EarlyStopping::new(3, 0.001));

        let mut ctx = CallbackContext::default();

        ctx.epoch = 0;
        ctx.loss = 1.0;
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);

        ctx.epoch = 1;
        ctx.loss = 0.5;
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);

        ctx.epoch = 2;
        ctx.loss = 0.6;
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);

        ctx.epoch = 3;
        ctx.loss = 0.3;
        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);

        // Three epochs without beating 0.3: the third is expected to stop.
        for i in 4..7 {
            ctx.epoch = i;
            ctx.loss = 0.35;
            let action = manager.on_epoch_end(&ctx);
            if i == 6 {
                assert_eq!(action, CallbackAction::Stop);
            } else {
                assert_eq!(action, CallbackAction::Continue);
            }
        }
    }

    #[test]
    fn test_cov4_manager_default_new_equivalent() {
        let m1 = CallbackManager::new();
        let m2 = CallbackManager::default();
        assert_eq!(m1.len(), m2.len());
        assert_eq!(m1.is_empty(), m2.is_empty());
    }
}
1064
// Property tests for `CallbackManager` dispatch.
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::train::callback::EarlyStopping;
    use proptest::prelude::*;

    proptest! {
        /// With a loss that never changes, the manager must relay the `Stop`
        /// raised by `EarlyStopping` at epoch index `patience`.
        #[test]
        fn callback_manager_propagates_stop(
            patience in 1usize..5,
        ) {
            let mut mgr = CallbackManager::new();
            mgr.add(EarlyStopping::new(patience, 0.001));

            let mut context = CallbackContext::default();
            context.loss = 1.0;

            for epoch in 0..patience {
                context.epoch = epoch;
                let outcome = mgr.on_epoch_end(&context);
                // Only epochs strictly before `patience - 1` are pinned to `Continue`.
                if epoch + 1 < patience {
                    prop_assert_eq!(outcome, CallbackAction::Continue);
                }
            }

            context.epoch = patience;
            prop_assert_eq!(mgr.on_epoch_end(&context), CallbackAction::Stop);
        }

        /// Every registered callback receives `on_train_begin` exactly once.
        #[test]
        fn multiple_callbacks_fire(
            num_callbacks in 1usize..5,
        ) {
            use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};

            // Counts `on_train_begin` calls; all other hooks use the trait defaults.
            struct CounterCallback {
                counter: Arc<AtomicUsize>,
            }

            impl TrainerCallback for CounterCallback {
                fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                    self.counter.fetch_add(1, Ordering::SeqCst);
                    CallbackAction::Continue
                }
                fn name(&self) -> &'static str { "CounterCallback" }
            }

            let hits = Arc::new(AtomicUsize::new(0));
            let mut mgr = CallbackManager::new();
            (0..num_callbacks).for_each(|_| {
                mgr.add(CounterCallback { counter: Arc::clone(&hits) });
            });

            mgr.on_train_begin(&CallbackContext::default());

            prop_assert_eq!(hits.load(Ordering::SeqCst), num_callbacks);
        }
    }
}