1use std::collections::HashMap;
36use std::fmt;
37
/// Tunable thresholds and window sizes for a [`TrainingMonitor`].
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of recent steps used for trend, divergence, and convergence windows.
    pub window_size: usize,
    /// Max per-step gradient norm above which a `GradientExplosion` alert fires.
    pub grad_norm_threshold: f32,
    /// Loss must exceed (moving average × this factor) to count as divergence.
    pub loss_divergence_factor: f32,
    /// Consecutive exactly-zero-gradient steps before a parameter is flagged dead.
    pub dead_neuron_threshold: usize,
    /// Whether to emit critical alerts for NaN/Inf losses and gradient norms.
    pub nan_check: bool,
    /// Loss range over a full window below which training counts as converged.
    pub convergence_threshold: f32,
    /// Upper bound on the length of each history buffer (loss, grad norm, lr).
    pub max_history: usize,
}
60
61impl Default for MonitorConfig {
62 fn default() -> Self {
63 Self {
64 window_size: 100,
65 grad_norm_threshold: 100.0,
66 loss_divergence_factor: 10.0,
67 dead_neuron_threshold: 50,
68 nan_check: true,
69 convergence_threshold: 1e-6,
70 max_history: 1000,
71 }
72 }
73}
74
/// A single diagnostic event raised by the monitor during training.
#[derive(Debug, Clone)]
pub struct TrainingAlert {
    /// The step (1-based) at which the alert was emitted.
    pub step: usize,
    /// How serious the condition is (info / warning / critical).
    pub severity: AlertSeverity,
    /// Which check raised the alert.
    pub kind: AlertKind,
    /// Human-readable description of the condition.
    pub message: String,
}
91
92impl fmt::Display for TrainingAlert {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 write!(
95 f,
96 "[step {}] {:?} {:?}: {}",
97 self.step, self.severity, self.kind, self.message
98 )
99 }
100}
101
/// How serious an alert is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
    /// Purely informational (e.g. convergence detected).
    Info,
    /// Suspicious but not necessarily fatal (e.g. gradient explosion).
    Warning,
    /// Training is broken (e.g. NaN/Inf values); health reports go unhealthy.
    Critical,
}
112
/// Which diagnostic check raised an alert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
    /// NaN found in the loss or a gradient norm.
    NaNDetected,
    /// Infinity found in the loss or a gradient norm.
    InfDetected,
    /// Max gradient norm exceeded `grad_norm_threshold`.
    GradientExplosion,
    /// Max gradient norm stayed near zero for many consecutive steps.
    GradientVanishing,
    /// Loss jumped well above its recent moving average.
    LossDivergence,
    /// Loss has stopped improving (reserved; not emitted by the visible checks).
    LossStagnation,
    /// A parameter's gradient has been exactly zero for too many steps.
    DeadNeuron,
    /// Learning rate appears too high (reserved; not emitted by the visible checks).
    LearningRateTooHigh,
    /// Learning rate appears too low (reserved; not emitted by the visible checks).
    LearningRateTooLow,
    /// Loss range over the window fell below `convergence_threshold`.
    Converged,
}
137
/// Snapshot of training health produced by [`TrainingMonitor::check_health`].
#[derive(Debug, Clone)]
pub struct HealthReport {
    /// True when there are no critical alerts, the loss is finite, and the
    /// loss trend is not increasing.
    pub is_healthy: bool,
    /// Step count at the time the report was generated.
    pub step: usize,
    /// Most recent recorded loss (NaN if nothing recorded yet).
    pub current_loss: f32,
    /// Direction the loss has been moving over the last two windows.
    pub loss_trend: LossTrend,
    /// Mean of the finite gradient norms in the most recent window.
    pub mean_grad_norm: f32,
    /// Max of the finite gradient norms in the most recent window.
    pub max_grad_norm: f32,
    /// 0.0–1.0 score; approaches 1.0 as the windowed loss range shrinks.
    pub convergence_score: f32,
    /// Alerts emitted within the last `window_size` steps.
    pub active_alerts: Vec<TrainingAlert>,
    /// Number of parameters whose zero-gradient streak reached the threshold.
    pub dead_neurons: usize,
}
164
/// Direction of the loss over the last two `window_size`-step windows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Recent window average is >5% below the previous window's.
    Decreasing,
    /// Averages within ±5% and low variation within the recent window.
    Stable,
    /// Recent window average is >5% above the previous window's.
    Increasing,
    /// Averages within ±5% but high variation (CV > 0.1) in the recent window.
    Oscillating,
    /// Not enough (finite) data to classify.
    Unknown,
}
179
/// Watches a training loop's loss, gradient norms, and learning rate,
/// emitting alerts for NaN/Inf values, gradient explosion/vanishing,
/// loss divergence, dead neurons, and convergence.
pub struct TrainingMonitor {
    /// Total number of steps recorded via `record_step`.
    step_count: usize,
    /// Recent losses, bounded by `config.max_history`.
    loss_history: Vec<f32>,
    /// Recent per-step max gradient norms, bounded by `config.max_history`.
    grad_norm_history: Vec<f32>,
    /// Recent learning rates, bounded by `config.max_history`.
    lr_history: Vec<f32>,
    /// Every alert emitted so far (until `clear_alerts`).
    alerts: Vec<TrainingAlert>,
    /// Thresholds and window sizes governing all checks.
    config: MonitorConfig,
    /// Per-parameter count of consecutive steps with exactly-zero gradient.
    zero_grad_counts: HashMap<String, usize>,
    /// Consecutive steps whose max gradient norm was below 1e-8.
    vanishing_streak: usize,
}
208
209impl TrainingMonitor {
210 pub fn new() -> Self {
212 Self::with_config(MonitorConfig::default())
213 }
214
215 pub fn with_config(config: MonitorConfig) -> Self {
217 Self {
218 step_count: 0,
219 loss_history: Vec::new(),
220 grad_norm_history: Vec::new(),
221 lr_history: Vec::new(),
222 alerts: Vec::new(),
223 config,
224 zero_grad_counts: HashMap::new(),
225 vanishing_streak: 0,
226 }
227 }
228
229 pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
237 self.step_count += 1;
238 let step = self.step_count;
239
240 self.push_bounded(&mut self.loss_history.clone(), loss);
242 self.loss_history.push(loss);
243 if self.loss_history.len() > self.config.max_history {
244 self.loss_history.remove(0);
245 }
246
247 let max_grad_norm = grad_norms
249 .iter()
250 .map(|(_, n)| *n)
251 .fold(0.0_f32, f32::max);
252
253 self.grad_norm_history.push(max_grad_norm);
254 if self.grad_norm_history.len() > self.config.max_history {
255 self.grad_norm_history.remove(0);
256 }
257
258 self.lr_history.push(lr);
259 if self.lr_history.len() > self.config.max_history {
260 self.lr_history.remove(0);
261 }
262
263 if self.config.nan_check {
265 if loss.is_nan() {
266 self.emit_alert(step, AlertSeverity::Critical, AlertKind::NaNDetected,
267 "NaN detected in loss value".to_string());
268 } else if loss.is_infinite() {
269 self.emit_alert(step, AlertSeverity::Critical, AlertKind::InfDetected,
270 "Infinity detected in loss value".to_string());
271 }
272
273 for (name, norm) in grad_norms {
274 if norm.is_nan() {
275 self.emit_alert(step, AlertSeverity::Critical, AlertKind::NaNDetected,
276 format!("NaN detected in gradient norm for '{}'", name));
277 } else if norm.is_infinite() {
278 self.emit_alert(step, AlertSeverity::Critical, AlertKind::InfDetected,
279 format!("Infinity detected in gradient norm for '{}'", name));
280 }
281 }
282 }
283
284 if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
286 self.emit_alert(step, AlertSeverity::Warning, AlertKind::GradientExplosion,
287 format!("Gradient norm {:.4} exceeds threshold {:.4}",
288 max_grad_norm, self.config.grad_norm_threshold));
289 }
290
291 if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
293 self.vanishing_streak += 1;
294 if self.vanishing_streak >= 10 {
295 self.emit_alert(step, AlertSeverity::Warning, AlertKind::GradientVanishing,
296 format!("Gradient norms near zero for {} consecutive steps",
297 self.vanishing_streak));
298 }
299 } else {
300 self.vanishing_streak = 0;
301 }
302
303 let dead_threshold = self.config.dead_neuron_threshold;
305 let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
306 for (name, norm) in grad_norms {
307 let count = self.zero_grad_counts.entry(name.to_string()).or_insert(0);
308 if *norm == 0.0 {
309 *count += 1;
310 if *count == dead_threshold {
311 new_dead_alerts.push((name.to_string(), *count));
312 }
313 } else {
314 *count = 0;
315 }
316 }
317 for (name, count) in new_dead_alerts {
318 self.emit_alert(step, AlertSeverity::Warning, AlertKind::DeadNeuron,
319 format!("Parameter '{}' has had zero gradient for {} steps (dead neuron)",
320 name, count));
321 }
322
323 if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
325 let window_start = self.loss_history.len().saturating_sub(self.config.window_size);
326 let window = &self.loss_history[window_start..self.loss_history.len() - 1];
327 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
328 if !finite_vals.is_empty() {
329 let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
330 if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
331 self.emit_alert(step, AlertSeverity::Warning, AlertKind::LossDivergence,
332 format!("Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
333 loss, avg, loss / avg));
334 }
335 }
336 }
337
338 if self.loss_history.len() >= self.config.window_size {
340 let window_start = self.loss_history.len() - self.config.window_size;
341 let window = &self.loss_history[window_start..];
342 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
343 if finite_vals.len() >= 2 {
344 let max_val = finite_vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
345 let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
346 let range = max_val - min_val;
347 if range < self.config.convergence_threshold {
348 self.emit_alert(step, AlertSeverity::Info, AlertKind::Converged,
349 format!("Training converged: loss range {:.2e} over last {} steps",
350 range, self.config.window_size));
351 }
352 }
353 }
354 }
355
356 pub fn check_health(&self) -> HealthReport {
358 let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
359 let trend = self.loss_trend();
360 let conv_score = self.convergence_score();
361
362 let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
363
364 let dead_neurons = self.zero_grad_counts.values()
366 .filter(|c| **c >= self.config.dead_neuron_threshold)
367 .count();
368
369 let has_critical = self.alerts.iter().any(|a| a.severity == AlertSeverity::Critical);
371 let is_healthy = !has_critical
372 && trend != LossTrend::Increasing
373 && !current_loss.is_nan()
374 && !current_loss.is_infinite();
375
376 let min_step = self.step_count.saturating_sub(self.config.window_size);
378 let active_alerts: Vec<TrainingAlert> = self.alerts.iter()
379 .filter(|a| a.step > min_step)
380 .cloned()
381 .collect();
382
383 HealthReport {
384 is_healthy,
385 step: self.step_count,
386 current_loss,
387 loss_trend: trend,
388 mean_grad_norm: mean_gn,
389 max_grad_norm: max_gn,
390 convergence_score: conv_score,
391 active_alerts,
392 dead_neurons,
393 }
394 }
395
396 pub fn is_healthy(&self) -> bool {
398 self.check_health().is_healthy
399 }
400
401 pub fn alerts(&self) -> &[TrainingAlert] {
403 &self.alerts
404 }
405
406 pub fn clear_alerts(&mut self) {
408 self.alerts.clear();
409 }
410
411 pub fn loss_trend(&self) -> LossTrend {
416 let w = self.config.window_size;
417 if self.loss_history.len() < w * 2 {
418 return LossTrend::Unknown;
419 }
420
421 let len = self.loss_history.len();
422 let recent = &self.loss_history[len - w..];
423 let previous = &self.loss_history[len - 2 * w..len - w];
424
425 let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
426 let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
427
428 if recent_finite.is_empty() || prev_finite.is_empty() {
429 return LossTrend::Unknown;
430 }
431
432 let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
433 let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
434
435 if prev_avg == 0.0 {
436 return LossTrend::Unknown;
437 }
438
439 let ratio = recent_avg / prev_avg;
440
441 let recent_mean = recent_avg;
443 let recent_var = recent_finite.iter()
444 .map(|v| (v - recent_mean).powi(2))
445 .sum::<f32>() / recent_finite.len() as f32;
446 let recent_std = recent_var.sqrt();
447 let cv = if recent_mean.abs() > 1e-12 { recent_std / recent_mean.abs() } else { 0.0 };
448
449 if ratio < 0.95 {
450 LossTrend::Decreasing
451 } else if ratio > 1.05 {
452 LossTrend::Increasing
453 } else if cv > 0.1 {
454 LossTrend::Oscillating
455 } else {
456 LossTrend::Stable
457 }
458 }
459
460 pub fn suggest_lr(&self) -> Option<f32> {
464 let current_lr = self.lr_history.last().copied()?;
465 let trend = self.loss_trend();
466
467 let (_, _, max_gn) = self.grad_norm_stats();
469 if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
470 return Some(current_lr * 0.1);
471 }
472
473 match trend {
474 LossTrend::Oscillating => Some(current_lr * 0.5),
475 LossTrend::Stable => {
476 let conv = self.convergence_score();
478 if conv > 0.99 {
479 None } else {
481 Some(current_lr * 2.0) }
483 }
484 LossTrend::Increasing => Some(current_lr * 0.1),
485 _ => None,
486 }
487 }
488
489 pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
491 if self.grad_norm_history.is_empty() {
492 return (0.0, 0.0, 0.0);
493 }
494
495 let w = self.config.window_size.min(self.grad_norm_history.len());
496 let start = self.grad_norm_history.len() - w;
497 let window = &self.grad_norm_history[start..];
498
499 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
500 if finite.is_empty() {
501 return (0.0, 0.0, 0.0);
502 }
503
504 let n = finite.len() as f32;
505 let mean = finite.iter().sum::<f32>() / n;
506 let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
507 let std = variance.sqrt();
508 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
509
510 (mean, std, max)
511 }
512
513 pub fn convergence_score(&self) -> f32 {
518 let w = self.config.window_size;
519 if self.loss_history.len() < w {
520 return 0.0;
521 }
522
523 let start = self.loss_history.len() - w;
524 let window = &self.loss_history[start..];
525
526 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
527 if finite.len() < 2 {
528 return 0.0;
529 }
530
531 let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
532 let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
533 let range = max_val - min_val;
534
535 let mean = finite.iter().sum::<f32>() / finite.len() as f32;
536 if mean.abs() < 1e-12 {
537 if range < self.config.convergence_threshold {
539 return 1.0;
540 }
541 return 0.0;
542 }
543
544 let relative_range = range / mean.abs();
546
547 let score = (-relative_range * 100.0).exp();
550 score.clamp(0.0, 1.0)
551 }
552
553 pub fn summary(&self) -> String {
555 let report = self.check_health();
556 let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
557
558 let mut s = String::new();
559 s.push_str(&format!("=== Training Health Report (step {}) ===\n", report.step));
560 s.push_str(&format!("Status: {}\n", if report.is_healthy { "HEALTHY" } else { "UNHEALTHY" }));
561 s.push_str(&format!("Loss: {:.6} (trend: {:?})\n", report.current_loss, report.loss_trend));
562 s.push_str(&format!("Grad norms: mean={:.4}, std={:.4}, max={:.4}\n", mean_gn, std_gn, max_gn));
563 s.push_str(&format!("Convergence: {:.2}%\n", report.convergence_score * 100.0));
564 s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
565
566 if !report.active_alerts.is_empty() {
567 s.push_str(&format!("Active alerts ({}):\n", report.active_alerts.len()));
568 for alert in &report.active_alerts {
569 s.push_str(&format!(" {}\n", alert));
570 }
571 }
572
573 if let Some(lr) = self.suggest_lr() {
574 s.push_str(&format!("Suggested LR: {:.6}\n", lr));
575 }
576
577 s
578 }
579
580 fn push_bounded(&self, _history: &mut Vec<f32>, _value: f32) {
585 }
588
589 fn emit_alert(&mut self, step: usize, severity: AlertSeverity, kind: AlertKind, message: String) {
590 self.alerts.push(TrainingAlert {
591 step,
592 severity,
593 kind,
594 message,
595 });
596 }
597}
598
599impl Default for TrainingMonitor {
600 fn default() -> Self {
601 Self::new()
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction and configuration ------------------------------------

    /// A fresh monitor starts empty with the documented default config.
    #[test]
    fn test_monitor_creation_defaults() {
        let monitor = TrainingMonitor::new();
        assert_eq!(monitor.step_count, 0);
        assert!(monitor.loss_history.is_empty());
        assert!(monitor.alerts.is_empty());
        assert_eq!(monitor.config.window_size, 100);
        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
        assert!(monitor.config.nan_check);
    }

    /// `with_config` stores the caller-supplied configuration verbatim.
    #[test]
    fn test_monitor_with_custom_config() {
        let config = MonitorConfig {
            window_size: 50,
            grad_norm_threshold: 50.0,
            loss_divergence_factor: 5.0,
            dead_neuron_threshold: 20,
            nan_check: false,
            convergence_threshold: 1e-5,
            max_history: 500,
        };
        let monitor = TrainingMonitor::with_config(config);
        assert_eq!(monitor.config.window_size, 50);
        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
        assert!(!monitor.config.nan_check);
        assert_eq!(monitor.config.max_history, 500);
    }

    // --- record_step bookkeeping --------------------------------------------

    /// One step pushes exactly one entry into each history buffer.
    #[test]
    fn test_record_step_updates_state() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.step_count, 1);
        assert_eq!(monitor.loss_history.len(), 1);
        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert_eq!(monitor.lr_history.len(), 1);
        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
    }

    // --- NaN / Inf detection -------------------------------------------------

    /// A NaN loss yields exactly one critical NaNDetected alert.
    #[test]
    fn test_nan_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
        assert!(monitor.alerts[0].message.contains("NaN"));
    }

    /// An infinite loss yields exactly one critical InfDetected alert.
    #[test]
    fn test_inf_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
        assert!(monitor.alerts[0].message.contains("Infinity"));
    }

    /// NaN in a gradient norm is attributed to the offending parameter.
    #[test]
    fn test_nan_in_grad_norm_detected() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);

        let nan_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::NaNDetected)
            .collect();
        assert_eq!(nan_alerts.len(), 1);
        assert!(nan_alerts[0].message.contains("w1"));
    }

    /// With nan_check off, NaN inputs produce no alerts at all.
    #[test]
    fn test_nan_check_disabled() {
        let config = MonitorConfig {
            nan_check: false,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);
        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);

        assert!(monitor.alerts.is_empty());
    }

    // --- Gradient explosion / vanishing -------------------------------------

    /// A norm above grad_norm_threshold fires a GradientExplosion warning.
    #[test]
    fn test_gradient_explosion_detection() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);

        let explosion_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::GradientExplosion)
            .collect();
        assert_eq!(explosion_alerts.len(), 1);
        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
    }

    /// Ten consecutive near-zero norms trigger GradientVanishing.
    #[test]
    fn test_gradient_vanishing_detection() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }

        let vanishing_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::GradientVanishing)
            .collect();
        assert!(!vanishing_alerts.is_empty());
    }

    /// A single healthy gradient resets the vanishing streak to zero.
    #[test]
    fn test_gradient_vanishing_resets_on_normal_grad() {
        let mut monitor = TrainingMonitor::new();

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        assert_eq!(monitor.vanishing_streak, 0);
    }

    // --- Loss divergence -----------------------------------------------------

    /// A loss far above the moving average fires LossDivergence.
    #[test]
    fn test_loss_divergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            loss_divergence_factor: 2.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        // 100.0 vs a 1.0 average is well past the 2x divergence factor.
        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);

        let divergence_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::LossDivergence)
            .collect();
        assert!(!divergence_alerts.is_empty());
    }

    // --- Dead-neuron tracking ------------------------------------------------

    /// Only the all-zero-gradient parameter is flagged, exactly once.
    #[test]
    fn test_dead_neuron_tracking() {
        let config = MonitorConfig {
            dead_neuron_threshold: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
        }

        let dead_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::DeadNeuron)
            .collect();
        assert_eq!(dead_alerts.len(), 1);
        assert!(dead_alerts[0].message.contains("dead_layer"));
    }

    /// Any nonzero gradient resets a parameter's zero-gradient counter.
    #[test]
    fn test_dead_neuron_resets_on_nonzero_grad() {
        let config = MonitorConfig {
            dead_neuron_threshold: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
        }
        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);

        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
    }

    // --- Convergence ---------------------------------------------------------

    /// A flat loss within convergence_threshold emits a Converged info alert.
    #[test]
    fn test_convergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
        }

        let converged_alerts: Vec<_> = monitor.alerts.iter()
            .filter(|a| a.kind == AlertKind::Converged)
            .collect();
        assert!(!converged_alerts.is_empty());
    }

    // --- Loss trend classification ------------------------------------------

    /// Recent window average well below the previous one → Decreasing.
    #[test]
    fn test_loss_trend_decreasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..5 {
            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
    }

    /// Recent window average well above the previous one → Increasing.
    #[test]
    fn test_loss_trend_increasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..5 {
            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
    }

    /// Same mean but high variation (alternating 1.3/0.7) → Oscillating.
    #[test]
    fn test_loss_trend_oscillating() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
    }

    /// A perfectly flat loss over both windows → Stable.
    #[test]
    fn test_loss_trend_stable() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
    }

    /// Fewer than 2 * window_size samples → Unknown.
    #[test]
    fn test_loss_trend_unknown_insufficient_data() {
        let config = MonitorConfig {
            window_size: 100,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);

        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
    }

    // --- Health reporting ----------------------------------------------------

    /// A steadily decreasing loss with sane gradients reports healthy.
    #[test]
    fn test_health_report_healthy_normal_training() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
        }

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 10);
        assert_eq!(report.dead_neurons, 0);
    }

    /// Any NaN loss makes the health report unhealthy.
    #[test]
    fn test_health_report_not_healthy_with_nan() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        let report = monitor.check_health();
        assert!(!report.is_healthy);
    }

    // --- Learning-rate suggestions -------------------------------------------

    /// Exploding gradients suggest cutting the lr to 10% (0.01 → 0.001).
    #[test]
    fn test_suggest_lr_exploding_gradients() {
        let config = MonitorConfig {
            window_size: 5,
            grad_norm_threshold: 10.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.001).abs() < 1e-6);
    }

    /// An oscillating loss suggests halving the lr (0.01 → 0.005).
    #[test]
    fn test_suggest_lr_oscillating_loss() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
        }
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.005).abs() < 1e-6);
    }

    /// Stable *and* converged training suggests no lr change at all.
    #[test]
    fn test_suggest_lr_converged_returns_none() {
        let config = MonitorConfig {
            window_size: 5,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
        }

        // Sanity-check the preconditions the None branch depends on.
        let trend = monitor.loss_trend();
        let conv = monitor.convergence_score();
        assert_eq!(trend, LossTrend::Stable);
        assert!(conv > 0.99);
        assert!(monitor.suggest_lr().is_none());
    }

    // --- Convergence score ----------------------------------------------------

    /// A constant loss scores ~1.0 (fully converged).
    #[test]
    fn test_convergence_score_fully_converged() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
    }

    /// A rapidly falling loss scores low (still changing).
    #[test]
    fn test_convergence_score_actively_changing() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!(score < 0.5, "Expected low score, got {}", score);
    }

    /// No data → score 0.0.
    #[test]
    fn test_convergence_score_insufficient_data() {
        let monitor = TrainingMonitor::new();
        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
    }

    // --- Summary / alerts / memory bounds ------------------------------------

    /// The text summary includes every headline metric.
    #[test]
    fn test_summary_contains_key_metrics() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        }

        let summary = monitor.summary();
        assert!(summary.contains("Training Health Report"));
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("Loss:"));
        assert!(summary.contains("Grad norms:"));
        assert!(summary.contains("Convergence:"));
        assert!(summary.contains("Dead neurons:"));
    }

    /// clear_alerts removes previously emitted alerts.
    #[test]
    fn test_clear_alerts_empties_list() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
        assert!(!monitor.alerts().is_empty());

        monitor.clear_alerts();
        assert!(monitor.alerts().is_empty());
    }

    /// Each history buffer is capped at max_history entries.
    #[test]
    fn test_max_history_bounds_memory() {
        let config = MonitorConfig {
            max_history: 20,
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..50 {
            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
        }

        assert!(monitor.loss_history.len() <= 20);
        assert!(monitor.grad_norm_history.len() <= 20);
        assert!(monitor.lr_history.len() <= 20);
    }

    // --- Gradient-norm statistics ---------------------------------------------

    /// Mean/std/max over norms {1,2,3,4}: mean 2.5, population std ~1.118, max 4.
    #[test]
    fn test_grad_norm_stats_computation() {
        let config = MonitorConfig {
            window_size: 4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);

        let (mean, std, max) = monitor.grad_norm_stats();
        assert!((mean - 2.5).abs() < 1e-4, "Expected mean ~2.5, got {}", mean);
        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
        assert!((std - 1.118).abs() < 0.01, "Expected std ~1.118, got {}", std);
    }

    /// No recorded steps → all-zero stats.
    #[test]
    fn test_grad_norm_stats_empty() {
        let monitor = TrainingMonitor::new();
        let (mean, std, max) = monitor.grad_norm_stats();
        assert!((mean - 0.0).abs() < 1e-6);
        assert!((std - 0.0).abs() < 1e-6);
        assert!((max - 0.0).abs() < 1e-6);
    }

    // --- Integration -----------------------------------------------------------

    /// End-to-end: 100 steps of exponentially improving training stays
    /// healthy, trends downward, and emits no critical alerts.
    #[test]
    fn test_integration_100_step_improving_training() {
        let config = MonitorConfig {
            window_size: 20,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..100 {
            // Exponentially decaying loss and gradient norms.
            let loss = 2.0 * (-0.03 * i as f32).exp();
            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
            let lr = 0.001;
            monitor.record_step(loss, &[("layer1.weight", grad_norm), ("layer2.weight", grad_norm * 0.5)], lr);
        }

        assert_eq!(monitor.step_count, 100);
        assert!(monitor.is_healthy());

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 100);
        assert!(report.current_loss < 0.2);
        assert_eq!(report.dead_neurons, 0);

        let trend = monitor.loss_trend();
        assert_eq!(trend, LossTrend::Decreasing);

        let critical_count = monitor.alerts.iter()
            .filter(|a| a.severity == AlertSeverity::Critical)
            .count();
        assert_eq!(critical_count, 0);

        let summary = monitor.summary();
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("step 100"));
    }

    /// Default trait matches `new()`.
    #[test]
    fn test_default_trait() {
        let monitor = TrainingMonitor::default();
        assert_eq!(monitor.step_count, 0);
        assert_eq!(monitor.config.window_size, 100);
    }

    /// Display formatting includes step, severity, and message text.
    #[test]
    fn test_alert_display() {
        let alert = TrainingAlert {
            step: 42,
            severity: AlertSeverity::Critical,
            kind: AlertKind::NaNDetected,
            message: "NaN detected".to_string(),
        };
        let display = format!("{}", alert);
        assert!(display.contains("42"));
        assert!(display.contains("Critical"));
        assert!(display.contains("NaN"));
    }

    /// The per-step gradient norm is the max across all parameters.
    #[test]
    fn test_multiple_parameters_grad_norms() {
        let mut monitor = TrainingMonitor::new();

        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);

        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
    }
}