//! Training callbacks for `yscv_model`: early stopping, best-model
//! checkpointing, and CSV metrics logging.
1use std::collections::HashMap;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
/// Mode for metric monitoring.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MonitorMode {
    /// Lower is better (e.g., loss).
    Min,
    /// Higher is better (e.g., accuracy).
    Max,
}

impl MonitorMode {
    /// Worst possible value for this mode, used to seed "best so far"
    /// tracking so that the first observed metric always counts as an
    /// improvement.
    fn worst(self) -> f32 {
        match self {
            MonitorMode::Min => f32::INFINITY,
            MonitorMode::Max => f32::NEG_INFINITY,
        }
    }
}

/// Early stopping to halt training when a metric stops improving.
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    patience: usize,   // epochs without improvement tolerated before stopping
    min_delta: f32,    // minimum change that counts as an improvement
    mode: MonitorMode, // whether lower or higher metric values are better
    best_value: f32,   // best monitored value seen so far
    counter: usize,    // consecutive epochs without improvement
    stopped: bool,     // latched once patience is exhausted
    monitor: String,   // metric key read by the TrainingCallback impl
}

impl EarlyStopping {
    /// Create a new early stopping monitor.
    ///
    /// - `patience`: number of epochs without improvement before stopping.
    ///   A patience of 0 stops on the first non-improving epoch.
    /// - `min_delta`: minimum change to qualify as an improvement.
    /// - `mode`: whether lower or higher metric values are better.
    pub fn new(patience: usize, min_delta: f32, mode: MonitorMode) -> Self {
        Self {
            patience,
            min_delta,
            mode,
            best_value: mode.worst(),
            counter: 0,
            stopped: false,
            monitor: "loss".to_string(),
        }
    }

    /// Check if training should stop. Call once per epoch with the monitored metric.
    ///
    /// Returns `true` when the patience has been exhausted (training should stop).
    pub fn check(&mut self, value: f32) -> bool {
        let improved = match self.mode {
            MonitorMode::Min => value < self.best_value - self.min_delta,
            MonitorMode::Max => value > self.best_value + self.min_delta,
        };

        if improved {
            self.best_value = value;
            self.counter = 0;
        } else {
            self.counter += 1;
            // Only a non-improving epoch can exhaust patience. Evaluating the
            // threshold here (rather than unconditionally) keeps patience == 0
            // from stopping on an epoch that actually improved.
            if self.counter >= self.patience {
                self.stopped = true;
            }
        }

        self.stopped
    }

    /// Whether stop was triggered.
    pub fn stopped(&self) -> bool {
        self.stopped
    }

    /// Best value seen so far.
    pub fn best_value(&self) -> f32 {
        self.best_value
    }

    /// Number of epochs without improvement.
    pub fn counter(&self) -> usize {
        self.counter
    }

    /// Reset state for a new training run.
    pub fn reset(&mut self) {
        self.best_value = self.mode.worst();
        self.counter = 0;
        self.stopped = false;
    }
}
96
97/// Saves model weights when monitored metric improves.
98///
99/// This struct tracks the best metric value and signals when a new best is found.
100/// The caller is responsible for performing the actual model serialisation.
101#[derive(Debug, Clone)]
102pub struct BestModelCheckpoint {
103    save_path: PathBuf,
104    mode: MonitorMode,
105    best_value: f32,
106    monitor: String,
107}
108
109impl BestModelCheckpoint {
110    /// Create a new checkpoint tracker.
111    pub fn new(save_path: PathBuf, mode: MonitorMode) -> Self {
112        let best_value = match mode {
113            MonitorMode::Min => f32::INFINITY,
114            MonitorMode::Max => f32::NEG_INFINITY,
115        };
116        Self {
117            save_path,
118            mode,
119            best_value,
120            monitor: "loss".to_string(),
121        }
122    }
123
124    /// Check if metric improved. Returns `true` if a new best was found.
125    pub fn check(&mut self, value: f32) -> bool {
126        let improved = match self.mode {
127            MonitorMode::Min => value < self.best_value,
128            MonitorMode::Max => value > self.best_value,
129        };
130
131        if improved {
132            self.best_value = value;
133        }
134
135        improved
136    }
137
138    /// Get the save path.
139    pub fn save_path(&self) -> &Path {
140        &self.save_path
141    }
142
143    /// Best value seen.
144    pub fn best_value(&self) -> f32 {
145        self.best_value
146    }
147}
148
/// Trait for training callbacks invoked after each epoch.
pub trait TrainingCallback {
    /// Called after each epoch. Returns true if training should stop.
    ///
    /// `metrics` maps metric names (e.g. `"loss"`, `"val_loss"`) to the values
    /// recorded for the epoch that just finished.
    fn on_epoch_end(&mut self, epoch: usize, metrics: &HashMap<String, f32>) -> bool;

