use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt;
/// Hyper-parameters controlling a single training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Number of samples per gradient update.
    pub batch_size: usize,
    /// When true, emit per-epoch progress output.
    pub verbose: bool,
    /// Consecutive epochs without loss improvement tolerated before
    /// stopping; `0` disables early stopping (see
    /// `TrainingHistory::should_early_stop`).
    pub early_stopping_patience: usize,
    /// Loss target used by early stopping: stopping only triggers while
    /// the best loss is still above this value.
    pub early_stopping_threshold: f32,
    /// Optional learning-rate schedule; `None` keeps the LR fixed.
    pub lr_schedule: Option<LearningRateSchedule>,
    /// Fraction of data reserved for validation.
    /// NOTE(review): not consumed anywhere in this file — presumably
    /// applied by the trainer; confirm against the caller.
    pub validation_split: f32,
    /// Whether to shuffle training data each epoch.
    /// NOTE(review): not consumed in this file — confirm at call site.
    pub shuffle: bool,
    /// Optional RNG seed for reproducible runs; `None` means unseeded.
    pub random_seed: Option<u64>,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
epochs: 100,
batch_size: 32,
verbose: true,
early_stopping_patience: 0,
early_stopping_threshold: 1e-4,
lr_schedule: None,
validation_split: 0.0,
shuffle: true,
random_seed: None,
}
}
}
/// Supported learning-rate schedules.
///
/// NOTE(review): the variant names mirror common scheduler families
/// (step, exponential, plateau, cosine, polynomial). Only the cosine and
/// polynomial math is visible in this file (`utils::cosine_annealing_lr`,
/// `utils::polynomial_lr`); the exact semantics of the other variants'
/// fields should be confirmed against the trainer that interprets them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningRateSchedule {
    /// Decay by `gamma` every `step_size` epochs (by analogy with the
    /// scheduler name — TODO confirm in the trainer).
    StepLR {
        step_size: usize,
        gamma: f32,
    },
    /// Multiply the LR by `gamma` each epoch (presumed — confirm).
    ExponentialLR {
        gamma: f32,
    },
    /// Scale the LR by `factor` after `patience` epochs without
    /// improvement beyond `threshold`, never going below `min_lr`
    /// (presumed — confirm in the trainer).
    ReduceOnPlateau {
        factor: f32,
        patience: usize,
        threshold: f32,
        min_lr: f32,
    },
    /// Cosine annealing over `t_max` epochs down to `eta_min`
    /// (see `utils::cosine_annealing_lr`).
    CosineAnnealingLR {
        t_max: usize,
        eta_min: f32,
    },
    /// Polynomial decay of degree `power` over `total_epochs`
    /// (see `utils::polynomial_lr`).
    PolynomialLR {
        total_epochs: usize,
        power: f32,
    },
}
/// Metrics recorded for one training epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Training loss for the epoch.
    pub loss: f32,
    /// Training accuracy for the epoch.
    pub accuracy: f32,
    /// Learning rate in effect during the epoch.
    pub learning_rate: f32,
    /// Wall-clock duration of the epoch, in milliseconds.
    pub epoch_time_ms: f32,
    /// Validation loss, if validation ran this epoch
    /// (set via `set_validation`).
    pub val_loss: Option<f32>,
    /// Validation accuracy, if validation ran this epoch
    /// (set via `set_validation`).
    pub val_accuracy: Option<f32>,
    /// Arbitrary user-defined metrics keyed by name
    /// (see `add_metric` / `get_metric`).
    pub custom_metrics: std::collections::HashMap<String, f32>,
}
impl TrainingMetrics {
    /// Builds a metrics record for one epoch. Validation fields start
    /// out `None` and the custom-metrics map starts empty.
    pub fn new(loss: f32, accuracy: f32, learning_rate: f32, epoch_time_ms: f32) -> Self {
        Self {
            loss,
            accuracy,
            learning_rate,
            epoch_time_ms,
            val_loss: None,
            val_accuracy: None,
            custom_metrics: Default::default(),
        }
    }

