use std::collections::HashMap;
use std::fmt;
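/// Thresholds and window sizes that control when [`TrainingMonitor`]
/// raises alerts.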
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of recent steps used for moving averages and trend detection.
    pub window_size: usize,
    /// Maximum gradient norm before a `GradientExplosion` warning fires.
    pub grad_norm_threshold: f32,
    /// Loss divergence fires when the loss exceeds the moving average by this factor.
    pub loss_divergence_factor: f32,
    /// Consecutive zero-gradient steps before a parameter is flagged dead.
    pub dead_neuron_threshold: usize,
    /// Whether to check losses and gradient norms for NaN/Inf.
    pub nan_check: bool,
    /// Loss range below which the window is considered converged.
    pub convergence_threshold: f32,
    /// Upper bound on the length of each history buffer.
    pub max_history: usize,
}
impl Default for MonitorConfig {
fn default() -> Self {
Self {
window_size: 100,
grad_norm_threshold: 100.0,
loss_divergence_factor: 10.0,
dead_neuron_threshold: 50,
nan_check: true,
convergence_threshold: 1e-6,
max_history: 1000,
}
}
}
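/// A diagnostic event emitted by the monitor at a particular training step.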
#[derive(Debug, Clone)]
pub struct TrainingAlert {
pub step: usize,
pub severity: AlertSeverity,
pub kind: AlertKind,
pub message: String,
}
impl fmt::Display for TrainingAlert {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"[step {}] {:?} {:?}: {}",
self.step, self.severity, self.kind, self.message
)
}
}
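/// Urgency of an alert; any `Critical` alert marks training as unhealthy.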
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
Info,
Warning,
Critical,
}
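/// The condition that triggered an alert.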
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
NaNDetected,
InfDetected,
GradientExplosion,
GradientVanishing,
LossDivergence,
LossStagnation,
DeadNeuron,
LearningRateTooHigh,
LearningRateTooLow,
Converged,
}
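/// Point-in-time snapshot of training health, produced by
/// [`TrainingMonitor::check_health`].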
#[derive(Debug, Clone)]
pub struct HealthReport {
pub is_healthy: bool,
pub step: usize,
pub current_loss: f32,
pub loss_trend: LossTrend,
pub mean_grad_norm: f32,
pub max_grad_norm: f32,
pub convergence_score: f32,
pub active_alerts: Vec<TrainingAlert>,
pub dead_neurons: usize,
}
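/// Direction of the loss curve, judged over the last two windows.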
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
Decreasing,
Stable,
Increasing,
Oscillating,
Unknown,
}
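/// Sliding-window monitor for training health.
///
/// Call [`record_step`](Self::record_step) once per optimizer step; the
/// monitor tracks loss, gradient norms, and learning rate in bounded
/// histories and emits [`TrainingAlert`]s for NaN/Inf values, exploding or
/// vanishing gradients, loss divergence, dead neurons, and convergence.
///
/// A minimal usage sketch (`train_one_step` stands in for a real training
/// loop, so the example is not compiled as a doctest):
///
/// ```ignore
/// let mut monitor = TrainingMonitor::new();
/// for _step in 0..100 {
///     let loss = train_one_step(); // hypothetical training-loop helper
///     monitor.record_step(loss, &[("layer1.weight", 0.5)], 1e-3);
/// }
/// if !monitor.is_healthy() {
///     println!("{}", monitor.summary());
/// }
/// ```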
pub struct TrainingMonitor {
step_count: usize,
loss_history: Vec<f32>,
grad_norm_history: Vec<f32>,
lr_history: Vec<f32>,
alerts: Vec<TrainingAlert>,
config: MonitorConfig,
zero_grad_counts: HashMap<String, usize>,
vanishing_streak: usize,
}
impl TrainingMonitor {
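    /// Creates a monitor with the default [`MonitorConfig`].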
pub fn new() -> Self {
Self::with_config(MonitorConfig::default())
}
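    /// Creates a monitor with a custom configuration.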
pub fn with_config(config: MonitorConfig) -> Self {
Self {
step_count: 0,
loss_history: Vec::new(),
grad_norm_history: Vec::new(),
lr_history: Vec::new(),
alerts: Vec::new(),
config,
zero_grad_counts: HashMap::new(),
vanishing_streak: 0,
}
}
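    /// Records one training step: the scalar loss, per-parameter gradient
    /// norms as `(name, norm)` pairs, and the current learning rate, then
    /// runs all health checks and emits any resulting alerts.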
pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
self.step_count += 1;
let step = self.step_count;
        let max_history = self.config.max_history;
        Self::push_bounded(&mut self.loss_history, loss, max_history);
        let max_grad_norm = grad_norms.iter().map(|(_, n)| *n).fold(0.0_f32, f32::max);
        Self::push_bounded(&mut self.grad_norm_history, max_grad_norm, max_history);
        Self::push_bounded(&mut self.lr_history, lr, max_history);
if self.config.nan_check {
if loss.is_nan() {
self.emit_alert(
step,
AlertSeverity::Critical,
AlertKind::NaNDetected,
"NaN detected in loss value".to_string(),
);
} else if loss.is_infinite() {
self.emit_alert(
step,
AlertSeverity::Critical,
AlertKind::InfDetected,
"Infinity detected in loss value".to_string(),
);
}
for (name, norm) in grad_norms {
if norm.is_nan() {
self.emit_alert(
step,
AlertSeverity::Critical,
AlertKind::NaNDetected,
format!("NaN detected in gradient norm for '{}'", name),
);
} else if norm.is_infinite() {
self.emit_alert(
step,
AlertSeverity::Critical,
AlertKind::InfDetected,
format!("Infinity detected in gradient norm for '{}'", name),
);
}
}
}
if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
self.emit_alert(
step,
AlertSeverity::Warning,
AlertKind::GradientExplosion,
format!(
"Gradient norm {:.4} exceeds threshold {:.4}",
max_grad_norm, self.config.grad_norm_threshold
),
);
}
if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
self.vanishing_streak += 1;
if self.vanishing_streak >= 10 {
self.emit_alert(
step,
AlertSeverity::Warning,
AlertKind::GradientVanishing,
format!(
"Gradient norms near zero for {} consecutive steps",
self.vanishing_streak
),
);
}
} else {
self.vanishing_streak = 0;
}
let dead_threshold = self.config.dead_neuron_threshold;
let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
for (name, norm) in grad_norms {
let count = self
.zero_grad_counts
.entry((*name).to_string())
.or_insert(0);
if *norm == 0.0 {
*count += 1;
if *count == dead_threshold {
new_dead_alerts.push(((*name).to_string(), *count));
}
} else {
*count = 0;
}
}
for (name, count) in new_dead_alerts {
self.emit_alert(
step,
AlertSeverity::Warning,
AlertKind::DeadNeuron,
format!(
"Parameter '{}' has had zero gradient for {} steps (dead neuron)",
name, count
),
);
}
if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
let window_start = self
.loss_history
.len()
.saturating_sub(self.config.window_size);
let window = &self.loss_history[window_start..self.loss_history.len() - 1];
let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
if !finite_vals.is_empty() {
let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
self.emit_alert(
step,
AlertSeverity::Warning,
AlertKind::LossDivergence,
format!(
"Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
loss,
avg,
loss / avg
),
);
}
}
}
if self.loss_history.len() >= self.config.window_size {
let window_start = self.loss_history.len() - self.config.window_size;
let window = &self.loss_history[window_start..];
let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
if finite_vals.len() >= 2 {
let max_val = finite_vals
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
let range = max_val - min_val;
if range < self.config.convergence_threshold {
self.emit_alert(
step,
AlertSeverity::Info,
AlertKind::Converged,
format!(
"Training converged: loss range {:.2e} over last {} steps",
range, self.config.window_size
),
);
}
}
}
}
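    /// Builds a [`HealthReport`] from the current histories. Training is
    /// healthy when no critical alert has fired, the current loss is
    /// finite, and the loss trend is not increasing.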
pub fn check_health(&self) -> HealthReport {
let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
let trend = self.loss_trend();
let conv_score = self.convergence_score();
let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
let dead_neurons = self
.zero_grad_counts
.values()
.filter(|c| **c >= self.config.dead_neuron_threshold)
.count();
let has_critical = self
.alerts
.iter()
.any(|a| a.severity == AlertSeverity::Critical);
let is_healthy = !has_critical
&& trend != LossTrend::Increasing
&& !current_loss.is_nan()
&& !current_loss.is_infinite();
let min_step = self.step_count.saturating_sub(self.config.window_size);
let active_alerts: Vec<TrainingAlert> = self
.alerts
.iter()
.filter(|a| a.step > min_step)
.cloned()
.collect();
HealthReport {
is_healthy,
step: self.step_count,
current_loss,
loss_trend: trend,
mean_grad_norm: mean_gn,
max_grad_norm: max_gn,
convergence_score: conv_score,
active_alerts,
dead_neurons,
}
}
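    /// Convenience shorthand for `check_health().is_healthy`.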
pub fn is_healthy(&self) -> bool {
self.check_health().is_healthy
}
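    /// All alerts emitted so far, oldest first.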
pub fn alerts(&self) -> &[TrainingAlert] {
&self.alerts
}
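    /// Discards all recorded alerts.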
pub fn clear_alerts(&mut self) {
self.alerts.clear();
}
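    /// Classifies the loss trend by comparing the mean of the most recent
    /// window against the mean of the window before it; returns `Unknown`
    /// until two full windows of finite data are available.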
pub fn loss_trend(&self) -> LossTrend {
let w = self.config.window_size;
if self.loss_history.len() < w * 2 {
return LossTrend::Unknown;
}
let len = self.loss_history.len();
let recent = &self.loss_history[len - w..];
let previous = &self.loss_history[len - 2 * w..len - w];
let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
if recent_finite.is_empty() || prev_finite.is_empty() {
return LossTrend::Unknown;
}
let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
if prev_avg == 0.0 {
return LossTrend::Unknown;
}
let ratio = recent_avg / prev_avg;
        let recent_var = recent_finite
            .iter()
            .map(|v| (v - recent_avg).powi(2))
            .sum::<f32>()
            / recent_finite.len() as f32;
        let recent_std = recent_var.sqrt();
        let cv = if recent_avg.abs() > 1e-12 {
            recent_std / recent_avg.abs()
        } else {
            0.0
        };
if ratio < 0.95 {
LossTrend::Decreasing
} else if ratio > 1.05 {
LossTrend::Increasing
} else if cv > 0.1 {
LossTrend::Oscillating
} else {
LossTrend::Stable
}
}
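    /// Suggests an adjusted learning rate from recent behavior: cut 10x on
    /// exploding gradients or an increasing loss, halve on oscillation,
    /// double on a stable but unconverged loss, and `None` when converged
    /// or when there is nothing to adjust.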
pub fn suggest_lr(&self) -> Option<f32> {
let current_lr = self.lr_history.last().copied()?;
let trend = self.loss_trend();
let (_, _, max_gn) = self.grad_norm_stats();
if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
return Some(current_lr * 0.1);
}
match trend {
LossTrend::Oscillating => Some(current_lr * 0.5),
            LossTrend::Stable => {
                let conv = self.convergence_score();
                if conv > 0.99 {
                    None
                } else {
                    Some(current_lr * 2.0)
                }
            }
LossTrend::Increasing => Some(current_lr * 0.1),
_ => None,
}
}
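    /// Returns `(mean, std, max)` of the finite gradient norms over the
    /// most recent window, or all zeros when there is no data.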
pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
if self.grad_norm_history.is_empty() {
return (0.0, 0.0, 0.0);
}
let w = self.config.window_size.min(self.grad_norm_history.len());
let start = self.grad_norm_history.len() - w;
let window = &self.grad_norm_history[start..];
let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
if finite.is_empty() {
return (0.0, 0.0, 0.0);
}
let n = finite.len() as f32;
let mean = finite.iter().sum::<f32>() / n;
let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
let std = variance.sqrt();
let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
(mean, std, max)
}
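    /// Scores convergence in `[0, 1]` as `exp(-100 * range / |mean|)` over
    /// the last window of finite losses; 1.0 means the loss is flat.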
pub fn convergence_score(&self) -> f32 {
let w = self.config.window_size;
if self.loss_history.len() < w {
return 0.0;
}
let start = self.loss_history.len() - w;
let window = &self.loss_history[start..];
let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
if finite.len() < 2 {
return 0.0;
}
let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
let range = max_val - min_val;
let mean = finite.iter().sum::<f32>() / finite.len() as f32;
if mean.abs() < 1e-12 {
if range < self.config.convergence_threshold {
return 1.0;
}
return 0.0;
}
let relative_range = range / mean.abs();
let score = (-relative_range * 100.0).exp();
score.clamp(0.0, 1.0)
}
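    /// Renders a multi-line, human-readable report covering loss, gradient
    /// statistics, convergence, dead neurons, active alerts, and any
    /// suggested learning rate.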
pub fn summary(&self) -> String {
let report = self.check_health();
let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
let mut s = String::new();
s.push_str(&format!(
"=== Training Health Report (step {}) ===\n",
report.step
));
s.push_str(&format!(
"Status: {}\n",
if report.is_healthy {
"HEALTHY"
} else {
"UNHEALTHY"
}
));
s.push_str(&format!(
"Loss: {:.6} (trend: {:?})\n",
report.current_loss, report.loss_trend
));
s.push_str(&format!(
"Grad norms: mean={:.4}, std={:.4}, max={:.4}\n",
mean_gn, std_gn, max_gn
));
s.push_str(&format!(
"Convergence: {:.2}%\n",
report.convergence_score * 100.0
));
s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
if !report.active_alerts.is_empty() {
s.push_str(&format!(
"Active alerts ({}):\n",
report.active_alerts.len()
));
for alert in &report.active_alerts {
s.push_str(&format!(" {}\n", alert));
}
}
if let Some(lr) = self.suggest_lr() {
s.push_str(&format!("Suggested LR: {:.6}\n", lr));
}
s
}
    /// Appends `value` to `history`, evicting the oldest entry once the
    /// buffer grows past `max` so memory use stays bounded.
    fn push_bounded(history: &mut Vec<f32>, value: f32, max: usize) {
        history.push(value);
        if history.len() > max {
            history.remove(0);
        }
    }
fn emit_alert(
&mut self,
step: usize,
severity: AlertSeverity,
kind: AlertKind,
message: String,
) {
self.alerts.push(TrainingAlert {
step,
severity,
kind,
message,
});
}
}
impl Default for TrainingMonitor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_monitor_creation_defaults() {
let monitor = TrainingMonitor::new();
assert_eq!(monitor.step_count, 0);
assert!(monitor.loss_history.is_empty());
assert!(monitor.alerts.is_empty());
assert_eq!(monitor.config.window_size, 100);
assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
assert!(monitor.config.nan_check);
}
#[test]
fn test_monitor_with_custom_config() {
let config = MonitorConfig {
window_size: 50,
grad_norm_threshold: 50.0,
loss_divergence_factor: 5.0,
dead_neuron_threshold: 20,
nan_check: false,
convergence_threshold: 1e-5,
max_history: 500,
};
let monitor = TrainingMonitor::with_config(config);
assert_eq!(monitor.config.window_size, 50);
assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
assert!(!monitor.config.nan_check);
assert_eq!(monitor.config.max_history, 500);
}
#[test]
fn test_record_step_updates_state() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
assert_eq!(monitor.step_count, 1);
assert_eq!(monitor.loss_history.len(), 1);
assert_eq!(monitor.grad_norm_history.len(), 1);
assert_eq!(monitor.lr_history.len(), 1);
assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
}
#[test]
fn test_nan_detection_generates_critical_alert() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
assert_eq!(monitor.alerts.len(), 1);
assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
assert!(monitor.alerts[0].message.contains("NaN"));
}
#[test]
fn test_inf_detection_generates_critical_alert() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);
assert_eq!(monitor.alerts.len(), 1);
assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
assert!(monitor.alerts[0].message.contains("Infinity"));
}
#[test]
fn test_nan_in_grad_norm_detected() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);
let nan_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::NaNDetected)
.collect();
assert_eq!(nan_alerts.len(), 1);
assert!(nan_alerts[0].message.contains("w1"));
}
#[test]
fn test_nan_check_disabled() {
let config = MonitorConfig {
nan_check: false,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);
assert!(monitor.alerts.is_empty());
}
#[test]
fn test_gradient_explosion_detection() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(0.5, &[("w1", 200.0)], 0.001);
let explosion_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::GradientExplosion)
.collect();
assert_eq!(explosion_alerts.len(), 1);
assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
}
#[test]
fn test_gradient_vanishing_detection() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
}
let vanishing_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::GradientVanishing)
.collect();
assert!(!vanishing_alerts.is_empty());
}
#[test]
fn test_gradient_vanishing_resets_on_normal_grad() {
let mut monitor = TrainingMonitor::new();
for _ in 0..5 {
monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
}
monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
assert_eq!(monitor.vanishing_streak, 0);
}
#[test]
fn test_loss_divergence_detection() {
let config = MonitorConfig {
window_size: 10,
loss_divergence_factor: 2.0,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
}
monitor.record_step(100.0, &[("w1", 0.5)], 0.001);
let divergence_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::LossDivergence)
.collect();
assert!(!divergence_alerts.is_empty());
}
#[test]
fn test_dead_neuron_tracking() {
let config = MonitorConfig {
dead_neuron_threshold: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..5 {
monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
}
let dead_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::DeadNeuron)
.collect();
assert_eq!(dead_alerts.len(), 1);
assert!(dead_alerts[0].message.contains("dead_layer"));
}
#[test]
fn test_dead_neuron_resets_on_nonzero_grad() {
let config = MonitorConfig {
dead_neuron_threshold: 10,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..5 {
monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
}
monitor.record_step(0.5, &[("layer", 1.0)], 0.001);
assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
}
#[test]
fn test_convergence_detection() {
let config = MonitorConfig {
window_size: 10,
convergence_threshold: 1e-4,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
}
let converged_alerts: Vec<_> = monitor
.alerts
.iter()
.filter(|a| a.kind == AlertKind::Converged)
.collect();
assert!(!converged_alerts.is_empty());
}
#[test]
fn test_loss_trend_decreasing() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..5 {
monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
}
for i in 0..5 {
monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
}
assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
}
#[test]
fn test_loss_trend_increasing() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..5 {
monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
}
for i in 0..5 {
monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
}
assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
}
#[test]
fn test_loss_trend_oscillating() {
let config = MonitorConfig {
window_size: 10,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
}
for i in 0..10 {
let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
monitor.record_step(loss, &[("w1", 0.5)], 0.001);
}
assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
}
#[test]
fn test_loss_trend_stable() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
}
assert_eq!(monitor.loss_trend(), LossTrend::Stable);
}
#[test]
fn test_loss_trend_unknown_insufficient_data() {
let config = MonitorConfig {
window_size: 100,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
}
#[test]
fn test_health_report_healthy_normal_training() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..10 {
monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
}
let report = monitor.check_health();
assert!(report.is_healthy);
assert_eq!(report.step, 10);
assert_eq!(report.dead_neurons, 0);
}
#[test]
fn test_health_report_not_healthy_with_nan() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
let report = monitor.check_health();
assert!(!report.is_healthy);
}
#[test]
fn test_suggest_lr_exploding_gradients() {
let config = MonitorConfig {
window_size: 5,
grad_norm_threshold: 10.0,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..5 {
monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
}
let suggested = monitor.suggest_lr();
assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.001).abs() < 1e-6);
    }
#[test]
fn test_suggest_lr_oscillating_loss() {
let config = MonitorConfig {
window_size: 10,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
}
for i in 0..10 {
let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
monitor.record_step(loss, &[("w1", 0.5)], 0.01);
}
let suggested = monitor.suggest_lr();
assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.005).abs() < 1e-6);
    }
#[test]
fn test_suggest_lr_converged_returns_none() {
let config = MonitorConfig {
window_size: 5,
convergence_threshold: 1e-4,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
}
let trend = monitor.loss_trend();
let conv = monitor.convergence_score();
assert_eq!(trend, LossTrend::Stable);
assert!(conv > 0.99);
assert!(monitor.suggest_lr().is_none());
}
#[test]
fn test_convergence_score_fully_converged() {
let config = MonitorConfig {
window_size: 10,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..10 {
monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
}
let score = monitor.convergence_score();
assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
}
#[test]
fn test_convergence_score_actively_changing() {
let config = MonitorConfig {
window_size: 10,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..10 {
monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
}
let score = monitor.convergence_score();
assert!(score < 0.5, "Expected low score, got {}", score);
}
#[test]
fn test_convergence_score_insufficient_data() {
let monitor = TrainingMonitor::new();
assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
}
#[test]
fn test_summary_contains_key_metrics() {
let config = MonitorConfig {
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for _ in 0..5 {
monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
}
let summary = monitor.summary();
assert!(summary.contains("Training Health Report"));
assert!(summary.contains("HEALTHY"));
assert!(summary.contains("Loss:"));
assert!(summary.contains("Grad norms:"));
assert!(summary.contains("Convergence:"));
assert!(summary.contains("Dead neurons:"));
}
#[test]
fn test_clear_alerts_empties_list() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
assert!(!monitor.alerts().is_empty());
monitor.clear_alerts();
assert!(monitor.alerts().is_empty());
}
#[test]
fn test_max_history_bounds_memory() {
let config = MonitorConfig {
max_history: 20,
window_size: 5,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..50 {
monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
}
assert!(monitor.loss_history.len() <= 20);
assert!(monitor.grad_norm_history.len() <= 20);
assert!(monitor.lr_history.len() <= 20);
}
#[test]
fn test_grad_norm_stats_computation() {
let config = MonitorConfig {
window_size: 4,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
monitor.record_step(1.0, &[("w1", 4.0)], 0.001);
let (mean, std, max) = monitor.grad_norm_stats();
assert!(
(mean - 2.5).abs() < 1e-4,
"Expected mean ~2.5, got {}",
mean
);
assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
assert!(
(std - 1.118).abs() < 0.01,
"Expected std ~1.118, got {}",
std
);
}
#[test]
fn test_grad_norm_stats_empty() {
let monitor = TrainingMonitor::new();
let (mean, std, max) = monitor.grad_norm_stats();
assert!((mean - 0.0).abs() < 1e-6);
assert!((std - 0.0).abs() < 1e-6);
assert!((max - 0.0).abs() < 1e-6);
}
#[test]
fn test_integration_100_step_improving_training() {
let config = MonitorConfig {
window_size: 20,
..MonitorConfig::default()
};
let mut monitor = TrainingMonitor::with_config(config);
for i in 0..100 {
            let loss = 2.0 * (-0.03 * i as f32).exp();
            let grad_norm = (-0.01 * i as f32).exp();
let lr = 0.001;
monitor.record_step(
loss,
&[
("layer1.weight", grad_norm),
("layer2.weight", grad_norm * 0.5),
],
lr,
);
}
assert_eq!(monitor.step_count, 100);
assert!(monitor.is_healthy());
let report = monitor.check_health();
assert!(report.is_healthy);
assert_eq!(report.step, 100);
        assert!(report.current_loss < 0.2);
        assert_eq!(report.dead_neurons, 0);
let trend = monitor.loss_trend();
assert_eq!(trend, LossTrend::Decreasing);
let critical_count = monitor
.alerts
.iter()
.filter(|a| a.severity == AlertSeverity::Critical)
.count();
assert_eq!(critical_count, 0);
let summary = monitor.summary();
assert!(summary.contains("HEALTHY"));
assert!(summary.contains("step 100"));
}
#[test]
fn test_default_trait() {
let monitor = TrainingMonitor::default();
assert_eq!(monitor.step_count, 0);
assert_eq!(monitor.config.window_size, 100);
}
#[test]
fn test_alert_display() {
let alert = TrainingAlert {
step: 42,
severity: AlertSeverity::Critical,
kind: AlertKind::NaNDetected,
message: "NaN detected".to_string(),
};
let display = format!("{}", alert);
assert!(display.contains("42"));
assert!(display.contains("Critical"));
assert!(display.contains("NaN"));
}
#[test]
fn test_multiple_parameters_grad_norms() {
let mut monitor = TrainingMonitor::new();
monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);
assert_eq!(monitor.grad_norm_history.len(), 1);
assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
}
}