Skip to main content

entrenar/monitor/report/
analyzer.rs

//! Hansei (反省) Post-Training Report Generator
//!
//! Toyota Way principle: Reflection and continuous improvement through
//! systematic analysis of training outcomes.
//!
//! Reference: Liker, J.K. (2004). The Toyota Way: 14 Management Principles
//! from the World's Greatest Manufacturer. McGraw-Hill.
7
8use super::output::PostTrainingReport;
9use super::types::{IssueSeverity, MetricSummary, TrainingIssue, Trend};
10use crate::monitor::{Metric, MetricStats, MetricsCollector};
11use std::collections::HashMap;
12use std::fmt::Write as FmtWrite;
13
/// Hansei post-training analyzer.
///
/// Carries the numeric thresholds consulted while a finished training run is
/// inspected for problems. Every field is public, so callers can start from
/// `HanseiAnalyzer::default()` and adjust individual limits.
pub struct HanseiAnalyzer {
    /// Loss increase that should trigger a warning (default `0.1`, i.e. 10%).
    ///
    /// NOTE(review): no check visible in this file reads this field — confirm
    /// whether it is consumed elsewhere.
    pub loss_increase_threshold: f64,
    /// Gradient norm above which the run is reported as exploding
    /// (default `100.0`).
    pub gradient_explosion_threshold: f64,
    /// Mean gradient norm below which the run is reported as possibly
    /// vanishing (default `1e-7`).
    pub gradient_vanishing_threshold: f64,
    /// Smallest accuracy spread (max - min) that still counts as an
    /// improvement (default `0.01`, i.e. 1%).
    pub min_accuracy_improvement: f64,
}
25
26impl Default for HanseiAnalyzer {
27    fn default() -> Self {
28        Self {
29            loss_increase_threshold: 0.1, // 10% increase
30            gradient_explosion_threshold: 100.0,
31            gradient_vanishing_threshold: 1e-7,
32            min_accuracy_improvement: 0.01, // 1% improvement
33        }
34    }
35}
36
37impl HanseiAnalyzer {
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Analyze a completed training run and generate a report
43    pub fn analyze(
44        &self,
45        training_id: &str,
46        collector: &MetricsCollector,
47        duration_secs: f64,
48    ) -> PostTrainingReport {
49        let mut issues = Vec::new();
50        let mut recommendations = Vec::new();
51        let mut metric_summaries = HashMap::new();
52        let mut final_metrics = HashMap::new();
53
54        let summary = collector.summary();
55        let total_steps = summary.values().map(|s| s.count).sum::<usize>() as u64;
56
57        // Analyze each metric
58        for (metric, stats) in &summary {
59            let metric_summary = self.analyze_metric(metric, stats);
60            metric_summaries.insert(metric.clone(), metric_summary.clone());
61            final_metrics.insert(metric.clone(), stats.mean);
62
63            // Check for issues based on metric type
64            self.check_metric_issues(metric, &metric_summary, stats, &mut issues);
65        }
66
67        // Generate recommendations based on issues
68        self.generate_recommendations(&issues, &mut recommendations);
69
70        // Check for missing expected metrics
71        self.check_missing_metrics(&summary, &mut issues);
72
73        // Sort issues by severity (critical first)
74        issues.sort_by(|a, b| b.severity.cmp(&a.severity));
75
76        PostTrainingReport {
77            training_id: training_id.to_string(),
78            duration_secs,
79            total_steps,
80            final_metrics,
81            metric_summaries,
82            issues,
83            recommendations,
84        }
85    }
86
87    fn analyze_metric(&self, metric: &Metric, stats: &MetricStats) -> MetricSummary {
88        // Determine trend based on metric type and statistics
89        let trend = self.determine_trend(metric, stats);
90
91        MetricSummary {
92            initial: stats.min,      // Approximation - would need history for actual initial
93            final_value: stats.mean, // Approximation - would need last value
94            min: stats.min,
95            max: stats.max,
96            mean: stats.mean,
97            std_dev: stats.std,
98            trend,
99        }
100    }
101
102    fn determine_trend(&self, metric: &Metric, stats: &MetricStats) -> Trend {
103        let cv = coeff_of_variation(stats);
104        if cv > 0.5 {
105            return Trend::Oscillating;
106        }
107        match metric {
108            Metric::Loss => range_trend(stats, true),
109            Metric::Accuracy => range_trend(stats, false),
110            Metric::GradientNorm => {
111                if cv < 0.2 {
112                    Trend::Stable
113                } else {
114                    Trend::Oscillating
115                }
116            }
117            Metric::LearningRate | Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
118                Trend::Stable
119            }
120        }
121    }
122}
123
124fn coeff_of_variation(stats: &MetricStats) -> f64 {
125    if stats.mean.abs() > 1e-10 {
126        stats.std / stats.mean.abs()
127    } else {
128        0.0
129    }
130}
131
132/// Determine trend based on whether mean is above or below midpoint.
133/// `lower_is_better` = true for loss, false for accuracy.
134fn range_trend(stats: &MetricStats, lower_is_better: bool) -> Trend {
135    if stats.max - stats.min < stats.std * 0.5 {
136        return Trend::Stable;
137    }
138    let mid = f64::midpoint(stats.min, stats.max);
139    let improving = if lower_is_better { stats.mean < mid } else { stats.mean > mid };
140    if improving {
141        Trend::Improving
142    } else {
143        Trend::Degrading
144    }
145}
146
147impl HanseiAnalyzer {
148    fn check_metric_issues(
149        &self,
150        metric: &Metric,
151        summary: &MetricSummary,
152        stats: &MetricStats,
153        issues: &mut Vec<TrainingIssue>,
154    ) {
155        match metric {
156            Metric::Loss => self.check_loss_issues(summary, stats, issues),
157            Metric::Accuracy => self.check_accuracy_issues(summary, stats, issues),
158            Metric::GradientNorm => self.check_gradient_issues(stats, issues),
159            Metric::LearningRate => self.check_lr_issues(summary, issues),
160            Metric::Epoch | Metric::Batch | Metric::Custom(_) => {}
161        }
162    }
163
164    /// Check loss metric for NaN/Inf, degrading trend, and oscillation.
165    fn check_loss_issues(
166        &self,
167        summary: &MetricSummary,
168        stats: &MetricStats,
169        issues: &mut Vec<TrainingIssue>,
170    ) {
171        if stats.has_nan {
172            issues.push(TrainingIssue {
173                severity: IssueSeverity::Critical,
174                category: "Numerical Stability".to_string(),
175                description: "NaN values detected in loss".to_string(),
176                recommendation:
177                    "Reduce learning rate, add gradient clipping, or check data preprocessing"
178                        .to_string(),
179            });
180        }
181        if stats.has_inf {
182            issues.push(TrainingIssue {
183                severity: IssueSeverity::Critical,
184                category: "Numerical Stability".to_string(),
185                description: "Infinity values detected in loss".to_string(),
186                recommendation: "Check for division by zero, reduce learning rate".to_string(),
187            });
188        }
189        if summary.trend == Trend::Degrading {
190            issues.push(TrainingIssue {
191                severity: IssueSeverity::Warning,
192                category: "Convergence".to_string(),
193                description: "Loss appears to be increasing over training".to_string(),
194                recommendation: "Consider reducing learning rate or checking data quality"
195                    .to_string(),
196            });
197        }
198        if summary.trend == Trend::Oscillating {
199            issues.push(TrainingIssue {
200                severity: IssueSeverity::Warning,
201                category: "Stability".to_string(),
202                description: "Loss is oscillating significantly".to_string(),
203                recommendation: "Reduce learning rate or increase batch size".to_string(),
204            });
205        }
206    }
207
208    /// Check accuracy metric for low values and stagnation.
209    fn check_accuracy_issues(
210        &self,
211        summary: &MetricSummary,
212        stats: &MetricStats,
213        issues: &mut Vec<TrainingIssue>,
214    ) {
215        if summary.final_value < 0.5 && stats.count > 100 {
216            issues.push(TrainingIssue {
217                severity: IssueSeverity::Warning,
218                category: "Performance".to_string(),
219                description: format!("Final accuracy is low: {:.2}%", summary.final_value * 100.0),
220                recommendation: "Consider model architecture changes or hyperparameter tuning"
221                    .to_string(),
222            });
223        }
224        if summary.trend == Trend::Stable
225            && summary.max - summary.min < self.min_accuracy_improvement
226        {
227            issues.push(TrainingIssue {
228                severity: IssueSeverity::Info,
229                category: "Convergence".to_string(),
230                description: "Accuracy shows minimal improvement".to_string(),
231                recommendation: "Model may have converged or may be stuck in local minimum"
232                    .to_string(),
233            });
234        }
235    }
236
237    /// Check gradient norms for explosion and vanishing.
238    fn check_gradient_issues(&self, stats: &MetricStats, issues: &mut Vec<TrainingIssue>) {
239        if stats.max > self.gradient_explosion_threshold {
240            issues.push(TrainingIssue {
241                severity: IssueSeverity::Error,
242                category: "Gradient Health".to_string(),
243                description: format!("Gradient explosion detected: max norm = {:.2e}", stats.max),
244                recommendation: "Enable gradient clipping (e.g., max_norm=1.0)".to_string(),
245            });
246        }
247        if stats.mean < self.gradient_vanishing_threshold && stats.count > 10 {
248            issues.push(TrainingIssue {
249                severity: IssueSeverity::Warning,
250                category: "Gradient Health".to_string(),
251                description: format!(
252                    "Possible vanishing gradients: mean norm = {:.2e}",
253                    stats.mean
254                ),
255                recommendation:
256                    "Consider using residual connections or different activation functions"
257                        .to_string(),
258            });
259        }
260    }
261
262    /// Check learning rate schedule for high variance.
263    fn check_lr_issues(&self, summary: &MetricSummary, issues: &mut Vec<TrainingIssue>) {
264        if summary.std_dev > summary.mean * 0.5 {
265            issues.push(TrainingIssue {
266                severity: IssueSeverity::Info,
267                category: "Hyperparameters".to_string(),
268                description: "Learning rate schedule shows high variance".to_string(),
269                recommendation: "Review learning rate schedule configuration".to_string(),
270            });
271        }
272    }
273
274    fn check_missing_metrics(
275        &self,
276        metrics: &HashMap<Metric, MetricStats>,
277        issues: &mut Vec<TrainingIssue>,
278    ) {
279        // Check for essential metrics
280        if !metrics.contains_key(&Metric::Loss) {
281            issues.push(TrainingIssue {
282                severity: IssueSeverity::Warning,
283                category: "Observability".to_string(),
284                description: "No loss metric recorded".to_string(),
285                recommendation: "Ensure loss is being tracked for proper monitoring".to_string(),
286            });
287        }
288    }
289
290    fn generate_recommendations(
291        &self,
292        issues: &[TrainingIssue],
293        recommendations: &mut Vec<String>,
294    ) {
295        let has_numerical_issues = issues.iter().any(|i| i.category == "Numerical Stability");
296        let has_gradient_issues = issues.iter().any(|i| i.category == "Gradient Health");
297        let has_convergence_issues = issues.iter().any(|i| i.category == "Convergence");
298
299        if has_numerical_issues {
300            recommendations.push(
301                "Priority 1: Address numerical stability before continuing training".to_string(),
302            );
303        }
304
305        if has_gradient_issues {
306            recommendations.push("Enable gradient clipping in optimizer configuration".to_string());
307        }
308
309        if has_convergence_issues {
310            recommendations.push(
311                "Consider hyperparameter search for learning rate and batch size".to_string(),
312            );
313        }
314
315        if issues.is_empty() {
316            recommendations.push(
317                "Training completed without detected issues. Consider running validation tests."
318                    .to_string(),
319            );
320        }
321    }
322
323    /// Generate a human-readable report
324    pub fn format_report(&self, report: &PostTrainingReport) -> String {
325        let mut output = String::new();
326
327        // Writing to String never fails, so we ignore the Result
328        let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
329        let _ =
330            writeln!(output, "                    HANSEI POST-TRAINING REPORT                 ");
331        let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
332        let _ = writeln!(output);
333        let _ = writeln!(output, "Training ID: {}", report.training_id);
334        let _ = writeln!(output, "Duration: {:.2}s", report.duration_secs);
335        let _ = writeln!(output, "Total Steps: {}", report.total_steps);
336        let _ = writeln!(output);
337
338        // Metric summaries
339        let _ =
340            writeln!(output, "─── Metric Summaries ───────────────────────────────────────────");
341        for (metric_type, summary) in &report.metric_summaries {
342            let _ = writeln!(output, "\n{metric_type:?}:");
343            let _ = writeln!(output, "  Mean: {:.6}  Std: {:.6}", summary.mean, summary.std_dev);
344            let _ = writeln!(output, "  Min: {:.6}   Max: {:.6}", summary.min, summary.max);
345            let _ = writeln!(output, "  Trend: {}", summary.trend);
346        }
347        let _ = writeln!(output);
348
349        // Issues
350        if !report.issues.is_empty() {
351            let _ = writeln!(
352                output,
353                "─── Issues Detected ────────────────────────────────────────────"
354            );
355            for issue in &report.issues {
356                let _ = writeln!(output, "\n[{}] {}", issue.severity, issue.category);
357                let _ = writeln!(output, "  {}", issue.description);
358                let _ = writeln!(output, "  → {}", issue.recommendation);
359            }
360            let _ = writeln!(output);
361        }
362
363        // Recommendations
364        let _ =
365            writeln!(output, "─── Recommendations ────────────────────────────────────────────");
366        for (i, rec) in report.recommendations.iter().enumerate() {
367            let _ = writeln!(output, "{}. {}", i + 1, rec);
368        }
369        let _ = writeln!(output);
370
371        let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
372
373        output
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// One instance of every `Metric` variant matched by the analyzer.
    fn all_metrics(custom_name: &str) -> [Metric; 7] {
        [
            Metric::Loss,
            Metric::Accuracy,
            Metric::GradientNorm,
            Metric::LearningRate,
            Metric::Epoch,
            Metric::Batch,
            Metric::Custom(custom_name.to_string()),
        ]
    }

    #[test]
    fn test_determine_trend_all_metric_variants() {
        let analyzer = HanseiAnalyzer::default();

        // Low coefficient of variation (0.01) and a narrow range
        let stable_stats = MetricStats {
            count: 100,
            mean: 1.0,
            std: 0.01,
            min: 0.99,
            max: 1.01,
            sum: 100.0,
            has_nan: false,
            has_inf: false,
        };

        for metric in &all_metrics("custom_metric") {
            let trend = analyzer.determine_trend(metric, &stable_stats);

            // Syntactic match covering all arms from determine_trend
            match metric {
                Metric::Loss => {
                    // CV is far below 0.5, so the oscillation shortcut must not fire.
                    assert_ne!(trend, Trend::Oscillating, "low-CV loss must not oscillate");
                }
                Metric::Accuracy => {
                    assert_ne!(trend, Trend::Oscillating, "low-CV accuracy must not oscillate");
                }
                Metric::GradientNorm => {
                    // CV below 0.2 selects the stable branch.
                    assert_eq!(trend, Trend::Stable);
                }
                Metric::LearningRate | Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
                    assert_eq!(trend, Trend::Stable);
                }
            }
        }
    }

    #[test]
    fn test_check_metric_issues_all_metric_variants() {
        let analyzer = HanseiAnalyzer::default();

        // Healthy values: none of the checks below should trigger.
        let stats = MetricStats {
            count: 200,
            mean: 0.5,
            std: 0.1,
            min: 0.3,
            max: 0.7,
            sum: 100.0,
            has_nan: false,
            has_inf: false,
        };

        let summary = MetricSummary {
            initial: 0.3,
            final_value: 0.5,
            min: 0.3,
            max: 0.7,
            mean: 0.5,
            std_dev: 0.1,
            trend: Trend::Stable,
        };

        for metric in &all_metrics("test") {
            let mut issues = Vec::new();
            analyzer.check_metric_issues(metric, &summary, &stats, &mut issues);

            // Syntactic match covering all arms from check_metric_issues
            match metric {
                Metric::Loss => {
                    // No NaN/Inf and a stable trend
                    assert!(issues.is_empty(), "healthy loss should produce no issues");
                }
                Metric::Accuracy => {
                    // final_value is not below 0.5 and the spread (0.4) is not minimal
                    assert!(issues.is_empty(), "healthy accuracy should produce no issues");
                }
                Metric::GradientNorm => {
                    // max is below the explosion limit, mean above the vanishing limit
                    assert!(issues.is_empty(), "healthy gradients should produce no issues");
                }
                Metric::LearningRate => {
                    // std_dev (0.1) does not exceed half the mean (0.25)
                    assert!(issues.is_empty(), "steady learning rate should produce no issues");
                }
                Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
                    // No-op branch
                    assert!(issues.is_empty(), "Epoch/Batch/Custom should produce no issues");
                }
            }
        }
    }

    #[test]
    fn test_analyzer_default() {
        let analyzer = HanseiAnalyzer::default();
        assert!((analyzer.loss_increase_threshold - 0.1).abs() < 1e-10);
        assert!((analyzer.gradient_explosion_threshold - 100.0).abs() < 1e-10);
        assert!((analyzer.gradient_vanishing_threshold - 1e-7).abs() < 1e-15);
        assert!((analyzer.min_accuracy_improvement - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_analyzer_new() {
        let analyzer = HanseiAnalyzer::new();
        assert!((analyzer.loss_increase_threshold - 0.1).abs() < 1e-10);
    }
}