    /// Stores (or overwrites) a named custom metric.
    pub fn add_metric<S: Into<String>>(&mut self, name: S, value: f32) {
        let key = name.into();
        self.custom_metrics.insert(key, value);
    }

    /// Looks up a custom metric by name, returning `None` if absent.
    pub fn get_metric(&self, name: &str) -> Option<f32> {
        self.custom_metrics.get(name).cloned()
    }

    /// Records validation results for this epoch.
    pub fn set_validation(&mut self, val_loss: f32, val_accuracy: f32) {
        self.val_accuracy = Some(val_accuracy);
        self.val_loss = Some(val_loss);
    }
}
impl fmt::Display for TrainingMetrics {
    /// One-line rendering of the epoch's metrics; validation figures are
    /// appended only when both are present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Loss: {:.6}, Accuracy: {:.4}, LR: {:.6}, Time: {:.2}ms",
            self.loss, self.accuracy, self.learning_rate, self.epoch_time_ms
        )?;
        match (self.val_loss, self.val_accuracy) {
            (Some(vl), Some(va)) => {
                write!(f, ", Val Loss: {:.6}, Val Acc: {:.4}", vl, va)
            }
            _ => Ok(()),
        }
    }
}
/// Accumulated per-epoch metrics for a run, with running "best" trackers
/// and an early-stopping counter (all maintained by `add_epoch`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingHistory {
    /// Metrics for every recorded epoch, oldest first.
    epochs: Vec<TrainingMetrics>,
    /// Lowest loss seen so far; `f32::INFINITY` before the first epoch.
    best_loss: f32,
    /// Highest accuracy seen so far; `0.0` before the first epoch.
    best_accuracy: f32,
    /// Epoch index at which `best_loss` was recorded.
    best_loss_epoch: usize,
    /// Epoch index at which `best_accuracy` was recorded.
    best_accuracy_epoch: usize,
    /// Consecutive epochs without a loss improvement.
    early_stopping_counter: usize,
}
impl TrainingHistory {
    /// Creates an empty history; the "best" trackers start at sentinel
    /// values (`f32::INFINITY` loss, `0.0` accuracy).
    pub fn new() -> Self {
        Self {
            epochs: Vec::new(),
            best_loss: f32::INFINITY,
            best_accuracy: 0.0,
            best_loss_epoch: 0,
            best_accuracy_epoch: 0,
            early_stopping_counter: 0,
        }
    }

    /// Appends one epoch of metrics and updates the best-loss /
    /// best-accuracy trackers.
    ///
    /// The early-stopping counter resets on any strict loss improvement
    /// and increments otherwise. NOTE(review):
    /// `TrainingConfig::early_stopping_threshold` is not consulted here —
    /// even a vanishingly small improvement resets the counter; confirm
    /// that is intended.
    pub fn add_epoch(&mut self, metrics: TrainingMetrics) {
        let epoch = self.epochs.len();
        if metrics.loss < self.best_loss {
            self.best_loss = metrics.loss;
            self.best_loss_epoch = epoch;
            self.early_stopping_counter = 0;
        } else {
            self.early_stopping_counter += 1;
        }
        if metrics.accuracy > self.best_accuracy {
            self.best_accuracy = metrics.accuracy;
            self.best_accuracy_epoch = epoch;
        }
        self.epochs.push(metrics);
    }

    /// Number of epochs recorded so far.
    pub fn epochs(&self) -> usize {
        self.epochs.len()
    }

    /// Metrics for a specific epoch, or `None` if out of range.
    pub fn get_epoch(&self, epoch: usize) -> Option<&TrainingMetrics> {
        self.epochs.get(epoch)
    }

    /// All recorded epochs, oldest first.
    pub fn all_epochs(&self) -> &[TrainingMetrics] {
        &self.epochs
    }

