use crate::DebugConfig;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::model_diagnostics::*;
/// Facade that coordinates all model-debugging subsystems.
///
/// Each `record_*` method fans incoming measurements out to the relevant
/// analyzers, and `generate_report_sync` combines their findings into a
/// `ModelDiagnosticsReport`.
#[derive(Debug)]
pub struct ModelDiagnostics {
    // Stored at construction but not read by any method in this file.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Records performance metrics and produces performance summaries.
    performance_analyzer: PerformanceAnalyzer,
    /// Records architecture info and produces architectural analyses.
    architecture_analyzer: ArchitectureAnalyzer,
    /// Records and analyzes training-dynamics samples.
    training_analyzer: TrainingDynamicsAnalyzer,
    /// Records per-layer activation statistics.
    layer_analyzer: LayerAnalyzer,
    /// Turns incoming data into alerts and tracks the active ones.
    alert_manager: AlertManager,
    /// Receives copies of all recorded data for automated debugging.
    auto_debugger: AutoDebugger,
    /// Receives performance metrics and produces analytics reports.
    analytics_engine: AdvancedAnalytics,
    /// Step counter advanced via `increment_step`.
    current_step: usize,
}
impl ModelDiagnostics {
    /// Creates a diagnostics coordinator with every sub-analyzer in its
    /// freshly-initialized state and the step counter at zero.
    pub fn new(config: &DebugConfig) -> Self {
        Self {
            config: config.clone(),
            current_step: 0,
            performance_analyzer: PerformanceAnalyzer::new(),
            architecture_analyzer: ArchitectureAnalyzer::new(),
            training_analyzer: TrainingDynamicsAnalyzer::new(),
            layer_analyzer: LayerAnalyzer::new(),
            alert_manager: AlertManager::new(),
            auto_debugger: AutoDebugger::new(),
            analytics_engine: AdvancedAnalytics::new(),
        }
    }

    /// Feeds one performance-metrics sample to every interested consumer.
    ///
    /// The sample is recorded by the performance analyzer, the auto
    /// debugger, and the analytics engine before the alert manager — the
    /// only fallible consumer — processes it, so the analyzers keep the
    /// sample even when alerting errors out.
    pub fn record_performance(&mut self, metrics: ModelPerformanceMetrics) -> Result<()> {
        self.performance_analyzer.record_metrics(metrics.clone());
        self.auto_debugger.record_performance_metrics(metrics.clone());
        self.analytics_engine.record_performance_metrics(&metrics);
        self.alert_manager.process_performance_metrics(&metrics)?;
        Ok(())
    }

    /// Forwards a model-architecture description to the architecture analyzer.
    pub fn record_architecture(&mut self, arch_info: ModelArchitectureInfo) {
        self.architecture_analyzer.record_architecture(arch_info);
    }

    /// Routes per-layer activation statistics to the layer analyzer, the
    /// auto debugger, and finally the (fallible) alert manager.
    pub fn record_layer_stats(&mut self, stats: LayerActivationStats) -> Result<()> {
        self.layer_analyzer.record_layer_stats(stats.clone());
        self.auto_debugger.record_layer_stats(stats.clone());
        self.alert_manager.process_layer_stats(&stats)?;
        Ok(())
    }

