use std::time::Instant;
use super::buffer::MetricsBuffer;
/// Severity attached to an [`Alert`].
///
/// Within this file, `Warning` is used for loss spikes and stalled training,
/// and `Critical` is used for NaN / infinite losses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlertLevel {
    /// Informational message; never influences the stop decision.
    Info,
    /// Something looks wrong, but training may continue.
    Warning,
    /// Training has failed; may request a stop (see `AndonSystem::should_stop`).
    Critical,
}
/// A single message recorded by the andon system.
#[derive(Clone, Debug)]
pub struct Alert {
    /// Severity of the alert.
    pub level: AlertLevel,
    /// Human-readable description of what was observed.
    pub message: String,
    /// Monotonic instant at which the alert was recorded.
    pub timestamp: Instant,
}
/// Training monitor that records alerts and watches the loss for divergence,
/// spikes and stalls.
#[derive(Debug)]
pub struct AndonSystem {
    /// Every alert raised so far, oldest first.
    alerts: Vec<Alert>,
    /// Whether a critical alert should request that training stops.
    stop_on_critical: bool,
    /// Finite loss samples recorded by `check_loss`.
    loss_history: MetricsBuffer,
    /// Exponential moving average of the loss.
    loss_ema: f32,
    /// Weight given to the newest sample when updating `loss_ema`.
    ema_alpha: f32,
    /// Z-score above the history mean beyond which a loss counts as a spike.
    sigma_threshold: f32,
    /// Consecutive checks that did not improve on `best_loss`.
    stall_counter: usize,
    /// Lowest loss observed so far.
    best_loss: f32,
    /// Number of consecutive non-improving checks after which stall warnings
    /// are raised.
    stall_threshold: usize,
}
impl Default for AndonSystem {
fn default() -> Self {
Self::new()
}
}
impl AndonSystem {
pub fn new() -> Self {
Self {
alerts: Vec::new(),
stop_on_critical: true,
loss_history: MetricsBuffer::new(100),
loss_ema: 0.0,
ema_alpha: 0.1,
sigma_threshold: 3.0,
stall_counter: 0,
best_loss: f32::INFINITY,
stall_threshold: 1000,
}
}
pub fn with_sigma_threshold(mut self, sigma: f32) -> Self {
self.sigma_threshold = sigma;
self
}
pub fn with_stall_threshold(mut self, steps: usize) -> Self {
self.stall_threshold = steps;
self
}
pub fn with_stop_on_critical(mut self, stop: bool) -> Self {
self.stop_on_critical = stop;
self
}
pub fn check_loss(&mut self, loss: f32) -> bool {
if loss.is_nan() {
self.critical("NaN loss detected - training diverged");
return self.stop_on_critical;
}
if loss.is_infinite() {
self.critical("Infinite loss detected - training diverged");
return self.stop_on_critical;
}
if self.loss_history.is_empty() {
self.loss_ema = loss;
} else {
self.loss_ema = self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema;
}
if self.loss_history.len() > 10 {
if let (Some(mean), Some(std)) = (self.loss_history.mean(), self.loss_std()) {
let z_score = (loss - mean) / std.max(f32::EPSILON);
if z_score > self.sigma_threshold {
self.warning(format!(
"Loss spike detected: {loss:.4} ({z_score:.1}σ above mean)"
));
}
}
}
if loss < self.best_loss {
self.best_loss = loss;
self.stall_counter = 0;
} else {
self.stall_counter += 1;
if self.stall_counter >= self.stall_threshold {
self.warning(format!(
"Training stalled: no improvement for {} steps",
self.stall_counter
));
}
}
self.loss_history.push(loss);
false
}
fn loss_std(&self) -> Option<f32> {
let values = self.loss_history.values();
if values.len() < 2 {
return None;
}
let mean = values.iter().sum::<f32>() / values.len().max(1) as f32;
let variance =
values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len().max(1) as f32;
Some(variance.sqrt())
}
pub fn info(&mut self, message: impl Into<String>) {
self.alerts.push(Alert {
level: AlertLevel::Info,
message: message.into(),
timestamp: Instant::now(),
});
}
pub fn warning(&mut self, message: impl Into<String>) {
self.alerts.push(Alert {
level: AlertLevel::Warning,
message: message.into(),
timestamp: Instant::now(),
});
}
pub fn critical(&mut self, message: impl Into<String>) {
self.alerts.push(Alert {
level: AlertLevel::Critical,
message: message.into(),
timestamp: Instant::now(),
});
}
pub fn has_critical(&self) -> bool {
self.alerts.iter().any(|a| a.level == AlertLevel::Critical)
}
pub fn should_stop(&self) -> bool {
self.stop_on_critical && self.has_critical()
}
pub fn recent_alerts(&self, count: usize) -> &[Alert] {
let start = self.alerts.len().saturating_sub(count);
&self.alerts[start..]
}
pub fn clear_alerts(&mut self) {
self.alerts.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` if a warning whose message starts with `prefix` is stored.
    fn has_warning_starting_with(andon: &AndonSystem, prefix: &str) -> bool {
        andon
            .alerts
            .iter()
            .any(|a| a.level == AlertLevel::Warning && a.message.starts_with(prefix))
    }

    #[test]
    fn test_andon_system_new() {
        let andon = AndonSystem::new();
        assert!(!andon.has_critical());
        assert!(!andon.should_stop());
    }

    #[test]
    fn test_andon_system_default_matches_new() {
        let andon = AndonSystem::default();
        assert!(andon.stop_on_critical);
        assert_eq!(andon.stall_threshold, 1000);
        assert!(andon.recent_alerts(10).is_empty());
    }

    #[test]
    fn test_andon_system_nan_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::NAN);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_inf_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::INFINITY);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_neg_inf_detection() {
        let mut andon = AndonSystem::new();
        let should_stop = andon.check_loss(f32::NEG_INFINITY);
        assert!(should_stop);
        assert!(andon.has_critical());
    }

    #[test]
    fn test_andon_system_normal_loss() {
        let mut andon = AndonSystem::new();
        for i in 0..20 {
            let should_stop = andon.check_loss(1.0 - i as f32 * 0.01);
            assert!(!should_stop);
        }
        assert!(!andon.has_critical());
    }

    #[test]
    fn test_andon_system_stall_warning() {
        let mut andon = AndonSystem::new().with_stall_threshold(3);
        // The first sample becomes the best loss; the next three do not
        // improve on it, which reaches the stall threshold.
        for _ in 0..4 {
            assert!(!andon.check_loss(1.0));
        }
        assert!(has_warning_starting_with(
            &andon,
            "Training stalled: no improvement for 3 steps"
        ));
        assert!(!andon.has_critical());
    }

    #[test]
    fn test_andon_system_loss_spike_warning() {
        let mut andon = AndonSystem::new();
        // Build up more than 10 samples of low-variance history first.
        for i in 0..12 {
            let loss = if i % 2 == 0 { 1.0 } else { 1.1 };
            assert!(!andon.check_loss(loss));
        }
        assert!(!has_warning_starting_with(&andon, "Loss spike detected"));
        assert!(!andon.check_loss(100.0));
        assert!(has_warning_starting_with(&andon, "Loss spike detected"));
    }

    #[test]
    fn test_andon_system_alerts() {
        let mut andon = AndonSystem::new();
        andon.info("Test info");
        andon.warning("Test warning");
        andon.critical("Test critical");
        let alerts = andon.recent_alerts(10);
        assert_eq!(alerts.len(), 3);
        assert_eq!(alerts[0].level, AlertLevel::Info);
        assert_eq!(alerts[1].level, AlertLevel::Warning);
        assert_eq!(alerts[2].level, AlertLevel::Critical);
    }

    #[test]
    fn test_andon_system_recent_alerts_window() {
        let mut andon = AndonSystem::new();
        andon.info("first");
        andon.warning("second");
        andon.critical("third");
        // Asking for fewer alerts than are stored yields the newest ones.
        let last_two = andon.recent_alerts(2);
        assert_eq!(last_two.len(), 2);
        assert_eq!(last_two[0].message, "second");
        assert_eq!(last_two[1].message, "third");
        assert!(andon.recent_alerts(0).is_empty());
    }

    #[test]
    fn test_andon_system_clear_alerts() {
        let mut andon = AndonSystem::new();
        andon.warning("Test");
        andon.clear_alerts();
        assert!(andon.recent_alerts(10).is_empty());
    }

    #[test]
    fn test_andon_system_should_stop_follows_alerts() {
        let mut andon = AndonSystem::new();
        andon.warning("not fatal");
        assert!(!andon.should_stop());
        andon.critical("fatal");
        assert!(andon.should_stop());
        andon.clear_alerts();
        assert!(!andon.has_critical());
        assert!(!andon.should_stop());
    }

    #[test]
    fn test_andon_system_builders() {
        let andon = AndonSystem::new()
            .with_sigma_threshold(5.0)
            .with_stall_threshold(500)
            .with_stop_on_critical(false);
        // Exact float comparison is intended: the builder stores the value
        // unchanged.
        assert_eq!(andon.sigma_threshold, 5.0);
        assert_eq!(andon.stall_threshold, 500);
        assert!(!andon.stop_on_critical);
    }

    #[test]
    fn test_andon_system_no_stop_on_critical() {
        let mut andon = AndonSystem::new().with_stop_on_critical(false);
        let should_stop = andon.check_loss(f32::NAN);
        assert!(!should_stop);
        assert!(andon.has_critical());
        assert!(!andon.should_stop());
    }
}