//! Training Health Monitor - Real-time Training Diagnostics
//!
//! # File
//! `crates/axonml-optim/src/health.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use std::collections::HashMap;
use std::fmt;

// =============================================================================
// Configuration
// =============================================================================

/// Configuration for the training monitor.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of recent steps used for rolling statistics (default 100).
    pub window_size: usize,
    /// Gradient norms above this value raise an alert (default 100.0).
    pub grad_norm_threshold: f32,
    /// Alert when loss exceeds this multiple of its moving average (default 10.0).
    pub loss_divergence_factor: f32,
    /// Consecutive zero-gradient steps before a dead-neuron alert (default 50).
    pub dead_neuron_threshold: usize,
    /// Enables NaN/Inf scanning of loss and gradient norms (default true).
    pub nan_check: bool,
    /// Loss range below this value counts as converged (default 1e-6).
    pub convergence_threshold: f32,
    /// Maximum number of entries retained per history series (default 1000).
    pub max_history: usize,
}

43impl Default for MonitorConfig {
44    fn default() -> Self {
45        Self {
46            window_size: 100,
47            grad_norm_threshold: 100.0,
48            loss_divergence_factor: 10.0,
49            dead_neuron_threshold: 50,
50            nan_check: true,
51            convergence_threshold: 1e-6,
52            max_history: 1000,
53        }
54    }
55}
56
57// =============================================================================
58// Alert Types
59// =============================================================================
60
61/// A training alert generated by the monitor.
62#[derive(Debug, Clone)]
63pub struct TrainingAlert {
64    /// The training step at which this alert was generated.
65    pub step: usize,
66    /// Severity level of the alert.
67    pub severity: AlertSeverity,
68    /// Category of the alert.
69    pub kind: AlertKind,
70    /// Human-readable description of the issue.
71    pub message: String,
72}
73
74impl fmt::Display for TrainingAlert {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        write!(
77            f,
78            "[step {}] {:?} {:?}: {}",
79            self.step, self.severity, self.kind, self.message
80        )
81    }
82}
83
/// Severity level for training alerts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
    /// Informational -- no action required.
    Info,
    /// Warning -- training may be suboptimal.
    Warning,
    /// Critical -- training is likely broken; act immediately.
    Critical,
}

/// Category of training alert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
    /// NaN found in the loss or a gradient norm.
    NaNDetected,
    /// Infinity found in the loss or a gradient norm.
    InfDetected,
    /// Gradient norm exceeded the configured threshold.
    GradientExplosion,
    /// Gradient norms near zero for many consecutive steps.
    GradientVanishing,
    /// Loss spiked well above its moving average.
    LossDivergence,
    /// Loss has stagnated (not improving).
    LossStagnation,
    /// A parameter's gradient has been zero for many steps (dead neuron).
    DeadNeuron,
    /// Learning rate appears too high based on loss behavior.
    LearningRateTooHigh,
    /// Learning rate appears too low based on loss behavior.
    LearningRateTooLow,
    /// Training has converged (loss is no longer changing).
    Converged,
}

120// =============================================================================
121// Health Report
122// =============================================================================
123
124/// Snapshot of current training health.
125#[derive(Debug, Clone)]
126pub struct HealthReport {
127    /// Whether training appears healthy overall.
128    pub is_healthy: bool,
129    /// Current training step.
130    pub step: usize,
131    /// Most recent loss value.
132    pub current_loss: f32,
133    /// Current loss trajectory.
134    pub loss_trend: LossTrend,
135    /// Mean gradient norm over the recent window.
136    pub mean_grad_norm: f32,
137    /// Maximum gradient norm over the recent window.
138    pub max_grad_norm: f32,
139    /// Convergence score from 0.0 (actively changing) to 1.0 (fully converged).
140    pub convergence_score: f32,
141    /// Currently active alerts.
142    pub active_alerts: Vec<TrainingAlert>,
143    /// Number of parameters detected as dead neurons.
144    pub dead_neurons: usize,
145}
146
/// Loss trajectory classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Loss is decreasing -- training is progressing.
    Decreasing,
    /// Loss is stable -- may be converged or stuck.
    Stable,
    /// Loss is increasing -- training may be diverging.
    Increasing,
    /// Loss is oscillating -- the learning rate may be too high.
    Oscillating,
    /// Not enough data to determine a trend.
    Unknown,
}

