1use super::output::PostTrainingReport;
9use super::types::{IssueSeverity, MetricSummary, TrainingIssue, Trend};
10use crate::monitor::{Metric, MetricStats, MetricsCollector};
11use std::collections::HashMap;
12use std::fmt::Write as FmtWrite;
13
/// Post-training analyzer that inspects collected training metrics and
/// produces a [`PostTrainingReport`] listing detected issues and
/// recommendations.
///
/// All thresholds are public so callers can tune them; `Default` provides the
/// stock values.
#[derive(Debug, Clone, PartialEq)]
pub struct HanseiAnalyzer {
    /// Loss-increase threshold (presumably a relative increase).
    ///
    /// NOTE(review): no check in this module currently reads this value.
    pub loss_increase_threshold: f64,
    /// Maximum gradient norm above which a gradient explosion is reported.
    pub gradient_explosion_threshold: f64,
    /// Mean gradient norm below which vanishing gradients are suspected.
    pub gradient_vanishing_threshold: f64,
    /// Accuracy spread (`max - min`) below which a stable accuracy is
    /// reported as showing minimal improvement.
    pub min_accuracy_improvement: f64,
}
25
26impl Default for HanseiAnalyzer {
27 fn default() -> Self {
28 Self {
29 loss_increase_threshold: 0.1, gradient_explosion_threshold: 100.0,
31 gradient_vanishing_threshold: 1e-7,
32 min_accuracy_improvement: 0.01, }
34 }
35}
36
37impl HanseiAnalyzer {
38 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn analyze(
44 &self,
45 training_id: &str,
46 collector: &MetricsCollector,
47 duration_secs: f64,
48 ) -> PostTrainingReport {
49 let mut issues = Vec::new();
50 let mut recommendations = Vec::new();
51 let mut metric_summaries = HashMap::new();
52 let mut final_metrics = HashMap::new();
53
54 let summary = collector.summary();
55 let total_steps = summary.values().map(|s| s.count).sum::<usize>() as u64;
56
57 for (metric, stats) in &summary {
59 let metric_summary = self.analyze_metric(metric, stats);
60 metric_summaries.insert(metric.clone(), metric_summary.clone());
61 final_metrics.insert(metric.clone(), stats.mean);
62
63 self.check_metric_issues(metric, &metric_summary, stats, &mut issues);
65 }
66
67 self.generate_recommendations(&issues, &mut recommendations);
69
70 self.check_missing_metrics(&summary, &mut issues);
72
73 issues.sort_by(|a, b| b.severity.cmp(&a.severity));
75
76 PostTrainingReport {
77 training_id: training_id.to_string(),
78 duration_secs,
79 total_steps,
80 final_metrics,
81 metric_summaries,
82 issues,
83 recommendations,
84 }
85 }
86
87 fn analyze_metric(&self, metric: &Metric, stats: &MetricStats) -> MetricSummary {
88 let trend = self.determine_trend(metric, stats);
90
91 MetricSummary {
92 initial: stats.min, final_value: stats.mean, min: stats.min,
95 max: stats.max,
96 mean: stats.mean,
97 std_dev: stats.std,
98 trend,
99 }
100 }
101
102 fn determine_trend(&self, metric: &Metric, stats: &MetricStats) -> Trend {
103 let cv = coeff_of_variation(stats);
104 if cv > 0.5 {
105 return Trend::Oscillating;
106 }
107 match metric {
108 Metric::Loss => range_trend(stats, true),
109 Metric::Accuracy => range_trend(stats, false),
110 Metric::GradientNorm => {
111 if cv < 0.2 {
112 Trend::Stable
113 } else {
114 Trend::Oscillating
115 }
116 }
117 Metric::LearningRate | Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
118 Trend::Stable
119 }
120 }
121 }
122}
123
124fn coeff_of_variation(stats: &MetricStats) -> f64 {
125 if stats.mean.abs() > 1e-10 {
126 stats.std / stats.mean.abs()
127 } else {
128 0.0
129 }
130}
131
132fn range_trend(stats: &MetricStats, lower_is_better: bool) -> Trend {
135 if stats.max - stats.min < stats.std * 0.5 {
136 return Trend::Stable;
137 }
138 let mid = f64::midpoint(stats.min, stats.max);
139 let improving = if lower_is_better { stats.mean < mid } else { stats.mean > mid };
140 if improving {
141 Trend::Improving
142 } else {
143 Trend::Degrading
144 }
145}
146
147impl HanseiAnalyzer {
148 fn check_metric_issues(
149 &self,
150 metric: &Metric,
151 summary: &MetricSummary,
152 stats: &MetricStats,
153 issues: &mut Vec<TrainingIssue>,
154 ) {
155 match metric {
156 Metric::Loss => self.check_loss_issues(summary, stats, issues),
157 Metric::Accuracy => self.check_accuracy_issues(summary, stats, issues),
158 Metric::GradientNorm => self.check_gradient_issues(stats, issues),
159 Metric::LearningRate => self.check_lr_issues(summary, issues),
160 Metric::Epoch | Metric::Batch | Metric::Custom(_) => {}
161 }
162 }
163
164 fn check_loss_issues(
166 &self,
167 summary: &MetricSummary,
168 stats: &MetricStats,
169 issues: &mut Vec<TrainingIssue>,
170 ) {
171 if stats.has_nan {
172 issues.push(TrainingIssue {
173 severity: IssueSeverity::Critical,
174 category: "Numerical Stability".to_string(),
175 description: "NaN values detected in loss".to_string(),
176 recommendation:
177 "Reduce learning rate, add gradient clipping, or check data preprocessing"
178 .to_string(),
179 });
180 }
181 if stats.has_inf {
182 issues.push(TrainingIssue {
183 severity: IssueSeverity::Critical,
184 category: "Numerical Stability".to_string(),
185 description: "Infinity values detected in loss".to_string(),
186 recommendation: "Check for division by zero, reduce learning rate".to_string(),
187 });
188 }
189 if summary.trend == Trend::Degrading {
190 issues.push(TrainingIssue {
191 severity: IssueSeverity::Warning,
192 category: "Convergence".to_string(),
193 description: "Loss appears to be increasing over training".to_string(),
194 recommendation: "Consider reducing learning rate or checking data quality"
195 .to_string(),
196 });
197 }
198 if summary.trend == Trend::Oscillating {
199 issues.push(TrainingIssue {
200 severity: IssueSeverity::Warning,
201 category: "Stability".to_string(),
202 description: "Loss is oscillating significantly".to_string(),
203 recommendation: "Reduce learning rate or increase batch size".to_string(),
204 });
205 }
206 }
207
208 fn check_accuracy_issues(
210 &self,
211 summary: &MetricSummary,
212 stats: &MetricStats,
213 issues: &mut Vec<TrainingIssue>,
214 ) {
215 if summary.final_value < 0.5 && stats.count > 100 {
216 issues.push(TrainingIssue {
217 severity: IssueSeverity::Warning,
218 category: "Performance".to_string(),
219 description: format!("Final accuracy is low: {:.2}%", summary.final_value * 100.0),
220 recommendation: "Consider model architecture changes or hyperparameter tuning"
221 .to_string(),
222 });
223 }
224 if summary.trend == Trend::Stable
225 && summary.max - summary.min < self.min_accuracy_improvement
226 {
227 issues.push(TrainingIssue {
228 severity: IssueSeverity::Info,
229 category: "Convergence".to_string(),
230 description: "Accuracy shows minimal improvement".to_string(),
231 recommendation: "Model may have converged or may be stuck in local minimum"
232 .to_string(),
233 });
234 }
235 }
236
237 fn check_gradient_issues(&self, stats: &MetricStats, issues: &mut Vec<TrainingIssue>) {
239 if stats.max > self.gradient_explosion_threshold {
240 issues.push(TrainingIssue {
241 severity: IssueSeverity::Error,
242 category: "Gradient Health".to_string(),
243 description: format!("Gradient explosion detected: max norm = {:.2e}", stats.max),
244 recommendation: "Enable gradient clipping (e.g., max_norm=1.0)".to_string(),
245 });
246 }
247 if stats.mean < self.gradient_vanishing_threshold && stats.count > 10 {
248 issues.push(TrainingIssue {
249 severity: IssueSeverity::Warning,
250 category: "Gradient Health".to_string(),
251 description: format!(
252 "Possible vanishing gradients: mean norm = {:.2e}",
253 stats.mean
254 ),
255 recommendation:
256 "Consider using residual connections or different activation functions"
257 .to_string(),
258 });
259 }
260 }
261
262 fn check_lr_issues(&self, summary: &MetricSummary, issues: &mut Vec<TrainingIssue>) {
264 if summary.std_dev > summary.mean * 0.5 {
265 issues.push(TrainingIssue {
266 severity: IssueSeverity::Info,
267 category: "Hyperparameters".to_string(),
268 description: "Learning rate schedule shows high variance".to_string(),
269 recommendation: "Review learning rate schedule configuration".to_string(),
270 });
271 }
272 }
273
274 fn check_missing_metrics(
275 &self,
276 metrics: &HashMap<Metric, MetricStats>,
277 issues: &mut Vec<TrainingIssue>,
278 ) {
279 if !metrics.contains_key(&Metric::Loss) {
281 issues.push(TrainingIssue {
282 severity: IssueSeverity::Warning,
283 category: "Observability".to_string(),
284 description: "No loss metric recorded".to_string(),
285 recommendation: "Ensure loss is being tracked for proper monitoring".to_string(),
286 });
287 }
288 }
289
290 fn generate_recommendations(
291 &self,
292 issues: &[TrainingIssue],
293 recommendations: &mut Vec<String>,
294 ) {
295 let has_numerical_issues = issues.iter().any(|i| i.category == "Numerical Stability");
296 let has_gradient_issues = issues.iter().any(|i| i.category == "Gradient Health");
297 let has_convergence_issues = issues.iter().any(|i| i.category == "Convergence");
298
299 if has_numerical_issues {
300 recommendations.push(
301 "Priority 1: Address numerical stability before continuing training".to_string(),
302 );
303 }
304
305 if has_gradient_issues {
306 recommendations.push("Enable gradient clipping in optimizer configuration".to_string());
307 }
308
309 if has_convergence_issues {
310 recommendations.push(
311 "Consider hyperparameter search for learning rate and batch size".to_string(),
312 );
313 }
314
315 if issues.is_empty() {
316 recommendations.push(
317 "Training completed without detected issues. Consider running validation tests."
318 .to_string(),
319 );
320 }
321 }
322
323 pub fn format_report(&self, report: &PostTrainingReport) -> String {
325 let mut output = String::new();
326
327 let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
329 let _ =
330 writeln!(output, " HANSEI POST-TRAINING REPORT ");
331 let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
332 let _ = writeln!(output);
333 let _ = writeln!(output, "Training ID: {}", report.training_id);
334 let _ = writeln!(output, "Duration: {:.2}s", report.duration_secs);
335 let _ = writeln!(output, "Total Steps: {}", report.total_steps);
336 let _ = writeln!(output);
337
338 let _ =
340 writeln!(output, "─── Metric Summaries ───────────────────────────────────────────");
341 for (metric_type, summary) in &report.metric_summaries {
342 let _ = writeln!(output, "\n{metric_type:?}:");
343 let _ = writeln!(output, " Mean: {:.6} Std: {:.6}", summary.mean, summary.std_dev);
344 let _ = writeln!(output, " Min: {:.6} Max: {:.6}", summary.min, summary.max);
345 let _ = writeln!(output, " Trend: {}", summary.trend);
346 }
347 let _ = writeln!(output);
348
349 if !report.issues.is_empty() {
351 let _ = writeln!(
352 output,
353 "─── Issues Detected ────────────────────────────────────────────"
354 );
355 for issue in &report.issues {
356 let _ = writeln!(output, "\n[{}] {}", issue.severity, issue.category);
357 let _ = writeln!(output, " {}", issue.description);
358 let _ = writeln!(output, " → {}", issue.recommendation);
359 }
360 let _ = writeln!(output);
361 }
362
363 let _ =
365 writeln!(output, "─── Recommendations ────────────────────────────────────────────");
366 for (i, rec) in report.recommendations.iter().enumerate() {
367 let _ = writeln!(output, "{}. {}", i + 1, rec);
368 }
369 let _ = writeln!(output);
370
371 let _ = writeln!(output, "═══════════════════════════════════════════════════════════════");
372
373 output
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-behaved statistics that trigger none of the per-metric checks.
    fn healthy_stats() -> MetricStats {
        MetricStats {
            count: 200,
            mean: 0.5,
            std: 0.1,
            min: 0.3,
            max: 0.7,
            sum: 100.0,
            has_nan: false,
            has_inf: false,
        }
    }

    /// Summary consistent with `healthy_stats`, with a stable trend.
    fn stable_summary() -> MetricSummary {
        MetricSummary {
            initial: 0.3,
            final_value: 0.5,
            min: 0.3,
            max: 0.7,
            mean: 0.5,
            std_dev: 0.1,
            trend: Trend::Stable,
        }
    }

    #[test]
    fn test_determine_trend_all_metric_variants() {
        let analyzer = HanseiAnalyzer::default();

        let stable_stats = MetricStats {
            count: 100,
            mean: 1.0,
            std: 0.01,
            min: 0.99,
            max: 1.01,
            sum: 100.0,
            has_nan: false,
            has_inf: false,
        };

        let metrics = [
            Metric::Loss,
            Metric::Accuracy,
            Metric::GradientNorm,
            Metric::LearningRate,
            Metric::Epoch,
            Metric::Batch,
            Metric::Custom("custom_metric".to_string()),
        ];

        for metric in &metrics {
            let trend = analyzer.determine_trend(metric, &stable_stats);
            match metric {
                Metric::Loss => {
                    assert!(matches!(
                        trend,
                        Trend::Stable | Trend::Improving | Trend::Degrading | Trend::Oscillating
                    ));
                }
                Metric::Accuracy => {
                    assert!(matches!(
                        trend,
                        Trend::Stable | Trend::Improving | Trend::Degrading | Trend::Oscillating
                    ));
                }
                Metric::GradientNorm => {
                    assert!(matches!(trend, Trend::Stable | Trend::Oscillating));
                }
                Metric::LearningRate | Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
                    assert_eq!(trend, Trend::Stable);
                }
            }
        }
    }

    #[test]
    fn test_check_metric_issues_all_metric_variants() {
        let analyzer = HanseiAnalyzer::default();
        let stats = healthy_stats();
        let summary = stable_summary();

        let metrics = [
            Metric::Loss,
            Metric::Accuracy,
            Metric::GradientNorm,
            Metric::LearningRate,
            Metric::Epoch,
            Metric::Batch,
            Metric::Custom("test".to_string()),
        ];

        for metric in &metrics {
            let mut issues = Vec::new();
            analyzer.check_metric_issues(metric, &summary, &stats, &mut issues);

            match metric {
                // No NaN/Inf values and a stable trend.
                Metric::Loss => {
                    assert!(issues.is_empty(), "healthy loss should produce no issues");
                }
                // 0.5 is not below the 0.5 cut-off; the 0.4 spread exceeds
                // `min_accuracy_improvement`.
                Metric::Accuracy => {
                    assert!(issues.is_empty(), "healthy accuracy should produce no issues");
                }
                // Max 0.7 is far below the explosion threshold, and mean 0.5 is
                // far above the vanishing threshold.
                Metric::GradientNorm => {
                    assert!(issues.is_empty(), "healthy gradients should produce no issues");
                }
                // std 0.1 does not exceed half of the 0.5 mean.
                Metric::LearningRate => {
                    assert!(issues.is_empty(), "steady learning rate should produce no issues");
                }
                Metric::Epoch | Metric::Batch | Metric::Custom(_) => {
                    assert!(issues.is_empty(), "Epoch/Batch/Custom should produce no issues");
                }
            }
        }
    }

    #[test]
    fn test_loss_nan_and_inf_are_critical() {
        let analyzer = HanseiAnalyzer::default();
        let stats = MetricStats { has_nan: true, has_inf: true, ..healthy_stats() };
        let mut issues = Vec::new();
        analyzer.check_metric_issues(&Metric::Loss, &stable_summary(), &stats, &mut issues);

        assert_eq!(issues.len(), 2);
        assert!(issues
            .iter()
            .all(|i| i.severity == IssueSeverity::Critical && i.category == "Numerical Stability"));
    }

    #[test]
    fn test_gradient_explosion_and_vanishing() {
        let analyzer = HanseiAnalyzer::default();
        let summary = stable_summary();

        let exploding = MetricStats { max: 150.0, ..healthy_stats() };
        let mut issues = Vec::new();
        analyzer.check_metric_issues(&Metric::GradientNorm, &summary, &exploding, &mut issues);
        assert_eq!(issues.len(), 1);
        assert!(issues[0].severity == IssueSeverity::Error);

        let vanishing = MetricStats { mean: 1e-9, min: 0.0, max: 1e-8, ..healthy_stats() };
        let mut issues = Vec::new();
        analyzer.check_metric_issues(&Metric::GradientNorm, &summary, &vanishing, &mut issues);
        assert_eq!(issues.len(), 1);
        assert!(issues[0].severity == IssueSeverity::Warning);
    }

    #[test]
    fn test_low_accuracy_and_lr_variance() {
        let analyzer = HanseiAnalyzer::default();

        let low = MetricSummary { final_value: 0.3, ..stable_summary() };
        let mut issues = Vec::new();
        analyzer.check_metric_issues(&Metric::Accuracy, &low, &healthy_stats(), &mut issues);
        assert_eq!(issues.len(), 1);
        assert!(issues[0].category == "Performance");

        let noisy_lr = MetricSummary { std_dev: 0.4, ..stable_summary() };
        let mut issues = Vec::new();
        analyzer.check_metric_issues(
            &Metric::LearningRate,
            &noisy_lr,
            &healthy_stats(),
            &mut issues,
        );
        assert_eq!(issues.len(), 1);
        assert!(issues[0].severity == IssueSeverity::Info);
    }

    #[test]
    fn test_generate_recommendations() {
        let analyzer = HanseiAnalyzer::default();

        let mut recs = Vec::new();
        analyzer.generate_recommendations(&[], &mut recs);
        assert_eq!(recs.len(), 1);
        assert!(recs[0].contains("without detected issues"));

        let issues = [TrainingIssue {
            severity: IssueSeverity::Critical,
            category: "Numerical Stability".to_string(),
            description: String::new(),
            recommendation: String::new(),
        }];
        let mut recs = Vec::new();
        analyzer.generate_recommendations(&issues, &mut recs);
        assert_eq!(recs.len(), 1);
        assert!(recs[0].starts_with("Priority 1"));
    }

    #[test]
    fn test_coeff_of_variation() {
        let s = MetricStats { mean: 2.0, std: 1.0, ..healthy_stats() };
        assert!((coeff_of_variation(&s) - 0.5).abs() < 1e-12);

        let zero_mean = MetricStats { mean: 0.0, std: 5.0, ..healthy_stats() };
        assert!(coeff_of_variation(&zero_mean).abs() < f64::EPSILON);
    }

    #[test]
    fn test_analyzer_default() {
        let analyzer = HanseiAnalyzer::default();
        assert!((analyzer.loss_increase_threshold - 0.1).abs() < 1e-10);
        assert!((analyzer.gradient_explosion_threshold - 100.0).abs() < 1e-10);
        assert!((analyzer.gradient_vanishing_threshold - 1e-7).abs() < 1e-15);
        assert!((analyzer.min_accuracy_improvement - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_analyzer_new() {
        let analyzer = HanseiAnalyzer::new();
        assert!((analyzer.loss_increase_threshold - 0.1).abs() < 1e-10);
    }
}