1use std::collections::HashMap;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
/// Direction in which a monitored metric is considered to improve.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the mode can be
/// used as a map/set key and in exhaustive equality contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MonitorMode {
    /// Smaller values are better; trackers start from `f32::INFINITY`.
    Min,
    /// Larger values are better; trackers start from `f32::NEG_INFINITY`.
    Max,
}
13
14#[derive(Debug, Clone)]
16pub struct EarlyStopping {
17 patience: usize,
18 min_delta: f32,
19 mode: MonitorMode,
20 best_value: f32,
21 counter: usize,
22 stopped: bool,
23 monitor: String,
24}
25
26impl EarlyStopping {
27 pub fn new(patience: usize, min_delta: f32, mode: MonitorMode) -> Self {
33 let best_value = match mode {
34 MonitorMode::Min => f32::INFINITY,
35 MonitorMode::Max => f32::NEG_INFINITY,
36 };
37 Self {
38 patience,
39 min_delta,
40 mode,
41 best_value,
42 counter: 0,
43 stopped: false,
44 monitor: "loss".to_string(),
45 }
46 }
47
48 pub fn check(&mut self, value: f32) -> bool {
52 let improved = match self.mode {
53 MonitorMode::Min => value < self.best_value - self.min_delta,
54 MonitorMode::Max => value > self.best_value + self.min_delta,
55 };
56
57 if improved {
58 self.best_value = value;
59 self.counter = 0;
60 } else {
61 self.counter += 1;
62 }
63
64 if self.counter >= self.patience {
65 self.stopped = true;
66 }
67
68 self.stopped
69 }
70
71 pub fn stopped(&self) -> bool {
73 self.stopped
74 }
75
76 pub fn best_value(&self) -> f32 {
78 self.best_value
79 }
80
81 pub fn counter(&self) -> usize {
83 self.counter
84 }
85
86 pub fn reset(&mut self) {
88 self.best_value = match self.mode {
89 MonitorMode::Min => f32::INFINITY,
90 MonitorMode::Max => f32::NEG_INFINITY,
91 };
92 self.counter = 0;
93 self.stopped = false;
94 }
95}
96
97#[derive(Debug, Clone)]
102pub struct BestModelCheckpoint {
103 save_path: PathBuf,
104 mode: MonitorMode,
105 best_value: f32,
106 monitor: String,
107}
108
109impl BestModelCheckpoint {
110 pub fn new(save_path: PathBuf, mode: MonitorMode) -> Self {
112 let best_value = match mode {
113 MonitorMode::Min => f32::INFINITY,
114 MonitorMode::Max => f32::NEG_INFINITY,
115 };
116 Self {
117 save_path,
118 mode,
119 best_value,
120 monitor: "loss".to_string(),
121 }
122 }
123
124 pub fn check(&mut self, value: f32) -> bool {
126 let improved = match self.mode {
127 MonitorMode::Min => value < self.best_value,
128 MonitorMode::Max => value > self.best_value,
129 };
130
131 if improved {
132 self.best_value = value;
133 }
134
135 improved
136 }
137
138 pub fn save_path(&self) -> &Path {
140 &self.save_path
141 }
142
143 pub fn best_value(&self) -> f32 {
145 self.best_value
146 }
147}
148
/// Hook interface driven by a training loop.
pub trait TrainingCallback {
    /// Receives the epoch index and the metrics reported for that epoch.
    ///
    /// Returning `true` asks the training loop to stop.
    fn on_epoch_end(&mut self, epoch: usize, metrics: &HashMap<String, f32>) -> bool;

    /// Per-batch hook; the default implementation does nothing.
    fn on_batch_end(&mut self, _epoch: usize, _batch: usize, _loss: f32) {}
}
157
158impl EarlyStopping {
159 pub fn with_monitor(mut self, key: impl Into<String>) -> Self {
161 self.monitor = key.into();
162 self
163 }
164}
165
166impl TrainingCallback for EarlyStopping {
167 fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
168 if let Some(&value) = metrics.get(&self.monitor) {
169 self.check(value)
170 } else {
171 false
172 }
173 }
174}
175
176impl BestModelCheckpoint {
177 pub fn with_monitor(mut self, key: impl Into<String>) -> Self {
179 self.monitor = key.into();
180 self
181 }
182}
183
184impl TrainingCallback for BestModelCheckpoint {
185 fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
186 if let Some(&value) = metrics.get(&self.monitor) {
187 self.check(value);
188 }
189 false
191 }
192}
193
/// Logs per-epoch metrics to a CSV file.
///
/// Derives `Debug` for consistency with the other public callback types in
/// this file; every field already implements it.
#[derive(Debug)]
pub struct MetricsLogger {
    /// Destination of the CSV file.
    path: PathBuf,
    /// Open handle to the CSV file; `None` until it has been created.
    file: Option<std::fs::File>,
    /// Reference instant used to measure the elapsed time reported per epoch.
    start_time: std::time::Instant,
}
206
207impl MetricsLogger {
208 pub fn new(path: impl Into<PathBuf>) -> Self {
210 Self {
211 path: path.into(),
212 file: None,
213 start_time: std::time::Instant::now(),
214 }
215 }
216
217 pub fn path(&self) -> &Path {
219 &self.path
220 }
221
222 fn ensure_file(&mut self) -> Option<&mut std::fs::File> {
224 if self.file.is_none()
225 && let Ok(mut f) = std::fs::File::create(&self.path)
226 {
227 let _ = writeln!(f, "epoch,train_loss,val_loss,learning_rate,duration_ms");
228 self.file = Some(f);
229 self.start_time = std::time::Instant::now();
230 }
231 self.file.as_mut()
232 }
233}
234
235impl TrainingCallback for MetricsLogger {
236 fn on_epoch_end(&mut self, epoch: usize, metrics: &HashMap<String, f32>) -> bool {
237 let train_loss = metrics
238 .get("train_loss")
239 .or_else(|| metrics.get("loss"))
240 .copied()
241 .unwrap_or(f32::NAN);
242 let val_loss = metrics.get("val_loss").copied().unwrap_or(f32::NAN);
243 let lr = metrics
244 .get("learning_rate")
245 .or_else(|| metrics.get("lr"))
246 .copied()
247 .unwrap_or(f32::NAN);
248 let duration_ms = self.start_time.elapsed().as_millis() as u64;
249
250 if let Some(f) = self.ensure_file() {
251 let _ = writeln!(f, "{epoch},{train_loss},{val_loss},{lr},{duration_ms}");
252 let _ = f.flush();
253 }
254
255 println!(
256 "[Epoch {epoch}] train_loss={train_loss:.4} val_loss={val_loss:.4} lr={lr:.6} elapsed={duration_ms}ms"
257 );
258
259 self.start_time = std::time::Instant::now();
261
262 false
264 }
265}
266
267pub fn train_epochs_with_callbacks<F>(
272 mut train_fn: F,
273 epochs: usize,
274 callbacks: &mut [&mut dyn TrainingCallback],
275) -> usize
276where
277 F: FnMut(usize) -> HashMap<String, f32>,
278{
279 for epoch in 0..epochs {
280 let metrics = train_fn(epoch);
281 let should_stop = callbacks
282 .iter_mut()
283 .fold(false, |stop, cb| cb.on_epoch_end(epoch, &metrics) || stop);
284 if should_stop {
285 return epoch + 1;
286 }
287 }
288 epochs
289}