use std::collections::VecDeque;
use super::types::{
ConvergenceStatus, ModelPerformanceMetrics, OverfittingIndicator, PlateauInfo,
TrainingDynamics, TrainingStability, UnderfittingIndicator,
};
/// Rolling analyzer of training-run health: convergence, stability,
/// learning efficiency, over/underfitting signals, and plateaus.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
/// Bounded FIFO of recent per-step metrics (capped at 1000 entries in `add_metrics`).
metrics_history: VecDeque<ModelPerformanceMetrics>,
/// Thresholds and window sizes driving every detection heuristic.
config: TrainingAnalysisConfig,
/// Incrementally maintained bookkeeping (best loss, plateau, verdict history).
current_state: TrainingState,
}
/// Tunable thresholds for the training-dynamics heuristics.
#[derive(Debug, Clone)]
pub struct TrainingAnalysisConfig {
/// Number of most-recent samples examined when classifying convergence.
pub convergence_window: usize,
/// Variance/trend threshold below which progress counts as "no improvement".
pub min_improvement_threshold: f64,
/// Loss variance above which training is considered unstable.
pub max_variance_threshold: f64,
/// Minimum number of flat samples before a plateau is reported.
pub min_plateau_duration: usize,
/// Train/validation gap treated as an overfitting signal.
/// NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
pub overfitting_gap_threshold: f64,
/// Learning-rate floor.
/// NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
pub min_learning_rate: f64,
}
impl Default for TrainingAnalysisConfig {
/// Conservative defaults: 20-sample convergence window, 0.1% improvement
/// floor, and a 10-sample minimum plateau duration.
fn default() -> Self {
Self {
convergence_window: 20,
min_improvement_threshold: 0.001,
max_variance_threshold: 0.1,
min_plateau_duration: 10,
overfitting_gap_threshold: 0.05,
min_learning_rate: 1e-6,
}
}
}
/// Mutable bookkeeping the analyzer updates on every metrics sample.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingState {
/// Consecutive samples since `best_loss` last improved.
steps_since_improvement: usize,
/// Lowest loss observed so far; starts at +infinity.
best_loss: f64,
/// Plateau in progress, if any.
/// NOTE(review): never written in this file (stays `None` after default) —
/// confirm whether it is populated elsewhere or is dead state.
current_plateau: Option<PlateauInfo>,
/// Rolling record of recent convergence verdicts (capped at 50 in `add_metrics`).
convergence_history: VecDeque<ConvergenceStatus>,
}
impl Default for TrainingState {
/// Fresh state: no samples seen, so the best loss starts at +infinity
/// (any real loss improves on it) — this is why `Default` is hand-written
/// rather than derived.
fn default() -> Self {
Self {
steps_since_improvement: 0,
best_loss: f64::INFINITY,
current_plateau: None,
convergence_history: VecDeque::new(),
}
}
}
impl TrainingDynamicsAnalyzer {
pub fn new() -> Self {
Self {
metrics_history: VecDeque::new(),
config: TrainingAnalysisConfig::default(),
current_state: TrainingState::default(),
}
}
/// Creates an analyzer that uses `config` in place of the defaults.
pub fn with_config(config: TrainingAnalysisConfig) -> Self {
    Self {
        config,
        metrics_history: VecDeque::new(),
        current_state: TrainingState::default(),
    }
}
/// Ingests one metrics sample: updates best-loss bookkeeping, appends to the
/// bounded history, and records a fresh convergence verdict.
pub fn add_metrics(&mut self, metrics: ModelPerformanceMetrics) {
    // Track the lowest loss seen and how long we have gone without beating it.
    let state = &mut self.current_state;
    if metrics.loss < state.best_loss {
        state.best_loss = metrics.loss;
        state.steps_since_improvement = 0;
    } else {
        state.steps_since_improvement += 1;
    }

    // Bounded history: keep only the 1000 most recent samples.
    self.metrics_history.push_back(metrics);
    while self.metrics_history.len() > 1000 {
        self.metrics_history.pop_front();
    }

    // Re-classify convergence after every sample; keep a short rolling record.
    let verdict = self.detect_convergence_status();
    let verdicts = &mut self.current_state.convergence_history;
    verdicts.push_back(verdict);
    while verdicts.len() > 50 {
        verdicts.pop_front();
    }
}
/// Accepts externally computed training dynamics.
/// Currently a no-op placeholder — the argument is deliberately ignored (note
/// the `_dynamics` binding); callers get fresh analysis from
/// `analyze_training_dynamics` instead.
pub fn record_training_dynamics(&mut self, _dynamics: TrainingDynamics) {
}
/// Runs every detector over the current history and bundles the results
/// into a single `TrainingDynamics` snapshot.
pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
    TrainingDynamics {
        convergence_status: self.detect_convergence_status(),
        training_stability: self.assess_training_stability(),
        learning_efficiency: self.calculate_learning_efficiency(),
        overfitting_indicators: self.detect_overfitting_indicators(),
        underfitting_indicators: self.detect_underfitting_indicators(),
        plateau_detection: self.detect_plateau(),
    }
}
/// Classifies the recent loss curve as converged, diverging, oscillating,
/// plateaued, converging, or unknown.
///
/// The window handed to the `is_*` helpers is collected via `.rev()`, so it
/// is ordered newest-first: `losses[0]` is the most recent sample.
pub fn detect_convergence_status(&self) -> ConvergenceStatus {
    // Not enough history for any verdict yet.
    if self.metrics_history.len() < self.config.convergence_window {
        return ConvergenceStatus::Unknown;
    }
    let losses: Vec<f64> = self
        .metrics_history
        .iter()
        .rev()
        .take(self.config.convergence_window)
        .map(|m| m.loss)
        .collect();
    // First matching classification wins; order matters.
    if self.is_converged(&losses) {
        return ConvergenceStatus::Converged;
    }
    if self.is_diverging(&losses) {
        return ConvergenceStatus::Diverging;
    }
    if self.is_oscillating(&losses) {
        return ConvergenceStatus::Oscillating;
    }
    if self.is_plateau(&losses) {
        return ConvergenceStatus::Plateau;
    }
    if self.is_converging(&losses) {
        return ConvergenceStatus::Converging;
    }
    ConvergenceStatus::Unknown
}
/// Buckets the variance of the last 20 losses into
/// stable / high-variance / unstable; `Unknown` below 10 samples.
pub fn assess_training_stability(&self) -> TrainingStability {
    if self.metrics_history.len() < 10 {
        return TrainingStability::Unknown;
    }
    let window: Vec<f64> =
        self.metrics_history.iter().rev().take(20).map(|m| m.loss).collect();
    let variance = self.calculate_variance(&window);
    let threshold = self.config.max_variance_threshold;
    // Half-threshold marks the boundary between merely noisy and stable.
    if variance > threshold {
        TrainingStability::Unstable
    } else if variance > threshold / 2.0 {
        TrainingStability::HighVariance
    } else {
        TrainingStability::Stable
    }
}
/// Relative loss improvement between the oldest and newest retained samples,
/// normalized by sqrt(sample count) and clamped to at most 1.0.
/// Returns 0.0 when there is no history or no improvement.
pub fn calculate_learning_efficiency(&self) -> f64 {
    // With zero samples this bails here; with exactly one, front == back and
    // the no-improvement guard below returns 0.0 — same result as before.
    let (Some(first), Some(last)) =
        (self.metrics_history.front(), self.metrics_history.back())
    else {
        return 0.0;
    };
    let (start_loss, end_loss) = (first.loss, last.loss);
    if start_loss <= end_loss {
        return 0.0;
    }
    let relative_improvement = (start_loss - end_loss) / start_loss;
    // sqrt(steps) in the denominator so long runs are not penalized linearly.
    let steps = self.metrics_history.len() as f64;
    (relative_improvement / steps.sqrt()).min(1.0)
}
/// Heuristic overfitting signals from the last 10 losses: a near-zero mean
/// loss, and high variance. Empty until more than 10 samples exist.
pub fn detect_overfitting_indicators(&self) -> Vec<OverfittingIndicator> {
    // Strictly more than 10 samples required, matching the history guard.
    if self.metrics_history.len() <= 10 {
        return Vec::new();
    }
    let recent: Vec<f64> =
        self.metrics_history.iter().rev().take(10).map(|m| m.loss).collect();
    let mean_loss = recent.iter().sum::<f64>() / recent.len() as f64;
    let mut indicators = Vec::new();
    // Training loss essentially at zero — memorization is plausible.
    if mean_loss < 0.01 {
        indicators.push(OverfittingIndicator::PerfectTrainingAccuracy);
    }
    // NOTE(review): this variance is computed over *training* losses, yet the
    // indicator name says validation — confirm which stream is intended.
    if self.calculate_variance(&recent) > 0.05 {
        indicators.push(OverfittingIndicator::HighVarianceInValidation);
    }
    indicators
}
/// Heuristic underfitting signals derived from the newest sample and the
/// steps-since-improvement counter. Empty when there is no history.
pub fn detect_underfitting_indicators(&self) -> Vec<UnderfittingIndicator> {
    let Some(latest) = self.metrics_history.back() else {
        return Vec::new();
    };
    let mut indicators = Vec::new();
    // Absolute loss still above 1.0.
    if latest.loss > 1.0 {
        indicators.push(UnderfittingIndicator::HighTrainingLoss {
            loss: latest.loss,
            threshold: 1.0,
        });
    }
    // Accuracy below chance-ish level, when the metric is reported at all.
    if let Some(accuracy) = latest.accuracy {
        if accuracy < 0.5 {
            indicators.push(UnderfittingIndicator::LowTrainingAccuracy {
                accuracy,
                threshold: 0.5,
            });
        }
    }
    let stalled_for = self.current_state.steps_since_improvement;
    // Both thresholds can fire together (>100 implies >50), as before.
    if stalled_for > 50 {
        indicators.push(UnderfittingIndicator::SlowConvergence {
            steps_taken: self.metrics_history.len(),
            expected: self.metrics_history.len() / 2,
        });
    }
    if stalled_for > 100 {
        indicators.push(UnderfittingIndicator::NoLearning {
            steps_without_improvement: stalled_for,
        });
    }
    indicators
}
/// Reports a plateau when the last `min_plateau_duration` losses have
/// variance below the improvement threshold; `None` otherwise.
pub fn detect_plateau(&self) -> Option<PlateauInfo> {
    let window = self.config.min_plateau_duration;
    if self.metrics_history.len() < window {
        return None;
    }
    let recent: Vec<f64> =
        self.metrics_history.iter().rev().take(window).map(|m| m.loss).collect();
    let variance = self.calculate_variance(&recent);
    if variance >= self.config.min_improvement_threshold {
        return None;
    }
    let mean_loss = recent.iter().sum::<f64>() / recent.len() as f64;
    Some(PlateauInfo {
        // NOTE(review): this is an index into the capped history buffer, not a
        // raw training step — once the buffer wraps at 1000 entries the two
        // diverge; confirm what consumers of `start_step` expect.
        start_step: self.metrics_history.len() - window,
        duration_steps: window,
        plateau_value: mean_loss,
        variance,
    })
}
/// Turns the current dynamics snapshot into actionable recommendations,
/// ordered roughly by severity (convergence, stability, efficiency).
pub fn generate_training_recommendations(&self) -> Vec<TrainingRecommendation> {
    let dynamics = self.analyze_training_dynamics();
    let mut recs = Vec::new();
    // Convergence problems first: divergence is critical, plateaus are high.
    match dynamics.convergence_status {
        ConvergenceStatus::Diverging => recs.push(TrainingRecommendation {
            category: "Convergence".to_string(),
            priority: TrainingRecommendationPriority::Critical,
            description: "Training is diverging".to_string(),
            action: "Reduce learning rate immediately".to_string(),
            expected_impact: 0.8,
        }),
        ConvergenceStatus::Plateau => recs.push(TrainingRecommendation {
            category: "Convergence".to_string(),
            priority: TrainingRecommendationPriority::High,
            description: "Training has reached a plateau".to_string(),
            action: "Consider learning rate scheduling or data augmentation".to_string(),
            expected_impact: 0.6,
        }),
        _ => {},
    }
    if matches!(dynamics.training_stability, TrainingStability::Unstable) {
        recs.push(TrainingRecommendation {
            category: "Stability".to_string(),
            priority: TrainingRecommendationPriority::High,
            description: "Training is unstable".to_string(),
            action: "Reduce learning rate or add gradient clipping".to_string(),
            expected_impact: 0.7,
        });
    }
    if dynamics.learning_efficiency < 0.3 {
        recs.push(TrainingRecommendation {
            category: "Efficiency".to_string(),
            priority: TrainingRecommendationPriority::Medium,
            description: "Low learning efficiency detected".to_string(),
            action: "Consider architecture changes or hyperparameter tuning".to_string(),
            expected_impact: 0.5,
        });
    }
    recs
}
/// True when training looks finished: `losses` arrives newest-first (see
/// `detect_convergence_status`), so the first five entries are the five most
/// recent samples — they must be both flat (tiny variance) and near zero.
fn is_converged(&self, losses: &[f64]) -> bool {
    match losses.get(..5) {
        Some(most_recent) => {
            self.calculate_variance(most_recent) < self.config.min_improvement_threshold
                && losses[0] < 0.01
        },
        None => false,
    }
}
/// True when the loss has been rising monotonically over the window and has
/// grown by more than 10%.
///
/// `losses` arrives newest-first (collected via `.rev()` in
/// `detect_convergence_status`): `losses[0]` is the most recent sample. Loss
/// rising over time therefore shows up as a *non-increasing* sequence in array
/// order, and the growth ratio is newest/oldest = first/last.
///
/// Bug fixed: the previous version tested for a sequence rising with index and
/// `last/first > 1.1`, which under newest-first ordering detects steady
/// *improvement*, not divergence — the exact opposite of its name.
fn is_diverging(&self, losses: &[f64]) -> bool {
    if losses.len() < 3 {
        return false;
    }
    // Every step back in time must be <= the step after it (loss never fell).
    let monotonically_worsening = losses.windows(2).all(|w| w[1] <= w[0]);
    let newest = *losses.first().expect("losses has at least 3 elements");
    let oldest = *losses.last().expect("losses has at least 3 elements");
    monotonically_worsening && newest / oldest > 1.1
}
/// True when the loss keeps flip-flopping: counts sign changes between
/// consecutive deltas and requires more than len/3 of them. Symmetric in the
/// ordering of `losses`, so newest-first input is fine.
fn is_oscillating(&self, losses: &[f64]) -> bool {
    if losses.len() < 6 {
        return false;
    }
    // A flat step (delta exactly 0.0) has signum 0.0 and never counts as a
    // change against another flat step, same as before.
    let direction_changes = losses
        .windows(3)
        .filter(|w| (w[1] - w[0]).signum() != (w[2] - w[1]).signum())
        .count();
    direction_changes > losses.len() / 3
}
/// True when the whole window is essentially flat. Order-insensitive.
/// NOTE(review): compares a variance against `min_improvement_threshold`,
/// which is also used as a slope threshold elsewhere — the units differ;
/// confirm one threshold is really meant to serve both roles.
fn is_plateau(&self, losses: &[f64]) -> bool {
    self.calculate_variance(losses) < self.config.min_improvement_threshold
}
/// True when the loss is still trending downward over time.
///
/// `losses` arrives newest-first (see `detect_convergence_status`), so
/// improvement over wall-clock time appears as values *growing* with index
/// (older samples are larger). A least-squares slope above the threshold
/// therefore means the loss is decreasing.
///
/// Bug fixed: the previous version required `trend < -threshold`, which under
/// newest-first ordering fires when the loss is *increasing* over time — the
/// sign was inverted.
fn is_converging(&self, losses: &[f64]) -> bool {
    if losses.len() < 3 {
        return false;
    }
    self.calculate_trend(losses) > self.config.min_improvement_threshold
}
/// Sample (Bessel-corrected, n-1 denominator) variance of `values`;
/// 0.0 for fewer than two samples.
fn calculate_variance(&self, values: &[f64]) -> f64 {
    let n = values.len();
    if n < 2 {
        return 0.0;
    }
    let mean = values.iter().sum::<f64>() / n as f64;
    values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64
}
/// Least-squares slope of `values` regressed against their indices 0..n.
/// Returns 0.0 with fewer than two points or a degenerate denominator.
fn calculate_trend(&self, values: &[f64]) -> f64 {
    if values.len() < 2 {
        return 0.0;
    }
    let n = values.len() as f64;
    // Index mean for 0..n-1 has the closed form (n-1)/2.
    let x_mean = (n - 1.0) / 2.0;
    let y_mean = values.iter().sum::<f64>() / n;
    // Accumulate covariance numerator and x-variance denominator in one pass,
    // in the same left-to-right order as the original loop.
    let (num, den) = values.iter().enumerate().fold((0.0, 0.0), |(num, den), (i, &y)| {
        let dx = i as f64 - x_mean;
        (num + dx * (y - y_mean), den + dx * dx)
    });
    if den == 0.0 {
        0.0
    } else {
        num / den
    }
}
/// Resets the analyzer: drops all metrics history and restores the
/// bookkeeping state to its defaults (best loss back to +infinity).
pub fn clear(&mut self) {
self.metrics_history.clear();
self.current_state = TrainingState::default();
}
/// Read-only view of the analyzer's internal bookkeeping state.
pub fn get_training_state(&self) -> &TrainingState {
&self.current_state
}
/// Builds a point-in-time report bundling the dynamics snapshot, generic
/// recommendations, and a copy of the bookkeeping state. Never fails; the
/// `Result` return keeps the async reporting interface uniform.
///
/// NOTE(review): this uses the generic `generate_recommendations`, not the
/// analysis-driven `generate_training_recommendations` — confirm intended.
pub async fn generate_report(&self) -> anyhow::Result<TrainingDynamicsReport> {
    let training_dynamics = self.analyze_training_dynamics();
    Ok(TrainingDynamicsReport {
        training_dynamics,
        recommendations: self.generate_recommendations(),
        current_state: self.current_state.clone(),
    })
}
/// Baseline recommendation set for reports: a single low-priority
/// "keep monitoring" entry, independent of the current dynamics.
fn generate_recommendations(&self) -> Vec<TrainingRecommendation> {
    vec![TrainingRecommendation {
        category: "General".to_string(),
        priority: TrainingRecommendationPriority::Low,
        description: "Continue monitoring training dynamics".to_string(),
        action: "Monitor training progress and adjust parameters as needed".to_string(),
        expected_impact: 0.1,
    }]
}
}
impl Default for TrainingDynamicsAnalyzer {
/// Same as `TrainingDynamicsAnalyzer::new()`: empty history, default config.
fn default() -> Self {
Self::new()
}
}
/// A single actionable suggestion produced by the analyzer.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingRecommendation {
/// Grouping label, e.g. "Convergence", "Stability", "Efficiency", "General".
pub category: String,
/// Urgency of acting on this recommendation.
pub priority: TrainingRecommendationPriority,
/// Human-readable summary of the detected issue.
pub description: String,
/// Suggested remediation.
pub action: String,
/// Estimated benefit of following the action, in [0, 1].
pub expected_impact: f64,
}
/// Urgency levels for recommendations, listed from least to most severe.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TrainingRecommendationPriority {
Low,
Medium,
High,
Critical,
}
/// Serializable snapshot produced by `generate_report`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingDynamicsReport {
/// Full dynamics analysis at report time.
pub training_dynamics: TrainingDynamics,
/// Recommendations accompanying the snapshot.
pub recommendations: Vec<TrainingRecommendation>,
/// Copy of the analyzer's bookkeeping state at report time.
pub current_state: TrainingState,
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a metrics sample in which only `step` and `loss` vary per test.
    fn create_test_metrics(step: usize, loss: f64) -> ModelPerformanceMetrics {
        ModelPerformanceMetrics {
            training_step: step,
            loss,
            accuracy: Some(0.8),
            learning_rate: 0.001,
            batch_size: 32,
            throughput_samples_per_sec: 100.0,
            memory_usage_mb: 1000.0,
            gpu_utilization: Some(0.9),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_training_dynamics_analyzer_creation() {
        let analyzer = TrainingDynamicsAnalyzer::new();
        assert_eq!(analyzer.metrics_history.len(), 0);
    }

    #[test]
    fn test_add_metrics() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        analyzer.add_metrics(create_test_metrics(1, 0.5));
        assert_eq!(analyzer.metrics_history.len(), 1);
        assert_eq!(analyzer.current_state.best_loss, 0.5);
    }

    #[test]
    fn test_convergence_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        // Steadily decreasing losses fill a full convergence window.
        for i in 1..=25 {
            analyzer.add_metrics(create_test_metrics(i, 1.0 / (i as f64)));
        }
        let status = analyzer.detect_convergence_status();
        // Bug fixed: the original discarded the `matches!` result, so this
        // test asserted nothing. With a full window of data the analyzer must
        // at least commit to a definite (non-Unknown) verdict.
        assert!(!matches!(status, ConvergenceStatus::Unknown));
    }

    #[test]
    fn test_learning_efficiency_calculation() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        analyzer.add_metrics(create_test_metrics(1, 1.0));
        analyzer.add_metrics(create_test_metrics(2, 0.5));
        analyzer.add_metrics(create_test_metrics(3, 0.25));
        // Loss fell from 1.0 to 0.25, so efficiency must be strictly positive.
        assert!(analyzer.calculate_learning_efficiency() > 0.0);
    }

    #[test]
    fn test_plateau_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        // A perfectly flat loss curve longer than min_plateau_duration (10).
        for i in 1..=15 {
            analyzer.add_metrics(create_test_metrics(i, 0.1));
        }
        assert!(analyzer.detect_plateau().is_some());
    }

    #[test]
    fn test_training_stability_assessment() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        // Tiny monotone drift: variance ~3.5e-5, far below the thresholds.
        for i in 1..=20 {
            analyzer.add_metrics(create_test_metrics(i, 0.5 + (i as f64 * 0.001)));
        }
        // Bug fixed: the original discarded the `matches!` result here too.
        assert!(matches!(
            analyzer.assess_training_stability(),
            TrainingStability::Stable
        ));
    }
}