1use std::path::PathBuf;
4
/// Settings for a training run.
///
/// Start from [`TrainConfig::new`] (identical to [`Default::default`]) and
/// refine the value with the chainable `with_*` / `without_*` methods, each of
/// which consumes the config and hands back the updated one.
#[derive(Clone, Debug)]
pub struct TrainConfig {
    /// Gradient-norm clipping threshold; `None` when no threshold is configured.
    pub max_grad_norm: Option<f32>,

    /// Logging interval (10 by default).
    pub log_interval: usize,

    /// Checkpoint save interval; `None` until [`TrainConfig::with_checkpoints`] sets it.
    pub save_interval: Option<usize>,

    /// Checkpoint directory; set together with `save_interval` by
    /// [`TrainConfig::with_checkpoints`].
    pub checkpoint_dir: Option<PathBuf>,

    /// Mixed-precision flag (`false` by default). No builder method in this
    /// file changes it; assign the field directly.
    pub mixed_precision: bool,

    /// Gradient-accumulation step count (1 by default).
    /// [`TrainConfig::with_gradient_accumulation`] never stores less than 1.
    pub gradient_accumulation_steps: usize,
}

impl Default for TrainConfig {
    /// Baseline configuration: threshold `Some(1.0)`, log interval 10, no
    /// checkpointing, mixed precision off, one accumulation step.
    fn default() -> Self {
        TrainConfig {
            gradient_accumulation_steps: 1,
            mixed_precision: false,
            checkpoint_dir: None,
            save_interval: None,
            log_interval: 10,
            max_grad_norm: Some(1.0),
        }
    }
}

impl TrainConfig {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Stores `max_norm` as the gradient-norm threshold.
    pub fn with_grad_clip(self, max_norm: f32) -> Self {
        Self { max_grad_norm: Some(max_norm), ..self }
    }

    /// Clears the gradient-norm threshold.
    pub fn without_grad_clip(self) -> Self {
        Self { max_grad_norm: None, ..self }
    }

    /// Stores `interval` as the logging interval, unchanged (no lower bound is applied).
    pub fn with_log_interval(self, interval: usize) -> Self {
        Self { log_interval: interval, ..self }
    }

    /// Sets the checkpoint interval and the checkpoint directory in one call.
    pub fn with_checkpoints(self, interval: usize, dir: PathBuf) -> Self {
        Self {
            save_interval: Some(interval),
            checkpoint_dir: Some(dir),
            ..self
        }
    }

    /// Stores the gradient-accumulation step count; a request for 0 is raised to 1.
    pub fn with_gradient_accumulation(self, steps: usize) -> Self {
        let gradient_accumulation_steps = if steps == 0 { 1 } else { steps };
        Self { gradient_accumulation_steps, ..self }
    }
}
85
/// Running record of training metrics.
///
/// All fields are public; the methods below are conveniences for appending to
/// the histories and querying them.
#[derive(Clone, Debug)]
pub struct MetricsTracker {
    /// Training losses, one per [`MetricsTracker::record_epoch`] call, oldest first.
    pub losses: Vec<f32>,

    /// Validation losses in the order [`MetricsTracker::record_val_loss`] received them.
    pub val_losses: Vec<f32>,

    /// Learning rates, pushed alongside each entry of `losses`.
    pub learning_rates: Vec<f32>,

    /// Step counter, advanced by [`MetricsTracker::increment_step`].
    pub steps: usize,

    /// Epoch counter, advanced by [`MetricsTracker::record_epoch`].
    pub epoch: usize,
}

