1use std::collections::HashMap;
18use std::fmt;
19
/// Tunable thresholds and window sizes for a training monitor.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Number of recent steps used for trend/convergence windows.
    pub window_size: usize,
    /// Gradient norms above this value raise a gradient-explosion warning.
    pub grad_norm_threshold: f32,
    /// Loss must exceed `moving_average * factor` to count as divergence.
    pub loss_divergence_factor: f32,
    /// Consecutive zero-gradient steps before a parameter is flagged dead.
    pub dead_neuron_threshold: usize,
    /// Whether NaN/Inf checks run on losses and gradient norms.
    pub nan_check: bool,
    /// Loss range below this value over a full window counts as converged.
    pub convergence_threshold: f32,
    /// Upper bound on the length of each rolling history buffer.
    pub max_history: usize,
}

/// Defaults suitable for a typical full-precision training run.
impl Default for MonitorConfig {
    fn default() -> Self {
        MonitorConfig {
            window_size: 100,
            grad_norm_threshold: 100.0,
            loss_divergence_factor: 10.0,
            dead_neuron_threshold: 50,
            nan_check: true,
            convergence_threshold: 1e-6,
            max_history: 1000,
        }
    }
}
56
/// A single diagnostic event raised during training.
#[derive(Debug, Clone)]
pub struct TrainingAlert {
    /// Step at which the alert was raised.
    pub step: usize,
    /// How serious the observed event is.
    pub severity: AlertSeverity,
    /// What category of problem was observed.
    pub kind: AlertKind,
    /// Human-readable details.
    pub message: String,
}

impl fmt::Display for TrainingAlert {
    /// Renders as `[step N] Severity Kind: message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { step, severity, kind, message } = self;
        write!(f, "[step {}] {:?} {:?}: {}", step, severity, kind, message)
    }
}

/// Severity level attached to each alert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertSeverity {
    /// Informational; no action needed.
    Info,
    /// Suspicious, but training may still recover.
    Warning,
    /// Training is almost certainly broken (e.g. NaN/Inf).
    Critical,
}

