use super::*;
use crate::monitor::{Metric, MetricsCollector};
/// The default analyzer uses a loss-increase threshold of 0.1 and a
/// gradient-explosion threshold of 100.0.
#[test]
fn test_hansei_analyzer_default() {
    let HanseiAnalyzer { loss_increase_threshold, gradient_explosion_threshold, .. } =
        HanseiAnalyzer::default();
    assert_eq!(loss_increase_threshold, 0.1);
    assert_eq!(gradient_explosion_threshold, 100.0);
}
/// A run with steadily falling loss and rising accuracy yields a report that
/// carries the run id, the expected step count and the supplied duration.
#[test]
fn test_analyze_healthy_training() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for i in 0..100 {
        let loss = 1.0 - (f64::from(i) * 0.008);
        let accuracy = 0.5 + (f64::from(i) * 0.004);
        collector.record(Metric::Loss, loss);
        collector.record(Metric::Accuracy, accuracy);
    }
    let report = analyzer.analyze("test-run-1", &collector, 120.0);
    assert_eq!(report.training_id, "test-run-1");
    // Expected value is 100 iterations x 2 recorded metrics per iteration.
    assert_eq!(report.total_steps, 200);
    // assert_eq! (rather than assert!(a == b)) prints both values on failure.
    assert_eq!(report.duration_secs, 120.0);
}
/// A NaN loss sample must surface as a critical issue that mentions "NaN".
#[test]
fn test_detect_nan_loss() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for value in [1.0, f64::NAN] {
        collector.record(Metric::Loss, value);
    }
    let report = analyzer.analyze("nan-test", &collector, 10.0);
    let first_critical = report
        .issues
        .iter()
        .find(|issue| issue.severity == IssueSeverity::Critical)
        .expect("expected at least one critical issue");
    assert!(first_critical.description.contains("NaN"));
}
/// An infinite loss sample must surface as a critical issue that mentions
/// "Infinity".
#[test]
fn test_detect_inf_loss() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for value in [1.0, f64::INFINITY] {
        collector.record(Metric::Loss, value);
    }
    let report = analyzer.analyze("inf-test", &collector, 10.0);
    let first_critical = report
        .issues
        .iter()
        .find(|issue| issue.severity == IssueSeverity::Critical)
        .expect("expected at least one critical issue");
    assert!(first_critical.description.contains("Infinity"));
}
/// A gradient-norm spike (1.0 -> 500.0) must produce a "Gradient Health"
/// issue whose description mentions an explosion.
#[test]
fn test_detect_gradient_explosion() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for norm in [1.0, 500.0] {
        collector.record(Metric::GradientNorm, norm);
    }
    let report = analyzer.analyze("grad-explosion", &collector, 10.0);
    let first_gradient_issue = report
        .issues
        .iter()
        .find(|issue| issue.category == "Gradient Health")
        .expect("expected at least one gradient health issue");
    assert!(first_gradient_issue.description.contains("explosion"));
}
/// Twenty consecutive tiny gradient norms must be reported as vanishing
/// gradients.
#[test]
fn test_detect_vanishing_gradients() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for norm in std::iter::repeat_n(1e-10, 20) {
        collector.record(Metric::GradientNorm, norm);
    }
    let report = analyzer.analyze("vanishing-grad", &collector, 10.0);
    let mentions_vanishing =
        report.issues.iter().any(|issue| issue.description.contains("vanishing"));
    assert!(mentions_vanishing);
}
/// A run that never records a loss metric is flagged under "Observability".
#[test]
fn test_missing_loss_metric() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    collector.record(Metric::Accuracy, 0.5);
    let report = analyzer.analyze("no-loss", &collector, 10.0);
    let flagged = report.issues.iter().any(|issue| issue.category == "Observability");
    assert!(flagged);
}
/// The formatted report includes the banner, the run id and the duration
/// rendered with two decimals.
#[test]
fn test_format_report() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for (metric, value) in [(Metric::Loss, 1.0), (Metric::Loss, 0.5), (Metric::Accuracy, 0.8)] {
        collector.record(metric, value);
    }
    let report = analyzer.analyze("format-test", &collector, 60.0);
    let formatted = analyzer.format_report(&report);
    for expected in ["HANSEI POST-TRAINING REPORT", "format-test", "Duration: 60.00s"] {
        assert!(formatted.contains(expected));
    }
}
/// A loss that starts high and settles low must be classified as Improving.
///
/// The data is a short burst at 2.0 followed by a long plateau at 1.0, so the
/// loss actually decreases over the run (the previous ordering went 1.0 -> 2.0,
/// i.e. a *rising* loss, which contradicted the test's name). The value
/// distribution is unchanged: mean 1.2, min/max midpoint 1.5.
///
/// NOTE(review): the failure message suggests the trend heuristic compares the
/// mean against the min/max midpoint. That comparison is order-independent, so
/// reversing the order should not change the classification. Confirm against
/// the implementation.
#[test]
fn test_trend_detection_improving_loss() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for _ in 0..10 {
        collector.record(Metric::Loss, 2.0);
    }
    for _ in 0..40 {
        collector.record(Metric::Loss, 1.0);
    }
    let report = analyzer.analyze("improving", &collector, 10.0);
    let loss_summary = report.metric_summaries.get(&Metric::Loss).expect("key should exist");
    assert!(
        loss_summary.trend == Trend::Improving,
        "Expected Improving, got {:?} (mean={:.2}, mid={:.2})",
        loss_summary.trend,
        loss_summary.mean,
        f64::midpoint(loss_summary.min, loss_summary.max)
    );
}
/// A loss that alternates between 10.0 and 1.0 every step is Oscillating.
#[test]
fn test_trend_detection_oscillating() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    let pattern = [10.0, 1.0];
    for step in 0..50_usize {
        collector.record(Metric::Loss, pattern[step % 2]);
    }
    let report = analyzer.analyze("oscillating", &collector, 10.0);
    let summary = report.metric_summaries.get(&Metric::Loss).expect("key should exist");
    assert_eq!(summary.trend, Trend::Oscillating);
}
/// Severities order strictly as Info < Warning < Error < Critical.
#[test]
fn test_issue_severity_ordering() {
    let ascending = [
        IssueSeverity::Info,
        IssueSeverity::Warning,
        IssueSeverity::Error,
        IssueSeverity::Critical,
    ];
    for pair in ascending.windows(2) {
        assert!(pair[1] > pair[0], "{:?} should rank above {:?}", pair[1], pair[0]);
    }
}
/// A NaN loss must lead to recommendations, the first of which mentions
/// numerical stability.
#[test]
fn test_recommendations_generated() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    collector.record(Metric::Loss, f64::NAN);
    let report = analyzer.analyze("rec-test", &collector, 10.0);
    let first = report.recommendations.first().expect("expected at least one recommendation");
    assert!(first.contains("numerical stability"));
}
/// A run stuck at 30% accuracy for 150 samples raises a "Performance" issue.
#[test]
fn test_low_accuracy_warning() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for accuracy in std::iter::repeat_n(0.3, 150) {
        collector.record(Metric::Accuracy, accuracy);
    }
    let report = analyzer.analyze("low-acc", &collector, 100.0);
    let has_performance_issue =
        report.issues.iter().any(|issue| issue.category == "Performance");
    assert!(has_performance_issue);
}
/// Analyzing a collector with no samples yields zero steps, no summaries, and
/// an "Observability" issue.
#[test]
fn test_empty_collector() {
    let analyzer = HanseiAnalyzer::new();
    let report = analyzer.analyze("empty", &MetricsCollector::new(), 0.0);
    let has_observability_issue =
        report.issues.iter().any(|issue| issue.category == "Observability");
    assert_eq!(report.total_steps, 0);
    assert!(report.metric_summaries.is_empty());
    assert!(has_observability_issue);
}
/// Each severity renders as its upper-case label via `Display`.
#[test]
fn test_issue_severity_display() {
    let cases = [
        (IssueSeverity::Info, "INFO"),
        (IssueSeverity::Warning, "WARNING"),
        (IssueSeverity::Error, "ERROR"),
        (IssueSeverity::Critical, "CRITICAL"),
    ];
    for (severity, label) in cases {
        assert_eq!(severity.to_string(), label);
    }
}
/// Each trend renders with its arrow glyph and name via `Display`.
#[test]
fn test_trend_display() {
    let cases = [
        (Trend::Improving, "↑ Improving"),
        (Trend::Degrading, "↓ Degrading"),
        (Trend::Stable, "→ Stable"),
        (Trend::Oscillating, "~ Oscillating"),
    ];
    for (trend, rendered) in cases {
        assert_eq!(trend.to_string(), rendered);
    }
}
/// A loss that starts low (1.0) and then stays high (2.0) is Degrading.
#[test]
fn test_trend_detection_degrading() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    let samples = std::iter::repeat_n(1.0, 10).chain(std::iter::repeat_n(2.0, 40));
    for value in samples {
        collector.record(Metric::Loss, value);
    }
    let report = analyzer.analyze("degrading", &collector, 10.0);
    let summary = report.metric_summaries.get(&Metric::Loss).expect("key should exist");
    assert!(summary.trend == Trend::Degrading, "Expected Degrading, got {:?}", summary.trend);
}
/// Gradient norms jittering within about +/-0.05 of 1.0 are classified as
/// Stable.
#[test]
fn test_trend_detection_stable() {
    let analyzer = HanseiAnalyzer::new();
    let mut collector = MetricsCollector::new();
    for step in 0..50 {
        let jitter = (f64::from(step) % 10.0 - 5.0) * 0.01;
        collector.record(Metric::GradientNorm, 1.0 + jitter);
    }
    let report = analyzer.analyze("stable", &collector, 10.0);
    let summary = report.metric_summaries.get(&Metric::GradientNorm).expect("key should exist");
    assert!(summary.trend == Trend::Stable, "Expected Stable, got {:?}", summary.trend);
}
/// Every threshold supplied through a struct literal is stored unchanged.
///
/// All four fields are checked; previously only two of them were verified.
#[test]
fn test_custom_thresholds() {
    let analyzer = HanseiAnalyzer {
        loss_increase_threshold: 0.5,
        gradient_explosion_threshold: 50.0,
        gradient_vanishing_threshold: 1e-8,
        min_accuracy_improvement: 0.2,
    };
    assert_eq!(analyzer.loss_increase_threshold, 0.5);
    assert_eq!(analyzer.gradient_explosion_threshold, 50.0);
    assert_eq!(analyzer.gradient_vanishing_threshold, 1e-8);
    assert_eq!(analyzer.min_accuracy_improvement, 0.2);
}
/// Cloning a `TrainingIssue` copies every field.
///
/// The string fields `description` and `recommendation` are compared as well;
/// previously they were not checked at all.
#[test]
fn test_training_issue_clone() {
    let issue = TrainingIssue {
        severity: IssueSeverity::Warning,
        category: "Test".to_string(),
        description: "Test description".to_string(),
        recommendation: "Test recommendation".to_string(),
    };
    let cloned = issue.clone();
    assert_eq!(issue.severity, cloned.severity);
    assert_eq!(issue.category, cloned.category);
    assert_eq!(issue.description, cloned.description);
    assert_eq!(issue.recommendation, cloned.recommendation);
}
/// Cloning a `MetricSummary` copies every statistic and the trend.
///
/// All seven fields are compared; previously only `initial` and `trend` were.
#[test]
fn test_metric_summary_clone() {
    let summary = MetricSummary {
        initial: 1.0,
        final_value: 0.5,
        min: 0.3,
        max: 1.2,
        mean: 0.6,
        std_dev: 0.2,
        trend: Trend::Improving,
    };
    let cloned = summary.clone();
    assert_eq!(summary.initial, cloned.initial);
    assert_eq!(summary.final_value, cloned.final_value);
    assert_eq!(summary.min, cloned.min);
    assert_eq!(summary.max, cloned.max);
    assert_eq!(summary.mean, cloned.mean);
    assert_eq!(summary.std_dev, cloned.std_dev);
    assert_eq!(summary.trend, cloned.trend);
}
/// Cloning a `PostTrainingReport` preserves its scalar fields and the sizes
/// of its collections.
///
/// `duration_secs` and the collection fields are now compared; previously only
/// `training_id` and `total_steps` were.
#[test]
fn test_post_training_report_clone() {
    use std::collections::HashMap;
    let report = PostTrainingReport {
        training_id: "test".to_string(),
        duration_secs: 10.0,
        total_steps: 100,
        final_metrics: HashMap::new(),
        metric_summaries: HashMap::new(),
        issues: vec![],
        recommendations: vec![],
    };
    let cloned = report.clone();
    assert_eq!(report.training_id, cloned.training_id);
    assert_eq!(report.duration_secs, cloned.duration_secs);
    assert_eq!(report.total_steps, cloned.total_steps);
    assert_eq!(report.final_metrics.len(), cloned.final_metrics.len());
    assert_eq!(report.metric_summaries.len(), cloned.metric_summaries.len());
    assert_eq!(report.issues.len(), cloned.issues.len());
    assert_eq!(report.recommendations.len(), cloned.recommendations.len());
}