/// Snapshot of the quantities observed during a single training epoch.
///
/// One of these is produced per epoch and accumulated into a
/// [`TrainingHistory`]. The `Option` fields are `None` when the
/// corresponding quantity was not computed for that epoch (e.g. no
/// validation pass was run).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EpochMetrics {
    /// Index of the epoch (0-based in this file's tests).
    pub epoch: usize,
    /// Training loss for the epoch.
    pub train_loss: f64,
    /// Validation loss, if a validation pass was run.
    pub val_loss: Option<f64>,
    /// Training-set evaluation metric, if one was computed.
    /// NOTE(review): the metric's meaning (accuracy, F1, ...) is defined
    /// by the caller — not determined here.
    pub train_metric: Option<f64>,
    /// Validation-set evaluation metric, if one was computed.
    pub val_metric: Option<f64>,
    /// Learning rate in effect during this epoch.
    pub learning_rate: f64,
    /// Gradient L2 norm for the epoch (see `compute_grad_norm`).
    pub grad_norm: f64,
    /// Wall-clock time spent on the epoch, in milliseconds.
    pub elapsed_ms: u64,
}
/// Ordered record of per-epoch metrics for one training run.
///
/// Thin wrapper around a `Vec<EpochMetrics>` providing convenience
/// accessors that extract a single metric series across all epochs.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TrainingHistory {
    /// Metrics for each recorded epoch, in the order they were pushed.
    pub epochs: Vec<EpochMetrics>,
}
impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the metrics for the most recently finished epoch.
    pub fn push(&mut self, metrics: EpochMetrics) {
        self.epochs.push(metrics);
    }

    /// Number of epochs recorded so far.
    pub fn len(&self) -> usize {
        self.epochs.len()
    }

    /// `true` when no epochs have been recorded.
    pub fn is_empty(&self) -> bool {
        self.epochs.is_empty()
    }

    /// Extracts one field from every recorded epoch, in push order.
    fn project<T>(&self, f: impl FnMut(&EpochMetrics) -> T) -> Vec<T> {
        self.epochs.iter().map(f).collect()
    }

    /// Extracts an optional field from every epoch, keeping only the
    /// `Some` values (epochs where the field is `None` are skipped).
    fn project_present<T>(&self, f: impl FnMut(&EpochMetrics) -> Option<T>) -> Vec<T> {
        self.epochs.iter().filter_map(f).collect()
    }

    /// Training loss of every epoch.
    pub fn train_losses(&self) -> Vec<f64> {
        self.project(|e| e.train_loss)
    }

    /// Validation losses of the epochs that ran validation.
    pub fn val_losses(&self) -> Vec<f64> {
        self.project_present(|e| e.val_loss)
    }

    /// Training metrics of the epochs that computed one.
    pub fn train_metrics(&self) -> Vec<f64> {
        self.project_present(|e| e.train_metric)
    }

    /// Validation metrics of the epochs that computed one.
    pub fn val_metrics(&self) -> Vec<f64> {
        self.project_present(|e| e.val_metric)
    }

    /// Gradient norm of every epoch.
    pub fn grad_norms(&self) -> Vec<f64> {
        self.project(|e| e.grad_norm)
    }

    /// Learning rate of every epoch.
    pub fn learning_rates(&self) -> Vec<f64> {
        self.project(|e| e.learning_rate)
    }

    /// Wall-clock duration of every epoch, in milliseconds.
    pub fn epoch_times_ms(&self) -> Vec<u64> {
        self.project(|e| e.elapsed_ms)
    }
}
/// Decision returned by a [`TrainingCallback`] after each epoch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Proceed to the next epoch.
    Continue,
    /// Request that training terminate (e.g. early stopping).
    Stop,
}
/// Observer hook intended to be driven by a training loop.
///
/// `Send + Sync` so implementations can be shared across threads.
pub trait TrainingCallback: Send + Sync {
    /// Called after each epoch with that epoch's metrics; returning
    /// [`CallbackAction::Stop`] asks the caller to halt training early.
    fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction;
    /// Called once after training finishes; the default is a no-op.
    fn on_training_end(&mut self) {}
}
/// Euclidean (L2) norm over all gradient entries of every layer.
///
/// `grads` holds one `(weight_grads, bias_grads)` pair per layer; the
/// norm is taken over the concatenation of both vectors across all
/// layers. Returns `0.0` for an empty slice.
pub(crate) fn compute_grad_norm(grads: &[(Vec<f64>, Vec<f64>)]) -> f64 {
    grads
        .iter()
        .flat_map(|(dw, db)| dw.iter().chain(db.iter()))
        .map(|&g| g * g)
        .sum::<f64>()
        .sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an `EpochMetrics` with every field zeroed / absent; tests
    // override only the fields they care about via struct-update syntax.
    fn blank(epoch: usize) -> EpochMetrics {
        EpochMetrics {
            epoch,
            train_loss: 0.0,
            val_loss: None,
            train_metric: None,
            val_metric: None,
            learning_rate: 0.0,
            grad_norm: 0.0,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn history_accumulates() {
        let mut h = TrainingHistory::new();
        assert!(h.is_empty());
        h.push(EpochMetrics {
            train_loss: 1.5,
            val_loss: Some(1.8),
            train_metric: Some(0.6),
            val_metric: Some(0.55),
            learning_rate: 0.001,
            grad_norm: 2.3,
            elapsed_ms: 42,
            ..blank(0)
        });
        h.push(EpochMetrics {
            train_loss: 1.2,
            val_loss: Some(1.4),
            train_metric: Some(0.7),
            val_metric: Some(0.65),
            learning_rate: 0.001,
            grad_norm: 1.8,
            elapsed_ms: 38,
            ..blank(1)
        });
        assert_eq!(h.len(), 2);
        assert_eq!(h.train_losses(), vec![1.5, 1.2]);
        assert_eq!(h.val_losses(), vec![1.8, 1.4]);
        assert_eq!(h.train_metrics(), vec![0.6, 0.7]);
        assert_eq!(h.grad_norms(), vec![2.3, 1.8]);
    }

    #[test]
    fn history_without_validation() {
        let mut h = TrainingHistory::new();
        h.push(EpochMetrics {
            train_loss: 1.0,
            learning_rate: 0.01,
            grad_norm: 5.0,
            elapsed_ms: 10,
            ..blank(0)
        });
        // Optional series are empty; mandatory ones still have one entry.
        assert!(h.val_losses().is_empty());
        assert!(h.val_metrics().is_empty());
        assert_eq!(h.train_losses(), vec![1.0]);
    }

    #[test]
    fn grad_norm_basic() {
        // Single layer: ||(3, 4, 0)|| = 5.
        let norm = compute_grad_norm(&[(vec![3.0, 4.0], vec![0.0])]);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn grad_norm_multi_layer() {
        // Across layers: sqrt(1^2 + 2^2) = sqrt(5).
        let layers = [(vec![1.0, 0.0], vec![0.0]), (vec![0.0, 0.0], vec![2.0])];
        let norm = compute_grad_norm(&layers);
        assert!((norm - 5.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn callback_action() {
        // Callback that keeps training through epoch 2, stops from 3 on.
        struct StopAt3;
        impl TrainingCallback for StopAt3 {
            fn on_epoch_end(&mut self, m: &EpochMetrics) -> CallbackAction {
                match m.epoch {
                    0..=2 => CallbackAction::Continue,
                    _ => CallbackAction::Stop,
                }
            }
        }
        let mut cb = StopAt3;
        assert_eq!(cb.on_epoch_end(&blank(2)), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&blank(3)), CallbackAction::Stop);
    }
}