entrenar/train/tui/
andon.rs1use std::time::Instant;
9
10use super::buffer::MetricsBuffer;
11
/// Severity attached to an [`Alert`].
///
/// Variants are listed from least to most severe.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlertLevel {
    /// Purely informational message.
    Info,
    /// Something looks wrong but is not flagged as critical.
    Warning,
    /// Most severe level; this is the level `AndonSystem::has_critical` looks for.
    Critical,
}
22
23#[derive(Debug, Clone)]
25pub struct Alert {
26 pub level: AlertLevel,
28 pub message: String,
30 pub timestamp: Instant,
32}
33
34#[derive(Debug)]
41pub struct AndonSystem {
42 alerts: Vec<Alert>,
44 stop_on_critical: bool,
46 loss_history: MetricsBuffer,
48 loss_ema: f32,
50 ema_alpha: f32,
52 sigma_threshold: f32,
54 stall_counter: usize,
56 best_loss: f32,
58 stall_threshold: usize,
60}
61
62impl Default for AndonSystem {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl AndonSystem {
69 pub fn new() -> Self {
71 Self {
72 alerts: Vec::new(),
73 stop_on_critical: true,
74 loss_history: MetricsBuffer::new(100),
75 loss_ema: 0.0,
76 ema_alpha: 0.1,
77 sigma_threshold: 3.0,
78 stall_counter: 0,
79 best_loss: f32::INFINITY,
80 stall_threshold: 1000,
81 }
82 }
83
84 pub fn with_sigma_threshold(mut self, sigma: f32) -> Self {
86 self.sigma_threshold = sigma;
87 self
88 }
89
90 pub fn with_stall_threshold(mut self, steps: usize) -> Self {
92 self.stall_threshold = steps;
93 self
94 }
95
96 pub fn with_stop_on_critical(mut self, stop: bool) -> Self {
98 self.stop_on_critical = stop;
99 self
100 }
101
102 pub fn check_loss(&mut self, loss: f32) -> bool {
106 if loss.is_nan() {
108 self.critical("NaN loss detected - training diverged");
109 return self.stop_on_critical;
110 }
111
112 if loss.is_infinite() {
113 self.critical("Infinite loss detected - training diverged");
114 return self.stop_on_critical;
115 }
116
117 if self.loss_history.is_empty() {
119 self.loss_ema = loss;
120 } else {
121 self.loss_ema = self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema;
122 }
123
124 if self.loss_history.len() > 10 {
126 if let (Some(mean), Some(std)) = (self.loss_history.mean(), self.loss_std()) {
127 let z_score = (loss - mean) / std.max(f32::EPSILON);
128 if z_score > self.sigma_threshold {
129 self.warning(format!(
130 "Loss spike detected: {loss:.4} ({z_score:.1}σ above mean)"
131 ));
132 }
133 }
134 }
135
136 if loss < self.best_loss {
138 self.best_loss = loss;
139 self.stall_counter = 0;
140 } else {
141 self.stall_counter += 1;
142 if self.stall_counter >= self.stall_threshold {
143 self.warning(format!(
144 "Training stalled: no improvement for {} steps",
145 self.stall_counter
146 ));
147 }
148 }
149
150 self.loss_history.push(loss);
151 false
152 }
153
154 fn loss_std(&self) -> Option<f32> {
156 let values = self.loss_history.values();
157 if values.len() < 2 {
158 return None;
159 }
160 let mean = values.iter().sum::<f32>() / values.len().max(1) as f32;
161 let variance =
162 values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len().max(1) as f32;
163 Some(variance.sqrt())
164 }
165
166 pub fn info(&mut self, message: impl Into<String>) {
168 self.alerts.push(Alert {
169 level: AlertLevel::Info,
170 message: message.into(),
171 timestamp: Instant::now(),
172 });
173 }
174
175 pub fn warning(&mut self, message: impl Into<String>) {
177 self.alerts.push(Alert {
178 level: AlertLevel::Warning,
179 message: message.into(),
180 timestamp: Instant::now(),
181 });
182 }
183
184 pub fn critical(&mut self, message: impl Into<String>) {
186 self.alerts.push(Alert {
187 level: AlertLevel::Critical,
188 message: message.into(),
189 timestamp: Instant::now(),
190 });
191 }
192
193 pub fn has_critical(&self) -> bool {
195 self.alerts.iter().any(|a| a.level == AlertLevel::Critical)
196 }
197
198 pub fn should_stop(&self) -> bool {
200 self.stop_on_critical && self.has_critical()
201 }
202
203 pub fn recent_alerts(&self, count: usize) -> &[Alert] {
205 let start = self.alerts.len().saturating_sub(count);
206 &self.alerts[start..]
207 }
208
209 pub fn clear_alerts(&mut self) {
211 self.alerts.clear();
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh monitor has no critical alerts and does not ask to stop.
    #[test]
    fn test_andon_system_new() {
        let monitor = AndonSystem::new();
        assert!(!monitor.has_critical());
        assert!(!monitor.should_stop());
    }

    /// NaN loss is critical and requests a stop by default.
    #[test]
    fn test_andon_system_nan_detection() {
        let mut monitor = AndonSystem::new();
        assert!(monitor.check_loss(f32::NAN));
        assert!(monitor.has_critical());
    }

    /// Positive infinity is critical and requests a stop by default.
    #[test]
    fn test_andon_system_inf_detection() {
        let mut monitor = AndonSystem::new();
        assert!(monitor.check_loss(f32::INFINITY));
        assert!(monitor.has_critical());
    }

    /// Negative infinity is treated the same as positive infinity.
    #[test]
    fn test_andon_system_neg_inf_detection() {
        let mut monitor = AndonSystem::new();
        assert!(monitor.check_loss(f32::NEG_INFINITY));
        assert!(monitor.has_critical());
    }

    /// A steadily decreasing loss never requests a stop nor raises a critical alert.
    #[test]
    fn test_andon_system_normal_loss() {
        let mut monitor = AndonSystem::new();
        for step in 0..20 {
            let loss = 1.0 - step as f32 * 0.01;
            assert!(!monitor.check_loss(loss));
        }
        assert!(!monitor.has_critical());
    }

    /// Alerts come back in insertion order with the level they were raised at.
    #[test]
    fn test_andon_system_alerts() {
        let mut monitor = AndonSystem::new();
        monitor.info("Test info");
        monitor.warning("Test warning");
        monitor.critical("Test critical");

        let levels: Vec<AlertLevel> = monitor
            .recent_alerts(10)
            .iter()
            .map(|alert| alert.level)
            .collect();
        assert_eq!(
            levels,
            vec![AlertLevel::Info, AlertLevel::Warning, AlertLevel::Critical]
        );
    }

    /// Clearing removes every recorded alert.
    #[test]
    fn test_andon_system_clear_alerts() {
        let mut monitor = AndonSystem::new();
        monitor.warning("Test");
        monitor.clear_alerts();
        assert!(monitor.recent_alerts(10).is_empty());
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_andon_system_builders() {
        let monitor = AndonSystem::new()
            .with_sigma_threshold(5.0)
            .with_stall_threshold(500)
            .with_stop_on_critical(false);

        assert_eq!(monitor.sigma_threshold, 5.0);
        assert_eq!(monitor.stall_threshold, 500);
        assert!(!monitor.stop_on_critical);
    }

    /// With stop-on-critical disabled, NaN is still recorded but does not request a stop.
    #[test]
    fn test_andon_system_no_stop_on_critical() {
        let mut monitor = AndonSystem::new().with_stop_on_critical(false);
        assert!(!monitor.check_loss(f32::NAN));
        assert!(monitor.has_critical());
    }
}