Skip to main content

entrenar/train/tui/
andon.rs

1//! Andon System - Health Monitoring (ENT-066)
2//!
3//! Implements Jidoka (automation with a human touch) principles:
4//! - Detects abnormalities automatically
5//! - Alerts immediately
6//! - Stops training if critical
7
8use std::time::Instant;
9
10use super::buffer::MetricsBuffer;
11
/// Severity of a training health [`Alert`].
///
/// Variants are declared in ascending order of severity, so the derived
/// ordering satisfies `Info < Warning < Critical`. This allows filters such
/// as `alert.level >= AlertLevel::Warning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AlertLevel {
    /// Informational message.
    Info,
    /// Warning - training may be suboptimal.
    Warning,
    /// Critical - training should stop.
    Critical,
}
22
23/// Training health alert.
24#[derive(Debug, Clone)]
25pub struct Alert {
26    /// Alert level
27    pub level: AlertLevel,
28    /// Alert message
29    pub message: String,
30    /// Timestamp
31    pub timestamp: Instant,
32}
33
34/// Andon system for training health monitoring.
35///
36/// Implements Jidoka (automation with a human touch) principles:
37/// - Detects abnormalities automatically
38/// - Alerts immediately
39/// - Stops training if critical
40#[derive(Debug)]
41pub struct AndonSystem {
42    /// Active alerts
43    alerts: Vec<Alert>,
44    /// Whether to stop on critical
45    stop_on_critical: bool,
46    /// Loss history for divergence detection
47    loss_history: MetricsBuffer,
48    /// EMA of loss for divergence detection
49    loss_ema: f32,
50    /// EMA alpha
51    ema_alpha: f32,
52    /// Sigma threshold for divergence
53    sigma_threshold: f32,
54    /// Steps since last improvement
55    stall_counter: usize,
56    /// Best loss seen
57    best_loss: f32,
58    /// Stall threshold (steps)
59    stall_threshold: usize,
60}
61
62impl Default for AndonSystem {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl AndonSystem {
69    /// Create a new Andon system.
70    pub fn new() -> Self {
71        Self {
72            alerts: Vec::new(),
73            stop_on_critical: true,
74            loss_history: MetricsBuffer::new(100),
75            loss_ema: 0.0,
76            ema_alpha: 0.1,
77            sigma_threshold: 3.0,
78            stall_counter: 0,
79            best_loss: f32::INFINITY,
80            stall_threshold: 1000,
81        }
82    }
83
84    /// Configure sigma threshold for divergence detection.
85    pub fn with_sigma_threshold(mut self, sigma: f32) -> Self {
86        self.sigma_threshold = sigma;
87        self
88    }
89
90    /// Configure stall detection threshold.
91    pub fn with_stall_threshold(mut self, steps: usize) -> Self {
92        self.stall_threshold = steps;
93        self
94    }
95
96    /// Configure whether to stop on critical alerts.
97    pub fn with_stop_on_critical(mut self, stop: bool) -> Self {
98        self.stop_on_critical = stop;
99        self
100    }
101
102    /// Check loss value for abnormalities.
103    ///
104    /// Returns `true` if training should stop.
105    pub fn check_loss(&mut self, loss: f32) -> bool {
106        // Check for NaN/Inf
107        if loss.is_nan() {
108            self.critical("NaN loss detected - training diverged");
109            return self.stop_on_critical;
110        }
111
112        if loss.is_infinite() {
113            self.critical("Infinite loss detected - training diverged");
114            return self.stop_on_critical;
115        }
116
117        // Update EMA
118        if self.loss_history.is_empty() {
119            self.loss_ema = loss;
120        } else {
121            self.loss_ema = self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema;
122        }
123
124        // Check for divergence (loss >> EMA)
125        if self.loss_history.len() > 10 {
126            if let (Some(mean), Some(std)) = (self.loss_history.mean(), self.loss_std()) {
127                let z_score = (loss - mean) / std.max(f32::EPSILON);
128                if z_score > self.sigma_threshold {
129                    self.warning(format!(
130                        "Loss spike detected: {loss:.4} ({z_score:.1}σ above mean)"
131                    ));
132                }
133            }
134        }
135
136        // Check for stall
137        if loss < self.best_loss {
138            self.best_loss = loss;
139            self.stall_counter = 0;
140        } else {
141            self.stall_counter += 1;
142            if self.stall_counter >= self.stall_threshold {
143                self.warning(format!(
144                    "Training stalled: no improvement for {} steps",
145                    self.stall_counter
146                ));
147            }
148        }
149
150        self.loss_history.push(loss);
151        false
152    }
153
154    /// Calculate standard deviation of loss history.
155    fn loss_std(&self) -> Option<f32> {
156        let values = self.loss_history.values();
157        if values.len() < 2 {
158            return None;
159        }
160        let mean = values.iter().sum::<f32>() / values.len().max(1) as f32;
161        let variance =
162            values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len().max(1) as f32;
163        Some(variance.sqrt())
164    }
165
166    /// Add an info alert.
167    pub fn info(&mut self, message: impl Into<String>) {
168        self.alerts.push(Alert {
169            level: AlertLevel::Info,
170            message: message.into(),
171            timestamp: Instant::now(),
172        });
173    }
174
175    /// Add a warning alert.
176    pub fn warning(&mut self, message: impl Into<String>) {
177        self.alerts.push(Alert {
178            level: AlertLevel::Warning,
179            message: message.into(),
180            timestamp: Instant::now(),
181        });
182    }
183
184    /// Add a critical alert.
185    pub fn critical(&mut self, message: impl Into<String>) {
186        self.alerts.push(Alert {
187            level: AlertLevel::Critical,
188            message: message.into(),
189            timestamp: Instant::now(),
190        });
191    }
192
193    /// Check if there are any critical alerts.
194    pub fn has_critical(&self) -> bool {
195        self.alerts.iter().any(|a| a.level == AlertLevel::Critical)
196    }
197
198    /// Check if training should stop.
199    pub fn should_stop(&self) -> bool {
200        self.stop_on_critical && self.has_critical()
201    }
202
203    /// Get recent alerts.
204    pub fn recent_alerts(&self, count: usize) -> &[Alert] {
205        let start = self.alerts.len().saturating_sub(count);
206        &self.alerts[start..]
207    }
208
209    /// Clear all alerts.
210    pub fn clear_alerts(&mut self) {
211        self.alerts.clear();
212    }
213}
214
#[cfg(test)]
mod tests {
    //! Unit tests for the Andon health monitor: non-finite loss detection,
    //! alert bookkeeping, builder configuration and stall detection.

    use super::*;

    #[test]
    fn test_andon_system_new() {
        let andon = AndonSystem::new();
        assert!(!andon.has_critical());
        assert!(!andon.should_stop());
    }

    #[test]
    fn test_andon_system_nan_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::NAN);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_inf_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::INFINITY);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_neg_inf_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::NEG_INFINITY);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_normal_loss() {
        let mut andon = AndonSystem::new();
        for i in 0..20 {
            let should_stop = andon.check_loss(1.0 - i as f32 * 0.01);
            assert!(!should_stop);
        }
        assert!(!andon.has_critical());
    }

    #[test]
    fn test_andon_system_alerts() {
        let mut andon = AndonSystem::new();
        andon.info("Test info");
        andon.warning("Test warning");
        andon.critical("Test critical");

        let alerts = andon.recent_alerts(10);
        assert_eq!(alerts.len(), 3);
        assert_eq!(alerts[0].level, AlertLevel::Info);
        assert_eq!(alerts[1].level, AlertLevel::Warning);
        assert_eq!(alerts[2].level, AlertLevel::Critical);
    }

    #[test]
    fn test_andon_system_recent_alerts_window() {
        let mut andon = AndonSystem::new();
        andon.info("first");
        andon.info("second");
        andon.info("third");

        // Only the newest `count` alerts are returned, oldest first.
        let recent = andon.recent_alerts(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].message, "second");
        assert_eq!(recent[1].message, "third");

        assert!(andon.recent_alerts(0).is_empty());
    }

    #[test]
    fn test_andon_system_clear_alerts() {
        let mut andon = AndonSystem::new();
        andon.warning("Test");
        andon.clear_alerts();
        assert!(andon.recent_alerts(10).is_empty());
    }

    #[test]
    fn test_andon_system_builders() {
        let andon = AndonSystem::new()
            .with_sigma_threshold(5.0)
            .with_stall_threshold(500)
            .with_stop_on_critical(false);

        assert_eq!(andon.sigma_threshold, 5.0);
        assert_eq!(andon.stall_threshold, 500);
        assert!(!andon.stop_on_critical);
    }

    #[test]
    fn test_andon_system_no_stop_on_critical() {
        let mut andon = AndonSystem::new().with_stop_on_critical(false);
        let should_stop = andon.check_loss(f32::NAN);
        assert!(!should_stop);
        assert!(andon.has_critical());
        // The critical alert is recorded but must not request a stop.
        assert!(!andon.should_stop());
    }

    #[test]
    fn test_andon_system_stall_warning() {
        let mut andon = AndonSystem::new().with_stall_threshold(3);

        // The first sample sets the best loss; the next two do not improve on
        // it but stay below the threshold.
        for _ in 0..3 {
            assert!(!andon.check_loss(1.0));
        }
        assert!(andon.recent_alerts(10).is_empty());

        // The third non-improving step reaches the threshold.
        assert!(!andon.check_loss(1.0));
        let alerts = andon.recent_alerts(10);
        assert!(alerts
            .iter()
            .any(|a| a.level == AlertLevel::Warning && a.message.contains("stalled")));
        assert!(!andon.has_critical());
    }

    #[test]
    fn test_andon_system_improvement_resets_stall() {
        let mut andon = AndonSystem::new().with_stall_threshold(3);
        andon.check_loss(1.0);
        andon.check_loss(1.0);
        andon.check_loss(1.0);
        // A strictly lower loss resets the counter before the threshold is hit.
        andon.check_loss(0.5);
        andon.check_loss(0.5);
        andon.check_loss(0.5);
        assert!(andon.recent_alerts(10).is_empty());
    }
}