use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Collection of per-session metric series.
///
/// Every series is a plain `Vec<f64>` that grows by one element each time the
/// matching `add_*` method is called, so values are stored in insertion order.
/// All fields are public and the type derives `Default`, so a fresh value has
/// every series empty.
// NOTE(review): presumably used to track a model-training session — confirm with callers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMetrics {
/// Loss values, appended by `add_loss`.
pub loss_history: Vec<f64>,
/// Accuracy values, appended by `add_accuracy`.
pub accuracy_history: Vec<f64>,
/// Values appended by `add_lr` (`lr` presumably means learning rate — confirm).
pub lr_history: Vec<f64>,
/// Gradient-norm values, appended by `add_grad_norm`.
pub grad_norm_history: Vec<f64>,
/// Arbitrary named series; `add_custom` creates the entry on first use and appends to it.
pub custom: HashMap<String, Vec<f64>>,
}
impl SessionMetrics {
    /// Creates a container with every series empty (same as `Self::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `loss` to the end of `loss_history`.
    pub fn add_loss(&mut self, loss: f64) {
        self.loss_history.push(loss);
    }

    /// Appends `accuracy` to the end of `accuracy_history`.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        self.accuracy_history.push(accuracy);
    }

    /// Appends `lr` to the end of `lr_history`.
    pub fn add_lr(&mut self, lr: f64) {
        self.lr_history.push(lr);
    }

    /// Appends `norm` to the end of `grad_norm_history`.
    pub fn add_grad_norm(&mut self, norm: f64) {
        self.grad_norm_history.push(norm);
    }

    /// Appends `value` to the custom series called `name`.
    ///
    /// The series is created (empty) on first use, so callers never need to
    /// register a name beforehand.
    pub fn add_custom(&mut self, name: impl Into<String>, value: f64) {
        self.custom.entry(name.into()).or_default().push(value);
    }

    /// Returns the most recently added loss, or `None` if no loss was recorded.
    pub fn final_loss(&self) -> Option<f64> {
        self.loss_history.last().copied()
    }

    /// Returns the most recently added accuracy, or `None` if none was recorded.
    pub fn final_accuracy(&self) -> Option<f64> {
        self.accuracy_history.last().copied()
    }

    /// Returns the smallest recorded loss, or `None` if no loss was recorded.
    ///
    /// Values are folded with `f64::min`, which returns the other operand when
    /// one of them is NaN; the result is therefore NaN only when every entry is NaN.
    pub fn best_loss(&self) -> Option<f64> {
        self.loss_history.iter().copied().reduce(f64::min)
    }

    /// Returns the largest recorded accuracy, or `None` if none was recorded.
    ///
    /// Values are folded with `f64::max`, which returns the other operand when
    /// one of them is NaN; the result is therefore NaN only when every entry is NaN.
    pub fn best_accuracy(&self) -> Option<f64> {
        self.accuracy_history.iter().copied().reduce(f64::max)
    }

    /// Returns the number of entries in `loss_history`.
    ///
    /// Only the loss series is counted; the other series do not influence the result.
    pub fn total_steps(&self) -> usize {
        self.loss_history.len()
    }

    /// Returns `true` when nothing has been recorded in any series.
    ///
    /// Every built-in series and the custom map are checked, so a container that
    /// holds only `lr_history` or `grad_norm_history` data is not reported as empty.
    pub fn is_empty(&self) -> bool {
        self.loss_history.is_empty()
            && self.accuracy_history.is_empty()
            && self.lr_history.is_empty()
            && self.grad_norm_history.is_empty()
            && self.custom.is_empty()
    }
}