impl MetricsTracker {
    /// Creates a tracker with empty histories and zeroed counters.
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            val_losses: Vec::new(),
            learning_rates: Vec::new(),
            steps: 0,
            epoch: 0,
        }
    }

    /// Appends one epoch's training `loss` and learning rate `lr`, then
    /// advances the epoch counter by one.
    pub fn record_epoch(&mut self, loss: f32, lr: f32) {
        self.losses.push(loss);
        self.learning_rates.push(lr);
        self.epoch += 1;
    }

    /// Appends a validation loss. The epoch counter is not touched.
    pub fn record_val_loss(&mut self, val_loss: f32) {
        self.val_losses.push(val_loss);
    }

    /// Returns the smallest recorded validation loss, or `None` if there is none.
    ///
    /// Ordering is `f32::total_cmp`, a total order in which NaN values take
    /// part (ordered by sign) rather than being skipped.
    pub fn best_val_loss(&self) -> Option<f32> {
        self.val_losses.iter().copied().min_by(f32::total_cmp)
    }

    /// Reports whether the validation loss dropped at least once within the
    /// last `patience` recorded values.
    ///
    /// Same rules as [`MetricsTracker::is_improving`], applied to `val_losses`.
    pub fn is_val_improving(&self, patience: usize) -> bool {
        Self::window_has_decrease(&self.val_losses, patience)
    }

    /// Advances the step counter by one.
    pub fn increment_step(&mut self) {
        self.steps += 1;
    }

    /// Returns the mean of the most recent `n` training losses.
    ///
    /// When fewer than `n` losses exist, all of them are averaged. Returns
    /// `0.0` when there is nothing to average, i.e. when the history is empty
    /// or `n` is 0.
    pub fn avg_loss(&self, n: usize) -> f32 {
        let start = self.losses.len().saturating_sub(n);
        let window = &self.losses[start..];
        if window.is_empty() {
            // Guard the division: an empty window would otherwise give 0.0 / 0.0 = NaN.
            return 0.0;
        }
        window.iter().sum::<f32>() / window.len() as f32
    }

    /// Returns the smallest recorded training loss, or `None` if there is none.
    ///
    /// Ordering is `f32::total_cmp`, a total order in which NaN values take
    /// part (ordered by sign) rather than being skipped.
    pub fn best_loss(&self) -> Option<f32> {
        self.losses.iter().copied().min_by(f32::total_cmp)
    }

    /// Reports whether the training loss dropped at least once within the
    /// last `patience` recorded values.
    ///
    /// * Fewer than `patience` losses recorded: returns `true`.
    /// * Otherwise returns `true` only if some loss in the window is strictly
    ///   lower than the one recorded just before it. A plateau or a steady
    ///   rise yields `false`.
    /// * A window of 0 or 1 values has no adjacent pair, so it yields `false`.
    /// * NaN never counts as a drop, so a NaN loss cannot by itself make the
    ///   run look like it is improving.
    pub fn is_improving(&self, patience: usize) -> bool {
        Self::window_has_decrease(&self.losses, patience)
    }

    /// Shared rule behind `is_improving` / `is_val_improving`: `true` when
    /// `history` is shorter than `patience`, otherwise whether any adjacent
    /// pair in the trailing `patience` values strictly decreases.
    fn window_has_decrease(history: &[f32], patience: usize) -> bool {
        match history.len().checked_sub(patience) {
            // Not enough history yet.
            None => true,
            // Any comparison involving NaN is false, so NaN is never a decrease.
            Some(start) => history[start..].windows(2).any(|pair| pair[1] < pair[0]),
        }
    }
}

impl Default for MetricsTracker {
    /// Same as [`MetricsTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Learning rate passed to every `record_epoch` call in this module.
    const LR: f32 = 0.001;

    /// Returns a tracker that has recorded one epoch per entry of `losses`.
    fn tracker_with_epochs(losses: &[f32]) -> MetricsTracker {
        let mut tracker = MetricsTracker::new();
        for &loss in losses {
            tracker.record_epoch(loss, LR);
        }
        tracker
    }

    /// Returns a tracker that has recorded each entry of `val_losses` in order.
    fn tracker_with_val_losses(val_losses: &[f32]) -> MetricsTracker {
        let mut tracker = MetricsTracker::new();
        for &val_loss in val_losses {
            tracker.record_val_loss(val_loss);
        }
        tracker
    }

    #[test]
    fn test_train_config_default() {
        let defaults = TrainConfig::default();
        assert_eq!(defaults.max_grad_norm, Some(1.0));
        assert_eq!(defaults.log_interval, 10);
        assert_eq!(defaults.save_interval, None);
        assert_eq!(defaults.gradient_accumulation_steps, 1);
    }

    #[test]
    fn test_train_config_builder() {
        let built = TrainConfig::new()
            .with_grad_clip(0.5)
            .with_log_interval(20)
            .without_grad_clip();

        assert!(built.max_grad_norm.is_none());
        assert_eq!(built.log_interval, 20);
    }

    #[test]
    fn test_metrics_tracker() {
        let tracker = tracker_with_epochs(&[1.0, 0.8, 0.6]);

        assert_eq!(tracker.epoch, 3);
        assert_eq!(tracker.losses.len(), 3);
        assert_eq!(tracker.best_loss(), Some(0.6));
    }

