Skip to main content

entrenar/train/callback/
manager.rs

1//! Callback manager for dispatching events to multiple callbacks
2
3use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
5/// Manages multiple callbacks and dispatches events
6pub struct CallbackManager {
7    callbacks: Vec<Box<dyn TrainerCallback>>,
8}
9
10impl CallbackManager {
11    /// Create new callback manager
12    pub fn new() -> Self {
13        Self { callbacks: Vec::new() }
14    }
15
16    /// Add a callback
17    pub fn add<C: TrainerCallback + 'static>(&mut self, callback: C) {
18        self.callbacks.push(Box::new(callback));
19    }
20
21    /// Check if no callbacks are registered
22    pub fn is_empty(&self) -> bool {
23        self.callbacks.is_empty()
24    }
25
26    /// Get number of callbacks
27    pub fn len(&self) -> usize {
28        self.callbacks.len()
29    }
30
31    /// Fire train begin event
32    pub fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
33        for cb in &mut self.callbacks {
34            if cb.on_train_begin(ctx) == CallbackAction::Stop {
35                return CallbackAction::Stop;
36            }
37        }
38        CallbackAction::Continue
39    }
40
41    /// Fire train end event
42    pub fn on_train_end(&mut self, ctx: &CallbackContext) {
43        for cb in &mut self.callbacks {
44            cb.on_train_end(ctx);
45        }
46    }
47
48    /// Fire epoch begin event
49    pub fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
50        for cb in &mut self.callbacks {
51            match cb.on_epoch_begin(ctx) {
52                CallbackAction::Stop => return CallbackAction::Stop,
53                CallbackAction::SkipEpoch => return CallbackAction::SkipEpoch,
54                CallbackAction::Continue => {}
55            }
56        }
57        CallbackAction::Continue
58    }
59
60    /// Fire epoch end event
61    pub fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
62        for cb in &mut self.callbacks {
63            if cb.on_epoch_end(ctx) == CallbackAction::Stop {
64                return CallbackAction::Stop;
65            }
66        }
67        CallbackAction::Continue
68    }
69
70    /// Fire step begin event
71    pub fn on_step_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
72        for cb in &mut self.callbacks {
73            if cb.on_step_begin(ctx) == CallbackAction::Stop {
74                return CallbackAction::Stop;
75            }
76        }
77        CallbackAction::Continue
78    }
79
80    /// Fire step end event
81    pub fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
82        for cb in &mut self.callbacks {
83            if cb.on_step_end(ctx) == CallbackAction::Stop {
84                return CallbackAction::Stop;
85            }
86        }
87        CallbackAction::Continue
88    }
89}
90
91impl Default for CallbackManager {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97#[cfg(test)]
98mod tests {
99    use super::*;
100    use crate::train::callback::{EarlyStopping, ProgressCallback};
101
102    #[test]
103    fn test_callback_manager_dispatch() {
104        let mut manager = CallbackManager::new();
105
106        // Add early stopping that triggers after 1 epoch without improvement
107        let es = EarlyStopping::new(1, 0.001);
108        manager.add(es);
109
110        let mut ctx = CallbackContext::default();
111        ctx.loss = 1.0;
112
113        // First epoch
114        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
115
116        // Second epoch - no improvement, should stop
117        ctx.epoch = 1;
118        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Stop);
119    }
120
121    #[test]
122    fn test_callback_manager_len_and_empty() {
123        let mut manager = CallbackManager::new();
124        assert!(manager.is_empty());
125        assert_eq!(manager.len(), 0);
126
127        manager.add(ProgressCallback::new(10));
128        assert!(!manager.is_empty());
129        assert_eq!(manager.len(), 1);
130    }
131
132    #[test]
133    fn test_callback_manager_on_train_begin_stop() {
134        struct StopCallback;
135        impl TrainerCallback for StopCallback {
136            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
137                CallbackAction::Stop
138            }
139            fn name(&self) -> &'static str {
140                "StopCallback"
141            }
142        }
143
144        let mut manager = CallbackManager::new();
145        manager.add(StopCallback);
146        assert_eq!(manager.on_train_begin(&CallbackContext::default()), CallbackAction::Stop);
147    }
148
149    #[test]
150    fn test_callback_manager_on_train_end() {
151        use std::sync::{
152            atomic::{AtomicBool, Ordering},
153            Arc,
154        };
155
156        struct EndCallback {
157            called: Arc<AtomicBool>,
158        }
159        impl TrainerCallback for EndCallback {
160            fn on_train_end(&mut self, _: &CallbackContext) {
161                self.called.store(true, Ordering::SeqCst);
162            }
163            fn name(&self) -> &'static str {
164                "EndCallback"
165            }
166        }
167
168        let called = Arc::new(AtomicBool::new(false));
169        let mut manager = CallbackManager::new();
170        manager.add(EndCallback { called: called.clone() });
171        manager.on_train_end(&CallbackContext::default());
172        assert!(called.load(Ordering::SeqCst));
173    }
174
175    #[test]
176    fn test_callback_manager_on_epoch_begin_skip() {
177        struct SkipCallback;
178        impl TrainerCallback for SkipCallback {
179            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
180                CallbackAction::SkipEpoch
181            }
182            fn name(&self) -> &'static str {
183                "SkipCallback"
184            }
185        }
186
187        let mut manager = CallbackManager::new();
188        manager.add(SkipCallback);
189        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::SkipEpoch);
190    }
191
192    #[test]
193    fn test_callback_manager_on_epoch_begin_stop() {
194        struct StopCallback;
195        impl TrainerCallback for StopCallback {
196            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
197                CallbackAction::Stop
198            }
199            fn name(&self) -> &'static str {
200                "StopCallback"
201            }
202        }
203
204        let mut manager = CallbackManager::new();
205        manager.add(StopCallback);
206        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::Stop);
207    }
208
209    #[test]
210    fn test_callback_manager_on_step_begin_stop() {
211        struct StopCallback;
212        impl TrainerCallback for StopCallback {
213            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
214                CallbackAction::Stop
215            }
216            fn name(&self) -> &'static str {
217                "StopCallback"
218            }
219        }
220
221        let mut manager = CallbackManager::new();
222        manager.add(StopCallback);
223        assert_eq!(manager.on_step_begin(&CallbackContext::default()), CallbackAction::Stop);
224    }
225
226    #[test]
227    fn test_callback_manager_default() {
228        let manager = CallbackManager::default();
229        assert!(manager.is_empty());
230    }
231
232    #[test]
233    fn test_callback_manager_stop_propagation() {
234        // Create a callback that always returns Stop
235        struct StopCallback;
236        impl TrainerCallback for StopCallback {
237            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
238                CallbackAction::Stop
239            }
240            fn name(&self) -> &'static str {
241                "StopCallback"
242            }
243        }
244
245        let mut manager = CallbackManager::new();
246        manager.add(StopCallback);
247        manager.add(ProgressCallback::new(10));
248
249        let ctx = CallbackContext::default();
250        let action = manager.on_epoch_end(&ctx);
251        // Stop should propagate
252        assert_eq!(action, CallbackAction::Stop);
253    }
254
255    #[test]
256    fn test_callback_manager_on_step_end_stop() {
257        struct StopCallback;
258        impl TrainerCallback for StopCallback {
259            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
260                CallbackAction::Stop
261            }
262            fn name(&self) -> &'static str {
263                "StopCallback"
264            }
265        }
266
267        let mut manager = CallbackManager::new();
268        manager.add(StopCallback);
269        assert_eq!(manager.on_step_end(&CallbackContext::default()), CallbackAction::Stop);
270    }
271
272    #[test]
273    fn test_callback_manager_all_continue() {
274        // Test that all callbacks continue properly
275        struct ContinueCallback;
276        impl TrainerCallback for ContinueCallback {
277            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
278                CallbackAction::Continue
279            }
280            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
281                CallbackAction::Continue
282            }
283            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
284                CallbackAction::Continue
285            }
286            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
287                CallbackAction::Continue
288            }
289            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
290                CallbackAction::Continue
291            }
292            fn name(&self) -> &'static str {
293                "ContinueCallback"
294            }
295        }
296
297        let mut manager = CallbackManager::new();
298        manager.add(ContinueCallback);
299        manager.add(ContinueCallback);
300
301        let ctx = CallbackContext::default();
302        assert_eq!(manager.on_train_begin(&ctx), CallbackAction::Continue);
303        assert_eq!(manager.on_epoch_begin(&ctx), CallbackAction::Continue);
304        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
305        assert_eq!(manager.on_step_begin(&ctx), CallbackAction::Continue);
306        assert_eq!(manager.on_step_end(&ctx), CallbackAction::Continue);
307    }
308
309    #[test]
310    fn test_callback_manager_multiple_train_end() {
311        use std::sync::{
312            atomic::{AtomicUsize, Ordering},
313            Arc,
314        };
315
316        struct CountingEndCallback {
317            count: Arc<AtomicUsize>,
318        }
319
320        impl TrainerCallback for CountingEndCallback {
321            fn on_train_end(&mut self, _: &CallbackContext) {
322                self.count.fetch_add(1, Ordering::SeqCst);
323            }
324            fn name(&self) -> &'static str {
325                "CountingEndCallback"
326            }
327        }
328
329        let count = Arc::new(AtomicUsize::new(0));
330        let mut manager = CallbackManager::new();
331        manager.add(CountingEndCallback { count: count.clone() });
332        manager.add(CountingEndCallback { count: count.clone() });
333        manager.add(CountingEndCallback { count: count.clone() });
334
335        manager.on_train_end(&CallbackContext::default());
336        assert_eq!(count.load(Ordering::SeqCst), 3);
337    }
338
339    #[test]
340    fn test_callback_manager_stop_after_first() {
341        use std::sync::{
342            atomic::{AtomicUsize, Ordering},
343            Arc,
344        };
345
346        struct CountingStopCallback {
347            count: Arc<AtomicUsize>,
348        }
349
350        impl TrainerCallback for CountingStopCallback {
351            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
352                self.count.fetch_add(1, Ordering::SeqCst);
353                CallbackAction::Stop
354            }
355            fn name(&self) -> &'static str {
356                "CountingStopCallback"
357            }
358        }
359
360        struct CountingContinueCallback {
361            count: Arc<AtomicUsize>,
362        }
363
364        impl TrainerCallback for CountingContinueCallback {
365            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
366                self.count.fetch_add(1, Ordering::SeqCst);
367                CallbackAction::Continue
368            }
369            fn name(&self) -> &'static str {
370                "CountingContinueCallback"
371            }
372        }
373
374        let count = Arc::new(AtomicUsize::new(0));
375        let mut manager = CallbackManager::new();
376        manager.add(CountingStopCallback { count: count.clone() });
377        manager.add(CountingContinueCallback { count: count.clone() });
378
379        // First callback stops, second should not be called
380        let action = manager.on_train_begin(&CallbackContext::default());
381        assert_eq!(action, CallbackAction::Stop);
382        assert_eq!(count.load(Ordering::SeqCst), 1);
383    }
384
385    // ── Additional coverage tests ─────────────────────────────────
386
387    #[test]
388    fn test_callback_manager_on_train_begin_continue() {
389        let mut manager = CallbackManager::new();
390        // No callbacks → should return Continue
391        assert_eq!(manager.on_train_begin(&CallbackContext::default()), CallbackAction::Continue);
392    }
393
394    #[test]
395    fn test_callback_manager_on_epoch_end_continue() {
396        let mut manager = CallbackManager::new();
397        // No callbacks → should return Continue
398        assert_eq!(manager.on_epoch_end(&CallbackContext::default()), CallbackAction::Continue);
399    }
400
401    #[test]
402    fn test_callback_manager_on_step_begin_continue() {
403        let mut manager = CallbackManager::new();
404        assert_eq!(manager.on_step_begin(&CallbackContext::default()), CallbackAction::Continue);
405    }
406
407    #[test]
408    fn test_callback_manager_on_step_end_continue() {
409        let mut manager = CallbackManager::new();
410        assert_eq!(manager.on_step_end(&CallbackContext::default()), CallbackAction::Continue);
411    }
412
413    #[test]
414    fn test_callback_manager_on_epoch_begin_continue() {
415        let mut manager = CallbackManager::new();
416        assert_eq!(manager.on_epoch_begin(&CallbackContext::default()), CallbackAction::Continue);
417    }
418
419    #[test]
420    fn test_callback_manager_stop_epoch_begin_does_not_call_second() {
421        use std::sync::{
422            atomic::{AtomicUsize, Ordering},
423            Arc,
424        };
425
426        struct StopEpochBegin {
427            count: Arc<AtomicUsize>,
428        }
429        impl TrainerCallback for StopEpochBegin {
430            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
431                self.count.fetch_add(1, Ordering::SeqCst);
432                CallbackAction::Stop
433            }
434            fn name(&self) -> &'static str {
435                "StopEpochBegin"
436            }
437        }
438
439        struct CountEpochBegin {
440            count: Arc<AtomicUsize>,
441        }
442        impl TrainerCallback for CountEpochBegin {
443            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
444                self.count.fetch_add(1, Ordering::SeqCst);
445                CallbackAction::Continue
446            }
447            fn name(&self) -> &'static str {
448                "CountEpochBegin"
449            }
450        }
451
452        let count = Arc::new(AtomicUsize::new(0));
453        let mut manager = CallbackManager::new();
454        manager.add(StopEpochBegin { count: count.clone() });
455        manager.add(CountEpochBegin { count: count.clone() });
456
457        let action = manager.on_epoch_begin(&CallbackContext::default());
458        assert_eq!(action, CallbackAction::Stop);
459        assert_eq!(count.load(Ordering::SeqCst), 1); // second never called
460    }
461
462    #[test]
463    fn test_callback_manager_stop_epoch_end_does_not_call_second() {
464        use std::sync::{
465            atomic::{AtomicUsize, Ordering},
466            Arc,
467        };
468
469        struct StopEpochEnd {
470            count: Arc<AtomicUsize>,
471        }
472        impl TrainerCallback for StopEpochEnd {
473            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
474                self.count.fetch_add(1, Ordering::SeqCst);
475                CallbackAction::Stop
476            }
477            fn name(&self) -> &'static str {
478                "StopEpochEnd"
479            }
480        }
481
482        struct CountEpochEnd {
483            count: Arc<AtomicUsize>,
484        }
485        impl TrainerCallback for CountEpochEnd {
486            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
487                self.count.fetch_add(1, Ordering::SeqCst);
488                CallbackAction::Continue
489            }
490            fn name(&self) -> &'static str {
491                "CountEpochEnd"
492            }
493        }
494
495        let count = Arc::new(AtomicUsize::new(0));
496        let mut manager = CallbackManager::new();
497        manager.add(StopEpochEnd { count: count.clone() });
498        manager.add(CountEpochEnd { count: count.clone() });
499
500        let action = manager.on_epoch_end(&CallbackContext::default());
501        assert_eq!(action, CallbackAction::Stop);
502        assert_eq!(count.load(Ordering::SeqCst), 1);
503    }
504
505    #[test]
506    fn test_callback_manager_stop_step_begin_does_not_call_second() {
507        use std::sync::{
508            atomic::{AtomicUsize, Ordering},
509            Arc,
510        };
511
512        struct StopStepBegin {
513            count: Arc<AtomicUsize>,
514        }
515        impl TrainerCallback for StopStepBegin {
516            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
517                self.count.fetch_add(1, Ordering::SeqCst);
518                CallbackAction::Stop
519            }
520            fn name(&self) -> &'static str {
521                "StopStepBegin"
522            }
523        }
524
525        struct CountStepBegin {
526            count: Arc<AtomicUsize>,
527        }
528        impl TrainerCallback for CountStepBegin {
529            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
530                self.count.fetch_add(1, Ordering::SeqCst);
531                CallbackAction::Continue
532            }
533            fn name(&self) -> &'static str {
534                "CountStepBegin"
535            }
536        }
537
538        let count = Arc::new(AtomicUsize::new(0));
539        let mut manager = CallbackManager::new();
540        manager.add(StopStepBegin { count: count.clone() });
541        manager.add(CountStepBegin { count: count.clone() });
542
543        let action = manager.on_step_begin(&CallbackContext::default());
544        assert_eq!(action, CallbackAction::Stop);
545        assert_eq!(count.load(Ordering::SeqCst), 1);
546    }
547
548    #[test]
549    fn test_callback_manager_stop_step_end_does_not_call_second() {
550        use std::sync::{
551            atomic::{AtomicUsize, Ordering},
552            Arc,
553        };
554
555        struct StopStepEnd {
556            count: Arc<AtomicUsize>,
557        }
558        impl TrainerCallback for StopStepEnd {
559            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
560                self.count.fetch_add(1, Ordering::SeqCst);
561                CallbackAction::Stop
562            }
563            fn name(&self) -> &'static str {
564                "StopStepEnd"
565            }
566        }
567
568        struct CountStepEnd {
569            count: Arc<AtomicUsize>,
570        }
571        impl TrainerCallback for CountStepEnd {
572            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
573                self.count.fetch_add(1, Ordering::SeqCst);
574                CallbackAction::Continue
575            }
576            fn name(&self) -> &'static str {
577                "CountStepEnd"
578            }
579        }
580
581        let count = Arc::new(AtomicUsize::new(0));
582        let mut manager = CallbackManager::new();
583        manager.add(StopStepEnd { count: count.clone() });
584        manager.add(CountStepEnd { count: count.clone() });
585
586        let action = manager.on_step_end(&CallbackContext::default());
587        assert_eq!(action, CallbackAction::Stop);
588        assert_eq!(count.load(Ordering::SeqCst), 1);
589    }
590
591    #[test]
592    fn test_callback_manager_skip_epoch_does_not_call_second() {
593        use std::sync::{
594            atomic::{AtomicUsize, Ordering},
595            Arc,
596        };
597
598        struct SkipCallback {
599            count: Arc<AtomicUsize>,
600        }
601        impl TrainerCallback for SkipCallback {
602            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
603                self.count.fetch_add(1, Ordering::SeqCst);
604                CallbackAction::SkipEpoch
605            }
606            fn name(&self) -> &'static str {
607                "SkipCallback"
608            }
609        }
610
611        struct ContinueCallback {
612            count: Arc<AtomicUsize>,
613        }
614        impl TrainerCallback for ContinueCallback {
615            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
616                self.count.fetch_add(1, Ordering::SeqCst);
617                CallbackAction::Continue
618            }
619            fn name(&self) -> &'static str {
620                "ContinueCallback"
621            }
622        }
623
624        let count = Arc::new(AtomicUsize::new(0));
625        let mut manager = CallbackManager::new();
626        manager.add(SkipCallback { count: count.clone() });
627        manager.add(ContinueCallback { count: count.clone() });
628
629        let action = manager.on_epoch_begin(&CallbackContext::default());
630        assert_eq!(action, CallbackAction::SkipEpoch);
631        assert_eq!(count.load(Ordering::SeqCst), 1);
632    }
633
634    #[test]
635    fn test_callback_manager_with_context_values() {
636        use std::sync::{
637            atomic::{AtomicUsize, Ordering},
638            Arc,
639        };
640
641        struct EpochTracker {
642            last_epoch: Arc<AtomicUsize>,
643        }
644        impl TrainerCallback for EpochTracker {
645            fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
646                self.last_epoch.store(ctx.epoch, Ordering::SeqCst);
647                CallbackAction::Continue
648            }
649            fn name(&self) -> &'static str {
650                "EpochTracker"
651            }
652        }
653
654        let last_epoch = Arc::new(AtomicUsize::new(999));
655        let mut manager = CallbackManager::new();
656        manager.add(EpochTracker { last_epoch: last_epoch.clone() });
657
658        let mut ctx = CallbackContext::default();
659        ctx.epoch = 42;
660        manager.on_epoch_begin(&ctx);
661        assert_eq!(last_epoch.load(Ordering::SeqCst), 42);
662    }
663
664    #[test]
665    fn test_callback_manager_train_end_all_called() {
666        use std::sync::{
667            atomic::{AtomicUsize, Ordering},
668            Arc,
669        };
670
671        struct CountCallback {
672            count: Arc<AtomicUsize>,
673        }
674        impl TrainerCallback for CountCallback {
675            fn on_train_end(&mut self, _: &CallbackContext) {
676                self.count.fetch_add(1, Ordering::SeqCst);
677            }
678            fn name(&self) -> &'static str {
679                "CountCallback"
680            }
681        }
682
683        let count = Arc::new(AtomicUsize::new(0));
684        let mut manager = CallbackManager::new();
685        for _ in 0..5 {
686            manager.add(CountCallback { count: count.clone() });
687        }
688        assert_eq!(manager.len(), 5);
689
690        manager.on_train_end(&CallbackContext::default());
691        assert_eq!(count.load(Ordering::SeqCst), 5);
692    }
693
694    // ── test_cov4 additional coverage tests ────────────────────────
695
696    #[test]
697    fn test_cov4_manager_full_lifecycle() {
698        // Exercise complete train begin→step begin→step end→epoch begin→epoch end→train end flow
699        use std::sync::Arc;
700
701        struct LifecycleCallback {
702            events: Arc<std::sync::Mutex<Vec<String>>>,
703        }
704        impl TrainerCallback for LifecycleCallback {
705            fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
706                self.events.lock().unwrap().push(format!("train_begin:{}", ctx.epoch));
707                CallbackAction::Continue
708            }
709            fn on_train_end(&mut self, ctx: &CallbackContext) {
710                self.events.lock().unwrap().push(format!("train_end:{}", ctx.epoch));
711            }
712            fn on_epoch_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
713                self.events.lock().unwrap().push(format!("epoch_begin:{}", ctx.epoch));
714                CallbackAction::Continue
715            }
716            fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
717                self.events.lock().unwrap().push(format!("epoch_end:{}", ctx.epoch));
718                CallbackAction::Continue
719            }
720            fn on_step_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
721                self.events.lock().unwrap().push(format!("step_begin:{}", ctx.step));
722                CallbackAction::Continue
723            }
724            fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
725                self.events.lock().unwrap().push(format!("step_end:{}", ctx.step));
726                CallbackAction::Continue
727            }
728            fn name(&self) -> &'static str {
729                "LifecycleCallback"
730            }
731        }
732
733        let events = Arc::new(std::sync::Mutex::new(Vec::new()));
734        let mut manager = CallbackManager::new();
735        manager.add(LifecycleCallback { events: events.clone() });
736
737        let mut ctx = CallbackContext::default();
738        ctx.max_epochs = 2;
739        ctx.steps_per_epoch = 3;
740
741        manager.on_train_begin(&ctx);
742        for epoch in 0..2 {
743            ctx.epoch = epoch;
744            manager.on_epoch_begin(&ctx);
745            for step in 0..3 {
746                ctx.step = step;
747                manager.on_step_begin(&ctx);
748                manager.on_step_end(&ctx);
749            }
750            manager.on_epoch_end(&ctx);
751        }
752        manager.on_train_end(&ctx);
753
754        let ev = events.lock().unwrap();
755        assert_eq!(ev[0], "train_begin:0");
756        assert_eq!(ev[1], "epoch_begin:0");
757        assert_eq!(ev[2], "step_begin:0");
758        assert_eq!(ev[3], "step_end:0");
759        assert!(ev.len() >= 16); // 1+2*(1+3*2+1)+1 = 18
760        assert_eq!(*ev.last().unwrap(), "train_end:1");
761    }
762
763    #[test]
764    fn test_cov4_manager_mixed_callbacks_epoch_end() {
765        // Mix callbacks with different epoch_end behaviors: first continues, second stops
766        use std::sync::{
767            atomic::{AtomicUsize, Ordering},
768            Arc,
769        };
770
771        struct ContinueTracker {
772            count: Arc<AtomicUsize>,
773        }
774        impl TrainerCallback for ContinueTracker {
775            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
776                self.count.fetch_add(1, Ordering::SeqCst);
777                CallbackAction::Continue
778            }
779            fn name(&self) -> &'static str {
780                "ContinueTracker"
781            }
782        }
783
784        struct StopTracker {
785            count: Arc<AtomicUsize>,
786        }
787        impl TrainerCallback for StopTracker {
788            fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
789                self.count.fetch_add(1, Ordering::SeqCst);
790                CallbackAction::Stop
791            }
792            fn name(&self) -> &'static str {
793                "StopTracker"
794            }
795        }
796
797        let count = Arc::new(AtomicUsize::new(0));
798        let mut manager = CallbackManager::new();
799        // First callback continues, second stops → both called, second triggers stop
800        manager.add(ContinueTracker { count: count.clone() });
801        manager.add(StopTracker { count: count.clone() });
802
803        let action = manager.on_epoch_end(&CallbackContext::default());
804        assert_eq!(action, CallbackAction::Stop);
805        assert_eq!(count.load(Ordering::SeqCst), 2); // both were called
806    }
807
808    #[test]
809    fn test_cov4_manager_mixed_callbacks_step_end() {
810        // Two continue then stop: all three called
811        use std::sync::{
812            atomic::{AtomicUsize, Ordering},
813            Arc,
814        };
815
816        struct ContinueCb {
817            count: Arc<AtomicUsize>,
818        }
819        impl TrainerCallback for ContinueCb {
820            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
821                self.count.fetch_add(1, Ordering::SeqCst);
822                CallbackAction::Continue
823            }
824            fn name(&self) -> &'static str {
825                "ContinueCb"
826            }
827        }
828
829        struct StopCb {
830            count: Arc<AtomicUsize>,
831        }
832        impl TrainerCallback for StopCb {
833            fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
834                self.count.fetch_add(1, Ordering::SeqCst);
835                CallbackAction::Stop
836            }
837            fn name(&self) -> &'static str {
838                "StopCb"
839            }
840        }
841
842        let count = Arc::new(AtomicUsize::new(0));
843        let mut manager = CallbackManager::new();
844        manager.add(ContinueCb { count: count.clone() });
845        manager.add(ContinueCb { count: count.clone() });
846        manager.add(StopCb { count: count.clone() });
847
848        let action = manager.on_step_end(&CallbackContext::default());
849        assert_eq!(action, CallbackAction::Stop);
850        assert_eq!(count.load(Ordering::SeqCst), 3);
851    }
852
853    #[test]
854    fn test_cov4_manager_ctx_with_rich_fields() {
855        // Use a context with all fields populated
856        use std::sync::Arc;
857
858        struct FieldChecker {
859            verified: Arc<std::sync::atomic::AtomicBool>,
860        }
861        impl TrainerCallback for FieldChecker {
862            fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
863                if ctx.epoch == 3
864                    && ctx.max_epochs == 10
865                    && ctx.step == 7
866                    && ctx.steps_per_epoch == 100
867                    && ctx.global_step == 307
868                    && (ctx.loss - 0.42).abs() < 1e-5
869                    && (ctx.lr - 1e-4).abs() < 1e-8
870                    && ctx.best_loss == Some(0.30)
871                    && ctx.val_loss == Some(0.50)
872                    && (ctx.elapsed_secs - 123.4).abs() < 0.1
873                {
874                    self.verified.store(true, std::sync::atomic::Ordering::SeqCst);
875                }
876                CallbackAction::Continue
877            }
878            fn name(&self) -> &'static str {
879                "FieldChecker"
880            }
881        }
882
883        let verified = Arc::new(std::sync::atomic::AtomicBool::new(false));
884        let mut manager = CallbackManager::new();
885        manager.add(FieldChecker { verified: verified.clone() });
886
887        let ctx = CallbackContext {
888            epoch: 3,
889            max_epochs: 10,
890            step: 7,
891            steps_per_epoch: 100,
892            global_step: 307,
893            loss: 0.42,
894            lr: 1e-4,
895            best_loss: Some(0.30),
896            val_loss: Some(0.50),
897            elapsed_secs: 123.4,
898        };
899
900        manager.on_step_end(&ctx);
901        assert!(verified.load(std::sync::atomic::Ordering::SeqCst));
902    }
903
904    #[test]
905    fn test_cov4_manager_multiple_adds() {
906        let mut manager = CallbackManager::new();
907        assert_eq!(manager.len(), 0);
908        assert!(manager.is_empty());
909
910        manager.add(ProgressCallback::new(10));
911        manager.add(ProgressCallback::new(20));
912        manager.add(EarlyStopping::new(5, 0.001));
913        assert_eq!(manager.len(), 3);
914        assert!(!manager.is_empty());
915    }
916
917    #[test]
918    fn test_cov4_manager_train_begin_multiple_continue() {
919        use std::sync::{
920            atomic::{AtomicUsize, Ordering},
921            Arc,
922        };
923
924        struct CountCb {
925            count: Arc<AtomicUsize>,
926        }
927        impl TrainerCallback for CountCb {
928            fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
929                self.count.fetch_add(1, Ordering::SeqCst);
930                CallbackAction::Continue
931            }
932            fn name(&self) -> &'static str {
933                "CountCb"
934            }
935        }
936
937        let count = Arc::new(AtomicUsize::new(0));
938        let mut manager = CallbackManager::new();
939        manager.add(CountCb { count: count.clone() });
940        manager.add(CountCb { count: count.clone() });
941        manager.add(CountCb { count: count.clone() });
942
943        let action = manager.on_train_begin(&CallbackContext::default());
944        assert_eq!(action, CallbackAction::Continue);
945        assert_eq!(count.load(Ordering::SeqCst), 3);
946    }
947
948    #[test]
949    fn test_cov4_manager_step_begin_multiple_continue() {
950        use std::sync::{
951            atomic::{AtomicUsize, Ordering},
952            Arc,
953        };
954
955        struct CountCb {
956            count: Arc<AtomicUsize>,
957        }
958        impl TrainerCallback for CountCb {
959            fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
960                self.count.fetch_add(1, Ordering::SeqCst);
961                CallbackAction::Continue
962            }
963            fn name(&self) -> &'static str {
964                "CountCb"
965            }
966        }
967
968        let count = Arc::new(AtomicUsize::new(0));
969        let mut manager = CallbackManager::new();
970        manager.add(CountCb { count: count.clone() });
971        manager.add(CountCb { count: count.clone() });
972
973        let action = manager.on_step_begin(&CallbackContext::default());
974        assert_eq!(action, CallbackAction::Continue);
975        assert_eq!(count.load(Ordering::SeqCst), 2);
976    }
977
978    #[test]
979    fn test_cov4_manager_epoch_begin_multiple_continue() {
980        use std::sync::{
981            atomic::{AtomicUsize, Ordering},
982            Arc,
983        };
984
985        struct CountCb {
986            count: Arc<AtomicUsize>,
987        }
988        impl TrainerCallback for CountCb {
989            fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
990                self.count.fetch_add(1, Ordering::SeqCst);
991                CallbackAction::Continue
992            }
993            fn name(&self) -> &'static str {
994                "CountCb"
995            }
996        }
997
998        let count = Arc::new(AtomicUsize::new(0));
999        let mut manager = CallbackManager::new();
1000        manager.add(CountCb { count: count.clone() });
1001        manager.add(CountCb { count: count.clone() });
1002        manager.add(CountCb { count: count.clone() });
1003
1004        let action = manager.on_epoch_begin(&CallbackContext::default());
1005        assert_eq!(action, CallbackAction::Continue);
1006        assert_eq!(count.load(Ordering::SeqCst), 3);
1007    }
1008
1009    #[test]
1010    fn test_cov4_manager_train_end_empty() {
1011        let mut manager = CallbackManager::new();
1012        // Should not panic with no callbacks
1013        manager.on_train_end(&CallbackContext::default());
1014    }
1015
1016    #[test]
1017    fn test_cov4_manager_early_stopping_with_improvement() {
1018        let mut manager = CallbackManager::new();
1019        manager.add(EarlyStopping::new(3, 0.001));
1020
1021        let mut ctx = CallbackContext::default();
1022
1023        // Epoch 0: loss=1.0
1024        ctx.epoch = 0;
1025        ctx.loss = 1.0;
1026        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
1027
1028        // Epoch 1: loss improves to 0.5
1029        ctx.epoch = 1;
1030        ctx.loss = 0.5;
1031        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
1032
1033        // Epoch 2: loss worsens to 0.6 (1 epoch no improvement)
1034        ctx.epoch = 2;
1035        ctx.loss = 0.6;
1036        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
1037
1038        // Epoch 3: loss improves again to 0.3 — resets patience
1039        ctx.epoch = 3;
1040        ctx.loss = 0.3;
1041        assert_eq!(manager.on_epoch_end(&ctx), CallbackAction::Continue);
1042
1043        // Epoch 4-6: no improvement
1044        for i in 4..7 {
1045            ctx.epoch = i;
1046            ctx.loss = 0.35;
1047            let action = manager.on_epoch_end(&ctx);
1048            if i == 6 {
1049                assert_eq!(action, CallbackAction::Stop);
1050            } else {
1051                assert_eq!(action, CallbackAction::Continue);
1052            }
1053        }
1054    }
1055
1056    #[test]
1057    fn test_cov4_manager_default_new_equivalent() {
1058        let m1 = CallbackManager::new();
1059        let m2 = CallbackManager::default();
1060        assert_eq!(m1.len(), m2.len());
1061        assert_eq!(m1.is_empty(), m2.is_empty());
1062    }
1063}
1064
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::train::callback::EarlyStopping;
    use proptest::prelude::*;

    proptest! {
        /// A `Stop` returned by a registered `EarlyStopping` callback must surface
        /// through the manager's `on_epoch_end`.
        #[test]
        fn callback_manager_propagates_stop(
            patience in 1usize..5,
        ) {
            let mut manager = CallbackManager::new();
            manager.add(EarlyStopping::new(patience, 0.001));

            // The loss is pinned at 1.0 for the whole run; only the epoch index changes.
            let mut ctx = CallbackContext { loss: 1.0, ..CallbackContext::default() };

            // Feed epochs 0..patience. Every epoch except the last one in this
            // range is asserted to continue; the last one's outcome is not checked.
            let unchecked_epoch = patience - 1;
            for epoch in 0..patience {
                ctx.epoch = epoch;
                let outcome = manager.on_epoch_end(&ctx);
                if epoch != unchecked_epoch {
                    prop_assert_eq!(outcome, CallbackAction::Continue);
                }
            }

            // One more epoch at the same loss must report `Stop`.
            ctx.epoch = patience;
            let final_outcome = manager.on_epoch_end(&ctx);
            prop_assert_eq!(final_outcome, CallbackAction::Stop);
        }

        /// Firing `on_train_begin` once must invoke every registered callback once.
        #[test]
        fn multiple_callbacks_fire(
            num_callbacks in 1usize..5,
        ) {
            use std::sync::atomic::{AtomicUsize, Ordering};
            use std::sync::Arc;

            // Counts train-begin invocations; every other event simply continues.
            struct CounterCallback {
                counter: Arc<AtomicUsize>,
            }

            impl TrainerCallback for CounterCallback {
                fn on_train_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                    self.counter.fetch_add(1, Ordering::SeqCst);
                    CallbackAction::Continue
                }

                fn on_train_end(&mut self, _: &CallbackContext) {}

                fn on_epoch_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                    CallbackAction::Continue
                }

                fn on_epoch_end(&mut self, _: &CallbackContext) -> CallbackAction {
                    CallbackAction::Continue
                }

                fn on_step_begin(&mut self, _: &CallbackContext) -> CallbackAction {
                    CallbackAction::Continue
                }

                fn on_step_end(&mut self, _: &CallbackContext) -> CallbackAction {
                    CallbackAction::Continue
                }

                fn name(&self) -> &'static str {
                    "CounterCallback"
                }
            }

            let fired = Arc::new(AtomicUsize::new(0));
            let mut manager = CallbackManager::new();

            // All callbacks share one counter, so the total equals the number registered.
            (0..num_callbacks).for_each(|_| {
                manager.add(CounterCallback { counter: Arc::clone(&fired) });
            });

            manager.on_train_begin(&CallbackContext::default());

            prop_assert_eq!(fired.load(Ordering::SeqCst), num_callbacks);
        }
    }
}