    /// Routes a training-dynamics sample to the training analyzer, the
    /// auto debugger, and finally the (fallible) alert manager.
    pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) -> Result<()> {
        self.training_analyzer.record_training_dynamics(dynamics.clone());
        self.auto_debugger.record_training_dynamics(dynamics.clone());
        self.alert_manager.process_training_dynamics(&dynamics)?;
        Ok(())
    }

    /// Overall health score in [0, 1], averaged over three components.
    ///
    /// NOTE(review): the component scores are fixed placeholders, so this
    /// always evaluates to 0.8 — wire in real analyzer output before
    /// relying on the value.
    fn calculate_health_score(&self) -> f64 {
        let component_scores = [0.8_f64, 0.7, 0.9];
        component_scores.iter().sum::<f64>() / component_scores.len() as f64
    }

    /// Collects advice from every analyzer into one de-duplicated list,
    /// each entry tagged with its source ([Architecture], [Performance],
    /// [Training], or [Analytics]).
    fn aggregate_recommendations(&self) -> Vec<String> {
        let mut advice: Vec<String> = Vec::new();

        // Architecture analysis is fallible; its advice is skipped on error.
        if let Ok(arch_analysis) = self.architecture_analyzer.analyze_architecture() {
            advice.extend(
                arch_analysis
                    .recommendations
                    .into_iter()
                    .map(|r| format!("[Architecture] {}", r)),
            );
        }

        let perf = self.performance_analyzer.generate_performance_summary();
        // Loss regressed more than 50% past the best value seen so far.
        if perf.current_loss > perf.best_loss * 1.5 {
            advice.push(
                "[Performance] Current loss significantly higher than best - check for training instability"
                    .to_string(),
            );
        }
        // 16 GiB peak-memory threshold.
        if perf.peak_memory_mb > 16384.0 {
            advice.push(
                "[Performance] High memory usage detected - consider gradient checkpointing or smaller batch size"
                    .to_string(),
            );
        }

        let dynamics = self.training_analyzer.analyze_training_dynamics();

        match dynamics.training_stability {
            TrainingStability::Unstable => advice.push(
                "[Training] Training stability issues detected - consider reducing learning rate or applying gradient clipping"
                    .to_string(),
            ),
            TrainingStability::Unknown => advice.push(
                "[Training] Collect more training metrics for better stability assessment"
                    .to_string(),
            ),
            _ => {},
        }

        // Only flag plateaus that have persisted for more than 100 steps.
        if dynamics
            .plateau_detection
            .as_ref()
            .map_or(false, |p| p.duration_steps > 100)
        {
            advice.push(
                "[Training] Training plateau detected - consider learning rate adjustment or early stopping"
                    .to_string(),
            );
        }

        match dynamics.convergence_status {
            ConvergenceStatus::Diverging => advice.push(
                "[Training] Model is diverging - reduce learning rate immediately".to_string(),
            ),
            ConvergenceStatus::Plateau => advice.push(
                "[Training] Training has reached a plateau - consider changing optimization strategy or early stopping"
                    .to_string(),
            ),
            ConvergenceStatus::Oscillating => advice.push(
                "[Training] Training is oscillating - reduce learning rate or increase batch size"
                    .to_string(),
            ),
            _ => {},
        }

        if !dynamics.overfitting_indicators.is_empty() {
            advice.push(
                "[Training] Overfitting detected - consider regularization, dropout, or early stopping"
                    .to_string(),
            );
        }
        if !dynamics.underfitting_indicators.is_empty() {
            advice.push(
                "[Training] Underfitting detected - consider increasing model capacity or training longer"
                    .to_string(),
            );
        }

        // Analytics-report generation is fallible; skipped on error.
        if let Ok(report) = self.analytics_engine.generate_analytics_report() {
            advice.extend(
                report
                    .recommendations
                    .into_iter()
                    .map(|r| format!("[Analytics] {}", r)),
            );
        }

        // Order-preserving de-duplication: keep the first occurrence only.
        let mut seen = std::collections::HashSet::new();
        advice.retain(|msg| seen.insert(msg.clone()));
        advice
    }

    /// Number of steps recorded so far via [`Self::increment_step`].
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Runs the training-dynamics analysis on the data collected so far.
    pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
        self.training_analyzer.analyze_training_dynamics()
    }

    /// Advances the internal step counter by one.
    pub fn increment_step(&mut self) {
        self.current_step += 1;
    }

    /// Starts the diagnostics subsystem. Currently a no-op, kept async for
    /// API symmetry with the reporting entry point.
    pub async fn start(&mut self) -> Result<()> {
        Ok(())
    }

    /// Async wrapper around [`Self::generate_report_sync`].
    pub async fn generate_report(&self) -> Result<ModelDiagnosticsReport> {
        self.generate_report_sync()
    }

    /// Assembles a full diagnostics report from every analyzer.
    ///
    /// Fallible sub-analyses (architecture, analytics) are folded in as
    /// `Option`s via `.ok()` rather than aborting report generation;
    /// auto-debugging results are not populated here.
    pub fn generate_report_sync(&self) -> Result<ModelDiagnosticsReport> {
        let performance_summary = self.performance_analyzer.generate_performance_summary();
        let architectural_analysis = self.architecture_analyzer.analyze_architecture().ok();
        let training_dynamics = self.training_analyzer.analyze_training_dynamics();
        let alert_entries = self.alert_manager.get_active_alerts().to_vec();
        let analytics_report = self.analytics_engine.generate_analytics_report().ok();
        Ok(ModelDiagnosticsReport {
            current_step: self.current_step,
            training_dynamics,
            performance_summary,
            architectural_analysis,
            // Unwrap each active-alert entry down to its alert payload.
            alerts: alert_entries.into_iter().map(|entry| entry.alert).collect(),
            recommendations: self.aggregate_recommendations(),
            health_score: self.calculate_health_score(),
            auto_debugging_results: None,
            analytics_report,
        })
    }
}
/// Serializable snapshot of the model's diagnostic state at one step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDiagnosticsReport {
    /// Training step the report was generated at.
    pub current_step: usize,
    /// Result of the training-dynamics analysis.
    pub training_dynamics: TrainingDynamics,
    /// Summary produced by the performance analyzer.
    pub performance_summary: PerformanceSummary,
    /// Architecture analysis; `None` when the analysis failed.
    pub architectural_analysis: Option<ArchitecturalAnalysis>,
    /// Alerts active at report time.
    pub alerts: Vec<ModelDiagnosticAlert>,
    /// De-duplicated, source-tagged recommendations from all analyzers.
    pub recommendations: Vec<String>,
    /// Overall health score (see `calculate_health_score`).
    pub health_score: f64,
    /// Auto-debugger findings; not populated by `generate_report_sync`.
    pub auto_debugging_results: Option<DebuggingReport>,
    /// Analytics report; `None` when generation failed.
    pub analytics_report: Option<AnalyticsReport>,
}
impl Default for ModelDiagnosticsReport {
    /// An empty report: step zero, unknown training state, no findings,
    /// and a zero health score.
    fn default() -> Self {
        // Baseline training-dynamics snapshot with everything unknown/empty.
        let baseline_dynamics = TrainingDynamics {
            convergence_status: ConvergenceStatus::Unknown,
            training_stability: TrainingStability::Unknown,
            learning_efficiency: 0.0,
            overfitting_indicators: Vec::new(),
            underfitting_indicators: Vec::new(),
            plateau_detection: None,
        };
        Self {
            current_step: 0,
            training_dynamics: baseline_dynamics,
            performance_summary: PerformanceSummary::default(),
            architectural_analysis: None,
            alerts: Vec::new(),
            recommendations: Vec::new(),
            health_score: 0.0,
            auto_debugging_results: None,
            analytics_report: None,
        }
    }
}