Skip to main content

axonml_optim/
health.rs

//! Training Health Monitor - Real-time Training Diagnostics
//!
//! Provides automated detection of pathological training behavior, including
//! gradient explosion/vanishing, loss divergence, dead neurons, and convergence
//! analysis. The monitor works alongside any optimizer: the training loop reports
//! the loss, per-parameter gradient norms, and learning rate once per step, and
//! the monitor turns that stream into alerts, health reports, and learning-rate
//! suggestions.
//!
//! PyTorch ships no built-in equivalent of these training diagnostics; with this
//! module, AxonML training loops can monitor themselves and detect problems
//! before they waste hours of compute.
//!
//! # Example
//!
//! ```ignore
//! use axonml_optim::health::{TrainingMonitor, MonitorConfig};
//!
//! let mut monitor = TrainingMonitor::new();
//!
//! for step in 0..1000 {
//!     // ... training step ...
//!     let grad_norms = vec![("layer1.weight", 0.5), ("layer2.weight", 0.3)];
//!     monitor.record_step(loss_val, &grad_norms, lr);
//!
//!     if !monitor.is_healthy() {
//!         eprintln!("{}", monitor.summary());
//!         if let Some(suggested) = monitor.suggest_lr() {
//!             optimizer.set_lr(suggested);
//!         }
//!     }
//! }
//! ```
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
34
35use std::collections::HashMap;
36use std::fmt;
37
38// =============================================================================
39// Configuration
40// =============================================================================
41
/// Configuration for [`TrainingMonitor`].
///
/// The windowed checks (loss divergence and convergence) only run once at
/// least `window_size` losses are stored, and the loss trend needs
/// `2 * window_size`. Histories are capped at `max_history` entries, so if
/// `max_history < window_size` those checks never run.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of most recent steps the windowed statistics look at (default 100).
    pub window_size: usize,
    /// A finite per-step max gradient norm above this value raises a
    /// `GradientExplosion` warning (default 100.0).
    pub grad_norm_threshold: f32,
    /// A finite loss greater than this factor times the mean of the preceding
    /// finite losses in the window raises a `LossDivergence` warning; only
    /// checked when that mean is positive (default 10.0).
    pub loss_divergence_factor: f32,
    /// Consecutive steps with an exactly-zero gradient norm after which a
    /// parameter is reported as a dead neuron (default 50).
    pub dead_neuron_threshold: usize,
    /// Whether to raise Critical alerts for NaN/Inf loss or gradient norms (default true).
    pub nan_check: bool,
    /// If the spread (max - min) of the finite losses over the last
    /// `window_size` steps is below this value, a `Converged` alert is raised
    /// (default 1e-6).
    pub convergence_threshold: f32,
    /// Maximum number of entries kept in each of the loss, gradient-norm and
    /// learning-rate histories (default 1000). Alerts are not capped by this value.
    pub max_history: usize,
}