162// =============================================================================
163// Training Monitor
164// =============================================================================
165
166/// Real-time training health monitor.
167///
168/// Attaches to any optimizer to detect pathological training behavior
169/// including gradient explosion/vanishing, loss divergence, dead neurons,
170/// and convergence. Generates alerts with severity levels and provides
171/// actionable suggestions.
172pub struct TrainingMonitor {
173    /// Current training step count.
174    step_count: usize,
175    /// History of loss values.
176    loss_history: Vec<f32>,
177    /// History of overall gradient norms (max across all parameters per step).
178    grad_norm_history: Vec<f32>,
179    /// History of learning rate values.
180    lr_history: Vec<f32>,
181    /// Accumulated training alerts.
182    alerts: Vec<TrainingAlert>,
183    /// Monitor configuration.
184    config: MonitorConfig,
185    /// Per-parameter count of consecutive steps with zero gradient.
186    zero_grad_counts: HashMap<String, usize>,
187    /// Count of consecutive steps with vanishing gradients.
188    vanishing_streak: usize,
189}
190
191impl TrainingMonitor {
192    /// Creates a new training monitor with default configuration.
193    pub fn new() -> Self {
194        Self::with_config(MonitorConfig::default())
195    }
196
197    /// Creates a new training monitor with the given configuration.
198    pub fn with_config(config: MonitorConfig) -> Self {
199        Self {
200            step_count: 0,
201            loss_history: Vec::new(),
202            grad_norm_history: Vec::new(),
203            lr_history: Vec::new(),
204            alerts: Vec::new(),
205            config,
206            zero_grad_counts: HashMap::new(),
207            vanishing_streak: 0,
208        }
209    }
210
211    /// Records a single training step.
212    ///
213    /// # Arguments
214    ///
215    /// * `loss` - The loss value for this step.
216    /// * `grad_norms` - Slice of (parameter_name, gradient_norm) pairs.
217    /// * `lr` - The current learning rate.
218    pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
219        self.step_count += 1;
220        let step = self.step_count;
221
222        // Append to histories, respecting max_history
223        self.push_bounded(&mut self.loss_history.clone(), loss);
224        self.loss_history.push(loss);
225        if self.loss_history.len() > self.config.max_history {
226            self.loss_history.remove(0);
227        }
228
229        // Compute max gradient norm across all parameters
230        let max_grad_norm = grad_norms.iter().map(|(_, n)| *n).fold(0.0_f32, f32::max);
231
232        self.grad_norm_history.push(max_grad_norm);
233        if self.grad_norm_history.len() > self.config.max_history {
234            self.grad_norm_history.remove(0);
235        }
236
237        self.lr_history.push(lr);
238        if self.lr_history.len() > self.config.max_history {
239            self.lr_history.remove(0);
240        }
241
242        // --- NaN/Inf checks ---
243        if self.config.nan_check {
244            if loss.is_nan() {
245                self.emit_alert(
246                    step,
247                    AlertSeverity::Critical,
248                    AlertKind::NaNDetected,
249                    "NaN detected in loss value".to_string(),
250                );
251            } else if loss.is_infinite() {
252                self.emit_alert(
253                    step,
254                    AlertSeverity::Critical,
255                    AlertKind::InfDetected,
256                    "Infinity detected in loss value".to_string(),
257                );
258            }
259
260            for (name, norm) in grad_norms {
261                if norm.is_nan() {
262                    self.emit_alert(
263                        step,
264                        AlertSeverity::Critical,
265                        AlertKind::NaNDetected,
266                        format!("NaN detected in gradient norm for '{}'", name),
267                    );
268                } else if norm.is_infinite() {
269                    self.emit_alert(
270                        step,
271                        AlertSeverity::Critical,
272                        AlertKind::InfDetected,
273                        format!("Infinity detected in gradient norm for '{}'", name),
274                    );
275                }
276            }
277        }
278
279        // --- Gradient explosion ---
280        if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
281            self.emit_alert(
282                step,
283                AlertSeverity::Warning,
284                AlertKind::GradientExplosion,
285                format!(
286                    "Gradient norm {:.4} exceeds threshold {:.4}",
287                    max_grad_norm, self.config.grad_norm_threshold
288                ),
289            );
290        }
291
292        // --- Gradient vanishing ---
293        if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
294            self.vanishing_streak += 1;
295            if self.vanishing_streak >= 10 {
296                self.emit_alert(
297                    step,
298                    AlertSeverity::Warning,
299                    AlertKind::GradientVanishing,
300                    format!(
301                        "Gradient norms near zero for {} consecutive steps",
302                        self.vanishing_streak
303                    ),
304                );
305            }
306        } else {
307            self.vanishing_streak = 0;
308        }
309
310        // --- Dead neuron tracking ---
311        let dead_threshold = self.config.dead_neuron_threshold;
312        let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
313        for (name, norm) in grad_norms {
314            let count = self
315                .zero_grad_counts
316                .entry((*name).to_string())
317                .or_insert(0);
318            if *norm == 0.0 {
319                *count += 1;
320                if *count == dead_threshold {
321                    new_dead_alerts.push(((*name).to_string(), *count));
322                }
323            } else {
324                *count = 0;
325            }
326        }
327        for (name, count) in new_dead_alerts {
328            self.emit_alert(
329                step,
330                AlertSeverity::Warning,
331                AlertKind::DeadNeuron,
332                format!(
333                    "Parameter '{}' has had zero gradient for {} steps (dead neuron)",
334                    name, count
335                ),
336            );
337        }
338
339        // --- Loss divergence ---
340        if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
341            let window_start = self
342                .loss_history
343                .len()
344                .saturating_sub(self.config.window_size);
345            let window = &self.loss_history[window_start..self.loss_history.len() - 1];
346            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
347            if !finite_vals.is_empty() {
348                let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
349                if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
350                    self.emit_alert(
351                        step,
352                        AlertSeverity::Warning,
353                        AlertKind::LossDivergence,
354                        format!(
355                            "Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
356                            loss,
357                            avg,
358                            loss / avg
359                        ),
360                    );
361                }
362            }
363        }
364
365        // --- Convergence check ---
366        if self.loss_history.len() >= self.config.window_size {
367            let window_start = self.loss_history.len() - self.config.window_size;
368            let window = &self.loss_history[window_start..];
369            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
370            if finite_vals.len() >= 2 {
371                let max_val = finite_vals
372                    .iter()
373                    .copied()
374                    .fold(f32::NEG_INFINITY, f32::max);
375                let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
376                let range = max_val - min_val;
377                if range < self.config.convergence_threshold {
378                    self.emit_alert(
379                        step,
380                        AlertSeverity::Info,
381                        AlertKind::Converged,
382                        format!(
383                            "Training converged: loss range {:.2e} over last {} steps",
384                            range, self.config.window_size
385                        ),
386                    );
387                }
388            }
389        }
390    }
391
392    /// Returns a full health report for the current training state.
393    pub fn check_health(&self) -> HealthReport {
394        let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
395        let trend = self.loss_trend();
396        let conv_score = self.convergence_score();
397
398        let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
399
400        // Count dead neurons
401        let dead_neurons = self
402            .zero_grad_counts
403            .values()
404            .filter(|c| **c >= self.config.dead_neuron_threshold)
405            .count();
406
407        // Determine overall health
408        let has_critical = self
409            .alerts
410            .iter()
411            .any(|a| a.severity == AlertSeverity::Critical);
412        let is_healthy = !has_critical
413            && trend != LossTrend::Increasing
414            && !current_loss.is_nan()
415            && !current_loss.is_infinite();
416
417        // Collect recent alerts (last window_size steps)
418        let min_step = self.step_count.saturating_sub(self.config.window_size);
419        let active_alerts: Vec<TrainingAlert> = self
420            .alerts
421            .iter()
422            .filter(|a| a.step > min_step)
423            .cloned()
424            .collect();
425
426        HealthReport {
427            is_healthy,
428            step: self.step_count,
429            current_loss,
430            loss_trend: trend,
431            mean_grad_norm: mean_gn,
432            max_grad_norm: max_gn,
433            convergence_score: conv_score,
434            active_alerts,
435            dead_neurons,
436        }
437    }
438
439    /// Returns true if training appears healthy (no critical alerts, loss not diverging).
440    pub fn is_healthy(&self) -> bool {
441        self.check_health().is_healthy
442    }
443
444    /// Returns the accumulated alerts.
445    pub fn alerts(&self) -> &[TrainingAlert] {
446        &self.alerts
447    }
448
449    /// Clears all accumulated alerts.
450    pub fn clear_alerts(&mut self) {
451        self.alerts.clear();
452    }
453
454    /// Analyzes the loss trajectory over the recent window.
455    ///
456    /// Compares the rolling average of the most recent window to the
457    /// rolling average of the previous window to classify the trend.
458    pub fn loss_trend(&self) -> LossTrend {
459        let w = self.config.window_size;
460        if self.loss_history.len() < w * 2 {
461            return LossTrend::Unknown;
462        }
463
464        let len = self.loss_history.len();
465        let recent = &self.loss_history[len - w..];
466        let previous = &self.loss_history[len - 2 * w..len - w];
467
468        let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
469        let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
470
471        if recent_finite.is_empty() || prev_finite.is_empty() {
472            return LossTrend::Unknown;
473        }
474
475        let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
476        let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
477
478        if prev_avg == 0.0 {
479            return LossTrend::Unknown;
480        }
481
482        let ratio = recent_avg / prev_avg;
483
484        // Check for oscillation: high variance in recent window
485        let recent_mean = recent_avg;
486        let recent_var = recent_finite
487            .iter()
488            .map(|v| (v - recent_mean).powi(2))
489            .sum::<f32>()
490            / recent_finite.len() as f32;
491        let recent_std = recent_var.sqrt();
492        let cv = if recent_mean.abs() > 1e-12 {
493            recent_std / recent_mean.abs()
494        } else {
495            0.0
496        };
497
498        if ratio < 0.95 {
499            LossTrend::Decreasing
500        } else if ratio > 1.05 {
501            LossTrend::Increasing
502        } else if cv > 0.1 {
503            LossTrend::Oscillating
504        } else {
505            LossTrend::Stable
506        }
507    }
508
509    /// Suggests a learning rate adjustment based on current training dynamics.
510    ///
511    /// Returns `None` if no adjustment is needed or training has converged.
512    pub fn suggest_lr(&self) -> Option<f32> {
513        let current_lr = self.lr_history.last().copied()?;
514        let trend = self.loss_trend();
515
516        // Check for gradient explosion
517        let (_, _, max_gn) = self.grad_norm_stats();
518        if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
519            return Some(current_lr * 0.1);
520        }
521
522        match trend {
523            LossTrend::Oscillating => Some(current_lr * 0.5),
524            LossTrend::Stable => {
525                // Check if truly converged or just stagnating
526                let conv = self.convergence_score();
527                if conv > 0.99 {
528                    None // Converged, no adjustment needed
529                } else {
530                    Some(current_lr * 2.0) // Stagnating, try higher LR
531                }
532            }
533            LossTrend::Increasing => Some(current_lr * 0.1),
534            _ => None,
535        }
536    }
537
538    /// Returns (mean, std, max) of gradient norms over the recent window.
539    pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
540        if self.grad_norm_history.is_empty() {
541            return (0.0, 0.0, 0.0);
542        }
543
544        let w = self.config.window_size.min(self.grad_norm_history.len());
545        let start = self.grad_norm_history.len() - w;
546        let window = &self.grad_norm_history[start..];
547
548        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
549        if finite.is_empty() {
550            return (0.0, 0.0, 0.0);
551        }
552
553        let n = finite.len() as f32;
554        let mean = finite.iter().sum::<f32>() / n;
555        let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
556        let std = variance.sqrt();
557        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
558
559        (mean, std, max)
560    }
561
562    /// Returns a convergence score between 0.0 and 1.0.
563    ///
564    /// 1.0 indicates full convergence (no loss change over the window).
565    /// 0.0 indicates the loss is still actively changing.
566    pub fn convergence_score(&self) -> f32 {
567        let w = self.config.window_size;
568        if self.loss_history.len() < w {
569            return 0.0;
570        }
571
572        let start = self.loss_history.len() - w;
573        let window = &self.loss_history[start..];
574
575        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
576        if finite.len() < 2 {
577            return 0.0;
578        }
579
580        let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
581        let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
582        let range = max_val - min_val;
583
584        let mean = finite.iter().sum::<f32>() / finite.len() as f32;
585        if mean.abs() < 1e-12 {
586            // If mean is essentially zero, check absolute range
587            if range < self.config.convergence_threshold {
588                return 1.0;
589            }
590            return 0.0;
591        }
592
593        // Relative range (normalized by mean)
594        let relative_range = range / mean.abs();
595
596        // Map relative_range to [0, 1] score: 0 range -> 1.0, large range -> 0.0
597        // Use exponential decay for smooth mapping
598        let score = (-relative_range * 100.0).exp();
599        score.clamp(0.0, 1.0)
600    }
601
602    /// Returns a human-readable summary of the current training state.
603    pub fn summary(&self) -> String {
604        let report = self.check_health();
605        let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
606
607        let mut s = String::new();
608        s.push_str(&format!(
609            "=== Training Health Report (step {}) ===\n",
610            report.step
611        ));
612        s.push_str(&format!(
613            "Status: {}\n",
614            if report.is_healthy {
615                "HEALTHY"
616            } else {
617                "UNHEALTHY"
618            }
619        ));
620        s.push_str(&format!(
621            "Loss: {:.6} (trend: {:?})\n",
622            report.current_loss, report.loss_trend
623        ));
624        s.push_str(&format!(
625            "Grad norms: mean={:.4}, std={:.4}, max={:.4}\n",
626            mean_gn, std_gn, max_gn
627        ));
628        s.push_str(&format!(
629            "Convergence: {:.2}%\n",
630            report.convergence_score * 100.0
631        ));
632        s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
633
634        if !report.active_alerts.is_empty() {
635            s.push_str(&format!(
636                "Active alerts ({}):\n",
637                report.active_alerts.len()
638            ));
639            for alert in &report.active_alerts {
640                s.push_str(&format!("  {}\n", alert));
641            }
642        }
643
644        if let Some(lr) = self.suggest_lr() {
645            s.push_str(&format!("Suggested LR: {:.6}\n", lr));
646        }
647
648        s
649    }
650
651    // =========================================================================
652    // Private helpers
653    // =========================================================================
654
655    fn push_bounded(&self, _history: &mut Vec<f32>, _value: f32) {
656        // Intentionally left as no-op: actual push + trim is done inline
657        // in record_step for each history vector individually.
658    }
659
660    fn emit_alert(
661        &mut self,
662        step: usize,
663        severity: AlertSeverity,
664        kind: AlertKind,
665        message: String,
666    ) {
667        self.alerts.push(TrainingAlert {
668            step,
669            severity,
670            kind,
671            message,
672        });
673    }
674}
675
676impl Default for TrainingMonitor {
677    fn default() -> Self {
678        Self::new()
679    }
680}
681
682// =============================================================================
683// Tests
684// =============================================================================
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689
690    // -------------------------------------------------------------------------
691    // Construction
692    // -------------------------------------------------------------------------
693
694    #[test]
695    fn test_monitor_creation_defaults() {
696        let monitor = TrainingMonitor::new();
697        assert_eq!(monitor.step_count, 0);
698        assert!(monitor.loss_history.is_empty());
699        assert!(monitor.alerts.is_empty());
700        assert_eq!(monitor.config.window_size, 100);
701        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
702        assert!(monitor.config.nan_check);
703    }
704
705    #[test]
706    fn test_monitor_with_custom_config() {
707        let config = MonitorConfig {
708            window_size: 50,
709            grad_norm_threshold: 50.0,
710            loss_divergence_factor: 5.0,
711            dead_neuron_threshold: 20,
712            nan_check: false,
713            convergence_threshold: 1e-5,
714            max_history: 500,
715        };
716        let monitor = TrainingMonitor::with_config(config);
717        assert_eq!(monitor.config.window_size, 50);
718        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
719        assert!(!monitor.config.nan_check);
720        assert_eq!(monitor.config.max_history, 500);
721    }
722
723    // -------------------------------------------------------------------------
724    // Recording steps
725    // -------------------------------------------------------------------------
726
727    #[test]
728    fn test_record_step_updates_state() {
729        let mut monitor = TrainingMonitor::new();
730        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
731
732        assert_eq!(monitor.step_count, 1);
733        assert_eq!(monitor.loss_history.len(), 1);
734        assert_eq!(monitor.grad_norm_history.len(), 1);
735        assert_eq!(monitor.lr_history.len(), 1);
736        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
737    }
738
739    // -------------------------------------------------------------------------
740    // NaN/Inf detection
741    // -------------------------------------------------------------------------
742
743    #[test]
744    fn test_nan_detection_generates_critical_alert() {
745        let mut monitor = TrainingMonitor::new();
746        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
747
748        assert_eq!(monitor.alerts.len(), 1);
749        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
750        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
751        assert!(monitor.alerts[0].message.contains("NaN"));
752    }
753
754    #[test]
755    fn test_inf_detection_generates_critical_alert() {
756        let mut monitor = TrainingMonitor::new();
757        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);
758
759        assert_eq!(monitor.alerts.len(), 1);
760        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
761        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
762        assert!(monitor.alerts[0].message.contains("Infinity"));
763    }
764
765    #[test]
766    fn test_nan_in_grad_norm_detected() {
767        let mut monitor = TrainingMonitor::new();
768        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);
769
770        let nan_alerts: Vec<_> = monitor
771            .alerts
772            .iter()
773            .filter(|a| a.kind == AlertKind::NaNDetected)
774            .collect();
775        assert_eq!(nan_alerts.len(), 1);
776        assert!(nan_alerts[0].message.contains("w1"));
777    }
778
779    #[test]
780    fn test_nan_check_disabled() {
781        let config = MonitorConfig {
782            nan_check: false,
783            ..MonitorConfig::default()
784        };
785        let mut monitor = TrainingMonitor::with_config(config);
786        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);
787
788        assert!(monitor.alerts.is_empty());
789    }
790
791    // -------------------------------------------------------------------------
792    // Gradient explosion/vanishing
793    // -------------------------------------------------------------------------
794
795    #[test]
796    fn test_gradient_explosion_detection() {
797        let mut monitor = TrainingMonitor::new();
798        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);
799
800        let explosion_alerts: Vec<_> = monitor
801            .alerts
802            .iter()
803            .filter(|a| a.kind == AlertKind::GradientExplosion)
804            .collect();
805        assert_eq!(explosion_alerts.len(), 1);
806        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
807    }
808
809    #[test]
810    fn test_gradient_vanishing_detection() {
811        let config = MonitorConfig {
812            window_size: 5,
813            ..MonitorConfig::default()
814        };
815        let mut monitor = TrainingMonitor::with_config(config);
816
817        // Record 10 steps with near-zero gradients to trigger vanishing alert
818        for _ in 0..10 {
819            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
820        }
821
822        let vanishing_alerts: Vec<_> = monitor
823            .alerts
824            .iter()
825            .filter(|a| a.kind == AlertKind::GradientVanishing)
826            .collect();
827        assert!(!vanishing_alerts.is_empty());
828    }
829
830    #[test]
831    fn test_gradient_vanishing_resets_on_normal_grad() {
832        let mut monitor = TrainingMonitor::new();
833
834        // 5 vanishing steps
835        for _ in 0..5 {
836            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
837        }
838        // Normal gradient resets the streak
839        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
840        assert_eq!(monitor.vanishing_streak, 0);
841    }
842
843    // -------------------------------------------------------------------------
844    // Loss divergence
845    // -------------------------------------------------------------------------
846
847    #[test]
848    fn test_loss_divergence_detection() {
849        let config = MonitorConfig {
850            window_size: 10,
851            loss_divergence_factor: 2.0,
852            ..MonitorConfig::default()
853        };
854        let mut monitor = TrainingMonitor::with_config(config);
855
856        // Build up a history of stable loss
857        for _ in 0..10 {
858            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
859        }
860
861        // Spike the loss to trigger divergence
862        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);
863
864        let divergence_alerts: Vec<_> = monitor
865            .alerts
866            .iter()
867            .filter(|a| a.kind == AlertKind::LossDivergence)
868            .collect();
869        assert!(!divergence_alerts.is_empty());
870    }
871
872    // -------------------------------------------------------------------------
873    // Dead neuron tracking
874    // -------------------------------------------------------------------------
875
876    #[test]
877    fn test_dead_neuron_tracking() {
878        let config = MonitorConfig {
879            dead_neuron_threshold: 5,
880            ..MonitorConfig::default()
881        };
882        let mut monitor = TrainingMonitor::with_config(config);
883
884        for _ in 0..5 {
885            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
886        }
887
888        let dead_alerts: Vec<_> = monitor
889            .alerts
890            .iter()
891            .filter(|a| a.kind == AlertKind::DeadNeuron)
892            .collect();
893        assert_eq!(dead_alerts.len(), 1);
894        assert!(dead_alerts[0].message.contains("dead_layer"));
895    }
896
897    #[test]
898    fn test_dead_neuron_resets_on_nonzero_grad() {
899        let config = MonitorConfig {
900            dead_neuron_threshold: 10,
901            ..MonitorConfig::default()
902        };
903        let mut monitor = TrainingMonitor::with_config(config);
904
905        // 5 steps with zero grad
906        for _ in 0..5 {
907            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
908        }
909        // One step with nonzero grad resets the counter
910        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);
911
912        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
913    }
914
915    // -------------------------------------------------------------------------
916    // Convergence detection
917    // -------------------------------------------------------------------------
918
919    #[test]
920    fn test_convergence_detection() {
921        let config = MonitorConfig {
922            window_size: 10,
923            convergence_threshold: 1e-4,
924            ..MonitorConfig::default()
925        };
926        let mut monitor = TrainingMonitor::with_config(config);
927
928        // Record identical losses
929        for _ in 0..10 {
930            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
931        }
932
933        let converged_alerts: Vec<_> = monitor
934            .alerts
935            .iter()
936            .filter(|a| a.kind == AlertKind::Converged)
937            .collect();
938        assert!(!converged_alerts.is_empty());
939    }
940
941    // -------------------------------------------------------------------------
942    // Loss trend analysis
943    // -------------------------------------------------------------------------
944
945    #[test]
946    fn test_loss_trend_decreasing() {
947        let config = MonitorConfig {
948            window_size: 5,
949            ..MonitorConfig::default()
950        };
951        let mut monitor = TrainingMonitor::with_config(config);
952
953        // First window: higher losses
954        for i in 0..5 {
955            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
956        }
957        // Second window: much lower losses
958        for i in 0..5 {
959            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
960        }
961
962        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
963    }
964
965    #[test]
966    fn test_loss_trend_increasing() {
967        let config = MonitorConfig {
968            window_size: 5,
969            ..MonitorConfig::default()
970        };
971        let mut monitor = TrainingMonitor::with_config(config);
972
973        // First window: lower losses
974        for i in 0..5 {
975            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
976        }
977        // Second window: much higher losses
978        for i in 0..5 {
979            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
980        }
981
982        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
983    }
984
985    #[test]
986    fn test_loss_trend_oscillating() {
987        let config = MonitorConfig {
988            window_size: 10,
989            ..MonitorConfig::default()
990        };
991        let mut monitor = TrainingMonitor::with_config(config);
992
993        // First window: stable
994        for _ in 0..10 {
995            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
996        }
997        // Second window: oscillating around same mean but with high variance
998        for i in 0..10 {
999            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
1000            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
1001        }
1002
1003        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
1004    }
1005
1006    #[test]
1007    fn test_loss_trend_stable() {
1008        let config = MonitorConfig {
1009            window_size: 5,
1010            ..MonitorConfig::default()
1011        };
1012        let mut monitor = TrainingMonitor::with_config(config);
1013
1014        // Two windows with nearly identical means and low variance
1015        for _ in 0..10 {
1016            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
1017        }
1018
1019        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
1020    }
1021
1022    #[test]
1023    fn test_loss_trend_unknown_insufficient_data() {
1024        let config = MonitorConfig {
1025            window_size: 100,
1026            ..MonitorConfig::default()
1027        };
1028        let mut monitor = TrainingMonitor::with_config(config);
1029
1030        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
1031
1032        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
1033    }
1034
1035    // -------------------------------------------------------------------------
1036    // Health report
1037    // -------------------------------------------------------------------------
1038
1039    #[test]
1040    fn test_health_report_healthy_normal_training() {
1041        let config = MonitorConfig {
1042            window_size: 5,
1043            ..MonitorConfig::default()
1044        };
1045        let mut monitor = TrainingMonitor::with_config(config);
1046
1047        for i in 0..10 {
1048            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
1049        }
1050
1051        let report = monitor.check_health();
1052        assert!(report.is_healthy);
1053        assert_eq!(report.step, 10);
1054        assert_eq!(report.dead_neurons, 0);
1055    }
1056
1057    #[test]
1058    fn test_health_report_not_healthy_with_nan() {
1059        let mut monitor = TrainingMonitor::new();
1060        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
1061
1062        let report = monitor.check_health();
1063        assert!(!report.is_healthy);
1064    }
1065
1066    // -------------------------------------------------------------------------
1067    // Learning rate suggestions
1068    // -------------------------------------------------------------------------
1069
1070    #[test]
1071    fn test_suggest_lr_exploding_gradients() {
1072        let config = MonitorConfig {
1073            window_size: 5,
1074            grad_norm_threshold: 10.0,
1075            ..MonitorConfig::default()
1076        };
1077        let mut monitor = TrainingMonitor::with_config(config);
1078
1079        // Record steps with exploding gradients
1080        for _ in 0..5 {
1081            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
1082        }
1083
1084        let suggested = monitor.suggest_lr();
1085        assert!(suggested.is_some());
1086        assert!((suggested.unwrap() - 0.001).abs() < 1e-6); // 0.01 * 0.1
1087    }
1088
1089    #[test]
1090    fn test_suggest_lr_oscillating_loss() {
1091        let config = MonitorConfig {
1092            window_size: 10,
1093            ..MonitorConfig::default()
1094        };
1095        let mut monitor = TrainingMonitor::with_config(config);
1096
1097        // First window stable
1098        for _ in 0..10 {
1099            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
1100        }
1101        // Second window oscillating
1102        for i in 0..10 {
1103            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
1104            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
1105        }
1106
1107        let suggested = monitor.suggest_lr();
1108        assert!(suggested.is_some());
1109        assert!((suggested.unwrap() - 0.005).abs() < 1e-6); // 0.01 * 0.5
1110    }
1111
1112    #[test]
1113    fn test_suggest_lr_converged_returns_none() {
1114        let config = MonitorConfig {
1115            window_size: 5,
1116            convergence_threshold: 1e-4,
1117            ..MonitorConfig::default()
1118        };
1119        let mut monitor = TrainingMonitor::with_config(config);
1120
1121        // Two windows of identical loss (converged + stable)
1122        for _ in 0..10 {
1123            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
1124        }
1125
1126        // Stable trend with high convergence score -> should return None
1127        let trend = monitor.loss_trend();
1128        let conv = monitor.convergence_score();
1129        assert_eq!(trend, LossTrend::Stable);
1130        assert!(conv > 0.99);
1131        assert!(monitor.suggest_lr().is_none());
1132    }
1133
1134    // -------------------------------------------------------------------------
1135    // Convergence score
1136    // -------------------------------------------------------------------------
1137
1138    #[test]
1139    fn test_convergence_score_fully_converged() {
1140        let config = MonitorConfig {
1141            window_size: 10,
1142            ..MonitorConfig::default()
1143        };
1144        let mut monitor = TrainingMonitor::with_config(config);
1145
1146        for _ in 0..10 {
1147            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
1148        }
1149
1150        let score = monitor.convergence_score();
1151        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
1152    }
1153
1154    #[test]
1155    fn test_convergence_score_actively_changing() {
1156        let config = MonitorConfig {
1157            window_size: 10,
1158            ..MonitorConfig::default()
1159        };
1160        let mut monitor = TrainingMonitor::with_config(config);
1161
1162        for i in 0..10 {
1163            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
1164        }
1165
1166        let score = monitor.convergence_score();
1167        assert!(score < 0.5, "Expected low score, got {}", score);
1168    }
1169
1170    #[test]
1171    fn test_convergence_score_insufficient_data() {
1172        let monitor = TrainingMonitor::new();
1173        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
1174    }
1175
1176    // -------------------------------------------------------------------------
1177    // Summary output
1178    // -------------------------------------------------------------------------
1179
1180    #[test]
1181    fn test_summary_contains_key_metrics() {
1182        let config = MonitorConfig {
1183            window_size: 5,
1184            ..MonitorConfig::default()
1185        };
1186        let mut monitor = TrainingMonitor::with_config(config);
1187
1188        for _ in 0..5 {
1189            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
1190        }
1191
1192        let summary = monitor.summary();
1193        assert!(summary.contains("Training Health Report"));
1194        assert!(summary.contains("HEALTHY"));
1195        assert!(summary.contains("Loss:"));
1196        assert!(summary.contains("Grad norms:"));
1197        assert!(summary.contains("Convergence:"));
1198        assert!(summary.contains("Dead neurons:"));
1199    }
1200
1201    // -------------------------------------------------------------------------
1202    // Alert management
1203    // -------------------------------------------------------------------------
1204
1205    #[test]
1206    fn test_clear_alerts_empties_list() {
1207        let mut monitor = TrainingMonitor::new();
1208        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
1209        assert!(!monitor.alerts().is_empty());
1210
1211        monitor.clear_alerts();
1212        assert!(monitor.alerts().is_empty());
1213    }
1214
1215    // -------------------------------------------------------------------------
1216    // Memory bounds
1217    // -------------------------------------------------------------------------
1218
1219    #[test]
1220    fn test_max_history_bounds_memory() {
1221        let config = MonitorConfig {
1222            max_history: 20,
1223            window_size: 5,
1224            ..MonitorConfig::default()
1225        };
1226        let mut monitor = TrainingMonitor::with_config(config);
1227
1228        for i in 0..50 {
1229            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
1230        }
1231
1232        assert!(monitor.loss_history.len() <= 20);
1233        assert!(monitor.grad_norm_history.len() <= 20);
1234        assert!(monitor.lr_history.len() <= 20);
1235    }
1236
1237    // -------------------------------------------------------------------------
1238    // Grad norm stats
1239    // -------------------------------------------------------------------------
1240
1241    #[test]
1242    fn test_grad_norm_stats_computation() {
1243        let config = MonitorConfig {
1244            window_size: 4,
1245            ..MonitorConfig::default()
1246        };
1247        let mut monitor = TrainingMonitor::with_config(config);
1248
1249        // Record known gradient norms: 1, 2, 3, 4
1250        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
1251        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
1252        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
1253        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);
1254
1255        let (mean, std, max) = monitor.grad_norm_stats();
1256        assert!(
1257            (mean - 2.5).abs() < 1e-4,
1258            "Expected mean ~2.5, got {}",
1259            mean
1260        );
1261        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
1262        // std of [1,2,3,4] = sqrt(1.25) ~= 1.118
1263        assert!(
1264            (std - 1.118).abs() < 0.01,
1265            "Expected std ~1.118, got {}",
1266            std
1267        );
1268    }
1269
1270    #[test]
1271    fn test_grad_norm_stats_empty() {
1272        let monitor = TrainingMonitor::new();
1273        let (mean, std, max) = monitor.grad_norm_stats();
1274        assert!((mean - 0.0).abs() < 1e-6);
1275        assert!((std - 0.0).abs() < 1e-6);
1276        assert!((max - 0.0).abs() < 1e-6);
1277    }
1278
1279    // -------------------------------------------------------------------------
1280    // Integration test
1281    // -------------------------------------------------------------------------
1282
1283    #[test]
1284    fn test_integration_100_step_improving_training() {
1285        let config = MonitorConfig {
1286            window_size: 20,
1287            ..MonitorConfig::default()
1288        };
1289        let mut monitor = TrainingMonitor::with_config(config);
1290
1291        // Simulate 100 steps of steadily improving loss
1292        for i in 0..100 {
1293            let loss = 2.0 * (-0.03 * i as f32).exp(); // Exponential decay
1294            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
1295            let lr = 0.001;
1296            monitor.record_step(
1297                loss,
1298                &[
1299                    ("layer1.weight", grad_norm),
1300                    ("layer2.weight", grad_norm * 0.5),
1301                ],
1302                lr,
1303            );
1304        }
1305
1306        assert_eq!(monitor.step_count, 100);
1307        assert!(monitor.is_healthy());
1308
1309        let report = monitor.check_health();
1310        assert!(report.is_healthy);
1311        assert_eq!(report.step, 100);
1312        assert!(report.current_loss < 0.2); // Should have decayed significantly
1313        assert_eq!(report.dead_neurons, 0);
1314
1315        // Should detect decreasing trend
1316        let trend = monitor.loss_trend();
1317        assert_eq!(trend, LossTrend::Decreasing);
1318
1319        // No critical alerts should exist
1320        let critical_count = monitor
1321            .alerts
1322            .iter()
1323            .filter(|a| a.severity == AlertSeverity::Critical)
1324            .count();
1325        assert_eq!(critical_count, 0);
1326
1327        // Summary should be valid
1328        let summary = monitor.summary();
1329        assert!(summary.contains("HEALTHY"));
1330        assert!(summary.contains("step 100"));
1331    }
1332
1333    #[test]
1334    fn test_default_trait() {
1335        let monitor = TrainingMonitor::default();
1336        assert_eq!(monitor.step_count, 0);
1337        assert_eq!(monitor.config.window_size, 100);
1338    }
1339
1340    #[test]
1341    fn test_alert_display() {
1342        let alert = TrainingAlert {
1343            step: 42,
1344            severity: AlertSeverity::Critical,
1345            kind: AlertKind::NaNDetected,
1346            message: "NaN detected".to_string(),
1347        };
1348        let display = format!("{}", alert);
1349        assert!(display.contains("42"));
1350        assert!(display.contains("Critical"));
1351        assert!(display.contains("NaN"));
1352    }
1353
1354    #[test]
1355    fn test_multiple_parameters_grad_norms() {
1356        let mut monitor = TrainingMonitor::new();
1357
1358        // The max grad norm should be tracked (not mean)
1359        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);
1360
1361        assert_eq!(monitor.grad_norm_history.len(), 1);
1362        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
1363    }
1364}