    /// Metrics for the most recent epoch, if any.
    pub fn latest(&self) -> Option<&TrainingMetrics> {
        self.epochs.last()
    }

    /// Loss of the last epoch; `f32::INFINITY` when no epochs exist.
    pub fn final_loss(&self) -> f32 {
        self.epochs.last().map(|m| m.loss).unwrap_or(f32::INFINITY)
    }

    /// Accuracy of the last epoch; `0.0` when no epochs exist.
    pub fn final_accuracy(&self) -> f32 {
        self.epochs.last().map(|m| m.accuracy).unwrap_or(0.0)
    }

    /// Lowest loss seen so far (`f32::INFINITY` before any epoch).
    pub fn best_loss(&self) -> f32 {
        self.best_loss
    }

    /// Highest accuracy seen so far (`0.0` before any epoch).
    pub fn best_accuracy(&self) -> f32 {
        self.best_accuracy
    }

    /// Epoch index at which the best loss was recorded.
    pub fn best_loss_epoch(&self) -> usize {
        self.best_loss_epoch
    }

    /// Epoch index at which the best accuracy was recorded.
    pub fn best_accuracy_epoch(&self) -> usize {
        self.best_accuracy_epoch
    }

    /// True when training should stop early: the loss has failed to
    /// improve for at least `patience` consecutive epochs while the best
    /// loss is still above `threshold` (i.e. the loss target has not yet
    /// been reached). `patience == 0` disables early stopping.
    pub fn should_early_stop(&self, patience: usize, threshold: f32) -> bool {
        if patience == 0 {
            return false;
        }
        self.early_stopping_counter >= patience && self.best_loss > threshold
    }

    /// Per-epoch loss values, oldest first.
    pub fn loss_history(&self) -> Vec<f32> {
        self.epochs.iter().map(|m| m.loss).collect()
    }

    /// Per-epoch accuracy values, oldest first.
    pub fn accuracy_history(&self) -> Vec<f32> {
        self.epochs.iter().map(|m| m.accuracy).collect()
    }

    /// Per-epoch learning rates, oldest first.
    pub fn lr_history(&self) -> Vec<f32> {
        self.epochs.iter().map(|m| m.learning_rate).collect()
    }

    /// Validation losses, in order, for the epochs that recorded one.
    pub fn val_loss_history(&self) -> Vec<f32> {
        self.epochs.iter().filter_map(|m| m.val_loss).collect()
    }

    /// Validation accuracies, in order, for the epochs that recorded one.
    pub fn val_accuracy_history(&self) -> Vec<f32> {
        self.epochs.iter().filter_map(|m| m.val_accuracy).collect()
    }

    /// Mean loss over the last `n` epochs (or all epochs if fewer).
    /// Returns `f32::INFINITY` when there is nothing to average — no
    /// epochs, or `n == 0` (which previously divided `0.0 / 0.0` and
    /// yielded `NaN`).
    pub fn average_loss(&self, n: usize) -> f32 {
        if n == 0 || self.epochs.is_empty() {
            return f32::INFINITY;
        }
        let start = self.epochs.len().saturating_sub(n);
        let window = &self.epochs[start..];
        // Iterate directly instead of collecting into a temporary Vec.
        window.iter().map(|m| m.loss).sum::<f32>() / window.len() as f32
    }

    /// Mean accuracy over the last `n` epochs (or all epochs if fewer).
    /// Returns `0.0` when there is nothing to average — no epochs, or
    /// `n == 0` (which previously divided `0.0 / 0.0` and yielded `NaN`).
    pub fn average_accuracy(&self, n: usize) -> f32 {
        if n == 0 || self.epochs.is_empty() {
            return 0.0;
        }
        let start = self.epochs.len().saturating_sub(n);
        let window = &self.epochs[start..];
        window.iter().map(|m| m.accuracy).sum::<f32>() / window.len() as f32
    }

