
scry_learn/neural/callback.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Training callback system — structured per-epoch metrics and user hooks.
//!
//! Every iterative model (MLP, gradient boosting) populates a
//! [`TrainingHistory`] during `fit()`. Users can also inject custom
//! [`TrainingCallback`] implementations to log, visualize, or early-stop
//! based on arbitrary criteria.
//!
//! # Example
//!
//! ```ignore
//! let mut clf = MLPClassifier::new().max_iter(100);
//! clf.fit(&train)?;
//!
//! let history = clf.history().unwrap();
//! println!("Final loss: {:.4}", history.epochs.last().unwrap().train_loss);
//! ```

/// Snapshot of metrics at the end of one training epoch.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EpochMetrics {
    /// Zero-indexed epoch number.
    pub epoch: usize,
    /// Mean training loss for this epoch.
    pub train_loss: f64,
    /// Validation loss (only when early stopping is enabled).
    pub val_loss: Option<f64>,
    /// Training accuracy (classification) or R² (regression).
    pub train_metric: Option<f64>,
    /// Validation accuracy / R² (only when early stopping is enabled).
    pub val_metric: Option<f64>,
    /// Current learning rate.
    pub learning_rate: f64,
    /// L2 norm of all parameter gradients (useful for detecting vanishing or exploding gradients).
    pub grad_norm: f64,
    /// Wall-clock milliseconds for this epoch.
    pub elapsed_ms: u64,
}

/// Accumulated history of training metrics, populated during `fit()`.
///
/// Access via `model.history()` on any iterative model.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TrainingHistory {
    /// Per-epoch snapshots.
    pub epochs: Vec<EpochMetrics>,
}

impl TrainingHistory {
    /// Create a new empty history.
    pub fn new() -> Self {
        Self { epochs: Vec::new() }
    }

    /// Push a new epoch snapshot.
    pub fn push(&mut self, metrics: EpochMetrics) {
        self.epochs.push(metrics);
    }

    /// Number of recorded epochs.
    pub fn len(&self) -> usize {
        self.epochs.len()
    }

    /// Whether the history is empty.
    pub fn is_empty(&self) -> bool {
        self.epochs.is_empty()
    }

    /// Training loss per epoch.
    pub fn train_losses(&self) -> Vec<f64> {
        self.epochs.iter().map(|e| e.train_loss).collect()
    }

    /// Validation loss per epoch (only epochs that have it).
    pub fn val_losses(&self) -> Vec<f64> {
        self.epochs.iter().filter_map(|e| e.val_loss).collect()
    }

    /// Training metric (accuracy / R²) per epoch.
    pub fn train_metrics(&self) -> Vec<f64> {
        self.epochs.iter().filter_map(|e| e.train_metric).collect()
    }

    /// Validation metric per epoch.
    pub fn val_metrics(&self) -> Vec<f64> {
        self.epochs.iter().filter_map(|e| e.val_metric).collect()
    }

    /// Gradient L2 norm per epoch.
    pub fn grad_norms(&self) -> Vec<f64> {
        self.epochs.iter().map(|e| e.grad_norm).collect()
    }

    /// Learning rate per epoch.
    pub fn learning_rates(&self) -> Vec<f64> {
        self.epochs.iter().map(|e| e.learning_rate).collect()
    }

    /// Wall-clock milliseconds per epoch.
    pub fn epoch_times_ms(&self) -> Vec<u64> {
        self.epochs.iter().map(|e| e.elapsed_ms).collect()
    }
}
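
// Illustrative sketch, not part of the original module: one way to consume the
// accessors above, e.g. to find the epoch with the lowest validation loss when
// deciding which checkpoint to keep. The function name and placement are
// hypothetical.
#[allow(dead_code)]
pub(crate) fn best_val_epoch(history: &TrainingHistory) -> Option<usize> {
    history
        .epochs
        .iter()
        .filter_map(|e| e.val_loss.map(|loss| (e.epoch, loss)))
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(epoch, _)| epoch)
}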

/// Action returned by a [`TrainingCallback`] to control training flow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Continue training normally.
    Continue,
    /// Stop training immediately (user-driven early stop).
    Stop,
}

/// Trait for user-supplied training callbacks.
///
/// Implement this to add custom logging, checkpointing, or stopping
/// criteria during training.
///
/// # Example
///
/// ```ignore
/// struct PrintLogger;
///
/// impl TrainingCallback for PrintLogger {
///     fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction {
///         println!("Epoch {}: loss={:.4}", metrics.epoch, metrics.train_loss);
///         CallbackAction::Continue
///     }
/// }
/// ```
pub trait TrainingCallback: Send + Sync {
    /// Called at the end of each training epoch.
    ///
    /// Return [`CallbackAction::Stop`] to halt training early.
    fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction;

    /// Called when training finishes (after the last epoch).
    ///
    /// Default implementation does nothing.
    fn on_training_end(&mut self) {}
}
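
// Illustrative sketch, not part of the original module: a patience-based early
// stopper built on the trait above. It halts training once the validation loss
// has not improved for `patience` consecutive epochs. The struct name and field
// layout are hypothetical, shown only to demonstrate how a callback can return
// [`CallbackAction::Stop`].
#[allow(dead_code)]
struct PatienceStopper {
    patience: usize,
    best_val_loss: f64,
    epochs_since_improvement: usize,
}

#[allow(dead_code)]
impl PatienceStopper {
    fn new(patience: usize) -> Self {
        Self {
            patience,
            best_val_loss: f64::INFINITY,
            epochs_since_improvement: 0,
        }
    }
}

impl TrainingCallback for PatienceStopper {
    fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction {
        // Epochs without a validation split never trigger early stopping.
        let Some(val_loss) = metrics.val_loss else {
            return CallbackAction::Continue;
        };
        if val_loss < self.best_val_loss {
            self.best_val_loss = val_loss;
            self.epochs_since_improvement = 0;
        } else {
            self.epochs_since_improvement += 1;
        }
        if self.epochs_since_improvement >= self.patience {
            CallbackAction::Stop
        } else {
            CallbackAction::Continue
        }
    }
}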

/// Compute the L2 norm of all gradients across layers.
///
/// `grads` is the list of `(weight_grads, bias_grads)` per layer.
pub(crate) fn compute_grad_norm(grads: &[(Vec<f64>, Vec<f64>)]) -> f64 {
    let mut sum_sq = 0.0;
    for (dw, db) in grads {
        for &g in dw {
            sum_sq += g * g;
        }
        for &g in db {
            sum_sq += g * g;
        }
    }
    sum_sq.sqrt()
}
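
// Illustrative sketch, not part of the original module: `compute_grad_norm`
// feeds `EpochMetrics::grad_norm`, so a callback can watch it directly and halt
// training when gradients explode. The struct name and threshold field are
// hypothetical.
#[allow(dead_code)]
struct GradNormWatchdog {
    /// Stop training once the gradient norm exceeds this value.
    max_norm: f64,
}

impl TrainingCallback for GradNormWatchdog {
    fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction {
        if metrics.grad_norm > self.max_norm {
            CallbackAction::Stop
        } else {
            CallbackAction::Continue
        }
    }
}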

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn history_accumulates() {
        let mut h = TrainingHistory::new();
        assert!(h.is_empty());

        h.push(EpochMetrics {
            epoch: 0,
            train_loss: 1.5,
            val_loss: Some(1.8),
            train_metric: Some(0.6),
            val_metric: Some(0.55),
            learning_rate: 0.001,
            grad_norm: 2.3,
            elapsed_ms: 42,
        });
        h.push(EpochMetrics {
            epoch: 1,
            train_loss: 1.2,
            val_loss: Some(1.4),
            train_metric: Some(0.7),
            val_metric: Some(0.65),
            learning_rate: 0.001,
            grad_norm: 1.8,
            elapsed_ms: 38,
        });

        assert_eq!(h.len(), 2);
        assert_eq!(h.train_losses(), vec![1.5, 1.2]);
        assert_eq!(h.val_losses(), vec![1.8, 1.4]);
        assert_eq!(h.train_metrics(), vec![0.6, 0.7]);
        assert_eq!(h.grad_norms(), vec![2.3, 1.8]);
    }

    #[test]
    fn history_without_validation() {
        let mut h = TrainingHistory::new();
        h.push(EpochMetrics {
            epoch: 0,
            train_loss: 1.0,
            val_loss: None,
            train_metric: None,
            val_metric: None,
            learning_rate: 0.01,
            grad_norm: 5.0,
            elapsed_ms: 10,
        });

        assert!(h.val_losses().is_empty());
        assert!(h.val_metrics().is_empty());
        assert_eq!(h.train_losses(), vec![1.0]);
    }

    #[test]
    fn grad_norm_basic() {
        let grads = vec![
            (vec![3.0, 4.0], vec![0.0]), // sqrt(9+16+0) = 5
        ];
        let norm = compute_grad_norm(&grads);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn grad_norm_multi_layer() {
        let grads = vec![(vec![1.0, 0.0], vec![0.0]), (vec![0.0, 0.0], vec![2.0])];
        // sqrt(1 + 0 + 0 + 0 + 0 + 4) = sqrt(5)
        let norm = compute_grad_norm(&grads);
        assert!((norm - 5.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn callback_action() {
        struct StopAt3;
        impl TrainingCallback for StopAt3 {
            fn on_epoch_end(&mut self, m: &EpochMetrics) -> CallbackAction {
                if m.epoch >= 3 {
                    CallbackAction::Stop
                } else {
                    CallbackAction::Continue
                }
            }
        }

        let mut cb = StopAt3;
        let m = EpochMetrics {
            epoch: 2,
            train_loss: 0.0,
            val_loss: None,
            train_metric: None,
            val_metric: None,
            learning_rate: 0.0,
            grad_norm: 0.0,
            elapsed_ms: 0,
        };
        assert_eq!(cb.on_epoch_end(&m), CallbackAction::Continue);

        let m = EpochMetrics { epoch: 3, ..m };
        assert_eq!(cb.on_epoch_end(&m), CallbackAction::Stop);
    }
}