#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Metrics collected for a single training epoch.
///
/// Validation fields are `Option`s because an epoch may run without a validation pass.
/// `PartialEq` is derived so records can be compared directly (e.g. in tests); note that
/// `f32` fields make this a partial equality (a NaN field never compares equal).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EpochMetrics {
    /// Index of the epoch this record belongs to.
    pub epoch: usize,
    /// Loss measured on the training data.
    pub train_loss: f32,
    /// Loss measured on the validation data, if a validation pass ran.
    pub val_loss: Option<f32>,
    /// Accuracy measured on the training data.
    pub train_accuracy: f32,
    /// Accuracy measured on the validation data, if a validation pass ran.
    pub val_accuracy: Option<f32>,
    /// Learning rate recorded for this epoch.
    pub learning_rate: f32,
    /// Time spent on the epoch, in seconds (per the field name).
    pub duration_secs: f32,
    /// Number of examples processed during the epoch.
    pub examples_seen: usize,
}
impl EpochMetrics {
    /// Creates a blank record for `epoch`: all numeric metrics start at zero and no
    /// validation results are attached.
    pub fn new(epoch: usize) -> Self {
        Self {
            examples_seen: 0,
            duration_secs: 0.0,
            learning_rate: 0.0,
            val_accuracy: None,
            train_accuracy: 0.0,
            val_loss: None,
            train_loss: 0.0,
            epoch,
        }
    }

    /// Reports whether a validation loss was recorded for this epoch.
    ///
    /// Only `val_loss` is consulted; `val_accuracy` alone does not count.
    pub fn has_validation(&self) -> bool {
        self.val_loss.is_some()
    }

    /// Looks up a metric by name.
    ///
    /// Recognised names are `"train_loss"`, `"val_loss"`, `"train_accuracy"` and
    /// `"val_accuracy"`. Validation metrics yield `None` when absent; any other name
    /// yields `None`.
    pub fn get_metric(&self, name: &str) -> Option<f32> {
        match name {
            // Validation metrics are already optional.
            "val_loss" => self.val_loss,
            "val_accuracy" => self.val_accuracy,
            // Training metrics are always populated.
            "train_loss" => Some(self.train_loss),
            "train_accuracy" => Some(self.train_accuracy),
            _ => None,
        }
    }
}
/// Summary of a whole training run, accumulated one epoch at a time through
/// `TrainingMetrics::add_epoch`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrainingMetrics {
/// Every recorded epoch, in the order it was added.
pub history: Vec<EpochMetrics>,
/// Lowest validation loss recorded by `add_epoch`, or `None` if no epoch reported one.
pub best_val_loss: Option<f32>,
/// `epoch` index of the record that produced `best_val_loss`.
pub best_epoch: Option<usize>,
/// Training loss of the most recently added epoch.
pub final_train_loss: f32,
/// Validation loss of the most recently added epoch (`None` if that epoch had none).
pub final_val_loss: Option<f32>,
/// Sum of `duration_secs` over all added epochs.
pub total_time_secs: f32,
/// Whether training stopped early. Nothing in this file sets it — presumably the
/// training loop does; confirm with callers.
pub early_stopped: bool,
/// Index of the last added epoch plus one (assumes 0-based epoch indices).
pub epochs_completed: usize,
}
impl Default for TrainingMetrics {
fn default() -> Self {
Self {
history: Vec::new(),
best_val_loss: None,
best_epoch: None,
final_train_loss: 0.0,
final_val_loss: None,
total_time_secs: 0.0,
early_stopped: false,
epochs_completed: 0,
}
}
}
impl TrainingMetrics {
    /// Records one finished epoch and updates the run-level summary fields.
    ///
    /// - `best_val_loss` / `best_epoch` move to this epoch when it reports a validation
    ///   loss strictly lower than the current best. NaN validation losses are never
    ///   taken as the best: a NaN best compares false against every later value and
    ///   would freeze best-epoch tracking for the rest of the run.
    /// - `final_train_loss` / `final_val_loss` are overwritten with this epoch's values
    ///   (so `final_val_loss` becomes `None` if this epoch had no validation pass).
    /// - `total_time_secs` accumulates `duration_secs`.
    /// - `epochs_completed` is set to `metrics.epoch + 1`, which assumes 0-based epoch
    ///   indices.
    pub fn add_epoch(&mut self, metrics: EpochMetrics) {
        if let Some(val_loss) = metrics.val_loss.filter(|v| !v.is_nan()) {
            if self.best_val_loss.is_none_or(|best| val_loss < best) {
                self.best_val_loss = Some(val_loss);
                self.best_epoch = Some(metrics.epoch);
            }
        }
        self.final_train_loss = metrics.train_loss;
        self.final_val_loss = metrics.val_loss;
        self.total_time_secs += metrics.duration_secs;
        self.epochs_completed = metrics.epoch + 1;
        self.history.push(metrics);
    }

    /// Training loss of every recorded epoch, in insertion order.
    pub fn train_loss_history(&self) -> Vec<f32> {
        self.history.iter().map(|m| m.train_loss).collect()
    }