    #[test]
    fn test_metrics_avg_loss() {
        let tracker = tracker_with_epochs(&[1.0, 0.8, 0.6]);

        // Mean of the last two losses: (0.8 + 0.6) / 2.
        let mean_of_last_two = tracker.avg_loss(2);
        assert!((mean_of_last_two - 0.7).abs() < 1e-5);
    }

    #[test]
    fn test_metrics_is_improving() {
        // Steadily falling losses.
        let tracker = tracker_with_epochs(&[1.0, 0.8, 0.6]);

        assert!(tracker.is_improving(2));
    }

    #[test]
    fn test_gradient_accumulation_builder() {
        let built = TrainConfig::new().with_gradient_accumulation(4);
        assert_eq!(built.gradient_accumulation_steps, 4);
    }

    #[test]
    fn test_gradient_accumulation_min_value() {
        // A request for zero steps is raised to one.
        let built = TrainConfig::new().with_gradient_accumulation(0);
        assert_eq!(built.gradient_accumulation_steps, 1);
    }

    #[test]
    fn test_validation_loss_tracking() {
        let mut tracker = MetricsTracker::new();

        // Interleave a training epoch with its validation loss.
        for (train_loss, val_loss) in [(1.0, 0.9), (0.8, 0.7), (0.6, 0.5)] {
            tracker.record_epoch(train_loss, LR);
            tracker.record_val_loss(val_loss);
        }

        assert_eq!(tracker.val_losses.len(), 3);
        assert_eq!(tracker.best_val_loss(), Some(0.5));
    }

    #[test]
    fn test_validation_is_improving() {
        // Steadily falling validation losses.
        let tracker = tracker_with_val_losses(&[0.9, 0.7, 0.5]);

        assert!(tracker.is_val_improving(2));
    }

    #[test]
    fn test_validation_not_improving() {
        // Steadily rising validation losses.
        let tracker = tracker_with_val_losses(&[0.5, 0.6, 0.7]);

        assert!(!tracker.is_val_improving(2));
    }

    #[test]
    fn test_with_checkpoints() {
        let dir = PathBuf::from("/tmp/checkpoints");
        let built = TrainConfig::new().with_checkpoints(5, dir.clone());
        assert_eq!(built.save_interval, Some(5));
        assert_eq!(built.checkpoint_dir, Some(dir));
    }

    #[test]
    fn test_increment_step() {
        let mut tracker = MetricsTracker::new();
        assert_eq!(tracker.steps, 0);
        for expected in 1..=2 {
            tracker.increment_step();
            assert_eq!(tracker.steps, expected);
        }
    }

    #[test]
    fn test_metrics_tracker_default() {
        let fresh = MetricsTracker::default();
        assert!(fresh.losses.is_empty());
        assert!(fresh.val_losses.is_empty());
        assert_eq!(fresh.steps, 0);
        assert_eq!(fresh.epoch, 0);
    }

    #[test]
    fn test_avg_loss_empty() {
        let fresh = MetricsTracker::new();
        assert_eq!(fresh.avg_loss(5), 0.0);
    }

    #[test]
    fn test_best_loss_empty() {
        let fresh = MetricsTracker::new();
        assert_eq!(fresh.best_loss(), None);
    }

    #[test]
    fn test_best_val_loss_empty() {
        let fresh = MetricsTracker::new();
        assert_eq!(fresh.best_val_loss(), None);
    }

    #[test]
    fn test_is_improving_insufficient_data() {
        // One loss recorded, patience of three: reported as improving.
        let tracker = tracker_with_epochs(&[1.0]);
        assert!(tracker.is_improving(3));
    }

    #[test]
    fn test_is_val_improving_insufficient_data() {
        // One validation loss recorded, patience of three: reported as improving.
        let tracker = tracker_with_val_losses(&[0.5]);
        assert!(tracker.is_val_improving(3));
    }

    #[test]
    fn test_train_config_clone() {
        let original = TrainConfig::new().with_grad_clip(0.5);
        let copy = original.clone();
        assert_eq!(original.max_grad_norm, copy.max_grad_norm);
    }

    #[test]
    fn test_metrics_tracker_clone() {
        let original = tracker_with_epochs(&[1.0]);
        let copy = original.clone();
        assert_eq!(original.losses, copy.losses);
    }
}