// File: scry_learn/neural/callback.rs
/// Snapshot of the statistics recorded at the end of one training epoch.
///
/// `PartialEq` is derived so histories and individual epochs can be compared
/// directly in tests and tooling (all fields are themselves `PartialEq`).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EpochMetrics {
    /// Zero-based index of the epoch these metrics belong to.
    pub epoch: usize,
    /// Training loss for this epoch.
    pub train_loss: f64,
    /// Validation loss, when a validation pass was run this epoch.
    pub val_loss: Option<f64>,
    /// Training-set metric (if one was computed this epoch).
    pub train_metric: Option<f64>,
    /// Validation-set metric (if one was computed this epoch).
    pub val_metric: Option<f64>,
    /// Learning rate in effect during this epoch.
    pub learning_rate: f64,
    /// Gradient norm recorded for this epoch (see `compute_grad_norm`).
    pub grad_norm: f64,
    /// Wall-clock duration of the epoch in milliseconds.
    pub elapsed_ms: u64,
}
40
41#[derive(Debug, Clone, Default)]
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46pub struct TrainingHistory {
47 pub epochs: Vec<EpochMetrics>,
49}
50
51impl TrainingHistory {
52 pub fn new() -> Self {
54 Self { epochs: Vec::new() }
55 }
56
57 pub fn push(&mut self, metrics: EpochMetrics) {
59 self.epochs.push(metrics);
60 }
61
62 pub fn len(&self) -> usize {
64 self.epochs.len()
65 }
66
67 pub fn is_empty(&self) -> bool {
69 self.epochs.is_empty()
70 }
71
72 pub fn train_losses(&self) -> Vec<f64> {
74 self.epochs.iter().map(|e| e.train_loss).collect()
75 }
76
77 pub fn val_losses(&self) -> Vec<f64> {
79 self.epochs.iter().filter_map(|e| e.val_loss).collect()
80 }
81
82 pub fn train_metrics(&self) -> Vec<f64> {
84 self.epochs.iter().filter_map(|e| e.train_metric).collect()
85 }
86
87 pub fn val_metrics(&self) -> Vec<f64> {
89 self.epochs.iter().filter_map(|e| e.val_metric).collect()
90 }
91
92 pub fn grad_norms(&self) -> Vec<f64> {
94 self.epochs.iter().map(|e| e.grad_norm).collect()
95 }
96
97 pub fn learning_rates(&self) -> Vec<f64> {
99 self.epochs.iter().map(|e| e.learning_rate).collect()
100 }
101
102 pub fn epoch_times_ms(&self) -> Vec<u64> {
104 self.epochs.iter().map(|e| e.elapsed_ms).collect()
105 }
106}
107
/// Signal returned by a [`TrainingCallback`] after each epoch.
///
/// `Hash` is derived alongside the existing `Eq` so the action can be used
/// as a set/map key without a manual impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackAction {
    /// Keep training.
    Continue,
    /// Stop training early.
    Stop,
}
116
/// Observer hook for a training run; implementors receive per-epoch metrics
/// and may request early termination. `Send + Sync` so callbacks can be
/// shared with or moved into worker threads.
pub trait TrainingCallback: Send + Sync {
    /// Invoked after every epoch with that epoch's metrics.
    /// Return [`CallbackAction::Stop`] to end training early,
    /// [`CallbackAction::Continue`] otherwise.
    fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction;

    /// Invoked once after the final epoch; the default implementation
    /// does nothing, so implementors only override it when needed.
    fn on_training_end(&mut self) {}
}
145
/// Global L2 norm over all gradients: sqrt of the sum of squares of every
/// weight-gradient and bias-gradient entry, across all layers.
///
/// Each slice element is one layer's `(weight_grads, bias_grads)` pair.
/// Returns `0.0` for an empty slice. The accumulation order (weights before
/// biases, layers in sequence) matches a plain nested-loop implementation.
pub(crate) fn compute_grad_norm(grads: &[(Vec<f64>, Vec<f64>)]) -> f64 {
    grads
        .iter()
        .flat_map(|(dw, db)| dw.iter().chain(db.iter()))
        .map(|&g| g * g)
        .sum::<f64>()
        .sqrt()
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// All-zero/`None` metrics for the given epoch; tests override only the
    /// fields they care about via struct-update syntax.
    fn blank(epoch: usize) -> EpochMetrics {
        EpochMetrics {
            epoch,
            train_loss: 0.0,
            val_loss: None,
            train_metric: None,
            val_metric: None,
            learning_rate: 0.0,
            grad_norm: 0.0,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn history_accumulates() {
        let mut history = TrainingHistory::new();
        assert!(history.is_empty());

        history.push(EpochMetrics {
            train_loss: 1.5,
            val_loss: Some(1.8),
            train_metric: Some(0.6),
            val_metric: Some(0.55),
            learning_rate: 0.001,
            grad_norm: 2.3,
            elapsed_ms: 42,
            ..blank(0)
        });
        history.push(EpochMetrics {
            train_loss: 1.2,
            val_loss: Some(1.4),
            train_metric: Some(0.7),
            val_metric: Some(0.65),
            learning_rate: 0.001,
            grad_norm: 1.8,
            elapsed_ms: 38,
            ..blank(1)
        });

        assert_eq!(history.len(), 2);
        assert_eq!(history.train_losses(), vec![1.5, 1.2]);
        assert_eq!(history.val_losses(), vec![1.8, 1.4]);
        assert_eq!(history.train_metrics(), vec![0.6, 0.7]);
        assert_eq!(history.grad_norms(), vec![2.3, 1.8]);
    }

    #[test]
    fn history_without_validation() {
        let mut history = TrainingHistory::new();
        history.push(EpochMetrics {
            train_loss: 1.0,
            learning_rate: 0.01,
            grad_norm: 5.0,
            elapsed_ms: 10,
            ..blank(0)
        });

        // No validation pass ran, so the Option-valued accessors yield nothing.
        assert!(history.val_losses().is_empty());
        assert!(history.val_metrics().is_empty());
        assert_eq!(history.train_losses(), vec![1.0]);
    }

    #[test]
    fn grad_norm_basic() {
        // 3-4-5 right triangle: sqrt(9 + 16 + 0) == 5.
        let single_layer = [(vec![3.0, 4.0], vec![0.0])];
        let norm = compute_grad_norm(&single_layer);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn grad_norm_multi_layer() {
        // Squares accumulate across layers: sqrt(1 + 4) == sqrt(5).
        let layers = [(vec![1.0, 0.0], vec![0.0]), (vec![0.0, 0.0], vec![2.0])];
        let norm = compute_grad_norm(&layers);
        assert!((norm - 5.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn callback_action() {
        struct StopAt3;
        impl TrainingCallback for StopAt3 {
            fn on_epoch_end(&mut self, metrics: &EpochMetrics) -> CallbackAction {
                match metrics.epoch {
                    0..=2 => CallbackAction::Continue,
                    _ => CallbackAction::Stop,
                }
            }
        }

        let mut cb = StopAt3;
        assert_eq!(cb.on_epoch_end(&blank(2)), CallbackAction::Continue);
        assert_eq!(cb.on_epoch_end(&blank(3)), CallbackAction::Stop);
    }
}