    /// True while the loss trend is still improving: the mean loss of
    /// the most recent `window` epochs must be below the mean loss of
    /// the `window` epochs before them. Optimistically returns `true`
    /// until two full windows of data exist.
    ///
    /// Bug fix: the previous code used
    /// `average_loss(window * 2) - recent_avg` as the "older" mean.
    /// Since `average_loss(2w)` is the mean of BOTH windows combined
    /// (i.e. `(recent_avg + older_avg) / 2`), the older window's mean is
    /// actually `2 * average_loss(2w) - recent_avg`.
    pub fn is_improving(&self, window: usize) -> bool {
        if self.epochs.len() < window * 2 {
            return true;
        }
        let recent_avg = self.average_loss(window);
        let older_avg = 2.0 * self.average_loss(window * 2) - recent_avg;
        recent_avg < older_avg
    }

    /// Condenses the history into a `TrainingSummary` snapshot.
    pub fn summary(&self) -> TrainingSummary {
        // Compute the total once and reuse it for the average.
        let total_time_ms: f32 = self.epochs.iter().map(|m| m.epoch_time_ms).sum();
        TrainingSummary {
            total_epochs: self.epochs(),
            best_loss: self.best_loss,
            best_accuracy: self.best_accuracy,
            final_loss: self.final_loss(),
            final_accuracy: self.final_accuracy(),
            best_loss_epoch: self.best_loss_epoch,
            best_accuracy_epoch: self.best_accuracy_epoch,
            total_time_ms,
            average_epoch_time_ms: if self.epochs.is_empty() {
                0.0
            } else {
                total_time_ms / self.epochs.len() as f32
            },
        }
    }
}
impl Default for TrainingHistory {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a completed (or in-progress) run, produced by
/// `TrainingHistory::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSummary {
    /// Number of epochs recorded.
    pub total_epochs: usize,
    /// Lowest loss observed over the run.
    pub best_loss: f32,
    /// Highest accuracy observed over the run.
    pub best_accuracy: f32,
    /// Loss of the final epoch (`f32::INFINITY` if no epochs ran).
    pub final_loss: f32,
    /// Accuracy of the final epoch (`0.0` if no epochs ran).
    pub final_accuracy: f32,
    /// Epoch index where the best loss occurred.
    pub best_loss_epoch: usize,
    /// Epoch index where the best accuracy occurred.
    pub best_accuracy_epoch: usize,
    /// Sum of all epoch durations, in milliseconds.
    pub total_time_ms: f32,
    /// Mean epoch duration, in milliseconds (`0.0` if no epochs ran).
    pub average_epoch_time_ms: f32,
}
impl fmt::Display for TrainingSummary {
    /// Multi-line, human-readable report. Output is byte-identical to
    /// the previous `writeln!`-per-line version; total time is shown in
    /// seconds (converted from the stored milliseconds).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Training Summary\n\
             ===============\n\
             Total Epochs: {}\n\
             Best Loss: {:.6} (epoch {})\n\
             Best Accuracy: {:.4} (epoch {})\n\
             Final Loss: {:.6}\n\
             Final Accuracy: {:.4}\n\
             Total Time: {:.2}s\n\
             Average Epoch Time: {:.2}ms\n",
            self.total_epochs,
            self.best_loss,
            self.best_loss_epoch,
            self.best_accuracy,
            self.best_accuracy_epoch,
            self.final_loss,
            self.final_accuracy,
            self.total_time_ms / 1000.0,
            self.average_epoch_time_ms,
        )
    }
}
/// Running mean over a fixed-size sliding window of `f32` samples.
///
/// Maintains a running `sum` alongside the stored values so each update
/// is O(1).
#[derive(Debug, Clone)]
pub struct MovingAverage {
    window_size: usize,
    values: VecDeque<f32>,
    sum: f32,
}