    /// Validation losses in insertion order, skipping epochs that had none.
    ///
    /// The result may therefore be shorter than `history`.
    pub fn val_loss_history(&self) -> Vec<f32> {
        self.history.iter().filter_map(|m| m.val_loss).collect()
    }

    /// Heuristic overfitting check over the last `2 * window` epochs.
    ///
    /// The tail of `history` is split into an earlier and a recent window of `window`
    /// epochs each. The method returns `true` when the mean training loss fell while
    /// the mean validation loss rose between the two windows.
    ///
    /// Validation means are taken only over epochs that actually recorded a validation
    /// loss. Missing entries are not counted as zeros, which would drag the mean down
    /// and mask overfitting.
    ///
    /// Returns `false` when:
    /// - `window` is 0;
    /// - fewer than `2 * window` epochs exist;
    /// - either window has no validation losses at all.
    pub fn is_overfitting(&self, window: usize) -> bool {
        /// Arithmetic mean of the yielded values, or `None` if there are none.
        fn mean(values: impl Iterator<Item = f32>) -> Option<f32> {
            let (sum, count) =
                values.fold((0.0_f32, 0_usize), |(sum, count), v| (sum + v, count + 1));
            if count == 0 {
                None
            } else {
                Some(sum / count as f32)
            }
        }

        let len = self.history.len();
        // `len / 2 < window` is equivalent to `len < 2 * window` but cannot overflow.
        if window == 0 || len / 2 < window {
            return false;
        }
        // Safe: the guard above ensures `2 * window <= len`.
        let (earlier, recent) = self.history[len - 2 * window..].split_at(window);

        let train_mean = |epochs: &[EpochMetrics]| mean(epochs.iter().map(|m| m.train_loss));
        let val_mean =
            |epochs: &[EpochMetrics]| mean(epochs.iter().filter_map(|m| m.val_loss));

        let (Some(recent_train), Some(earlier_train)) = (train_mean(recent), train_mean(earlier))
        else {
            return false;
        };
        let (Some(recent_val), Some(earlier_val)) = (val_mean(recent), val_mean(earlier)) else {
            return false;
        };
        recent_train < earlier_train && recent_val > earlier_val
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an epoch record carrying only the given losses.
    fn epoch(index: usize, train_loss: f32, val_loss: Option<f32>) -> EpochMetrics {
        let mut metrics = EpochMetrics::new(index);
        metrics.train_loss = train_loss;
        metrics.val_loss = val_loss;
        metrics
    }

    #[test]
    fn test_epoch_metrics() {
        let mut metrics = EpochMetrics::new(0);
        metrics.train_loss = 0.5;
        metrics.val_loss = Some(0.6);
        metrics.train_accuracy = 0.8;
        assert!(metrics.has_validation());
        assert_eq!(metrics.get_metric("train_loss"), Some(0.5));
        assert_eq!(metrics.get_metric("val_loss"), Some(0.6));
        assert_eq!(metrics.get_metric("train_accuracy"), Some(0.8));
        assert_eq!(metrics.get_metric("val_accuracy"), None);
        assert_eq!(metrics.get_metric("no_such_metric"), None);
    }

    #[test]
    fn test_new_epoch_has_no_validation() {
        let metrics = EpochMetrics::new(3);
        assert_eq!(metrics.epoch, 3);
        assert!(!metrics.has_validation());
        assert_eq!(metrics.get_metric("val_loss"), None);
    }

    #[test]
    fn test_training_metrics_history() {
        let mut metrics = TrainingMetrics::default();
        for i in 0..5 {
            let mut epoch_metrics = EpochMetrics::new(i);
            epoch_metrics.train_loss = 1.0 - 0.1 * i as f32;
            epoch_metrics.val_loss = Some(1.1 - 0.1 * i as f32);
            metrics.add_epoch(epoch_metrics);
        }
        assert_eq!(metrics.history.len(), 5);
        assert_eq!(metrics.best_epoch, Some(4));
        assert_eq!(metrics.epochs_completed, 5);
        assert_eq!(metrics.train_loss_history().len(), 5);
        assert_eq!(metrics.val_loss_history().len(), 5);
    }

    #[test]
    fn test_val_loss_history_skips_missing() {
        let mut metrics = TrainingMetrics::default();
        metrics.add_epoch(epoch(0, 1.0, Some(1.5)));
        metrics.add_epoch(epoch(1, 0.5, None));
        assert_eq!(metrics.train_loss_history(), vec![1.0, 0.5]);
        assert_eq!(metrics.val_loss_history(), vec![1.5]);
        assert_eq!(metrics.best_epoch, Some(0));
        assert_eq!(metrics.best_val_loss, Some(1.5));
        assert_eq!(metrics.final_val_loss, None);
    }

    #[test]
    fn test_is_overfitting() {
        let mut metrics = TrainingMetrics::default();
        // Training loss keeps falling while validation loss climbs.
        let data: [(f32, f32); 4] = [(1.0, 0.5), (0.8, 0.5), (0.5, 1.0), (0.25, 1.0)];
        for (i, (train, val)) in data.into_iter().enumerate() {
            metrics.add_epoch(epoch(i, train, Some(val)));
        }
        assert!(metrics.is_overfitting(2));
        // Not enough history for two windows of three epochs.
        assert!(!metrics.is_overfitting(3));
    }
}