// entrenar/train/callback/monitor.rs
use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
5#[derive(Debug)]
7pub struct MonitorCallback {
8 collector: crate::monitor::MetricsCollector,
9 andon: crate::monitor::AndonSystem,
10}
11
12impl MonitorCallback {
13 pub fn new() -> Self {
15 Self {
16 collector: crate::monitor::MetricsCollector::new(),
17 andon: crate::monitor::AndonSystem::new(),
18 }
19 }
20
21 pub fn collector(&self) -> &crate::monitor::MetricsCollector {
23 &self.collector
24 }
25
26 pub fn summary_json(&self) -> Result<String, serde_json::Error> {
28 let summary: std::collections::HashMap<String, _> = self
30 .collector
31 .summary()
32 .into_iter()
33 .map(|(k, v)| (k.as_str().to_string(), v))
34 .collect();
35 serde_json::to_string_pretty(&summary)
36 }
37}
38
39impl Default for MonitorCallback {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl TrainerCallback for MonitorCallback {
46 fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
47 self.collector.record(crate::monitor::Metric::Loss, f64::from(ctx.loss));
49 self.collector.record(crate::monitor::Metric::LearningRate, f64::from(ctx.lr));
50 CallbackAction::Continue
51 }
52
53 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
54 self.collector.record(crate::monitor::Metric::Epoch, ctx.epoch as f64);
55
56 if ctx.loss.is_nan() {
58 self.andon.critical("NaN loss detected");
59 } else if ctx.loss.is_infinite() {
60 self.andon.critical("Infinite loss detected");
61 }
62
63 if self.andon.should_stop() {
65 CallbackAction::Stop
66 } else {
67 CallbackAction::Continue
68 }
69 }
70
71 fn name(&self) -> &'static str {
72 "MonitorCallback"
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// A finite loss lets training continue and the loss shows up in the summary.
    #[test]
    fn test_monitor_callback() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { epoch: 0, step: 0, loss: 0.5, lr: 0.001, ..Default::default() };

        assert_eq!(monitor.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Continue);

        let summary = monitor.collector().summary();
        assert!(summary.contains_key(&crate::monitor::Metric::Loss));
    }

    /// A NaN loss at epoch end stops training.
    #[test]
    fn test_monitor_callback_nan_detection() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::NAN, ..Default::default() };

        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    /// `Default` produces a usable callback with the expected name.
    #[test]
    fn test_monitor_callback_default() {
        let mc = MonitorCallback::default();
        assert_eq!(mc.name(), "MonitorCallback");
    }

    /// The summary serializes to JSON once at least one step has been recorded.
    #[test]
    fn test_monitor_callback_summary_json() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: 0.5, lr: 0.001, ..Default::default() };
        mc.on_step_end(&ctx);

        let json = mc.summary_json();
        assert!(json.is_ok());
    }

    /// A positive-infinity loss at epoch end stops training.
    #[test]
    fn test_monitor_callback_inf_detection() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::INFINITY, ..Default::default() };
        assert_eq!(mc.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    /// Same NaN scenario, built by mutating a default context; the outcome must
    /// be `Stop` (the previous `Stop || Continue` assertion could never fail).
    #[test]
    fn test_monitor_callback_nan_loss() {
        let mut cb = MonitorCallback::new();
        let mut ctx = CallbackContext::default();
        ctx.loss = f32::NAN;

        let action = cb.on_epoch_end(&ctx);
        assert_eq!(action, CallbackAction::Stop);
    }

    /// Both signs of infinity take the `is_infinite()` branch and must stop
    /// training; each case uses a fresh callback so the alerts are independent.
    #[test]
    fn test_monitor_callback_infinite_loss() {
        for loss in [f32::INFINITY, f32::NEG_INFINITY] {
            let mut cb = MonitorCallback::new();
            let mut ctx = CallbackContext::default();
            ctx.loss = loss;

            assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Stop);
        }
    }
}
146
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any finite loss in [-100, 100) continues training, while NaN and
        // positive infinity each stop it.
        #[test]
        fn monitor_callback_detects_nan_inf(
            normal_loss in -100.0f32..100.0,
        ) {
            // Runs a single epoch-end hook on a brand-new callback so that no
            // alert state leaks between the three scenarios below.
            let epoch_end_action = |loss: f32| {
                let ctx = CallbackContext { loss, ..Default::default() };
                MonitorCallback::new().on_epoch_end(&ctx)
            };

            prop_assert_eq!(epoch_end_action(normal_loss), CallbackAction::Continue);
            prop_assert_eq!(epoch_end_action(f32::NAN), CallbackAction::Stop);
            prop_assert_eq!(epoch_end_action(f32::INFINITY), CallbackAction::Stop);
        }
    }
}