Skip to main content

entrenar/train/callback/
monitor.rs

1//! Monitor callback that integrates with entrenar's monitoring system
2
3use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
5/// Callback that integrates with entrenar's monitoring system
6#[derive(Debug)]
7pub struct MonitorCallback {
8    collector: crate::monitor::MetricsCollector,
9    andon: crate::monitor::AndonSystem,
10}
11
12impl MonitorCallback {
13    /// Create a new monitor callback
14    pub fn new() -> Self {
15        Self {
16            collector: crate::monitor::MetricsCollector::new(),
17            andon: crate::monitor::AndonSystem::new(),
18        }
19    }
20
21    /// Get the metrics collector
22    pub fn collector(&self) -> &crate::monitor::MetricsCollector {
23        &self.collector
24    }
25
26    /// Get summary as JSON
27    pub fn summary_json(&self) -> Result<String, serde_json::Error> {
28        // Convert summary to string keys for JSON
29        let summary: std::collections::HashMap<String, _> = self
30            .collector
31            .summary()
32            .into_iter()
33            .map(|(k, v)| (k.as_str().to_string(), v))
34            .collect();
35        serde_json::to_string_pretty(&summary)
36    }
37}
38
39impl Default for MonitorCallback {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl TrainerCallback for MonitorCallback {
46    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
47        // Record loss at each step
48        self.collector.record(crate::monitor::Metric::Loss, f64::from(ctx.loss));
49        self.collector.record(crate::monitor::Metric::LearningRate, f64::from(ctx.lr));
50        CallbackAction::Continue
51    }
52
53    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
54        self.collector.record(crate::monitor::Metric::Epoch, ctx.epoch as f64);
55
56        // Check for NaN/Inf loss
57        if ctx.loss.is_nan() {
58            self.andon.critical("NaN loss detected");
59        } else if ctx.loss.is_infinite() {
60            self.andon.critical("Infinite loss detected");
61        }
62
63        // Check if andon suggests stopping
64        if self.andon.should_stop() {
65            CallbackAction::Stop
66        } else {
67            CallbackAction::Continue
68        }
69    }
70
71    fn name(&self) -> &'static str {
72        "MonitorCallback"
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_monitor_callback() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { epoch: 0, step: 0, loss: 0.5, lr: 0.001, ..Default::default() };

        assert_eq!(monitor.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Continue);

        // Verify metrics were recorded
        let summary = monitor.collector().summary();
        assert!(summary.contains_key(&crate::monitor::Metric::Loss));
    }

    #[test]
    fn test_monitor_callback_nan_detection() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::NAN, ..Default::default() };

        // NaN should trigger stop via andon
        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    #[test]
    fn test_monitor_callback_default() {
        let mc = MonitorCallback::default();
        assert_eq!(mc.name(), "MonitorCallback");
    }

    #[test]
    fn test_monitor_callback_summary_json() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: 0.5, lr: 0.001, ..Default::default() };
        mc.on_step_end(&ctx);

        let json = mc.summary_json().expect("summary should serialize to JSON");

        // The output must be well-formed JSON: an object keyed by metric name
        // that contains the metrics recorded above.
        let value: serde_json::Value =
            serde_json::from_str(&json).expect("summary_json should emit valid JSON");
        let metrics = value.as_object().expect("summary JSON should be an object");
        assert!(!metrics.is_empty(), "recorded metrics should appear in the summary");
    }

    #[test]
    fn test_monitor_callback_inf_detection() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::INFINITY, ..Default::default() };
        assert_eq!(mc.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    #[test]
    fn test_monitor_callback_nan_loss() {
        // A NaN must still be caught after healthy epochs have already been recorded.
        let mut cb = MonitorCallback::new();

        let healthy = CallbackContext { loss: 0.25, ..Default::default() };
        assert_eq!(cb.on_epoch_end(&healthy), CallbackAction::Continue);

        let diverged = CallbackContext { loss: f32::NAN, ..Default::default() };
        assert_eq!(cb.on_epoch_end(&diverged), CallbackAction::Stop);
    }

    #[test]
    fn test_monitor_callback_infinite_loss() {
        // Negative infinity is non-finite too and must stop training just like +inf.
        let mut cb = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::NEG_INFINITY, ..Default::default() };

        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Stop);
    }
}
146
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Any finite loss reported at epoch end lets training continue.
        #[test]
        fn monitor_callback_continues_on_finite_loss(
            normal_loss in -100.0f32..100.0,
        ) {
            let mut monitor = MonitorCallback::new();
            let ctx = CallbackContext {
                loss: normal_loss,
                ..Default::default()
            };
            prop_assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Continue);
        }

        /// A NaN or infinite (either sign) loss stops training, even after an
        /// arbitrary run of healthy epochs has been recorded.
        #[test]
        fn monitor_callback_detects_nan_inf(
            history in prop::collection::vec(-100.0f32..100.0, 0..8),
            bad_loss in prop_oneof![
                Just(f32::NAN),
                Just(f32::INFINITY),
                Just(f32::NEG_INFINITY),
            ],
        ) {
            let mut monitor = MonitorCallback::new();

            for loss in history {
                let ctx = CallbackContext {
                    loss,
                    ..Default::default()
                };
                prop_assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Continue);
            }

            let ctx_bad = CallbackContext {
                loss: bad_loss,
                ..Default::default()
            };
            prop_assert_eq!(monitor.on_epoch_end(&ctx_bad), CallbackAction::Stop);
        }
    }
}