Skip to main content

entrenar/ecosystem/ruchy/
metrics.rs

1//! Session metrics for training history tracking.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// Training metrics from a session.
7#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8pub struct SessionMetrics {
9    /// Loss values over time
10    pub loss_history: Vec<f64>,
11    /// Accuracy values over time (optional)
12    pub accuracy_history: Vec<f64>,
13    /// Learning rate schedule
14    pub lr_history: Vec<f64>,
15    /// Gradient norms (for debugging)
16    pub grad_norm_history: Vec<f64>,
17    /// Custom metrics
18    pub custom: HashMap<String, Vec<f64>>,
19}
20
21impl SessionMetrics {
22    /// Create empty metrics.
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    /// Add a loss value.
28    pub fn add_loss(&mut self, loss: f64) {
29        self.loss_history.push(loss);
30    }
31
32    /// Add an accuracy value.
33    pub fn add_accuracy(&mut self, accuracy: f64) {
34        self.accuracy_history.push(accuracy);
35    }
36
37    /// Add a learning rate value.
38    pub fn add_lr(&mut self, lr: f64) {
39        self.lr_history.push(lr);
40    }
41
42    /// Add a gradient norm value.
43    pub fn add_grad_norm(&mut self, norm: f64) {
44        self.grad_norm_history.push(norm);
45    }
46
47    /// Add a custom metric value.
48    pub fn add_custom(&mut self, name: impl Into<String>, value: f64) {
49        self.custom.entry(name.into()).or_default().push(value);
50    }
51
52    /// Get final loss (last value).
53    pub fn final_loss(&self) -> Option<f64> {
54        self.loss_history.last().copied()
55    }
56
57    /// Get final accuracy (last value).
58    pub fn final_accuracy(&self) -> Option<f64> {
59        self.accuracy_history.last().copied()
60    }
61
62    /// Get best loss (minimum).
63    pub fn best_loss(&self) -> Option<f64> {
64        self.loss_history.iter().copied().reduce(f64::min)
65    }
66
67    /// Get best accuracy (maximum).
68    pub fn best_accuracy(&self) -> Option<f64> {
69        self.accuracy_history.iter().copied().reduce(f64::max)
70    }
71
72    /// Get total training steps.
73    pub fn total_steps(&self) -> usize {
74        self.loss_history.len()
75    }
76
77    /// Check if metrics are empty.
78    pub fn is_empty(&self) -> bool {
79        self.loss_history.is_empty() && self.accuracy_history.is_empty() && self.custom.is_empty()
80    }
81}