1use std::collections::HashMap;
26use std::fmt;
27
/// Tunable thresholds and window sizes controlling [`TrainingMonitor`] checks.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of recent steps used for trend, statistics, and convergence windows.
    pub window_size: usize,
    /// Per-step max gradient norms above this value raise a `GradientExplosion` alert.
    pub grad_norm_threshold: f32,
    /// Loss must exceed `moving_average * loss_divergence_factor` to count as divergence.
    pub loss_divergence_factor: f32,
    /// Consecutive exactly-zero-gradient steps before a parameter is flagged as dead.
    pub dead_neuron_threshold: usize,
    /// When true, the loss and every gradient norm are checked for NaN/Inf each step.
    pub nan_check: bool,
    /// A loss range below this value over a full window counts as converged.
    pub convergence_threshold: f32,
    /// Maximum number of entries retained in each history buffer.
    pub max_history: usize,
}
50
impl Default for MonitorConfig {
    /// Defaults sized for typical training runs: 100-step analysis windows,
    /// up to 1000 retained history entries, NaN checking enabled.
    fn default() -> Self {
        Self {
            window_size: 100,
            grad_norm_threshold: 100.0,
            loss_divergence_factor: 10.0,
            dead_neuron_threshold: 50,
            nan_check: true,
            convergence_threshold: 1e-6,
            max_history: 1000,
        }
    }
}
64
/// A single anomaly observation emitted by [`TrainingMonitor`].
#[derive(Debug, Clone)]
pub struct TrainingAlert {
    /// Step number (1-based) at which the alert was raised.
    pub step: usize,
    /// How serious the detected condition is.
    pub severity: AlertSeverity,
    /// Which condition was detected.
    pub kind: AlertKind,
    /// Human-readable detail, e.g. which parameter triggered the alert.
    pub message: String,
}
81
82impl fmt::Display for TrainingAlert {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 write!(
85 f,
86 "[step {}] {:?} {:?}: {}",
87 self.step, self.severity, self.kind, self.message
88 )
89 }
90}
91
/// Severity level attached to a [`TrainingAlert`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
    /// Informational, e.g. convergence detected.
    Info,
    /// Suspicious but potentially recoverable (explosion, divergence, dead neurons).
    Warning,
    /// Training is numerically broken (NaN/Inf observed); marks the run unhealthy.
    Critical,
}
102
/// The specific condition a [`TrainingAlert`] reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
    /// NaN found in the loss or a gradient norm.
    NaNDetected,
    /// Infinity found in the loss or a gradient norm.
    InfDetected,
    /// Max gradient norm exceeded `grad_norm_threshold`.
    GradientExplosion,
    /// Gradient norms near zero for many consecutive steps.
    GradientVanishing,
    /// Loss jumped far above its moving average.
    LossDivergence,
    /// Loss stopped improving. NOTE(review): not emitted by any check in this
    /// file — presumably reserved for future use; confirm before relying on it.
    LossStagnation,
    /// A parameter's gradient has been exactly zero for the configured streak.
    DeadNeuron,
    /// Reserved; not emitted by the checks in this file.
    LearningRateTooHigh,
    /// Reserved; not emitted by the checks in this file.
    LearningRateTooLow,
    /// Loss range over a full window fell below `convergence_threshold`.
    Converged,
}
127
/// Snapshot of training health produced by [`TrainingMonitor::check_health`].
#[derive(Debug, Clone)]
pub struct HealthReport {
    /// False if any Critical alert was ever raised, the loss trend is
    /// increasing, or the current loss is NaN/Inf.
    pub is_healthy: bool,
    /// Step count at the time the report was generated.
    pub step: usize,
    /// Most recent recorded loss (NaN when no steps have been recorded).
    pub current_loss: f32,
    /// Direction of the loss over the recent window.
    pub loss_trend: LossTrend,
    /// Mean of the finite per-step max gradient norms over the window.
    pub mean_grad_norm: f32,
    /// Largest finite per-step max gradient norm over the window.
    pub max_grad_norm: f32,
    /// In `0.0..=1.0`; approaches 1.0 as the loss flattens out.
    pub convergence_score: f32,
    /// Alerts raised within the last `window_size` steps.
    pub active_alerts: Vec<TrainingAlert>,
    /// Count of parameters whose zero-gradient streak meets the dead-neuron threshold.
    pub dead_neurons: usize,
}
154
/// Qualitative direction of the loss, comparing the most recent window
/// against the window before it (see [`TrainingMonitor::loss_trend`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Recent average is more than 5% below the previous window's average.
    Decreasing,
    /// Averages within 5% and low variation within the recent window.
    Stable,
    /// Recent average is more than 5% above the previous window's average.
    Increasing,
    /// Averages within 5% but high variation (coefficient of variation > 0.1).
    Oscillating,
    /// Not enough (finite) data to classify.
    Unknown,
}
169
/// Tracks per-step training metrics and raises alerts for numerical
/// problems (NaN/Inf, exploding/vanishing gradients, dead neurons,
/// loss divergence) and for convergence.
pub struct TrainingMonitor {
    // Number of steps recorded so far (1-based step ids in alerts).
    step_count: usize,
    // Recent losses, bounded by `config.max_history`.
    loss_history: Vec<f32>,
    // Per-step max gradient norm, bounded by `config.max_history`.
    grad_norm_history: Vec<f32>,
    // Learning rate per step, bounded by `config.max_history`.
    lr_history: Vec<f32>,
    // Every alert ever emitted (until `clear_alerts`).
    alerts: Vec<TrainingAlert>,
    // Thresholds and window sizes governing all checks.
    config: MonitorConfig,
    // Per-parameter count of consecutive exactly-zero gradient norms.
    zero_grad_counts: HashMap<String, usize>,
    // Consecutive steps with near-zero max gradient norm.
    vanishing_streak: usize,
}
198
199impl TrainingMonitor {
200 pub fn new() -> Self {
202 Self::with_config(MonitorConfig::default())
203 }
204
205 pub fn with_config(config: MonitorConfig) -> Self {
207 Self {
208 step_count: 0,
209 loss_history: Vec::new(),
210 grad_norm_history: Vec::new(),
211 lr_history: Vec::new(),
212 alerts: Vec::new(),
213 config,
214 zero_grad_counts: HashMap::new(),
215 vanishing_streak: 0,
216 }
217 }
218
219 pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
227 self.step_count += 1;
228 let step = self.step_count;
229
230 self.push_bounded(&mut self.loss_history.clone(), loss);
232 self.loss_history.push(loss);
233 if self.loss_history.len() > self.config.max_history {
234 self.loss_history.remove(0);
235 }
236
237 let max_grad_norm = grad_norms.iter().map(|(_, n)| *n).fold(0.0_f32, f32::max);
239
240 self.grad_norm_history.push(max_grad_norm);
241 if self.grad_norm_history.len() > self.config.max_history {
242 self.grad_norm_history.remove(0);
243 }
244
245 self.lr_history.push(lr);
246 if self.lr_history.len() > self.config.max_history {
247 self.lr_history.remove(0);
248 }
249
250 if self.config.nan_check {
252 if loss.is_nan() {
253 self.emit_alert(
254 step,
255 AlertSeverity::Critical,
256 AlertKind::NaNDetected,
257 "NaN detected in loss value".to_string(),
258 );
259 } else if loss.is_infinite() {
260 self.emit_alert(
261 step,
262 AlertSeverity::Critical,
263 AlertKind::InfDetected,
264 "Infinity detected in loss value".to_string(),
265 );
266 }
267
268 for (name, norm) in grad_norms {
269 if norm.is_nan() {
270 self.emit_alert(
271 step,
272 AlertSeverity::Critical,
273 AlertKind::NaNDetected,
274 format!("NaN detected in gradient norm for '{}'", name),
275 );
276 } else if norm.is_infinite() {
277 self.emit_alert(
278 step,
279 AlertSeverity::Critical,
280 AlertKind::InfDetected,
281 format!("Infinity detected in gradient norm for '{}'", name),
282 );
283 }
284 }
285 }
286
287 if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
289 self.emit_alert(
290 step,
291 AlertSeverity::Warning,
292 AlertKind::GradientExplosion,
293 format!(
294 "Gradient norm {:.4} exceeds threshold {:.4}",
295 max_grad_norm, self.config.grad_norm_threshold
296 ),
297 );
298 }
299
300 if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
302 self.vanishing_streak += 1;
303 if self.vanishing_streak >= 10 {
304 self.emit_alert(
305 step,
306 AlertSeverity::Warning,
307 AlertKind::GradientVanishing,
308 format!(
309 "Gradient norms near zero for {} consecutive steps",
310 self.vanishing_streak
311 ),
312 );
313 }
314 } else {
315 self.vanishing_streak = 0;
316 }
317
318 let dead_threshold = self.config.dead_neuron_threshold;
320 let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
321 for (name, norm) in grad_norms {
322 let count = self
323 .zero_grad_counts
324 .entry((*name).to_string())
325 .or_insert(0);
326 if *norm == 0.0 {
327 *count += 1;
328 if *count == dead_threshold {
329 new_dead_alerts.push(((*name).to_string(), *count));
330 }
331 } else {
332 *count = 0;
333 }
334 }
335 for (name, count) in new_dead_alerts {
336 self.emit_alert(
337 step,
338 AlertSeverity::Warning,
339 AlertKind::DeadNeuron,
340 format!(
341 "Parameter '{}' has had zero gradient for {} steps (dead neuron)",
342 name, count
343 ),
344 );
345 }
346
347 if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
349 let window_start = self
350 .loss_history
351 .len()
352 .saturating_sub(self.config.window_size);
353 let window = &self.loss_history[window_start..self.loss_history.len() - 1];
354 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
355 if !finite_vals.is_empty() {
356 let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
357 if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
358 self.emit_alert(
359 step,
360 AlertSeverity::Warning,
361 AlertKind::LossDivergence,
362 format!(
363 "Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
364 loss,
365 avg,
366 loss / avg
367 ),
368 );
369 }
370 }
371 }
372
373 if self.loss_history.len() >= self.config.window_size {
375 let window_start = self.loss_history.len() - self.config.window_size;
376 let window = &self.loss_history[window_start..];
377 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
378 if finite_vals.len() >= 2 {
379 let max_val = finite_vals
380 .iter()
381 .copied()
382 .fold(f32::NEG_INFINITY, f32::max);
383 let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
384 let range = max_val - min_val;
385 if range < self.config.convergence_threshold {
386 self.emit_alert(
387 step,
388 AlertSeverity::Info,
389 AlertKind::Converged,
390 format!(
391 "Training converged: loss range {:.2e} over last {} steps",
392 range, self.config.window_size
393 ),
394 );
395 }
396 }
397 }
398 }
399
400 pub fn check_health(&self) -> HealthReport {
402 let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
403 let trend = self.loss_trend();
404 let conv_score = self.convergence_score();
405
406 let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
407
408 let dead_neurons = self
410 .zero_grad_counts
411 .values()
412 .filter(|c| **c >= self.config.dead_neuron_threshold)
413 .count();
414
415 let has_critical = self
417 .alerts
418 .iter()
419 .any(|a| a.severity == AlertSeverity::Critical);
420 let is_healthy = !has_critical
421 && trend != LossTrend::Increasing
422 && !current_loss.is_nan()
423 && !current_loss.is_infinite();
424
425 let min_step = self.step_count.saturating_sub(self.config.window_size);
427 let active_alerts: Vec<TrainingAlert> = self
428 .alerts
429 .iter()
430 .filter(|a| a.step > min_step)
431 .cloned()
432 .collect();
433
434 HealthReport {
435 is_healthy,
436 step: self.step_count,
437 current_loss,
438 loss_trend: trend,
439 mean_grad_norm: mean_gn,
440 max_grad_norm: max_gn,
441 convergence_score: conv_score,
442 active_alerts,
443 dead_neurons,
444 }
445 }
446
447 pub fn is_healthy(&self) -> bool {
449 self.check_health().is_healthy
450 }
451
452 pub fn alerts(&self) -> &[TrainingAlert] {
454 &self.alerts
455 }
456
457 pub fn clear_alerts(&mut self) {
459 self.alerts.clear();
460 }
461
462 pub fn loss_trend(&self) -> LossTrend {
467 let w = self.config.window_size;
468 if self.loss_history.len() < w * 2 {
469 return LossTrend::Unknown;
470 }
471
472 let len = self.loss_history.len();
473 let recent = &self.loss_history[len - w..];
474 let previous = &self.loss_history[len - 2 * w..len - w];
475
476 let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
477 let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
478
479 if recent_finite.is_empty() || prev_finite.is_empty() {
480 return LossTrend::Unknown;
481 }
482
483 let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
484 let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
485
486 if prev_avg == 0.0 {
487 return LossTrend::Unknown;
488 }
489
490 let ratio = recent_avg / prev_avg;
491
492 let recent_mean = recent_avg;
494 let recent_var = recent_finite
495 .iter()
496 .map(|v| (v - recent_mean).powi(2))
497 .sum::<f32>()
498 / recent_finite.len() as f32;
499 let recent_std = recent_var.sqrt();
500 let cv = if recent_mean.abs() > 1e-12 {
501 recent_std / recent_mean.abs()
502 } else {
503 0.0
504 };
505
506 if ratio < 0.95 {
507 LossTrend::Decreasing
508 } else if ratio > 1.05 {
509 LossTrend::Increasing
510 } else if cv > 0.1 {
511 LossTrend::Oscillating
512 } else {
513 LossTrend::Stable
514 }
515 }
516
517 pub fn suggest_lr(&self) -> Option<f32> {
521 let current_lr = self.lr_history.last().copied()?;
522 let trend = self.loss_trend();
523
524 let (_, _, max_gn) = self.grad_norm_stats();
526 if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
527 return Some(current_lr * 0.1);
528 }
529
530 match trend {
531 LossTrend::Oscillating => Some(current_lr * 0.5),
532 LossTrend::Stable => {
533 let conv = self.convergence_score();
535 if conv > 0.99 {
536 None } else {
538 Some(current_lr * 2.0) }
540 }
541 LossTrend::Increasing => Some(current_lr * 0.1),
542 _ => None,
543 }
544 }
545
546 pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
548 if self.grad_norm_history.is_empty() {
549 return (0.0, 0.0, 0.0);
550 }
551
552 let w = self.config.window_size.min(self.grad_norm_history.len());
553 let start = self.grad_norm_history.len() - w;
554 let window = &self.grad_norm_history[start..];
555
556 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
557 if finite.is_empty() {
558 return (0.0, 0.0, 0.0);
559 }
560
561 let n = finite.len() as f32;
562 let mean = finite.iter().sum::<f32>() / n;
563 let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
564 let std = variance.sqrt();
565 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
566
567 (mean, std, max)
568 }
569
570 pub fn convergence_score(&self) -> f32 {
575 let w = self.config.window_size;
576 if self.loss_history.len() < w {
577 return 0.0;
578 }
579
580 let start = self.loss_history.len() - w;
581 let window = &self.loss_history[start..];
582
583 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
584 if finite.len() < 2 {
585 return 0.0;
586 }
587
588 let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
589 let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
590 let range = max_val - min_val;
591
592 let mean = finite.iter().sum::<f32>() / finite.len() as f32;
593 if mean.abs() < 1e-12 {
594 if range < self.config.convergence_threshold {
596 return 1.0;
597 }
598 return 0.0;
599 }
600
601 let relative_range = range / mean.abs();
603
604 let score = (-relative_range * 100.0).exp();
607 score.clamp(0.0, 1.0)
608 }
609
610 pub fn summary(&self) -> String {
612 let report = self.check_health();
613 let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
614
615 let mut s = String::new();
616 s.push_str(&format!(
617 "=== Training Health Report (step {}) ===\n",
618 report.step
619 ));
620 s.push_str(&format!(
621 "Status: {}\n",
622 if report.is_healthy {
623 "HEALTHY"
624 } else {
625 "UNHEALTHY"
626 }
627 ));
628 s.push_str(&format!(
629 "Loss: {:.6} (trend: {:?})\n",
630 report.current_loss, report.loss_trend
631 ));
632 s.push_str(&format!(
633 "Grad norms: mean={:.4}, std={:.4}, max={:.4}\n",
634 mean_gn, std_gn, max_gn
635 ));
636 s.push_str(&format!(
637 "Convergence: {:.2}%\n",
638 report.convergence_score * 100.0
639 ));
640 s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
641
642 if !report.active_alerts.is_empty() {
643 s.push_str(&format!(
644 "Active alerts ({}):\n",
645 report.active_alerts.len()
646 ));
647 for alert in &report.active_alerts {
648 s.push_str(&format!(" {}\n", alert));
649 }
650 }
651
652 if let Some(lr) = self.suggest_lr() {
653 s.push_str(&format!("Suggested LR: {:.6}\n", lr));
654 }
655
656 s
657 }
658
659 fn push_bounded(&self, _history: &mut Vec<f32>, _value: f32) {
664 }
667
668 fn emit_alert(
669 &mut self,
670 step: usize,
671 severity: AlertSeverity,
672 kind: AlertKind,
673 message: String,
674 ) {
675 self.alerts.push(TrainingAlert {
676 step,
677 severity,
678 kind,
679 message,
680 });
681 }
682}
683
684impl Default for TrainingMonitor {
685 fn default() -> Self {
686 Self::new()
687 }
688}
689
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction & configuration ---

    #[test]
    fn test_monitor_creation_defaults() {
        let monitor = TrainingMonitor::new();
        assert_eq!(monitor.step_count, 0);
        assert!(monitor.loss_history.is_empty());
        assert!(monitor.alerts.is_empty());
        assert_eq!(monitor.config.window_size, 100);
        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
        assert!(monitor.config.nan_check);
    }

    #[test]
    fn test_monitor_with_custom_config() {
        let config = MonitorConfig {
            window_size: 50,
            grad_norm_threshold: 50.0,
            loss_divergence_factor: 5.0,
            dead_neuron_threshold: 20,
            nan_check: false,
            convergence_threshold: 1e-5,
            max_history: 500,
        };
        let monitor = TrainingMonitor::with_config(config);
        assert_eq!(monitor.config.window_size, 50);
        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
        assert!(!monitor.config.nan_check);
        assert_eq!(monitor.config.max_history, 500);
    }

    // --- Step recording & NaN/Inf detection ---

    #[test]
    fn test_record_step_updates_state() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.step_count, 1);
        assert_eq!(monitor.loss_history.len(), 1);
        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert_eq!(monitor.lr_history.len(), 1);
        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_nan_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
        assert!(monitor.alerts[0].message.contains("NaN"));
    }

    #[test]
    fn test_inf_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
        assert!(monitor.alerts[0].message.contains("Infinity"));
    }

    #[test]
    fn test_nan_in_grad_norm_detected() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);

        let nan_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::NaNDetected)
            .collect();
        assert_eq!(nan_alerts.len(), 1);
        assert!(nan_alerts[0].message.contains("w1"));
    }

    #[test]
    fn test_nan_check_disabled() {
        let config = MonitorConfig {
            nan_check: false,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);
        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);

        assert!(monitor.alerts.is_empty());
    }

    // --- Gradient explosion & vanishing ---

    #[test]
    fn test_gradient_explosion_detection() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);

        let explosion_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientExplosion)
            .collect();
        assert_eq!(explosion_alerts.len(), 1);
        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // 10 consecutive near-zero norms reach the vanishing streak cutoff.
        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }

        let vanishing_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientVanishing)
            .collect();
        assert!(!vanishing_alerts.is_empty());
    }

    #[test]
    fn test_gradient_vanishing_resets_on_normal_grad() {
        let mut monitor = TrainingMonitor::new();

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }
        // One healthy gradient resets the streak.
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        assert_eq!(monitor.vanishing_streak, 0);
    }

    // --- Loss divergence ---

    #[test]
    fn test_loss_divergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            loss_divergence_factor: 2.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        // 100.0 is far beyond avg * factor (1.0 * 2.0).
        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);

        let divergence_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::LossDivergence)
            .collect();
        assert!(!divergence_alerts.is_empty());
    }

    // --- Dead neuron tracking ---

    #[test]
    fn test_dead_neuron_tracking() {
        let config = MonitorConfig {
            dead_neuron_threshold: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
        }

        let dead_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::DeadNeuron)
            .collect();
        assert_eq!(dead_alerts.len(), 1);
        assert!(dead_alerts[0].message.contains("dead_layer"));
    }

    #[test]
    fn test_dead_neuron_resets_on_nonzero_grad() {
        let config = MonitorConfig {
            dead_neuron_threshold: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
        }
        // A nonzero gradient resets the per-parameter streak counter.
        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);

        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
    }

    // --- Convergence detection ---

    #[test]
    fn test_convergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
        }

        let converged_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::Converged)
            .collect();
        assert!(!converged_alerts.is_empty());
    }

    // --- Loss trend classification ---

    #[test]
    fn test_loss_trend_decreasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..5 {
            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
    }

    #[test]
    fn test_loss_trend_increasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..5 {
            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
    }

    #[test]
    fn test_loss_trend_oscillating() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }
        for i in 0..10 {
            // Alternating around 1.0 keeps the averages equal but the
            // coefficient of variation high.
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
    }

    #[test]
    fn test_loss_trend_stable() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
    }

    #[test]
    fn test_loss_trend_unknown_insufficient_data() {
        let config = MonitorConfig {
            window_size: 100,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);

        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
    }

    // --- Health reporting ---

    #[test]
    fn test_health_report_healthy_normal_training() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
        }

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 10);
        assert_eq!(report.dead_neurons, 0);
    }

    #[test]
    fn test_health_report_not_healthy_with_nan() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        let report = monitor.check_health();
        assert!(!report.is_healthy);
    }

    // --- Learning-rate suggestions ---

    #[test]
    fn test_suggest_lr_exploding_gradients() {
        let config = MonitorConfig {
            window_size: 5,
            grad_norm_threshold: 10.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        // Explosion path: 0.01 * 0.1 = 0.001.
        assert!((suggested.unwrap() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_suggest_lr_oscillating_loss() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
        }
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
        }

        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        // Oscillating path: 0.01 * 0.5 = 0.005.
        assert!((suggested.unwrap() - 0.005).abs() < 1e-6);
    }

    #[test]
    fn test_suggest_lr_converged_returns_none() {
        let config = MonitorConfig {
            window_size: 5,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
        }

        // Stable and converged: no LR change should be suggested.
        let trend = monitor.loss_trend();
        let conv = monitor.convergence_score();
        assert_eq!(trend, LossTrend::Stable);
        assert!(conv > 0.99);
        assert!(monitor.suggest_lr().is_none());
    }

    // --- Convergence score ---

    #[test]
    fn test_convergence_score_fully_converged() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
    }

    #[test]
    fn test_convergence_score_actively_changing() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!(score < 0.5, "Expected low score, got {}", score);
    }

    #[test]
    fn test_convergence_score_insufficient_data() {
        let monitor = TrainingMonitor::new();
        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
    }

    // --- Summary, alerts, history bounds ---

    #[test]
    fn test_summary_contains_key_metrics() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        }

        let summary = monitor.summary();
        assert!(summary.contains("Training Health Report"));
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("Loss:"));
        assert!(summary.contains("Grad norms:"));
        assert!(summary.contains("Convergence:"));
        assert!(summary.contains("Dead neurons:"));
    }

    #[test]
    fn test_clear_alerts_empties_list() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
        assert!(!monitor.alerts().is_empty());

        monitor.clear_alerts();
        assert!(monitor.alerts().is_empty());
    }

    #[test]
    fn test_max_history_bounds_memory() {
        let config = MonitorConfig {
            max_history: 20,
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..50 {
            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
        }

        assert!(monitor.loss_history.len() <= 20);
        assert!(monitor.grad_norm_history.len() <= 20);
        assert!(monitor.lr_history.len() <= 20);
    }

    // --- Gradient-norm statistics ---

    #[test]
    fn test_grad_norm_stats_computation() {
        let config = MonitorConfig {
            window_size: 4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);

        let (mean, std, max) = monitor.grad_norm_stats();
        assert!(
            (mean - 2.5).abs() < 1e-4,
            "Expected mean ~2.5, got {}",
            mean
        );
        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
        // Population std of [1, 2, 3, 4] is sqrt(1.25) ~= 1.118.
        assert!(
            (std - 1.118).abs() < 0.01,
            "Expected std ~1.118, got {}",
            std
        );
    }

    #[test]
    fn test_grad_norm_stats_empty() {
        let monitor = TrainingMonitor::new();
        let (mean, std, max) = monitor.grad_norm_stats();
        assert!((mean - 0.0).abs() < 1e-6);
        assert!((std - 0.0).abs() < 1e-6);
        assert!((max - 0.0).abs() < 1e-6);
    }

    // --- End-to-end & miscellaneous ---

    #[test]
    fn test_integration_100_step_improving_training() {
        let config = MonitorConfig {
            window_size: 20,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..100 {
            // Exponentially decaying loss and gradient norms.
            let loss = 2.0 * (-0.03 * i as f32).exp();
            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
            let lr = 0.001;
            monitor.record_step(
                loss,
                &[
                    ("layer1.weight", grad_norm),
                    ("layer2.weight", grad_norm * 0.5),
                ],
                lr,
            );
        }

        assert_eq!(monitor.step_count, 100);
        assert!(monitor.is_healthy());

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 100);
        assert!(report.current_loss < 0.2);
        assert_eq!(report.dead_neurons, 0);

        let trend = monitor.loss_trend();
        assert_eq!(trend, LossTrend::Decreasing);

        let critical_count = monitor
            .alerts
            .iter()
            .filter(|a| a.severity == AlertSeverity::Critical)
            .count();
        assert_eq!(critical_count, 0);

        let summary = monitor.summary();
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("step 100"));
    }

    #[test]
    fn test_default_trait() {
        let monitor = TrainingMonitor::default();
        assert_eq!(monitor.step_count, 0);
        assert_eq!(monitor.config.window_size, 100);
    }

    #[test]
    fn test_alert_display() {
        let alert = TrainingAlert {
            step: 42,
            severity: AlertSeverity::Critical,
            kind: AlertKind::NaNDetected,
            message: "NaN detected".to_string(),
        };
        let display = format!("{}", alert);
        assert!(display.contains("42"));
        assert!(display.contains("Critical"));
        assert!(display.contains("NaN"));
    }

    #[test]
    fn test_multiple_parameters_grad_norms() {
        let mut monitor = TrainingMonitor::new();

        // The history stores the max norm across all parameters.
        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);

        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
    }
}