impl MovingAverage {
    /// Creates an empty window holding at most `window_size` samples.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            sum: 0.0,
            values: VecDeque::new(),
        }
    }

    /// Pushes a sample, evicting the oldest one once the window is full,
    /// and returns the resulting average.
    pub fn update(&mut self, value: f32) -> f32 {
        self.sum += value;
        self.values.push_back(value);
        while self.values.len() > self.window_size {
            // Non-empty here: len > window_size >= 0.
            let evicted = self.values.pop_front().expect("window is non-empty");
            self.sum -= evicted;
        }
        self.average()
    }

    /// Current mean of the window contents; `0.0` when empty.
    pub fn average(&self) -> f32 {
        match self.values.len() {
            0 => 0.0,
            n => self.sum / n as f32,
        }
    }

    /// Discards all samples and resets the running sum.
    pub fn reset(&mut self) {
        self.sum = 0.0;
        self.values.clear();
    }
}
/// Learning-rate schedule helpers.
pub mod utils {
    /// Exponential moving average: blends `current` into `previous`
    /// with smoothing factor `alpha` (`1.0` = only the current value).
    pub fn ema(current: f32, previous: f32, alpha: f32) -> f32 {
        alpha * current + (1.0 - alpha) * previous
    }

    /// Cosine-annealed LR: starts at `initial_lr` when
    /// `current_epoch == 0` and decays to `eta_min` at
    /// `current_epoch == total_epochs`. A zero `total_epochs` leaves the
    /// LR untouched.
    pub fn cosine_annealing_lr(
        initial_lr: f32,
        current_epoch: usize,
        total_epochs: usize,
        eta_min: f32,
    ) -> f32 {
        if total_epochs == 0 {
            return initial_lr;
        }
        let progress = current_epoch as f32 / total_epochs as f32;
        let cosine = (std::f32::consts::PI * progress).cos();
        eta_min + (initial_lr - eta_min) * (1.0 + cosine) / 2.0
    }

    /// Polynomial decay of degree `power`: `initial_lr` at epoch 0 down
    /// to `0.0` at `total_epochs` (progress is clamped to 1.0 beyond
    /// that). A zero `total_epochs` leaves the LR untouched.
    pub fn polynomial_lr(
        initial_lr: f32,
        current_epoch: usize,
        total_epochs: usize,
        power: f32,
    ) -> f32 {
        if total_epochs == 0 {
            return initial_lr;
        }
        let progress = (current_epoch as f32 / total_epochs as f32).min(1.0);
        let decay = (1.0 - progress).powf(power);
        initial_lr * decay
    }