impl Default for MonitorConfig {
    /// Window 100, gradient-norm threshold 100.0, divergence factor 10.0,
    /// dead-neuron threshold 50, NaN checks on, convergence threshold 1e-6,
    /// history cap 1000.
    fn default() -> Self {
        Self {
            window_size: 100,
            grad_norm_threshold: 100.0,
            loss_divergence_factor: 10.0,
            dead_neuron_threshold: 50,
            nan_check: true,
            convergence_threshold: 1e-6,
            max_history: 1000,
        }
    }
}
74
75// =============================================================================
76// Alert Types
77// =============================================================================
78
79/// A training alert generated by the monitor.
80#[derive(Debug, Clone)]
81pub struct TrainingAlert {
82    /// The training step at which this alert was generated.
83    pub step: usize,
84    /// Severity level of the alert.
85    pub severity: AlertSeverity,
86    /// Category of the alert.
87    pub kind: AlertKind,
88    /// Human-readable description of the issue.
89    pub message: String,
90}
91
92impl fmt::Display for TrainingAlert {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        write!(
95            f,
96            "[step {}] {:?} {:?}: {}",
97            self.step, self.severity, self.kind, self.message
98        )
99    }
100}
101
/// Severity level for training alerts.
///
/// Variants are declared from least to most severe, so the derived ordering
/// gives `Info < Warning < Critical`; e.g. `alert.severity >= AlertSeverity::Warning`
/// selects every alert that is not purely informational.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AlertSeverity {
    /// Informational -- no action required.
    Info,
    /// Warning -- training may be suboptimal.
    Warning,
    /// Critical -- training is likely broken, immediate action needed.
    Critical,
}
112
/// Category of training alert.
///
/// `LossStagnation`, `LearningRateTooHigh` and `LearningRateTooLow` are not
/// currently emitted by `TrainingMonitor::record_step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlertKind {
    /// NaN detected in loss or gradients.
    NaNDetected,
    /// Infinity detected in loss or gradients.
    InfDetected,
    /// Gradient norms are exploding (exceeding threshold).
    GradientExplosion,
    /// Gradient norms are vanishing (near zero for many consecutive steps).
    GradientVanishing,
    /// Loss jumped well above its recent moving average.
    LossDivergence,
    /// Loss has stagnated (not improving).
    LossStagnation,
    /// A parameter has had zero gradient for many steps (dead neuron).
    DeadNeuron,
    /// Learning rate appears too high based on loss behavior.
    LearningRateTooHigh,
    /// Learning rate appears too low based on loss behavior.
    LearningRateTooLow,
    /// Training has converged (loss is no longer changing).
    Converged,
}
137
138// =============================================================================
139// Health Report
140// =============================================================================
141
142/// Snapshot of current training health.
143#[derive(Debug, Clone)]
144pub struct HealthReport {
145    /// Whether training appears healthy overall.
146    pub is_healthy: bool,
147    /// Current training step.
148    pub step: usize,
149    /// Most recent loss value.
150    pub current_loss: f32,
151    /// Current loss trajectory.
152    pub loss_trend: LossTrend,
153    /// Mean gradient norm over the recent window.
154    pub mean_grad_norm: f32,
155    /// Maximum gradient norm over the recent window.
156    pub max_grad_norm: f32,
157    /// Convergence score from 0.0 (actively changing) to 1.0 (fully converged).
158    pub convergence_score: f32,
159    /// Currently active alerts.
160    pub active_alerts: Vec<TrainingAlert>,
161    /// Number of parameters detected as dead neurons.
162    pub dead_neurons: usize,
163}
164
/// Coarse classification of where the loss is heading.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LossTrend {
    /// The recent loss average is clearly below the previous one: training is progressing.
    Decreasing,
    /// The loss is flat: training has either converged or stalled.
    Stable,
    /// The recent loss average is clearly above the previous one: training may be diverging.
    Increasing,
    /// The loss average is flat but noisy: the learning rate may be too high.
    Oscillating,
    /// Too little history to classify the trend.
    Unknown,
}
179
180// =============================================================================
181// Training Monitor
182// =============================================================================
183
184/// Real-time training health monitor.
185///
186/// Attaches to any optimizer to detect pathological training behavior
187/// including gradient explosion/vanishing, loss divergence, dead neurons,
188/// and convergence. Generates alerts with severity levels and provides
189/// actionable suggestions.
190pub struct TrainingMonitor {
191    /// Current training step count.
192    step_count: usize,
193    /// History of loss values.
194    loss_history: Vec<f32>,
195    /// History of overall gradient norms (max across all parameters per step).
196    grad_norm_history: Vec<f32>,
197    /// History of learning rate values.
198    lr_history: Vec<f32>,
199    /// Accumulated training alerts.
200    alerts: Vec<TrainingAlert>,
201    /// Monitor configuration.
202    config: MonitorConfig,
203    /// Per-parameter count of consecutive steps with zero gradient.
204    zero_grad_counts: HashMap<String, usize>,
205    /// Count of consecutive steps with vanishing gradients.
206    vanishing_streak: usize,
207}
208
209impl TrainingMonitor {
210    /// Creates a new training monitor with default configuration.
211    pub fn new() -> Self {
212        Self::with_config(MonitorConfig::default())
213    }
214
215    /// Creates a new training monitor with the given configuration.
216    pub fn with_config(config: MonitorConfig) -> Self {
217        Self {
218            step_count: 0,
219            loss_history: Vec::new(),
220            grad_norm_history: Vec::new(),
221            lr_history: Vec::new(),
222            alerts: Vec::new(),
223            config,
224            zero_grad_counts: HashMap::new(),
225            vanishing_streak: 0,
226        }
227    }
228
229    /// Records a single training step.
230    ///
231    /// # Arguments
232    ///
233    /// * `loss` - The loss value for this step.
234    /// * `grad_norms` - Slice of (parameter_name, gradient_norm) pairs.
235    /// * `lr` - The current learning rate.
236    pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
237        self.step_count += 1;
238        let step = self.step_count;
239
240        // Append to histories, respecting max_history
241        self.push_bounded(&mut self.loss_history.clone(), loss);
242        self.loss_history.push(loss);
243        if self.loss_history.len() > self.config.max_history {
244            self.loss_history.remove(0);
245        }
246
247        // Compute max gradient norm across all parameters
248        let max_grad_norm = grad_norms
249            .iter()
250            .map(|(_, n)| *n)
251            .fold(0.0_f32, f32::max);
252
253        self.grad_norm_history.push(max_grad_norm);
254        if self.grad_norm_history.len() > self.config.max_history {
255            self.grad_norm_history.remove(0);
256        }
257
258        self.lr_history.push(lr);
259        if self.lr_history.len() > self.config.max_history {
260            self.lr_history.remove(0);
261        }
262
263        // --- NaN/Inf checks ---
264        if self.config.nan_check {
265            if loss.is_nan() {
266                self.emit_alert(step, AlertSeverity::Critical, AlertKind::NaNDetected,
267                    "NaN detected in loss value".to_string());
268            } else if loss.is_infinite() {
269                self.emit_alert(step, AlertSeverity::Critical, AlertKind::InfDetected,
270                    "Infinity detected in loss value".to_string());
271            }
272
273            for (name, norm) in grad_norms {
274                if norm.is_nan() {
275                    self.emit_alert(step, AlertSeverity::Critical, AlertKind::NaNDetected,
276                        format!("NaN detected in gradient norm for '{}'", name));
277                } else if norm.is_infinite() {
278                    self.emit_alert(step, AlertSeverity::Critical, AlertKind::InfDetected,
279                        format!("Infinity detected in gradient norm for '{}'", name));
280                }
281            }
282        }
283
284        // --- Gradient explosion ---
285        if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
286            self.emit_alert(step, AlertSeverity::Warning, AlertKind::GradientExplosion,
287                format!("Gradient norm {:.4} exceeds threshold {:.4}",
288                    max_grad_norm, self.config.grad_norm_threshold));
289        }
290
291        // --- Gradient vanishing ---
292        if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
293            self.vanishing_streak += 1;
294            if self.vanishing_streak >= 10 {
295                self.emit_alert(step, AlertSeverity::Warning, AlertKind::GradientVanishing,
296                    format!("Gradient norms near zero for {} consecutive steps",
297                        self.vanishing_streak));
298            }
299        } else {
300            self.vanishing_streak = 0;
301        }
302
303        // --- Dead neuron tracking ---
304        let dead_threshold = self.config.dead_neuron_threshold;
305        let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
306        for (name, norm) in grad_norms {
307            let count = self.zero_grad_counts.entry(name.to_string()).or_insert(0);
308            if *norm == 0.0 {
309                *count += 1;
310                if *count == dead_threshold {
311                    new_dead_alerts.push((name.to_string(), *count));
312                }
313            } else {
314                *count = 0;
315            }
316        }
317        for (name, count) in new_dead_alerts {
318            self.emit_alert(step, AlertSeverity::Warning, AlertKind::DeadNeuron,
319                format!("Parameter '{}' has had zero gradient for {} steps (dead neuron)",
320                    name, count));
321        }
322
323        // --- Loss divergence ---
324        if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
325            let window_start = self.loss_history.len().saturating_sub(self.config.window_size);
326            let window = &self.loss_history[window_start..self.loss_history.len() - 1];
327            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
328            if !finite_vals.is_empty() {
329                let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
330                if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
331                    self.emit_alert(step, AlertSeverity::Warning, AlertKind::LossDivergence,
332                        format!("Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
333                            loss, avg, loss / avg));
334                }
335            }
336        }
337
338        // --- Convergence check ---
339        if self.loss_history.len() >= self.config.window_size {
340            let window_start = self.loss_history.len() - self.config.window_size;
341            let window = &self.loss_history[window_start..];
342            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
343            if finite_vals.len() >= 2 {
344                let max_val = finite_vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
345                let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
346                let range = max_val - min_val;
347                if range < self.config.convergence_threshold {
348                    self.emit_alert(step, AlertSeverity::Info, AlertKind::Converged,
349                        format!("Training converged: loss range {:.2e} over last {} steps",
350                            range, self.config.window_size));
351                }
352            }
353        }
354    }
355
356    /// Returns a full health report for the current training state.
357    pub fn check_health(&self) -> HealthReport {
358        let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
359        let trend = self.loss_trend();
360        let conv_score = self.convergence_score();
361
362        let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
363
364        // Count dead neurons
365        let dead_neurons = self.zero_grad_counts.values()
366            .filter(|c| **c >= self.config.dead_neuron_threshold)
367            .count();
368
369        // Determine overall health
370        let has_critical = self.alerts.iter().any(|a| a.severity == AlertSeverity::Critical);
371        let is_healthy = !has_critical
372            && trend != LossTrend::Increasing
373            && !current_loss.is_nan()
374            && !current_loss.is_infinite();
375
376        // Collect recent alerts (last window_size steps)
377        let min_step = self.step_count.saturating_sub(self.config.window_size);
378        let active_alerts: Vec<TrainingAlert> = self.alerts.iter()
379            .filter(|a| a.step > min_step)
380            .cloned()
381            .collect();
382
383        HealthReport {
384            is_healthy,
385            step: self.step_count,
386            current_loss,
387            loss_trend: trend,
388            mean_grad_norm: mean_gn,
389            max_grad_norm: max_gn,
390            convergence_score: conv_score,
391            active_alerts,
392            dead_neurons,
393        }
394    }
395
396    /// Returns true if training appears healthy (no critical alerts, loss not diverging).
397    pub fn is_healthy(&self) -> bool {
398        self.check_health().is_healthy
399    }
400
401    /// Returns the accumulated alerts.
402    pub fn alerts(&self) -> &[TrainingAlert] {
403        &self.alerts
404    }
405
406    /// Clears all accumulated alerts.
407    pub fn clear_alerts(&mut self) {
408        self.alerts.clear();
409    }
410
411    /// Analyzes the loss trajectory over the recent window.
412    ///
413    /// Compares the rolling average of the most recent window to the
414    /// rolling average of the previous window to classify the trend.
415    pub fn loss_trend(&self) -> LossTrend {
416        let w = self.config.window_size;
417        if self.loss_history.len() < w * 2 {
418            return LossTrend::Unknown;
419        }
420
421        let len = self.loss_history.len();
422        let recent = &self.loss_history[len - w..];
423        let previous = &self.loss_history[len - 2 * w..len - w];
424
425        let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
426        let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
427
428        if recent_finite.is_empty() || prev_finite.is_empty() {
429            return LossTrend::Unknown;
430        }
431
432        let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
433        let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
434
435        if prev_avg == 0.0 {
436            return LossTrend::Unknown;
437        }
438
439        let ratio = recent_avg / prev_avg;
440
441        // Check for oscillation: high variance in recent window
442        let recent_mean = recent_avg;
443        let recent_var = recent_finite.iter()
444            .map(|v| (v - recent_mean).powi(2))
445            .sum::<f32>() / recent_finite.len() as f32;
446        let recent_std = recent_var.sqrt();
447        let cv = if recent_mean.abs() > 1e-12 { recent_std / recent_mean.abs() } else { 0.0 };
448
449        if ratio < 0.95 {
450            LossTrend::Decreasing
451        } else if ratio > 1.05 {
452            LossTrend::Increasing
453        } else if cv > 0.1 {
454            LossTrend::Oscillating
455        } else {
456            LossTrend::Stable
457        }
458    }
459
460    /// Suggests a learning rate adjustment based on current training dynamics.
461    ///
462    /// Returns `None` if no adjustment is needed or training has converged.
463    pub fn suggest_lr(&self) -> Option<f32> {
464        let current_lr = self.lr_history.last().copied()?;
465        let trend = self.loss_trend();
466
467        // Check for gradient explosion
468        let (_, _, max_gn) = self.grad_norm_stats();
469        if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
470            return Some(current_lr * 0.1);
471        }
472
473        match trend {
474            LossTrend::Oscillating => Some(current_lr * 0.5),
475            LossTrend::Stable => {
476                // Check if truly converged or just stagnating
477                let conv = self.convergence_score();
478                if conv > 0.99 {
479                    None // Converged, no adjustment needed
480                } else {
481                    Some(current_lr * 2.0) // Stagnating, try higher LR
482                }
483            }
484            LossTrend::Increasing => Some(current_lr * 0.1),
485            _ => None,
486        }
487    }
488
489    /// Returns (mean, std, max) of gradient norms over the recent window.
490    pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
491        if self.grad_norm_history.is_empty() {
492            return (0.0, 0.0, 0.0);
493        }
494
495        let w = self.config.window_size.min(self.grad_norm_history.len());
496        let start = self.grad_norm_history.len() - w;
497        let window = &self.grad_norm_history[start..];
498
499        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
500        if finite.is_empty() {
501            return (0.0, 0.0, 0.0);
502        }
503
504        let n = finite.len() as f32;
505        let mean = finite.iter().sum::<f32>() / n;
506        let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
507        let std = variance.sqrt();
508        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
509
510        (mean, std, max)
511    }
512
513    /// Returns a convergence score between 0.0 and 1.0.
514    ///
515    /// 1.0 indicates full convergence (no loss change over the window).
516    /// 0.0 indicates the loss is still actively changing.
517    pub fn convergence_score(&self) -> f32 {
518        let w = self.config.window_size;
519        if self.loss_history.len() < w {
520            return 0.0;
521        }
522
523        let start = self.loss_history.len() - w;
524        let window = &self.loss_history[start..];
525
526        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
527        if finite.len() < 2 {
528            return 0.0;
529        }
530
531        let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
532        let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
533        let range = max_val - min_val;
534
535        let mean = finite.iter().sum::<f32>() / finite.len() as f32;
536        if mean.abs() < 1e-12 {
537            // If mean is essentially zero, check absolute range
538            if range < self.config.convergence_threshold {
539                return 1.0;
540            }
541            return 0.0;
542        }
543
544        // Relative range (normalized by mean)
545        let relative_range = range / mean.abs();
546
547        // Map relative_range to [0, 1] score: 0 range -> 1.0, large range -> 0.0
548        // Use exponential decay for smooth mapping
549        let score = (-relative_range * 100.0).exp();
550        score.clamp(0.0, 1.0)
551    }
552
553    /// Returns a human-readable summary of the current training state.
554    pub fn summary(&self) -> String {
555        let report = self.check_health();
556        let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
557
558        let mut s = String::new();
559        s.push_str(&format!("=== Training Health Report (step {}) ===\n", report.step));
560        s.push_str(&format!("Status: {}\n", if report.is_healthy { "HEALTHY" } else { "UNHEALTHY" }));
561        s.push_str(&format!("Loss: {:.6} (trend: {:?})\n", report.current_loss, report.loss_trend));
562        s.push_str(&format!("Grad norms: mean={:.4}, std={:.4}, max={:.4}\n", mean_gn, std_gn, max_gn));
563        s.push_str(&format!("Convergence: {:.2}%\n", report.convergence_score * 100.0));
564        s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
565
566        if !report.active_alerts.is_empty() {
567            s.push_str(&format!("Active alerts ({}):\n", report.active_alerts.len()));
568            for alert in &report.active_alerts {
569                s.push_str(&format!("  {}\n", alert));
570            }
571        }
572
573        if let Some(lr) = self.suggest_lr() {
574            s.push_str(&format!("Suggested LR: {:.6}\n", lr));
575        }
576
577        s
578    }
579
580    // =========================================================================
581    // Private helpers
582    // =========================================================================
583
584    fn push_bounded(&self, _history: &mut Vec<f32>, _value: f32) {
585        // Intentionally left as no-op: actual push + trim is done inline
586        // in record_step for each history vector individually.
587    }
588
589    fn emit_alert(&mut self, step: usize, severity: AlertSeverity, kind: AlertKind, message: String) {
590        self.alerts.push(TrainingAlert {
591            step,
592            severity,
593            kind,
594            message,
595        });
596    }
597}
598
599impl Default for TrainingMonitor {
600    fn default() -> Self {
601        Self::new()
602    }
603}
604
605// =============================================================================
606// Tests
607// =============================================================================
608
609#[cfg(test)]
610mod tests {
611    use super::*;
612
613    // -------------------------------------------------------------------------
614    // Construction
615    // -------------------------------------------------------------------------
616
617    #[test]
618    fn test_monitor_creation_defaults() {
619        let monitor = TrainingMonitor::new();
620        assert_eq!(monitor.step_count, 0);
621        assert!(monitor.loss_history.is_empty());
622        assert!(monitor.alerts.is_empty());
623        assert_eq!(monitor.config.window_size, 100);
624        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
625        assert!(monitor.config.nan_check);
626    }
627
628    #[test]
629    fn test_monitor_with_custom_config() {
630        let config = MonitorConfig {
631            window_size: 50,
632            grad_norm_threshold: 50.0,
633            loss_divergence_factor: 5.0,
634            dead_neuron_threshold: 20,
635            nan_check: false,
636            convergence_threshold: 1e-5,
637            max_history: 500,
638        };
639        let monitor = TrainingMonitor::with_config(config);
640        assert_eq!(monitor.config.window_size, 50);
641        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
642        assert!(!monitor.config.nan_check);
643        assert_eq!(monitor.config.max_history, 500);
644    }
645
646    // -------------------------------------------------------------------------
647    // Recording steps
648    // -------------------------------------------------------------------------
649
650    #[test]
651    fn test_record_step_updates_state() {
652        let mut monitor = TrainingMonitor::new();
653        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
654
655        assert_eq!(monitor.step_count, 1);
656        assert_eq!(monitor.loss_history.len(), 1);
657        assert_eq!(monitor.grad_norm_history.len(), 1);
658        assert_eq!(monitor.lr_history.len(), 1);
659        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
660    }
661
662    // -------------------------------------------------------------------------
663    // NaN/Inf detection
664    // -------------------------------------------------------------------------
665
666    #[test]
667    fn test_nan_detection_generates_critical_alert() {
668        let mut monitor = TrainingMonitor::new();
669        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
670
671        assert_eq!(monitor.alerts.len(), 1);
672        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
673        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
674        assert!(monitor.alerts[0].message.contains("NaN"));
675    }
676
677    #[test]
678    fn test_inf_detection_generates_critical_alert() {
679        let mut monitor = TrainingMonitor::new();
680        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);
681
682        assert_eq!(monitor.alerts.len(), 1);
683        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
684        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
685        assert!(monitor.alerts[0].message.contains("Infinity"));
686    }
687
688    #[test]
689    fn test_nan_in_grad_norm_detected() {
690        let mut monitor = TrainingMonitor::new();
691        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);
692
693        let nan_alerts: Vec<_> = monitor.alerts.iter()
694            .filter(|a| a.kind == AlertKind::NaNDetected)
695            .collect();
696        assert_eq!(nan_alerts.len(), 1);
697        assert!(nan_alerts[0].message.contains("w1"));
698    }
699
700    #[test]
701    fn test_nan_check_disabled() {
702        let config = MonitorConfig {
703            nan_check: false,
704            ..MonitorConfig::default()
705        };
706        let mut monitor = TrainingMonitor::with_config(config);
707        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);
708
709        assert!(monitor.alerts.is_empty());
710    }
711
712    // -------------------------------------------------------------------------
713    // Gradient explosion/vanishing
714    // -------------------------------------------------------------------------
715
716    #[test]
717    fn test_gradient_explosion_detection() {
718        let mut monitor = TrainingMonitor::new();
719        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);
720
721        let explosion_alerts: Vec<_> = monitor.alerts.iter()
722            .filter(|a| a.kind == AlertKind::GradientExplosion)
723            .collect();
724        assert_eq!(explosion_alerts.len(), 1);
725        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
726    }
727
728    #[test]
729    fn test_gradient_vanishing_detection() {
730        let config = MonitorConfig {
731            window_size: 5,
732            ..MonitorConfig::default()
733        };
734        let mut monitor = TrainingMonitor::with_config(config);
735
736        // Record 10 steps with near-zero gradients to trigger vanishing alert
737        for _ in 0..10 {
738            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
739        }
740
741        let vanishing_alerts: Vec<_> = monitor.alerts.iter()
742            .filter(|a| a.kind == AlertKind::GradientVanishing)
743            .collect();
744        assert!(!vanishing_alerts.is_empty());
745    }
746
747    #[test]
748    fn test_gradient_vanishing_resets_on_normal_grad() {
749        let mut monitor = TrainingMonitor::new();
750
751        // 5 vanishing steps
752        for _ in 0..5 {
753            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
754        }
755        // Normal gradient resets the streak
756        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
757        assert_eq!(monitor.vanishing_streak, 0);
758    }
759
760    // -------------------------------------------------------------------------
761    // Loss divergence
762    // -------------------------------------------------------------------------
763
764    #[test]
765    fn test_loss_divergence_detection() {
766        let config = MonitorConfig {
767            window_size: 10,
768            loss_divergence_factor: 2.0,
769            ..MonitorConfig::default()
770        };
771        let mut monitor = TrainingMonitor::with_config(config);
772
773        // Build up a history of stable loss
774        for _ in 0..10 {
775            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
776        }
777
778        // Spike the loss to trigger divergence
779        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);
780
781        let divergence_alerts: Vec<_> = monitor.alerts.iter()
782            .filter(|a| a.kind == AlertKind::LossDivergence)
783            .collect();
784        assert!(!divergence_alerts.is_empty());
785    }
786
787    // -------------------------------------------------------------------------
788    // Dead neuron tracking
789    // -------------------------------------------------------------------------
790
791    #[test]
792    fn test_dead_neuron_tracking() {
793        let config = MonitorConfig {
794            dead_neuron_threshold: 5,
795            ..MonitorConfig::default()
796        };
797        let mut monitor = TrainingMonitor::with_config(config);
798
799        for _ in 0..5 {
800            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
801        }
802
803        let dead_alerts: Vec<_> = monitor.alerts.iter()
804            .filter(|a| a.kind == AlertKind::DeadNeuron)
805            .collect();
806        assert_eq!(dead_alerts.len(), 1);
807        assert!(dead_alerts[0].message.contains("dead_layer"));
808    }
809
810    #[test]
811    fn test_dead_neuron_resets_on_nonzero_grad() {
812        let config = MonitorConfig {
813            dead_neuron_threshold: 10,
814            ..MonitorConfig::default()
815        };
816        let mut monitor = TrainingMonitor::with_config(config);
817
818        // 5 steps with zero grad
819        for _ in 0..5 {
820            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
821        }
822        // One step with nonzero grad resets the counter
823        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);
824
825        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
826    }
827
828    // -------------------------------------------------------------------------
829    // Convergence detection
830    // -------------------------------------------------------------------------
831
832    #[test]
833    fn test_convergence_detection() {
834        let config = MonitorConfig {
835            window_size: 10,
836            convergence_threshold: 1e-4,
837            ..MonitorConfig::default()
838        };
839        let mut monitor = TrainingMonitor::with_config(config);
840
841        // Record identical losses
842        for _ in 0..10 {
843            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
844        }
845
846        let converged_alerts: Vec<_> = monitor.alerts.iter()
847            .filter(|a| a.kind == AlertKind::Converged)
848            .collect();
849        assert!(!converged_alerts.is_empty());
850    }
851
852    // -------------------------------------------------------------------------
853    // Loss trend analysis
854    // -------------------------------------------------------------------------
855
856    #[test]
857    fn test_loss_trend_decreasing() {
858        let config = MonitorConfig {
859            window_size: 5,
860            ..MonitorConfig::default()
861        };
862        let mut monitor = TrainingMonitor::with_config(config);
863
864        // First window: higher losses
865        for i in 0..5 {
866            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
867        }
868        // Second window: much lower losses
869        for i in 0..5 {
870            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
871        }
872
873        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
874    }
875
876    #[test]
877    fn test_loss_trend_increasing() {
878        let config = MonitorConfig {
879            window_size: 5,
880            ..MonitorConfig::default()
881        };
882        let mut monitor = TrainingMonitor::with_config(config);
883
884        // First window: lower losses
885        for i in 0..5 {
886            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
887        }
888        // Second window: much higher losses
889        for i in 0..5 {
890            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
891        }
892
893        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
894    }
895
896    #[test]
897    fn test_loss_trend_oscillating() {
898        let config = MonitorConfig {
899            window_size: 10,
900            ..MonitorConfig::default()
901        };
902        let mut monitor = TrainingMonitor::with_config(config);
903
904        // First window: stable
905        for _ in 0..10 {
906            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
907        }
908        // Second window: oscillating around same mean but with high variance
909        for i in 0..10 {
910            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
911            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
912        }
913
914        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
915    }
916
917    #[test]
918    fn test_loss_trend_stable() {
919        let config = MonitorConfig {
920            window_size: 5,
921            ..MonitorConfig::default()
922        };
923        let mut monitor = TrainingMonitor::with_config(config);
924
925        // Two windows with nearly identical means and low variance
926        for _ in 0..10 {
927            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
928        }
929
930        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
931    }
932
933    #[test]
934    fn test_loss_trend_unknown_insufficient_data() {
935        let config = MonitorConfig {
936            window_size: 100,
937            ..MonitorConfig::default()
938        };
939        let mut monitor = TrainingMonitor::with_config(config);
940
941        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
942
943        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
944    }
945
946    // -------------------------------------------------------------------------
947    // Health report
948    // -------------------------------------------------------------------------
949
950    #[test]
951    fn test_health_report_healthy_normal_training() {
952        let config = MonitorConfig {
953            window_size: 5,
954            ..MonitorConfig::default()
955        };
956        let mut monitor = TrainingMonitor::with_config(config);
957
958        for i in 0..10 {
959            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
960        }
961
962        let report = monitor.check_health();
963        assert!(report.is_healthy);
964        assert_eq!(report.step, 10);
965        assert_eq!(report.dead_neurons, 0);
966    }
967
968    #[test]
969    fn test_health_report_not_healthy_with_nan() {
970        let mut monitor = TrainingMonitor::new();
971        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
972
973        let report = monitor.check_health();
974        assert!(!report.is_healthy);
975    }
976
977    // -------------------------------------------------------------------------
978    // Learning rate suggestions
979    // -------------------------------------------------------------------------
980
981    #[test]
982    fn test_suggest_lr_exploding_gradients() {
983        let config = MonitorConfig {
984            window_size: 5,
985            grad_norm_threshold: 10.0,
986            ..MonitorConfig::default()
987        };
988        let mut monitor = TrainingMonitor::with_config(config);
989
990        // Record steps with exploding gradients
991        for _ in 0..5 {
992            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
993        }
994
995        let suggested = monitor.suggest_lr();
996        assert!(suggested.is_some());
997        assert!((suggested.unwrap() - 0.001).abs() < 1e-6); // 0.01 * 0.1
998    }
999
1000    #[test]
1001    fn test_suggest_lr_oscillating_loss() {
1002        let config = MonitorConfig {
1003            window_size: 10,
1004            ..MonitorConfig::default()
1005        };
1006        let mut monitor = TrainingMonitor::with_config(config);
1007
1008        // First window stable
1009        for _ in 0..10 {
1010            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
1011        }
1012        // Second window oscillating
1013        for i in 0..10 {
1014            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
1015            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
1016        }
1017
1018        let suggested = monitor.suggest_lr();
1019        assert!(suggested.is_some());
1020        assert!((suggested.unwrap() - 0.005).abs() < 1e-6); // 0.01 * 0.5
1021    }
1022
1023    #[test]
1024    fn test_suggest_lr_converged_returns_none() {
1025        let config = MonitorConfig {
1026            window_size: 5,
1027            convergence_threshold: 1e-4,
1028            ..MonitorConfig::default()
1029        };
1030        let mut monitor = TrainingMonitor::with_config(config);
1031
1032        // Two windows of identical loss (converged + stable)
1033        for _ in 0..10 {
1034            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
1035        }
1036
1037        // Stable trend with high convergence score -> should return None
1038        let trend = monitor.loss_trend();
1039        let conv = monitor.convergence_score();
1040        assert_eq!(trend, LossTrend::Stable);
1041        assert!(conv > 0.99);
1042        assert!(monitor.suggest_lr().is_none());
1043    }
1044
1045    // -------------------------------------------------------------------------
1046    // Convergence score
1047    // -------------------------------------------------------------------------
1048
1049    #[test]
1050    fn test_convergence_score_fully_converged() {
1051        let config = MonitorConfig {
1052            window_size: 10,
1053            ..MonitorConfig::default()
1054        };
1055        let mut monitor = TrainingMonitor::with_config(config);
1056
1057        for _ in 0..10 {
1058            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
1059        }
1060
1061        let score = monitor.convergence_score();
1062        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
1063    }
1064
1065    #[test]
1066    fn test_convergence_score_actively_changing() {
1067        let config = MonitorConfig {
1068            window_size: 10,
1069            ..MonitorConfig::default()
1070        };
1071        let mut monitor = TrainingMonitor::with_config(config);
1072
1073        for i in 0..10 {
1074            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
1075        }
1076
1077        let score = monitor.convergence_score();
1078        assert!(score < 0.5, "Expected low score, got {}", score);
1079    }
1080
1081    #[test]
1082    fn test_convergence_score_insufficient_data() {
1083        let monitor = TrainingMonitor::new();
1084        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
1085    }
1086
1087    // -------------------------------------------------------------------------
1088    // Summary output
1089    // -------------------------------------------------------------------------
1090
1091    #[test]
1092    fn test_summary_contains_key_metrics() {
1093        let config = MonitorConfig {
1094            window_size: 5,
1095            ..MonitorConfig::default()
1096        };
1097        let mut monitor = TrainingMonitor::with_config(config);
1098
1099        for _ in 0..5 {
1100            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
1101        }
1102
1103        let summary = monitor.summary();
1104        assert!(summary.contains("Training Health Report"));
1105        assert!(summary.contains("HEALTHY"));
1106        assert!(summary.contains("Loss:"));
1107        assert!(summary.contains("Grad norms:"));
1108        assert!(summary.contains("Convergence:"));
1109        assert!(summary.contains("Dead neurons:"));
1110    }
1111
1112    // -------------------------------------------------------------------------
1113    // Alert management
1114    // -------------------------------------------------------------------------
1115
1116    #[test]
1117    fn test_clear_alerts_empties_list() {
1118        let mut monitor = TrainingMonitor::new();
1119        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
1120        assert!(!monitor.alerts().is_empty());
1121
1122        monitor.clear_alerts();
1123        assert!(monitor.alerts().is_empty());
1124    }
1125
1126    // -------------------------------------------------------------------------
1127    // Memory bounds
1128    // -------------------------------------------------------------------------
1129
1130    #[test]
1131    fn test_max_history_bounds_memory() {
1132        let config = MonitorConfig {
1133            max_history: 20,
1134            window_size: 5,
1135            ..MonitorConfig::default()
1136        };
1137        let mut monitor = TrainingMonitor::with_config(config);
1138
1139        for i in 0..50 {
1140            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
1141        }
1142
1143        assert!(monitor.loss_history.len() <= 20);
1144        assert!(monitor.grad_norm_history.len() <= 20);
1145        assert!(monitor.lr_history.len() <= 20);
1146    }
1147
1148    // -------------------------------------------------------------------------
1149    // Grad norm stats
1150    // -------------------------------------------------------------------------
1151
1152    #[test]
1153    fn test_grad_norm_stats_computation() {
1154        let config = MonitorConfig {
1155            window_size: 4,
1156            ..MonitorConfig::default()
1157        };
1158        let mut monitor = TrainingMonitor::with_config(config);
1159
1160        // Record known gradient norms: 1, 2, 3, 4
1161        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
1162        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
1163        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
1164        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);
1165
1166        let (mean, std, max) = monitor.grad_norm_stats();
1167        assert!((mean - 2.5).abs() < 1e-4, "Expected mean ~2.5, got {}", mean);
1168        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
1169        // std of [1,2,3,4] = sqrt(1.25) ~= 1.118
1170        assert!((std - 1.118).abs() < 0.01, "Expected std ~1.118, got {}", std);
1171    }
1172
1173    #[test]
1174    fn test_grad_norm_stats_empty() {
1175        let monitor = TrainingMonitor::new();
1176        let (mean, std, max) = monitor.grad_norm_stats();
1177        assert!((mean - 0.0).abs() < 1e-6);
1178        assert!((std - 0.0).abs() < 1e-6);
1179        assert!((max - 0.0).abs() < 1e-6);
1180    }
1181
1182    // -------------------------------------------------------------------------
1183    // Integration test
1184    // -------------------------------------------------------------------------
1185
1186    #[test]
1187    fn test_integration_100_step_improving_training() {
1188        let config = MonitorConfig {
1189            window_size: 20,
1190            ..MonitorConfig::default()
1191        };
1192        let mut monitor = TrainingMonitor::with_config(config);
1193
1194        // Simulate 100 steps of steadily improving loss
1195        for i in 0..100 {
1196            let loss = 2.0 * (-0.03 * i as f32).exp(); // Exponential decay
1197            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
1198            let lr = 0.001;
1199            monitor.record_step(loss, &[("layer1.weight", grad_norm), ("layer2.weight", grad_norm * 0.5)], lr);
1200        }
1201
1202        assert_eq!(monitor.step_count, 100);
1203        assert!(monitor.is_healthy());
1204
1205        let report = monitor.check_health();
1206        assert!(report.is_healthy);
1207        assert_eq!(report.step, 100);
1208        assert!(report.current_loss < 0.2); // Should have decayed significantly
1209        assert_eq!(report.dead_neurons, 0);
1210
1211        // Should detect decreasing trend
1212        let trend = monitor.loss_trend();
1213        assert_eq!(trend, LossTrend::Decreasing);
1214
1215        // No critical alerts should exist
1216        let critical_count = monitor.alerts.iter()
1217            .filter(|a| a.severity == AlertSeverity::Critical)
1218            .count();
1219        assert_eq!(critical_count, 0);
1220
1221        // Summary should be valid
1222        let summary = monitor.summary();
1223        assert!(summary.contains("HEALTHY"));
1224        assert!(summary.contains("step 100"));
1225    }
1226
1227    #[test]
1228    fn test_default_trait() {
1229        let monitor = TrainingMonitor::default();
1230        assert_eq!(monitor.step_count, 0);
1231        assert_eq!(monitor.config.window_size, 100);
1232    }
1233
1234    #[test]
1235    fn test_alert_display() {
1236        let alert = TrainingAlert {
1237            step: 42,
1238            severity: AlertSeverity::Critical,
1239            kind: AlertKind::NaNDetected,
1240            message: "NaN detected".to_string(),
1241        };
1242        let display = format!("{}", alert);
1243        assert!(display.contains("42"));
1244        assert!(display.contains("Critical"));
1245        assert!(display.contains("NaN"));
1246    }
1247
1248    #[test]
1249    fn test_multiple_parameters_grad_norms() {
1250        let mut monitor = TrainingMonitor::new();
1251
1252        // The max grad norm should be tracked (not mean)
1253        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);
1254
1255        assert_eq!(monitor.grad_norm_history.len(), 1);
1256        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
1257    }
1258}