/// Category of the observed training problem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlertKind {
    NaNDetected,
    InfDetected,
    GradientExplosion,
    GradientVanishing,
    LossDivergence,
    LossStagnation,
    DeadNeuron,
    LearningRateTooHigh,
    LearningRateTooLow,
    Converged,
}
119
/// Point-in-time snapshot of training health produced by
/// `TrainingMonitor::check_health`.
#[derive(Debug, Clone)]
pub struct HealthReport {
    /// Overall verdict: no critical alerts, loss not increasing, loss finite.
    pub is_healthy: bool,
    /// Step count at the time the report was generated.
    pub step: usize,
    /// Most recent loss value (NaN when no step has been recorded yet).
    pub current_loss: f32,
    /// Direction of the loss over the last two windows.
    pub loss_trend: LossTrend,
    /// Mean of finite gradient norms within the recent window.
    pub mean_grad_norm: f32,
    /// Maximum finite gradient norm within the recent window.
    pub max_grad_norm: f32,
    /// 0.0 (actively changing) to 1.0 (flat) score of loss stability.
    pub convergence_score: f32,
    /// Alerts raised within the last `window_size` steps.
    pub active_alerts: Vec<TrainingAlert>,
    /// Parameters whose zero-gradient streak reached the dead-neuron threshold.
    pub dead_neurons: usize,
}
146
/// Coarse classification of the recent loss trajectory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Recent window mean is meaningfully below the previous window's.
    Decreasing,
    /// Means are close and the recent window has low variance.
    Stable,
    /// Recent window mean is meaningfully above the previous window's.
    Increasing,
    /// Means are close but the recent window has high variance.
    Oscillating,
    /// Not enough (finite) data to classify.
    Unknown,
}
161
/// Stateful monitor that ingests per-step training metrics and surfaces
/// alerts and health diagnostics.
pub struct TrainingMonitor {
    // Total number of steps recorded so far.
    step_count: usize,
    // Rolling loss values, bounded by `config.max_history`.
    loss_history: Vec<f32>,
    // Rolling per-step maximum gradient norms, bounded by `config.max_history`.
    grad_norm_history: Vec<f32>,
    // Rolling learning rates, bounded by `config.max_history`.
    lr_history: Vec<f32>,
    // Every alert emitted since creation (or since the last `clear_alerts`).
    alerts: Vec<TrainingAlert>,
    // Thresholds and window sizes governing all checks.
    config: MonitorConfig,
    // Per-parameter count of consecutive steps with exactly-zero gradient.
    zero_grad_counts: HashMap<String, usize>,
    // Consecutive steps whose max gradient norm was below 1e-8.
    vanishing_streak: usize,
}
190
191impl TrainingMonitor {
192 pub fn new() -> Self {
194 Self::with_config(MonitorConfig::default())
195 }
196
197 pub fn with_config(config: MonitorConfig) -> Self {
199 Self {
200 step_count: 0,
201 loss_history: Vec::new(),
202 grad_norm_history: Vec::new(),
203 lr_history: Vec::new(),
204 alerts: Vec::new(),
205 config,
206 zero_grad_counts: HashMap::new(),
207 vanishing_streak: 0,
208 }
209 }
210
211 pub fn record_step(&mut self, loss: f32, grad_norms: &[(&str, f32)], lr: f32) {
219 self.step_count += 1;
220 let step = self.step_count;
221
222 self.push_bounded(&mut self.loss_history.clone(), loss);
224 self.loss_history.push(loss);
225 if self.loss_history.len() > self.config.max_history {
226 self.loss_history.remove(0);
227 }
228
229 let max_grad_norm = grad_norms.iter().map(|(_, n)| *n).fold(0.0_f32, f32::max);
231
232 self.grad_norm_history.push(max_grad_norm);
233 if self.grad_norm_history.len() > self.config.max_history {
234 self.grad_norm_history.remove(0);
235 }
236
237 self.lr_history.push(lr);
238 if self.lr_history.len() > self.config.max_history {
239 self.lr_history.remove(0);
240 }
241
242 if self.config.nan_check {
244 if loss.is_nan() {
245 self.emit_alert(
246 step,
247 AlertSeverity::Critical,
248 AlertKind::NaNDetected,
249 "NaN detected in loss value".to_string(),
250 );
251 } else if loss.is_infinite() {
252 self.emit_alert(
253 step,
254 AlertSeverity::Critical,
255 AlertKind::InfDetected,
256 "Infinity detected in loss value".to_string(),
257 );
258 }
259
260 for (name, norm) in grad_norms {
261 if norm.is_nan() {
262 self.emit_alert(
263 step,
264 AlertSeverity::Critical,
265 AlertKind::NaNDetected,
266 format!("NaN detected in gradient norm for '{}'", name),
267 );
268 } else if norm.is_infinite() {
269 self.emit_alert(
270 step,
271 AlertSeverity::Critical,
272 AlertKind::InfDetected,
273 format!("Infinity detected in gradient norm for '{}'", name),
274 );
275 }
276 }
277 }
278
279 if max_grad_norm > self.config.grad_norm_threshold && max_grad_norm.is_finite() {
281 self.emit_alert(
282 step,
283 AlertSeverity::Warning,
284 AlertKind::GradientExplosion,
285 format!(
286 "Gradient norm {:.4} exceeds threshold {:.4}",
287 max_grad_norm, self.config.grad_norm_threshold
288 ),
289 );
290 }
291
292 if max_grad_norm < 1e-8 && max_grad_norm.is_finite() {
294 self.vanishing_streak += 1;
295 if self.vanishing_streak >= 10 {
296 self.emit_alert(
297 step,
298 AlertSeverity::Warning,
299 AlertKind::GradientVanishing,
300 format!(
301 "Gradient norms near zero for {} consecutive steps",
302 self.vanishing_streak
303 ),
304 );
305 }
306 } else {
307 self.vanishing_streak = 0;
308 }
309
310 let dead_threshold = self.config.dead_neuron_threshold;
312 let mut new_dead_alerts: Vec<(String, usize)> = Vec::new();
313 for (name, norm) in grad_norms {
314 let count = self
315 .zero_grad_counts
316 .entry((*name).to_string())
317 .or_insert(0);
318 if *norm == 0.0 {
319 *count += 1;
320 if *count == dead_threshold {
321 new_dead_alerts.push(((*name).to_string(), *count));
322 }
323 } else {
324 *count = 0;
325 }
326 }
327 for (name, count) in new_dead_alerts {
328 self.emit_alert(
329 step,
330 AlertSeverity::Warning,
331 AlertKind::DeadNeuron,
332 format!(
333 "Parameter '{}' has had zero gradient for {} steps (dead neuron)",
334 name, count
335 ),
336 );
337 }
338
339 if self.loss_history.len() >= self.config.window_size && loss.is_finite() {
341 let window_start = self
342 .loss_history
343 .len()
344 .saturating_sub(self.config.window_size);
345 let window = &self.loss_history[window_start..self.loss_history.len() - 1];
346 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
347 if !finite_vals.is_empty() {
348 let avg: f32 = finite_vals.iter().sum::<f32>() / finite_vals.len() as f32;
349 if avg > 0.0 && loss > avg * self.config.loss_divergence_factor {
350 self.emit_alert(
351 step,
352 AlertSeverity::Warning,
353 AlertKind::LossDivergence,
354 format!(
355 "Loss {:.6} diverged from moving average {:.6} (factor {:.1}x)",
356 loss,
357 avg,
358 loss / avg
359 ),
360 );
361 }
362 }
363 }
364
365 if self.loss_history.len() >= self.config.window_size {
367 let window_start = self.loss_history.len() - self.config.window_size;
368 let window = &self.loss_history[window_start..];
369 let finite_vals: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
370 if finite_vals.len() >= 2 {
371 let max_val = finite_vals
372 .iter()
373 .copied()
374 .fold(f32::NEG_INFINITY, f32::max);
375 let min_val = finite_vals.iter().copied().fold(f32::INFINITY, f32::min);
376 let range = max_val - min_val;
377 if range < self.config.convergence_threshold {
378 self.emit_alert(
379 step,
380 AlertSeverity::Info,
381 AlertKind::Converged,
382 format!(
383 "Training converged: loss range {:.2e} over last {} steps",
384 range, self.config.window_size
385 ),
386 );
387 }
388 }
389 }
390 }
391
392 pub fn check_health(&self) -> HealthReport {
394 let (mean_gn, _std_gn, max_gn) = self.grad_norm_stats();
395 let trend = self.loss_trend();
396 let conv_score = self.convergence_score();
397
398 let current_loss = self.loss_history.last().copied().unwrap_or(f32::NAN);
399
400 let dead_neurons = self
402 .zero_grad_counts
403 .values()
404 .filter(|c| **c >= self.config.dead_neuron_threshold)
405 .count();
406
407 let has_critical = self
409 .alerts
410 .iter()
411 .any(|a| a.severity == AlertSeverity::Critical);
412 let is_healthy = !has_critical
413 && trend != LossTrend::Increasing
414 && !current_loss.is_nan()
415 && !current_loss.is_infinite();
416
417 let min_step = self.step_count.saturating_sub(self.config.window_size);
419 let active_alerts: Vec<TrainingAlert> = self
420 .alerts
421 .iter()
422 .filter(|a| a.step > min_step)
423 .cloned()
424 .collect();
425
426 HealthReport {
427 is_healthy,
428 step: self.step_count,
429 current_loss,
430 loss_trend: trend,
431 mean_grad_norm: mean_gn,
432 max_grad_norm: max_gn,
433 convergence_score: conv_score,
434 active_alerts,
435 dead_neurons,
436 }
437 }
438
439 pub fn is_healthy(&self) -> bool {
441 self.check_health().is_healthy
442 }
443
444 pub fn alerts(&self) -> &[TrainingAlert] {
446 &self.alerts
447 }
448
449 pub fn clear_alerts(&mut self) {
451 self.alerts.clear();
452 }
453
454 pub fn loss_trend(&self) -> LossTrend {
459 let w = self.config.window_size;
460 if self.loss_history.len() < w * 2 {
461 return LossTrend::Unknown;
462 }
463
464 let len = self.loss_history.len();
465 let recent = &self.loss_history[len - w..];
466 let previous = &self.loss_history[len - 2 * w..len - w];
467
468 let recent_finite: Vec<f32> = recent.iter().copied().filter(|v| v.is_finite()).collect();
469 let prev_finite: Vec<f32> = previous.iter().copied().filter(|v| v.is_finite()).collect();
470
471 if recent_finite.is_empty() || prev_finite.is_empty() {
472 return LossTrend::Unknown;
473 }
474
475 let recent_avg = recent_finite.iter().sum::<f32>() / recent_finite.len() as f32;
476 let prev_avg = prev_finite.iter().sum::<f32>() / prev_finite.len() as f32;
477
478 if prev_avg == 0.0 {
479 return LossTrend::Unknown;
480 }
481
482 let ratio = recent_avg / prev_avg;
483
484 let recent_mean = recent_avg;
486 let recent_var = recent_finite
487 .iter()
488 .map(|v| (v - recent_mean).powi(2))
489 .sum::<f32>()
490 / recent_finite.len() as f32;
491 let recent_std = recent_var.sqrt();
492 let cv = if recent_mean.abs() > 1e-12 {
493 recent_std / recent_mean.abs()
494 } else {
495 0.0
496 };
497
498 if ratio < 0.95 {
499 LossTrend::Decreasing
500 } else if ratio > 1.05 {
501 LossTrend::Increasing
502 } else if cv > 0.1 {
503 LossTrend::Oscillating
504 } else {
505 LossTrend::Stable
506 }
507 }
508
509 pub fn suggest_lr(&self) -> Option<f32> {
513 let current_lr = self.lr_history.last().copied()?;
514 let trend = self.loss_trend();
515
516 let (_, _, max_gn) = self.grad_norm_stats();
518 if max_gn > self.config.grad_norm_threshold && max_gn.is_finite() {
519 return Some(current_lr * 0.1);
520 }
521
522 match trend {
523 LossTrend::Oscillating => Some(current_lr * 0.5),
524 LossTrend::Stable => {
525 let conv = self.convergence_score();
527 if conv > 0.99 {
528 None } else {
530 Some(current_lr * 2.0) }
532 }
533 LossTrend::Increasing => Some(current_lr * 0.1),
534 _ => None,
535 }
536 }
537
538 pub fn grad_norm_stats(&self) -> (f32, f32, f32) {
540 if self.grad_norm_history.is_empty() {
541 return (0.0, 0.0, 0.0);
542 }
543
544 let w = self.config.window_size.min(self.grad_norm_history.len());
545 let start = self.grad_norm_history.len() - w;
546 let window = &self.grad_norm_history[start..];
547
548 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
549 if finite.is_empty() {
550 return (0.0, 0.0, 0.0);
551 }
552
553 let n = finite.len() as f32;
554 let mean = finite.iter().sum::<f32>() / n;
555 let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
556 let std = variance.sqrt();
557 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
558
559 (mean, std, max)
560 }
561
562 pub fn convergence_score(&self) -> f32 {
567 let w = self.config.window_size;
568 if self.loss_history.len() < w {
569 return 0.0;
570 }
571
572 let start = self.loss_history.len() - w;
573 let window = &self.loss_history[start..];
574
575 let finite: Vec<f32> = window.iter().copied().filter(|v| v.is_finite()).collect();
576 if finite.len() < 2 {
577 return 0.0;
578 }
579
580 let max_val = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
581 let min_val = finite.iter().copied().fold(f32::INFINITY, f32::min);
582 let range = max_val - min_val;
583
584 let mean = finite.iter().sum::<f32>() / finite.len() as f32;
585 if mean.abs() < 1e-12 {
586 if range < self.config.convergence_threshold {
588 return 1.0;
589 }
590 return 0.0;
591 }
592
593 let relative_range = range / mean.abs();
595
596 let score = (-relative_range * 100.0).exp();
599 score.clamp(0.0, 1.0)
600 }
601
602 pub fn summary(&self) -> String {
604 let report = self.check_health();
605 let (mean_gn, std_gn, max_gn) = self.grad_norm_stats();
606
607 let mut s = String::new();
608 s.push_str(&format!(
609 "=== Training Health Report (step {}) ===\n",
610 report.step
611 ));
612 s.push_str(&format!(
613 "Status: {}\n",
614 if report.is_healthy {
615 "HEALTHY"
616 } else {
617 "UNHEALTHY"
618 }
619 ));
620 s.push_str(&format!(
621 "Loss: {:.6} (trend: {:?})\n",
622 report.current_loss, report.loss_trend
623 ));
624 s.push_str(&format!(
625 "Grad norms: mean={:.4}, std={:.4}, max={:.4}\n",
626 mean_gn, std_gn, max_gn
627 ));
628 s.push_str(&format!(
629 "Convergence: {:.2}%\n",
630 report.convergence_score * 100.0
631 ));
632 s.push_str(&format!("Dead neurons: {}\n", report.dead_neurons));
633
634 if !report.active_alerts.is_empty() {
635 s.push_str(&format!(
636 "Active alerts ({}):\n",
637 report.active_alerts.len()
638 ));
639 for alert in &report.active_alerts {
640 s.push_str(&format!(" {}\n", alert));
641 }
642 }
643
644 if let Some(lr) = self.suggest_lr() {
645 s.push_str(&format!("Suggested LR: {:.6}\n", lr));
646 }
647
648 s
649 }
650
651 fn push_bounded(&self, _history: &mut Vec<f32>, _value: f32) {
656 }
659
660 fn emit_alert(
661 &mut self,
662 step: usize,
663 severity: AlertSeverity,
664 kind: AlertKind,
665 message: String,
666 ) {
667 self.alerts.push(TrainingAlert {
668 step,
669 severity,
670 kind,
671 message,
672 });
673 }
674}
675
676impl Default for TrainingMonitor {
677 fn default() -> Self {
678 Self::new()
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction and configuration ---

    #[test]
    fn test_monitor_creation_defaults() {
        let monitor = TrainingMonitor::new();
        assert_eq!(monitor.step_count, 0);
        assert!(monitor.loss_history.is_empty());
        assert!(monitor.alerts.is_empty());
        assert_eq!(monitor.config.window_size, 100);
        assert!((monitor.config.grad_norm_threshold - 100.0).abs() < 1e-6);
        assert!(monitor.config.nan_check);
    }

    #[test]
    fn test_monitor_with_custom_config() {
        let config = MonitorConfig {
            window_size: 50,
            grad_norm_threshold: 50.0,
            loss_divergence_factor: 5.0,
            dead_neuron_threshold: 20,
            nan_check: false,
            convergence_threshold: 1e-5,
            max_history: 500,
        };
        let monitor = TrainingMonitor::with_config(config);
        assert_eq!(monitor.config.window_size, 50);
        assert!((monitor.config.grad_norm_threshold - 50.0).abs() < 1e-6);
        assert!(!monitor.config.nan_check);
        assert_eq!(monitor.config.max_history, 500);
    }

    // --- Basic recording ---

    #[test]
    fn test_record_step_updates_state() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.step_count, 1);
        assert_eq!(monitor.loss_history.len(), 1);
        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert_eq!(monitor.lr_history.len(), 1);
        assert!((monitor.loss_history[0] - 0.5).abs() < 1e-6);
    }

    // --- NaN / Inf detection ---

    #[test]
    fn test_nan_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::NaNDetected);
        assert!(monitor.alerts[0].message.contains("NaN"));
    }

    #[test]
    fn test_inf_detection_generates_critical_alert() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::INFINITY, &[("w1", 1.0)], 0.001);

        assert_eq!(monitor.alerts.len(), 1);
        assert_eq!(monitor.alerts[0].severity, AlertSeverity::Critical);
        assert_eq!(monitor.alerts[0].kind, AlertKind::InfDetected);
        assert!(monitor.alerts[0].message.contains("Infinity"));
    }

    #[test]
    fn test_nan_in_grad_norm_detected() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", f32::NAN)], 0.001);

        let nan_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::NaNDetected)
            .collect();
        assert_eq!(nan_alerts.len(), 1);
        assert!(nan_alerts[0].message.contains("w1"));
    }

    #[test]
    fn test_nan_check_disabled() {
        let config = MonitorConfig {
            nan_check: false,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);
        monitor.record_step(f32::NAN, &[("w1", f32::NAN)], 0.001);

        // With nan_check off, even NaN loss and gradients raise nothing.
        assert!(monitor.alerts.is_empty());
    }

    // --- Gradient magnitude alerts ---

    #[test]
    fn test_gradient_explosion_detection() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(0.5, &[("w1", 200.0)], 0.001);

        let explosion_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientExplosion)
            .collect();
        assert_eq!(explosion_alerts.len(), 1);
        assert_eq!(explosion_alerts[0].severity, AlertSeverity::Warning);
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // 10 consecutive near-zero norms reach the vanishing streak limit.
        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }

        let vanishing_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::GradientVanishing)
            .collect();
        assert!(!vanishing_alerts.is_empty());
    }

    #[test]
    fn test_gradient_vanishing_resets_on_normal_grad() {
        let mut monitor = TrainingMonitor::new();

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1e-10)], 0.001);
        }
        // A healthy gradient norm must zero the streak.
        monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        assert_eq!(monitor.vanishing_streak, 0);
    }

    // --- Loss divergence ---

    #[test]
    fn test_loss_divergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            loss_divergence_factor: 2.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        // 100x the moving average exceeds the 2x divergence factor.
        monitor.record_step(100.0, &[("w1", 0.5)], 0.001);

        let divergence_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::LossDivergence)
            .collect();
        assert!(!divergence_alerts.is_empty());
    }

    // --- Dead neuron tracking ---

    #[test]
    fn test_dead_neuron_tracking() {
        let config = MonitorConfig {
            dead_neuron_threshold: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("dead_layer", 0.0), ("alive_layer", 0.5)], 0.001);
        }

        let dead_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::DeadNeuron)
            .collect();
        assert_eq!(dead_alerts.len(), 1);
        assert!(dead_alerts[0].message.contains("dead_layer"));
    }

    #[test]
    fn test_dead_neuron_resets_on_nonzero_grad() {
        let config = MonitorConfig {
            dead_neuron_threshold: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("layer", 0.0)], 0.001);
        }
        monitor.record_step(0.5, &[("layer", 1.0)], 0.001);

        assert_eq!(*monitor.zero_grad_counts.get("layer").unwrap(), 0);
    }

    // --- Convergence alerting ---

    #[test]
    fn test_convergence_detection() {
        let config = MonitorConfig {
            window_size: 10,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // A perfectly flat loss over a full window triggers Converged.
        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.1)], 0.001);
        }

        let converged_alerts: Vec<_> = monitor
            .alerts
            .iter()
            .filter(|a| a.kind == AlertKind::Converged)
            .collect();
        assert!(!converged_alerts.is_empty());
    }

    // --- Loss trend classification ---

    #[test]
    fn test_loss_trend_decreasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Previous window around 2.0, recent window around 1.0.
        for i in 0..5 {
            monitor.record_step(2.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(1.0 - i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Decreasing);
    }

    #[test]
    fn test_loss_trend_increasing() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..5 {
            monitor.record_step(1.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }
        for i in 0..5 {
            monitor.record_step(2.0 + i as f32 * 0.01, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Increasing);
    }

    #[test]
    fn test_loss_trend_oscillating() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }
        // Alternating 1.3 / 0.7 keeps the mean near 1.0 but variance high.
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Oscillating);
    }

    #[test]
    fn test_loss_trend_stable() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.001);
        }

        assert_eq!(monitor.loss_trend(), LossTrend::Stable);
    }

    #[test]
    fn test_loss_trend_unknown_insufficient_data() {
        let config = MonitorConfig {
            window_size: 100,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 0.5)], 0.001);

        assert_eq!(monitor.loss_trend(), LossTrend::Unknown);
    }

    // --- Health reporting ---

    #[test]
    fn test_health_report_healthy_normal_training() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(1.0 - i as f32 * 0.05, &[("w1", 0.5)], 0.001);
        }

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 10);
        assert_eq!(report.dead_neurons, 0);
    }

    #[test]
    fn test_health_report_not_healthy_with_nan() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);

        let report = monitor.check_health();
        assert!(!report.is_healthy);
    }

    // --- Learning-rate suggestions ---

    #[test]
    fn test_suggest_lr_exploding_gradients() {
        let config = MonitorConfig {
            window_size: 5,
            grad_norm_threshold: 10.0,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(1.0, &[("w1", 50.0)], 0.01);
        }

        // Exploding gradients should cut the LR to one tenth.
        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_suggest_lr_oscillating_loss() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(1.0, &[("w1", 0.5)], 0.01);
        }
        for i in 0..10 {
            let loss = if i % 2 == 0 { 1.3 } else { 0.7 };
            monitor.record_step(loss, &[("w1", 0.5)], 0.01);
        }

        // Oscillation should halve the LR.
        let suggested = monitor.suggest_lr();
        assert!(suggested.is_some());
        assert!((suggested.unwrap() - 0.005).abs() < 1e-6);
    }

    #[test]
    fn test_suggest_lr_converged_returns_none() {
        let config = MonitorConfig {
            window_size: 5,
            convergence_threshold: 1e-4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.001, &[("w1", 0.01)], 0.001);
        }

        // Stable trend plus a high convergence score means "leave LR alone".
        let trend = monitor.loss_trend();
        let conv = monitor.convergence_score();
        assert_eq!(trend, LossTrend::Stable);
        assert!(conv > 0.99);
        assert!(monitor.suggest_lr().is_none());
    }

    // --- Convergence scoring ---

    #[test]
    fn test_convergence_score_fully_converged() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..10 {
            monitor.record_step(0.5, &[("w1", 0.1)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!((score - 1.0).abs() < 1e-3, "Expected ~1.0, got {}", score);
    }

    #[test]
    fn test_convergence_score_actively_changing() {
        let config = MonitorConfig {
            window_size: 10,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..10 {
            monitor.record_step(10.0 - i as f32 * 1.0, &[("w1", 0.5)], 0.001);
        }

        let score = monitor.convergence_score();
        assert!(score < 0.5, "Expected low score, got {}", score);
    }

    #[test]
    fn test_convergence_score_insufficient_data() {
        let monitor = TrainingMonitor::new();
        assert!((monitor.convergence_score() - 0.0).abs() < 1e-6);
    }

    // --- Summary formatting ---

    #[test]
    fn test_summary_contains_key_metrics() {
        let config = MonitorConfig {
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for _ in 0..5 {
            monitor.record_step(0.5, &[("w1", 1.0)], 0.001);
        }

        let summary = monitor.summary();
        assert!(summary.contains("Training Health Report"));
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("Loss:"));
        assert!(summary.contains("Grad norms:"));
        assert!(summary.contains("Convergence:"));
        assert!(summary.contains("Dead neurons:"));
    }

    // --- Alert lifecycle ---

    #[test]
    fn test_clear_alerts_empties_list() {
        let mut monitor = TrainingMonitor::new();
        monitor.record_step(f32::NAN, &[("w1", 1.0)], 0.001);
        assert!(!monitor.alerts().is_empty());

        monitor.clear_alerts();
        assert!(monitor.alerts().is_empty());
    }

    // --- History bounding ---

    #[test]
    fn test_max_history_bounds_memory() {
        let config = MonitorConfig {
            max_history: 20,
            window_size: 5,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        for i in 0..50 {
            monitor.record_step(i as f32, &[("w1", 0.5)], 0.001);
        }

        // All three histories must stay within max_history entries.
        assert!(monitor.loss_history.len() <= 20);
        assert!(monitor.grad_norm_history.len() <= 20);
        assert!(monitor.lr_history.len() <= 20);
    }

    // --- Gradient statistics ---

    #[test]
    fn test_grad_norm_stats_computation() {
        let config = MonitorConfig {
            window_size: 4,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        monitor.record_step(1.0, &[("w1", 1.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 2.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 3.0)], 0.001);
        monitor.record_step(1.0, &[("w1", 4.0)], 0.001);

        let (mean, std, max) = monitor.grad_norm_stats();
        assert!(
            (mean - 2.5).abs() < 1e-4,
            "Expected mean ~2.5, got {}",
            mean
        );
        assert!((max - 4.0).abs() < 1e-4, "Expected max 4.0, got {}", max);
        // Population std of [1,2,3,4] is sqrt(1.25) ~= 1.118.
        assert!(
            (std - 1.118).abs() < 0.01,
            "Expected std ~1.118, got {}",
            std
        );
    }

    #[test]
    fn test_grad_norm_stats_empty() {
        let monitor = TrainingMonitor::new();
        let (mean, std, max) = monitor.grad_norm_stats();
        assert!((mean - 0.0).abs() < 1e-6);
        assert!((std - 0.0).abs() < 1e-6);
        assert!((max - 0.0).abs() < 1e-6);
    }

    // --- End-to-end run ---

    #[test]
    fn test_integration_100_step_improving_training() {
        let config = MonitorConfig {
            window_size: 20,
            ..MonitorConfig::default()
        };
        let mut monitor = TrainingMonitor::with_config(config);

        // Exponentially decaying loss and gradients simulate healthy training.
        for i in 0..100 {
            let loss = 2.0 * (-0.03 * i as f32).exp();
            let grad_norm = 1.0 * (-0.01 * i as f32).exp();
            let lr = 0.001;
            monitor.record_step(
                loss,
                &[
                    ("layer1.weight", grad_norm),
                    ("layer2.weight", grad_norm * 0.5),
                ],
                lr,
            );
        }

        assert_eq!(monitor.step_count, 100);
        assert!(monitor.is_healthy());

        let report = monitor.check_health();
        assert!(report.is_healthy);
        assert_eq!(report.step, 100);
        assert!(report.current_loss < 0.2);
        assert_eq!(report.dead_neurons, 0);

        let trend = monitor.loss_trend();
        assert_eq!(trend, LossTrend::Decreasing);

        let critical_count = monitor
            .alerts
            .iter()
            .filter(|a| a.severity == AlertSeverity::Critical)
            .count();
        assert_eq!(critical_count, 0);

        let summary = monitor.summary();
        assert!(summary.contains("HEALTHY"));
        assert!(summary.contains("step 100"));
    }

    // --- Trait and formatting coverage ---

    #[test]
    fn test_default_trait() {
        let monitor = TrainingMonitor::default();
        assert_eq!(monitor.step_count, 0);
        assert_eq!(monitor.config.window_size, 100);
    }

    #[test]
    fn test_alert_display() {
        let alert = TrainingAlert {
            step: 42,
            severity: AlertSeverity::Critical,
            kind: AlertKind::NaNDetected,
            message: "NaN detected".to_string(),
        };
        let display = format!("{}", alert);
        assert!(display.contains("42"));
        assert!(display.contains("Critical"));
        assert!(display.contains("NaN"));
    }

    #[test]
    fn test_multiple_parameters_grad_norms() {
        let mut monitor = TrainingMonitor::new();

        monitor.record_step(1.0, &[("w1", 5.0), ("w2", 10.0), ("w3", 3.0)], 0.001);

        // The per-step signal is the maximum norm across all parameters.
        assert_eq!(monitor.grad_norm_history.len(), 1);
        assert!((monitor.grad_norm_history[0] - 10.0).abs() < 1e-6);
    }
}