    /// Linear warmup: ramps from `0.0` at step 0 to `target_lr` at
    /// `warmup_steps`, clamped thereafter. Zero `warmup_steps` means no
    /// warmup (returns `target_lr` immediately).
    pub fn warmup_lr(target_lr: f32, current_step: usize, warmup_steps: usize) -> f32 {
        if warmup_steps == 0 {
            return target_lr;
        }
        let ramp = (current_step as f32 / warmup_steps as f32).min(1.0);
        target_lr * ramp
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.epochs, 100);
        assert_eq!(config.batch_size, 32);
        assert!(config.verbose);
        assert_eq!(config.early_stopping_patience, 0);
    }

    #[test]
    fn test_training_metrics() {
        let mut metrics = TrainingMetrics::new(0.5, 0.8, 0.001, 100.0);
        assert_eq!(metrics.loss, 0.5);
        assert_eq!(metrics.accuracy, 0.8);
        assert_eq!(metrics.learning_rate, 0.001);

        metrics.add_metric("precision", 0.85);
        assert_eq!(metrics.get_metric("precision"), Some(0.85));

        metrics.set_validation(0.6, 0.75);
        assert_eq!(metrics.val_loss, Some(0.6));
        assert_eq!(metrics.val_accuracy, Some(0.75));
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();

        history.add_epoch(TrainingMetrics::new(1.0, 0.5, 0.01, 100.0));
        assert_eq!(history.epochs(), 1);
        assert_eq!(history.best_loss(), 1.0);
        assert_eq!(history.best_accuracy(), 0.5);

        history.add_epoch(TrainingMetrics::new(0.8, 0.7, 0.01, 100.0));
        assert_eq!(history.epochs(), 2);
        assert_eq!(history.best_loss(), 0.8);
        assert_eq!(history.best_accuracy(), 0.7);
        assert_eq!(history.best_loss_epoch(), 1);
        assert_eq!(history.best_accuracy_epoch(), 1);
    }

    #[test]
    fn test_early_stopping() {
        let mut history = TrainingHistory::new();
        // Strictly increasing loss: only epoch 0 improves, so the
        // early-stopping counter ends at 4.
        for i in 0..5 {
            let loss = 1.0 + i as f32 * 0.1;
            history.add_epoch(TrainingMetrics::new(loss, 0.5, 0.01, 100.0));
        }
        assert!(history.should_early_stop(3, 0.5));
        assert!(!history.should_early_stop(10, 0.5));
    }

    #[test]
    fn test_moving_average() {
        let mut ma = MovingAverage::new(3);
        assert_eq!(ma.update(1.0), 1.0);
        assert_eq!(ma.update(2.0), 1.5);
        assert_eq!(ma.update(3.0), 2.0);
        assert_eq!(ma.update(4.0), 3.0);
        ma.reset();
        assert_eq!(ma.average(), 0.0);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let initial_lr = 0.1;
        let total_epochs = 100;
        let eta_min = 0.001;

        let lr_0 = utils::cosine_annealing_lr(initial_lr, 0, total_epochs, eta_min);
        assert!((lr_0 - initial_lr).abs() < 1e-6);

        let lr_50 = utils::cosine_annealing_lr(initial_lr, 50, total_epochs, eta_min);
        assert!(lr_50 > eta_min && lr_50 < initial_lr);

        let lr_100 = utils::cosine_annealing_lr(initial_lr, 100, total_epochs, eta_min);
        assert!((lr_100 - eta_min).abs() < 1e-6);
    }

    #[test]
    fn test_polynomial_lr() {
        let initial_lr = 0.1;
        let total_epochs = 100;
        let power = 2.0;

        let lr_0 = utils::polynomial_lr(initial_lr, 0, total_epochs, power);
        assert!((lr_0 - initial_lr).abs() < 1e-6);

        let lr_100 = utils::polynomial_lr(initial_lr, 100, total_epochs, power);
        assert!(lr_100 < 1e-6);
    }

    #[test]
    fn test_warmup_lr() {
        let target_lr = 0.01;
        let warmup_steps = 1000;

        let lr_0 = utils::warmup_lr(target_lr, 0, warmup_steps);
        assert!(lr_0 < 1e-6);

        let lr_500 = utils::warmup_lr(target_lr, 500, warmup_steps);
        assert!((lr_500 - target_lr / 2.0).abs() < 1e-6);

        let lr_1000 = utils::warmup_lr(target_lr, 1000, warmup_steps);
        assert!((lr_1000 - target_lr).abs() < 1e-6);
    }

    #[test]
    fn test_training_summary() {
        let mut history = TrainingHistory::new();
        // Loss improves and accuracy rises each epoch.
        for i in 0..5 {
            let loss = 1.0 - i as f32 * 0.1;
            let accuracy = 0.5 + i as f32 * 0.1;
            history.add_epoch(TrainingMetrics::new(loss, accuracy, 0.01, 100.0));
        }
        let summary = history.summary();
        assert_eq!(summary.total_epochs, 5);
        assert_eq!(summary.best_loss, 0.6);
        assert_eq!(summary.best_accuracy, 0.9);
        assert_eq!(summary.total_time_ms, 500.0);
        assert_eq!(summary.average_epoch_time_ms, 100.0);
    }
}