    /// Called after each batch within an epoch. Default implementation does nothing.
    fn on_batch_end(&mut self, _epoch: usize, _batch: usize, _loss: f32) {}
}
157
158impl EarlyStopping {
159    /// Set the metric key to monitor (default: `"loss"`).
160    pub fn with_monitor(mut self, key: impl Into<String>) -> Self {
161        self.monitor = key.into();
162        self
163    }
164}
165
166impl TrainingCallback for EarlyStopping {
167    fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
168        if let Some(&value) = metrics.get(&self.monitor) {
169            self.check(value)
170        } else {
171            false
172        }
173    }
174}
175
176impl BestModelCheckpoint {
177    /// Set the metric key to monitor (default: `"loss"`).
178    pub fn with_monitor(mut self, key: impl Into<String>) -> Self {
179        self.monitor = key.into();
180        self
181    }
182}
183
184impl TrainingCallback for BestModelCheckpoint {
185    fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
186        if let Some(&value) = metrics.get(&self.monitor) {
187            self.check(value);
188        }
189        // BestModelCheckpoint never requests training to stop.
190        false
191    }
192}
193
/// Logs training metrics to a CSV file and prints a summary line to stdout.
///
/// Creates a CSV file at the specified path with columns:
/// `epoch,train_loss,val_loss,learning_rate,duration_ms`
///
/// On each `on_epoch_end`, appends a row with the current metrics and prints a
/// formatted summary to stdout.
#[derive(Debug)]
pub struct MetricsLogger {
    path: PathBuf,
    file: Option<std::fs::File>,    // lazily created on first epoch
    start_time: std::time::Instant, // measures per-epoch duration
}

impl MetricsLogger {
    /// Create a new metrics logger that writes CSV rows to `path`.
    ///
    /// The file is not created until the first row is logged.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            file: None,
            start_time: std::time::Instant::now(),
        }
    }

    /// Returns the path of the CSV file.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Ensure the CSV file is open and the header has been written.
    ///
    /// Returns `None` when the file cannot be created; logging is best-effort
    /// and creation is retried on the next call. Header write errors are
    /// deliberately ignored for the same reason.
    fn ensure_file(&mut self) -> Option<&mut std::fs::File> {
        if self.file.is_none() {
            if let Ok(mut f) = std::fs::File::create(&self.path) {
                let _ = writeln!(f, "epoch,train_loss,val_loss,learning_rate,duration_ms");
                self.file = Some(f);
            }
        }
        self.file.as_mut()
    }
}
234
235impl TrainingCallback for MetricsLogger {
236    fn on_epoch_end(&mut self, epoch: usize, metrics: &HashMap<String, f32>) -> bool {
237        let train_loss = metrics
238            .get("train_loss")
239            .or_else(|| metrics.get("loss"))
240            .copied()
241            .unwrap_or(f32::NAN);
242        let val_loss = metrics.get("val_loss").copied().unwrap_or(f32::NAN);
243        let lr = metrics
244            .get("learning_rate")
245            .or_else(|| metrics.get("lr"))
246            .copied()
247            .unwrap_or(f32::NAN);
248        let duration_ms = self.start_time.elapsed().as_millis() as u64;
249
250        if let Some(f) = self.ensure_file() {
251            let _ = writeln!(f, "{epoch},{train_loss},{val_loss},{lr},{duration_ms}");
252            let _ = f.flush();
253        }
254
255        println!(
256            "[Epoch {epoch}] train_loss={train_loss:.4} val_loss={val_loss:.4} lr={lr:.6} elapsed={duration_ms}ms"
257        );
258
259        // Reset timer for next epoch measurement.
260        self.start_time = std::time::Instant::now();
261
262        // MetricsLogger never requests training to stop.
263        false
264    }
265}
266
267/// Train for multiple epochs with callbacks.
268///
269/// Training stops early if any callback returns `true` from `on_epoch_end`.
270/// Returns the number of epochs actually trained.
271pub fn train_epochs_with_callbacks<F>(
272    mut train_fn: F,
273    epochs: usize,
274    callbacks: &mut [&mut dyn TrainingCallback],
275) -> usize
276where
277    F: FnMut(usize) -> HashMap<String, f32>,
278{
279    for epoch in 0..epochs {
280        let metrics = train_fn(epoch);
281        let should_stop = callbacks
282            .iter_mut()
283            .fold(false, |stop, cb| cb.on_epoch_end(epoch, &metrics) || stop);
284        if should_stop {
285            return epoch + 1;
286        }
287    }
288    epochs
289}