//! Training Health Monitor — a novel AxonML feature for real-time diagnostics.
//!
//! Attaches to any optimizer and monitors gradient norms, loss trends, and
//! parameter statistics each step. Detects: NaN/Inf gradients, gradient
//! explosion (norm > threshold), vanishing gradients (norm < epsilon), dead
//! neurons (parameters with persistently zero gradients), loss trend
//! classification (decreasing/stable/increasing/oscillating), and convergence
//! scoring. Provides automatic learning rate suggestions based on the detected
//! issues.
//!
//! # File
//! `crates/axonml-optim/src/health.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
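//!
//! # Example
//!
//! A minimal usage sketch (this assumes the crate exposes the module publicly
//! as `axonml_optim::health`):
//!
//! ```
//! use axonml_optim::health::TrainingMonitor;
//!
//! let mut monitor = TrainingMonitor::new();
//!
//! // Feed the monitor once per optimizer step.
//! for step in 0..200 {
//!     let loss = 1.0 / (step as f32 + 1.0);
//!     monitor.record_step(loss, &[("layer1.weight", 0.5)], 1e-3);
//! }
//!
//! let report = monitor.check_health();
//! assert!(report.is_healthy);
//! println!("{}", monitor.summary());
//! ```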

use std::collections::HashMap;
use std::fmt;

// =============================================================================
// Configuration
// =============================================================================

/// Configuration for the training monitor.
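///
/// # Example
///
/// A minimal sketch overriding a couple of defaults with struct-update syntax:
///
/// ```
/// use axonml_optim::health::MonitorConfig;
///
/// let config = MonitorConfig {
///     window_size: 50,
///     grad_norm_threshold: 25.0,
///     ..MonitorConfig::default()
/// };
/// assert_eq!(config.max_history, 1000);
/// ```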
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Rolling window size for statistics (default 100).
    pub window_size: usize,
    /// Alert if gradient norm exceeds this value (default 100.0).
    pub grad_norm_threshold: f32,
    /// Alert if loss exceeds factor * moving average (default 10.0).
    pub loss_divergence_factor: f32,
    /// Steps with zero gradient before alerting on dead neurons (default 50).
    pub dead_neuron_threshold: usize,
    /// Whether to check for NaN/Inf values (default true).
    pub nan_check: bool,
    /// Loss change below this threshold indicates convergence (default 1e-6).
    pub convergence_threshold: f32,
    /// Maximum history entries to retain in memory (default 1000).
    pub max_history: usize,
}

impl Default for MonitorConfig {
    fn default() -> Self {
        Self {
            window_size: 100,
            grad_norm_threshold: 100.0,
            loss_divergence_factor: 10.0,
            dead_neuron_threshold: 50,
            nan_check: true,
            convergence_threshold: 1e-6,
            max_history: 1000,
        }
    }
}

// =============================================================================
// Alert Types
// =============================================================================

/// A training alert generated by the monitor.
#[derive(Debug, Clone)]
pub struct TrainingAlert {
    /// The training step at which this alert was generated.
    pub step: usize,
    /// Severity level of the alert.
    pub severity: AlertSeverity,
    /// Category of the alert.
    pub kind: AlertKind,
    /// Human-readable description of the issue.
    pub message: String,
}

impl fmt::Display for TrainingAlert {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "[step {}] {:?} {:?}: {}",
            self.step, self.severity, self.kind, self.message
        )
    }
}

/// Severity level for training alerts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
    /// Informational -- no action required.
    Info,
    /// Warning -- training may be suboptimal.
    Warning,
    /// Critical -- training is likely broken, immediate action needed.
    Critical,
}

/// Category of training alert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
    /// NaN detected in loss or gradients.
    NaNDetected,
    /// Infinity detected in loss or gradients.
    InfDetected,
    /// Gradient norms are exploding (exceeding threshold).
    GradientExplosion,
    /// Gradient norms are vanishing (near zero for many steps).
    GradientVanishing,
    /// Loss is diverging from its moving average.
    LossDivergence,
    /// Loss has stagnated (not improving).
    LossStagnation,
    /// A parameter has had zero gradient for many steps (dead neuron).
    DeadNeuron,
    /// Learning rate appears too high based on loss behavior.
    LearningRateTooHigh,
    /// Learning rate appears too low based on loss behavior.
    LearningRateTooLow,
    /// Training has converged (loss is no longer changing).
    Converged,
}

// =============================================================================
// Health Report
// =============================================================================

/// Snapshot of current training health.
#[derive(Debug, Clone)]
pub struct HealthReport {
    /// Whether training appears healthy overall.
    pub is_healthy: bool,
    /// Current training step.
    pub step: usize,
    /// Most recent loss value.
    pub current_loss: f32,
    /// Current loss trajectory.
    pub loss_trend: LossTrend,
    /// Mean gradient norm over the recent window.
    pub mean_grad_norm: f32,
    /// Maximum gradient norm over the recent window.
    pub max_grad_norm: f32,
    /// Convergence score from 0.0 (actively changing) to 1.0 (fully converged).
    pub convergence_score: f32,
    /// Currently active alerts.
    pub active_alerts: Vec<TrainingAlert>,
    /// Number of parameters detected as dead neurons.
    pub dead_neurons: usize,
}

/// Loss trajectory classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Loss is decreasing -- training is progressing.
    Decreasing,
    /// Loss is stable -- may be converged or stuck.
    Stable,
    /// Loss is increasing -- training may be diverging.
    Increasing,
    /// Loss is oscillating -- learning rate may be too high.
    Oscillating,
    /// Not enough data to determine trend.
    Unknown,
}

// =============================================================================
// Training Monitor
// =============================================================================

/// Real-time training health monitor.
///
/// Attaches to any optimizer to detect pathological training behavior
/// including gradient explosion/vanishing, loss divergence, dead neurons,
/// and convergence. Generates alerts with severity levels and provides
/// actionable suggestions.
pub struct TrainingMonitor {
    /// Current training step count.
    step_count: usize,
    /// History of loss values.
    loss_history: Vec<f32>,
    /// History of overall gradient norms (max across all parameters per step).
    grad_norm_history: Vec<f32>,
    /// History of learning rate values.
    lr_history: Vec<f32>,
    /// Accumulated training alerts.
    alerts: Vec<TrainingAlert>,
    /// Monitor configuration.
    config: MonitorConfig,
    /// Per-parameter count of consecutive steps with zero gradient.
    zero_grad_counts: HashMap<String, usize>,
    /// Count of consecutive steps with vanishing gradients.
    vanishing_streak: usize,
}

impl TrainingMonitor {
    /// Creates a new training monitor with default configuration.
    pub fn new() -> Self {
        Self::with_config(MonitorConfig::default())
    }

    /// Creates a new training monitor with the given configuration.
    pub fn with_config(config: MonitorConfig) -> Self {
        Self {
            step_count: 0,
            loss_history: Vec::new(),
            grad_norm_history: Vec::new(),
            lr_history: Vec::new(),
            alerts: Vec::new(),
            config,
            zero_grad_counts: HashMap::new(),
            vanishing_streak: 0,
        }
    }

    /// Records a single training step.
    ///
    /// # Arguments
    ///
    /// * `loss` - The loss value for this step.
    /// * `grad_norms` - Slice of (parameter_name, gradient_norm) pairs.
    /// * `lr` - The current learning rate.
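    ///
    /// # Example
    ///
    /// A minimal sketch of recording one step (the parameter names here are
    /// arbitrary placeholders):
    ///
    /// ```
    /// use axonml_optim::health::TrainingMonitor;
    ///
    /// let mut monitor = TrainingMonitor::new();
    /// monitor.record_step(0.42, &[("encoder.weight", 1.3), ("decoder.weight", 0.8)], 1e-3);
    /// assert!(monitor.is_healthy());
    /// ```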
    pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
        self.step_count += 1;
        let step = self.step_count;

        // Append to histories, respecting max_history
        let max_history = self.config.max_history;
        Self::push_bounded(&mut self.loss_history, loss, max_history);

        // Compute max gradient norm across all parameters
        let max_grad_norm = grad_norms.iter().map(|(_, n)| *n).fold(0.0_f32, f32::max);

        Self::push_bounded(&mut self.grad_norm_history, max_grad_norm, max_history);
        Self::push_bounded(&mut self.lr_history, lr, max_history);

        // --- NaN/Inf checks ---
        if self.config.nan_check {
            if loss.is_nan() {
                self.emit_alert(
                    step,
                    AlertSeverity::Critical,
                    AlertKind::NaNDetected,
                    "NaN detected in loss value".to_string(),
                );
            } else if loss.is_infinite() {
                self.emit_alert(
                    step,
                    AlertSeverity::Critical,
                    AlertKind::InfDetected,
                    "Infinity detected in loss value".to_string(),
                );
            }

            for (name, norm) in grad_norms {
                if norm.is_nan() {
                    self.emit_alert(
                        step,
                        AlertSeverity::Critical,
                        AlertKind::NaNDetected,
                        format!("NaN detected in gradient norm for '{}'", name),
                    );
                } else if norm.is_infinite() {
                    self.emit_alert(
                        step,
                        AlertSeverity::Critical,
                        AlertKind::InfDetected,
                        format!("Infinity detected in gradient norm for '{}'", name),
                    );
                }
            }
        }

        // --- Gradient explosion ---
        if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
            self.emit_alert(
                step,
                AlertSeverity::Warning,
                AlertKind::GradientExplosion,
                format!(
                    "Gradient norm {:.4} exceeds threshold {:.4}",
                    max_grad_norm, self.config.grad_norm_threshold
                ),
            );
        }

        // --- Gradient vanishing ---
        if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
            self.vanishing_streak += 1;
            if self.vanishing_streak >= 10 {
                self.emit_alert(
                    step,
                    AlertSeverity::Warning,
                    AlertKind::GradientVanishing,
                    format!(
                        "Gradient norms near zero for {} consecutive steps",
                        self.vanishing_streak
                    ),
                );
            }
        } else {
            self.vanishing_streak = 0;
        }

        // --- Dead neuron tracking ---
        let dead_threshold = self.config.dead_neuron_threshold;
        let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
        for (name, norm) in grad_norms {
            let count = self
                .zero_grad_counts
                .entry((*name).to_string())
                .or_insert(0);
            if *norm == 0.0 {
                *count += 1;
                if *count == dead_threshold {
                    new_dead_alerts.push(((*name).to_string(), *count));
                }
            } else {
                *count = 0;
            }
        }
        for (name, count) in new_dead_alerts {
            self.emit_alert(
                step,
                AlertSeverity::Warning,
                AlertKind::DeadNeuron,
                format!(
                    "Parameter '{}' has had zero gradient for {} steps (dead neuron)",
                    name, count
                ),
            );
        }

        // --- Loss divergence ---
        if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
            let window_start = self
                .loss_history
                .len()
                .saturating_sub(self.config.window_size);
            let window = &self.loss_history[window_start..self.loss_history.len() - 1];
            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
            if !finite_vals.is_empty() {
                let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
                if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
                    self.emit_alert(
                        step,
                        AlertSeverity::Warning,
                        AlertKind::LossDivergence,
                        format!(
                            "Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
                            loss,
                            avg,
                            loss / avg
                        ),
                    );
                }
            }
        }

        // --- Convergence check ---
        if self.loss_history.len() >= self.config.window_size {
            let window_start = self.loss_history.len() - self.config.window_size;
            let window = &self.loss_history[window_start..];
            let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
            if finite_vals.len() >= 2 {
                let max_val = finite_vals
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, f32::max);
                let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
                let range = max_val - min_val;
                if range < self.config.convergence_threshold {
                    self.emit_alert(
                        step,
                        AlertSeverity::Info,
                        AlertKind::Converged,
                        format!(
                            "Training converged: loss range {:.2e} over last {} steps",
                            range, self.config.window_size
                        ),
                    );
                }
            }
        }
    }

    /// Returns a full health report for the current training state.
    pub fn check_health(&self) -> HealthReport {
        let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
        let trend = self.loss_trend();
        let conv_score = self.convergence_score();

        let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);

        // Count dead neurons
        let dead_neurons = self
            .zero_grad_counts
            .values()
            .filter(|c| **c >= self.config.dead_neuron_threshold)
            .count();

        // Determine overall health
        let has_critical = self
            .alerts
            .iter()
            .any(|a| a.severity == AlertSeverity::Critical);
        let is_healthy = !has_critical
            && trend != LossTrend::Increasing
            && !current_loss.is_nan()
            && !current_loss.is_infinite();

        // Collect recent alerts (last window_size steps)
        let min_step = self.step_count.saturating_sub(self.config.window_size);
        let active_alerts: Vec<TrainingAlert> = self
            .alerts
            .iter()
            .filter(|a| a.step > min_step)
            .cloned()
            .collect();

        HealthReport {
            is_healthy,
            step: self.step_count,
            current_loss,
            loss_trend: trend,
            mean_grad_norm: mean_gn,
            max_grad_norm: max_gn,
            convergence_score: conv_score,
            active_alerts,
            dead_neurons,
        }
    }

    /// Returns true if training appears healthy (no critical alerts, loss not diverging).
    pub fn is_healthy(&self) -> bool {
        self.check_health().is_healthy
    }

    /// Returns the accumulated alerts.
    pub fn alerts(&self) -> &[TrainingAlert] {
        &self.alerts
    }

    /// Clears all accumulated alerts.
    pub fn clear_alerts(&mut self) {
        self.alerts.clear();
    }

    /// Analyzes the loss trajectory over the recent window.
    ///
    /// Compares the rolling average of the most recent window to the
    /// rolling average of the previous window to classify the trend.
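    ///
    /// Classification mirrors the thresholds used below: a recent/previous
    /// ratio under 0.95 is `Decreasing`, over 1.05 is `Increasing`; otherwise
    /// a coefficient of variation above 0.1 in the recent window is
    /// `Oscillating`, and anything else is `Stable`.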
    pub fn loss_trend(&self) -> LossTrend {
        let w = self.config.window_size;
        if self.loss_history.len() < w * 2 {
            return LossTrend::Unknown;
        }

        let len = self.loss_history.len();
        let recent = &self.loss_history[len - w..];
        let previous = &self.loss_history[len - 2 * w..len - w];

        let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
        let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();

        if recent_finite.is_empty() || prev_finite.is_empty() {
            return LossTrend::Unknown;
        }

        let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
        let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;

        if prev_avg == 0.0 {
            return LossTrend::Unknown;
        }

        let ratio = recent_avg / prev_avg;

        // Check for oscillation: high variance in recent window
        let recent_mean = recent_avg;
        let recent_var = recent_finite
            .iter()
            .map(|v| (v - recent_mean).powi(2))
            .sum::<f32>()
            / recent_finite.len() as f32;
        let recent_std = recent_var.sqrt();
        let cv = if recent_mean.abs() > 1e-12 {
            recent_std / recent_mean.abs()
        } else {
            0.0
        };

        if ratio < 0.95 {
            LossTrend::Decreasing
        } else if ratio > 1.05 {
            LossTrend::Increasing
        } else if cv > 0.1 {
            LossTrend::Oscillating
        } else {
            LossTrend::Stable
        }
    }

    /// Suggests a learning rate adjustment based on current training dynamics.
    ///
    /// Returns `None` if no adjustment is needed or training has converged.
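    ///
    /// # Example
    ///
    /// A minimal sketch of applying a suggestion when one is produced
    /// (`current_lr` is a hypothetical local variable, not part of this API):
    ///
    /// ```
    /// use axonml_optim::health::TrainingMonitor;
    ///
    /// let mut monitor = TrainingMonitor::new();
    /// let mut current_lr = 0.01_f32;
    /// monitor.record_step(1.0, &[("w", 0.5)], current_lr);
    ///
    /// if let Some(lr) = monitor.suggest_lr() {
    ///     current_lr = lr;
    /// }
    /// println!("next learning rate: {}", current_lr);
    /// ```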
    pub fn suggest_lr(&self) -> Option<f32> {
        let current_lr = self.lr_history.last().copied()?;
        let trend = self.loss_trend();

        // Check for gradient explosion
        let (_, _, max_gn) = self.grad_norm_stats();
        if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
            return Some(current_lr * 0.1);
        }

        match trend {
            LossTrend::Oscillating => Some(current_lr * 0.5),
            LossTrend::Stable => {
                // Check if truly converged or just stagnating
                let conv = self.convergence_score();
                if conv > 0.99 {
                    None // Converged, no adjustment needed
                } else {
                    Some(current_lr * 2.0) // Stagnating, try higher LR
                }
            }
            LossTrend::Increasing => Some(current_lr * 0.1),
            _ => None,
        }
    }

    /// Returns (mean, std, max) of gradient norms over the recent window.
    pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
        if self.grad_norm_history.is_empty() {
            return (0.0, 0.0, 0.0);
        }

        let w = self.config.window_size.min(self.grad_norm_history.len());
        let start = self.grad_norm_history.len() - w;
        let window = &self.grad_norm_history[start..];

        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
        if finite.is_empty() {
            return (0.0, 0.0, 0.0);
        }

        let n = finite.len() as f32;
        let mean = finite.iter().sum::<f32>() / n;
        let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = variance.sqrt();
        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        (mean, std, max)
    }

    /// Returns a convergence score between 0.0 and 1.0.
    ///
    /// 1.0 indicates full convergence (no loss change over the window).
    /// 0.0 indicates the loss is still actively changing.
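    ///
    /// The score is computed as `exp(-100 * relative_range)`, where
    /// `relative_range` is the loss range over the window divided by the
    /// absolute mean loss, so a zero range maps to 1.0 and larger relative
    /// swings decay quickly toward 0.0.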
    pub fn convergence_score(&self) -> f32 {
        let w = self.config.window_size;
        if self.loss_history.len() < w {
            return 0.0;
        }

        let start = self.loss_history.len() - w;
        let window = &self.loss_history[start..];

        let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
        if finite.len() < 2 {
            return 0.0;
        }

        let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
        let range = max_val - min_val;

        let mean = finite.iter().sum::<f32>() / finite.len() as f32;
        if mean.abs() < 1e-12 {
            // If mean is essentially zero, check absolute range
            if range < self.config.convergence_threshold {
                return 1.0;
            }
            return 0.0;
        }

        // Relative range (normalized by mean)
        let relative_range = range / mean.abs();

        // Map relative_range to [0, 1] score: 0 range -> 1.0, large range -> 0.0
        // Use exponential decay for smooth mapping
        let score = (-relative_range * 100.0).exp();
        score.clamp(0.0, 1.0)
    }

    /// Returns a human-readable summary of the current training state.
    pub fn summary(&self) -> String {
        let report = self.check_health();
        let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();

        let mut s = String::new();
        s.push_str(&format!(
            "=== Training Health Report (step {}) ===\n",
            report.step
        ));
        s.push_str(&format!(
            "Status: {}\n",
            if report.is_healthy {
                "HEALTHY"
            } else {
                "UNHEALTHY"
            }
        ));
        s.push_str(&format!(
            "Loss: {:.6} (trend: {:?})\n",
            report.current_loss, report.loss_trend
        ));
        s.push_str(&format!(
            "Grad norms: mean={:.4}, std={:.4}, max={:.4}\n",
            mean_gn, std_gn, max_gn
        ));
        s.push_str(&format!(
            "Convergence: {:.2}%\n",
            report.convergence_score * 100.0
        ));
        s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));

        if !report.active_alerts.is_empty() {
            s.push_str(&format!(
                "Active alerts ({}):\n",
                report.active_alerts.len()
            ));
            for alert in &report.active_alerts {
                s.push_str(&format!("  {}\n", alert));
            }
        }

        if let Some(lr) = self.suggest_lr() {
            s.push_str(&format!("Suggested LR: {:.6}\n", lr));
        }

        s
    }

    // =========================================================================
    // Private helpers
    // =========================================================================

    /// Pushes `value` onto `history`, dropping the oldest entry once the
    /// length exceeds `max_history`.
    fn push_bounded(history: &mut Vec<f32>, value: f32, max_history: usize) {
        history.push(value);
        if history.len() > max_history {
            history.remove(0);
        }
    }

    fn emit_alert(
        &mut self,
        step: usize,
        severity: AlertSeverity,
        kind: AlertKind,
        message: String,
    ) {
        self.alerts.push(TrainingAlert {
            step,
            severity,
            kind,
            message,
        });
    }
}

impl Default for TrainingMonitor {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Construction
    // -------------------------------------------------------------------------

    #[test]
    fn test_monitor_creation_defaults() {
        let monitor = TrainingMonitor::new();
        assert_eq!(monitor.step_count, 0);
        assert!(monitor.loss_history.is_empty());
        assert!(monitor.alerts.is_empty());
        assert_eq!(monitor.config.window_size, 100);
        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
        assert!(monitor.config.nan_check);
    }

    #[test]
    fn test_monitor_with_custom_config() {
        let config = MonitorConfig {
            window_size: 50,
            grad_norm_threshold: 50.0,
            loss_divergence_factor: 5.0,
            dead_neuron_threshold: 20,
            nan_check: false,
            convergence_threshold: 1e-5,
            max_history: 500,
        };
        let monitor = TrainingMonitor::with_config(config);
        assert_eq!(monitor.config.window_size, 50);
        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
        assert!(!monitor.config.nan_check);
        assert_eq!(monitor.config.max_history, 500);
    }

    // -------------------------------------------------------------------------
    // Recording steps
    // -------------------------------------------------------------------------

    #[test]
    fn test_record_step_updates_state() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.step_count, 1);
        assert_eq!(monitor.loss_history.len(), 1);
        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert_eq!(monitor.lr_history.len(), 1);
        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
    }

    // -------------------------------------------------------------------------
    // NaN/Inf detection
    // -------------------------------------------------------------------------

    #[test]
    fn test_nan_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
        assert!(monitor.alerts[0].message.contains("NaN"));
    }

    #[test]
    fn test_inf_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
        assert!(monitor.alerts[0].message.contains("Infinity"));
    }

    #[test]
    fn test_nan_in_grad_norm_detected() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);

        let nan_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::NaNDetected)
            .collect();
        assert_eq!(nan_alerts.len(), 1);
        assert!(nan_alerts[0].message.contains("w1"));
    }

    #[test]
    fn test_nan_check_disabled() {
        let config = MonitorConfig {
            nan_check: false,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);
        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);

        assert!(monitor.alerts.is_empty());
    }

    // -------------------------------------------------------------------------
    // Gradient explosion/vanishing
    // -------------------------------------------------------------------------

    #[test]
    fn test_gradient_explosion_detection() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);

        let explosion_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientExplosion)
            .collect();
        assert_eq!(explosion_alerts.len(), 1);
        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Record 10 steps with near-zero gradients to trigger vanishing alert
        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }

        let vanishing_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientVanishing)
            .collect();
        assert!(!vanishing_alerts.is_empty());
    }

    #[test]
    fn test_gradient_vanishing_resets_on_normal_grad() {
        let mut monitor = TrainingMonitor::new();

        // 5 vanishing steps
        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }
        // Normal gradient resets the streak
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        assert_eq!(monitor.vanishing_streak, 0);
    }

    // -------------------------------------------------------------------------
    // Loss divergence
    // -------------------------------------------------------------------------

    #[test]
    fn test_loss_divergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            loss_divergence_factor: 2.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Build up a history of stable loss
        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        // Spike the loss to trigger divergence
        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);

        let divergence_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::LossDivergence)
            .collect();
        assert!(!divergence_alerts.is_empty());
    }

    // -------------------------------------------------------------------------
    // Dead neuron tracking
    // -------------------------------------------------------------------------

    #[test]
    fn test_dead_neuron_tracking() {
        let config = MonitorConfig {
            dead_neuron_threshold: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
        }

        let dead_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::DeadNeuron)
            .collect();
        assert_eq!(dead_alerts.len(), 1);
        assert!(dead_alerts[0].message.contains("dead_layer"));
    }

    #[test]
    fn test_dead_neuron_resets_on_nonzero_grad() {
        let config = MonitorConfig {
            dead_neuron_threshold: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // 5 steps with zero grad
        for _ in 0..5 {
            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
        }
        // One step with nonzero grad resets the counter
        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);

        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
    }

    // -------------------------------------------------------------------------
    // Convergence detection
    // -------------------------------------------------------------------------

    #[test]
    fn test_convergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Record identical losses
        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
        }

        let converged_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::Converged)
            .collect();
        assert!(!converged_alerts.is_empty());
    }

    // -------------------------------------------------------------------------
    // Loss trend analysis
    // -------------------------------------------------------------------------

    #[test]
    fn test_loss_trend_decreasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // First window: higher losses
        for i in 0..5 {
            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        // Second window: much lower losses
        for i in 0..5 {
            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
    }

    #[test]
    fn test_loss_trend_increasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // First window: lower losses
        for i in 0..5 {
            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        // Second window: much higher losses
        for i in 0..5 {
            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
    }

    #[test]
    fn test_loss_trend_oscillating() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // First window: stable
        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }
        // Second window: oscillating around same mean but with high variance
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
    }

    #[test]
    fn test_loss_trend_stable() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Two windows with nearly identical means and low variance
        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
    }

    #[test]
    fn test_loss_trend_unknown_insufficient_data() {
        let config = MonitorConfig {
            window_size: 100,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);

        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
    }

    // -------------------------------------------------------------------------
    // Health report
    // -------------------------------------------------------------------------

    #[test]
    fn test_health_report_healthy_normal_training() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
        }

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 10);
        assert_eq!(report.dead_neurons, 0);
    }

    #[test]
    fn test_health_report_not_healthy_with_nan() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        let report = monitor.check_health();
        assert!(!report.is_healthy);
    }

    // -------------------------------------------------------------------------
    // Learning rate suggestions
    // -------------------------------------------------------------------------

    #[test]
    fn test_suggest_lr_exploding_gradients() {
        let config = MonitorConfig {
            window_size: 5,
            grad_norm_threshold: 10.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Record steps with exploding gradients
        for _ in 0..5 {
            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.001).abs() < 1e-6); // 0.01 * 0.1
    }

    #[test]
    fn test_suggest_lr_oscillating_loss() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // First window stable
        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
        }
        // Second window oscillating
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.005).abs() < 1e-6); // 0.01 * 0.5
    }

    #[test]
    fn test_suggest_lr_converged_returns_none() {
        let config = MonitorConfig {
            window_size: 5,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Two windows of identical loss (converged + stable)
        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
        }

        // Stable trend with high convergence score -> should return None
        let trend = monitor.loss_trend();
        let conv = monitor.convergence_score();
        assert_eq!(trend, LossTrend::Stable);
        assert!(conv > 0.99);
        assert!(monitor.suggest_lr().is_none());
    }

    // -------------------------------------------------------------------------
    // Convergence score
    // -------------------------------------------------------------------------

    #[test]
    fn test_convergence_score_fully_converged() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
    }

    #[test]
    fn test_convergence_score_actively_changing() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!(score < 0.5, "Expected low score, got {}", score);
    }

    #[test]
    fn test_convergence_score_insufficient_data() {
        let monitor = TrainingMonitor::new();
        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
    }

    // -------------------------------------------------------------------------
    // Summary output
    // -------------------------------------------------------------------------

    #[test]
    fn test_summary_contains_key_metrics() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        }

        let summary = monitor.summary();
        assert!(summary.contains("Training Health Report"));
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("Loss:"));
        assert!(summary.contains("Grad norms:"));
        assert!(summary.contains("Convergence:"));
        assert!(summary.contains("Dead neurons:"));
    }

    // -------------------------------------------------------------------------
    // Alert management
    // -------------------------------------------------------------------------

    #[test]
    fn test_clear_alerts_empties_list() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
        assert!(!monitor.alerts().is_empty());

        monitor.clear_alerts();
        assert!(monitor.alerts().is_empty());
    }

    // -------------------------------------------------------------------------
    // Memory bounds
    // -------------------------------------------------------------------------

    #[test]
    fn test_max_history_bounds_memory() {
        let config = MonitorConfig {
            max_history: 20,
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..50 {
            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
        }

        assert!(monitor.loss_history.len() <= 20);
        assert!(monitor.grad_norm_history.len() <= 20);
        assert!(monitor.lr_history.len() <= 20);
    }

    // -------------------------------------------------------------------------
    // Grad norm stats
    // -------------------------------------------------------------------------

    #[test]
    fn test_grad_norm_stats_computation() {
        let config = MonitorConfig {
            window_size: 4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Record known gradient norms: 1, 2, 3, 4
        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);

        let (mean, std, max) = monitor.grad_norm_stats();
        assert!(
            (mean - 2.5).abs() < 1e-4,
            "Expected mean ~2.5, got {}",
            mean
        );
        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
        // std of [1,2,3,4] = sqrt(1.25) ~= 1.118
        assert!(
            (std - 1.118).abs() < 0.01,
            "Expected std ~1.118, got {}",
            std
        );
    }

    #[test]
    fn test_grad_norm_stats_empty() {
        let monitor = TrainingMonitor::new();
        let (mean, std, max) = monitor.grad_norm_stats();
        assert!((mean - 0.0).abs() < 1e-6);
        assert!((std - 0.0).abs() < 1e-6);
        assert!((max - 0.0).abs() < 1e-6);
    }

    // -------------------------------------------------------------------------
    // Integration test
    // -------------------------------------------------------------------------

    #[test]
    fn test_integration_100_step_improving_training() {
        let config = MonitorConfig {
            window_size: 20,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Simulate 100 steps of steadily improving loss
        for i in 0..100 {
            let loss = 2.0 * (-0.03 * i as f32).exp(); // Exponential decay
            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
            let lr = 0.001;
            monitor.record_step(
                loss,
                &[
                    ("layer1.weight", grad_norm),
                    ("layer2.weight", grad_norm * 0.5),
                ],
                lr,
            );
        }

        assert_eq!(monitor.step_count, 100);
        assert!(monitor.is_healthy());

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 100);
        assert!(report.current_loss < 0.2); // Should have decayed significantly
        assert_eq!(report.dead_neurons, 0);

        // Should detect decreasing trend
        let trend = monitor.loss_trend();
        assert_eq!(trend, LossTrend::Decreasing);

        // No critical alerts should exist
        let critical_count = monitor
            .alerts
            .iter()
            .filter(|a| a.severity == AlertSeverity::Critical)
            .count();
        assert_eq!(critical_count, 0);

        // Summary should be valid
        let summary = monitor.summary();
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("step 100"));
    }

    #[test]
    fn test_default_trait() {
        let monitor = TrainingMonitor::default();
        assert_eq!(monitor.step_count, 0);
        assert_eq!(monitor.config.window_size, 100);
    }

    #[test]
    fn test_alert_display() {
        let alert = TrainingAlert {
            step: 42,
            severity: AlertSeverity::Critical,
            kind: AlertKind::NaNDetected,
            message: "NaN detected".to_string(),
        };
        let display = format!("{}", alert);
        assert!(display.contains("42"));
        assert!(display.contains("Critical"));
        assert!(display.contains("NaN"));
    }

    #[test]
    fn test_multiple_parameters_grad_norms() {
        let mut monitor = TrainingMonitor::new();

        // The max grad norm should be tracked (not mean)
        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);

        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
    }
}