// entrenar/train/config.rs

//! Training configuration and metrics

use std::path::PathBuf;

/// Training configuration
///
/// Start from [`TrainConfig::new`] (equivalent to [`Default::default`]) and
/// adjust individual settings with the chainable builder methods, each of
/// which consumes the configuration and returns the updated value.
#[derive(Clone, Debug)]
pub struct TrainConfig {
    /// Maximum gradient norm for clipping (None = no clipping)
    pub max_grad_norm: Option<f32>,

    /// Print training progress every N steps
    pub log_interval: usize,

    /// Save checkpoint every N epochs
    pub save_interval: Option<usize>,

    /// Directory to save checkpoints
    pub checkpoint_dir: Option<PathBuf>,

    /// Use mixed precision training
    pub mixed_precision: bool,

    /// Gradient accumulation steps (1 = no accumulation)
    ///
    /// Simulates larger batch sizes by accumulating gradients over
    /// multiple mini-batches before performing an optimizer step.
    /// Effective batch size = batch_size * gradient_accumulation_steps
    pub gradient_accumulation_steps: usize,
}

impl Default for TrainConfig {
    /// Default settings: gradient clipping at norm `1.0`, logging every 10
    /// steps, no checkpoint interval or directory, mixed precision off, and
    /// a gradient accumulation step count of 1.
    fn default() -> Self {
        Self {
            max_grad_norm: Some(1.0),
            log_interval: 10,
            save_interval: None,
            checkpoint_dir: None,
            mixed_precision: false,
            gradient_accumulation_steps: 1,
        }
    }
}

impl TrainConfig {
    /// Create a new training configuration with the default settings
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set gradient clipping norm
    ///
    /// Stores `Some(max_norm)` in `max_grad_norm`; the value is not validated.
    #[must_use]
    pub fn with_grad_clip(self, max_norm: f32) -> Self {
        Self {
            max_grad_norm: Some(max_norm),
            ..self
        }
    }

    /// Disable gradient clipping
    ///
    /// Resets `max_grad_norm` to `None`.
    #[must_use]
    pub fn without_grad_clip(self) -> Self {
        Self {
            max_grad_norm: None,
            ..self
        }
    }

    /// Set logging interval
    ///
    /// Stores `interval` in `log_interval` unchanged (no clamping).
    #[must_use]
    pub fn with_log_interval(self, interval: usize) -> Self {
        Self {
            log_interval: interval,
            ..self
        }
    }

    /// Set checkpoint saving
    ///
    /// Sets both `save_interval` (to `Some(interval)`) and `checkpoint_dir`
    /// (to `Some(dir)`) in one call.
    #[must_use]
    pub fn with_checkpoints(self, interval: usize, dir: PathBuf) -> Self {
        Self {
            save_interval: Some(interval),
            checkpoint_dir: Some(dir),
            ..self
        }
    }

    /// Set gradient accumulation steps
    ///
    /// Simulates larger batch sizes by accumulating gradients over
    /// multiple mini-batches before performing an optimizer step.
    /// Effective batch size = batch_size * gradient_accumulation_steps
    ///
    /// A `steps` value of 0 is raised to 1, so the stored count is never zero
    /// when set through this method.
    #[must_use]
    pub fn with_gradient_accumulation(self, steps: usize) -> Self {
        Self {
            gradient_accumulation_steps: steps.max(1),
            ..self
        }
    }
}

/// Tracks training metrics across epochs
///
/// All histories are append-only vectors; nothing in this type removes or
/// rewrites recorded values.
#[derive(Clone, Debug)]
pub struct MetricsTracker {
    /// Training loss history (one per epoch)
    pub losses: Vec<f32>,

    /// Validation loss history (one per epoch, if validation is used)
    pub val_losses: Vec<f32>,

    /// Learning rates (one per epoch)
    pub learning_rates: Vec<f32>,

    /// Training step count
    pub steps: usize,

    /// Current epoch
    pub epoch: usize,
}

impl MetricsTracker {
    /// Create a new metrics tracker with empty histories and zeroed counters
    pub fn new() -> Self {
        Self {
            losses: Vec::new(),
            val_losses: Vec::new(),
            learning_rates: Vec::new(),
            steps: 0,
            epoch: 0,
        }
    }

    /// Record an epoch's training metrics
    ///
    /// Appends `loss` and `lr` to their histories and advances `epoch` by one.
    pub fn record_epoch(&mut self, loss: f32, lr: f32) {
        self.losses.push(loss);
        self.learning_rates.push(lr);
        self.epoch += 1;
    }

    /// Record validation loss for the current epoch
    ///
    /// Only appends to `val_losses`; the epoch counter is not touched.
    pub fn record_val_loss(&mut self, val_loss: f32) {
        self.val_losses.push(val_loss);
    }

    /// Get best (minimum) validation loss
    ///
    /// The minimum is taken under [`f32::total_cmp`]. Returns `None` when no
    /// validation loss has been recorded.
    pub fn best_val_loss(&self) -> Option<f32> {
        self.val_losses.iter().copied().min_by(f32::total_cmp)
    }

    /// Check if validation loss is improving
    ///
    /// Looks at the last `patience` validation losses and returns `true` if
    /// any value is strictly lower than the one recorded just before it.
    /// Returns `true` when fewer than `patience` values exist (not enough
    /// data to call a stall). A NaN never counts as a decrease.
    ///
    /// With `patience < 2` the window has no adjacent pair, so the result is
    /// `false` as soon as at least `patience` values are recorded.
    pub fn is_val_improving(&self, patience: usize) -> bool {
        Self::tail_has_decrease(&self.val_losses, patience)
    }

    /// Increment step counter
    pub fn increment_step(&mut self) {
        self.steps += 1;
    }

    /// Get average loss over last N epochs
    ///
    /// Averages the most recent `n` training losses, or all of them when
    /// fewer than `n` exist. Returns `0.0` when the window is empty, i.e.
    /// when no loss has been recorded or `n` is 0.
    pub fn avg_loss(&self, n: usize) -> f32 {
        let start = self.losses.len().saturating_sub(n);
        let window = &self.losses[start..];

        // Guard the division: an empty window (no data, or n == 0) would
        // otherwise evaluate 0.0 / 0.0 and yield NaN.
        if window.is_empty() {
            return 0.0;
        }

        // The count-to-float cast is only a divisor; precision loss on very
        // long windows is acceptable here.
        window.iter().sum::<f32>() / window.len() as f32
    }

    /// Get best (minimum) loss
    ///
    /// The minimum is taken under [`f32::total_cmp`]. Returns `None` when no
    /// training loss has been recorded.
    pub fn best_loss(&self) -> Option<f32> {
        self.losses.iter().copied().min_by(f32::total_cmp)
    }

    /// Check if training is improving (loss decreasing)
    ///
    /// Looks at the last `patience` training losses and returns `true` if
    /// any value is strictly lower than the one recorded just before it.
    /// Returns `true` when fewer than `patience` values exist (not enough
    /// data to call a stall). A NaN never counts as a decrease.
    ///
    /// With `patience < 2` the window has no adjacent pair, so the result is
    /// `false` as soon as at least `patience` values are recorded.
    pub fn is_improving(&self, patience: usize) -> bool {
        Self::tail_has_decrease(&self.losses, patience)
    }

    /// Shared trend check for both loss histories.
    ///
    /// Returns `true` if `history` holds fewer than `patience` values, or if
    /// the last `patience` values contain an adjacent pair whose second
    /// element is strictly smaller than its first.
    fn tail_has_decrease(history: &[f32], patience: usize) -> bool {
        if history.len() < patience {
            return true;
        }

        let recent = &history[history.len() - patience..];

        // A window is "not sorted ascending" exactly when some neighbour pair
        // goes down. `<` is false for every comparison involving NaN, so a
        // diverged (NaN) loss is never mistaken for progress.
        recent.windows(2).any(|pair| pair[1] < pair[0])
    }
}

impl Default for MetricsTracker {
    /// Same as [`MetricsTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Unit tests for [`TrainConfig`] and [`MetricsTracker`].
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_config_default() {
        let config = TrainConfig::default();
        assert_eq!(config.max_grad_norm, Some(1.0));
        assert_eq!(config.log_interval, 10);
        assert!(config.save_interval.is_none());
        assert_eq!(config.gradient_accumulation_steps, 1);
    }

    #[test]
    fn test_train_config_builder() {
        // The last clipping call in the chain wins.
        let config =
            TrainConfig::new().with_grad_clip(0.5).with_log_interval(20).without_grad_clip();

        assert_eq!(config.max_grad_norm, None);
        assert_eq!(config.log_interval, 20);
    }

    #[test]
    fn test_metrics_tracker() {
        let mut tracker = MetricsTracker::new();

        tracker.record_epoch(1.0, 0.001);
        tracker.record_epoch(0.8, 0.001);
        tracker.record_epoch(0.6, 0.001);

        assert_eq!(tracker.epoch, 3);
        assert_eq!(tracker.losses.len(), 3);
        assert_eq!(tracker.best_loss(), Some(0.6));
    }

    #[test]
    fn test_metrics_avg_loss() {
        let mut tracker = MetricsTracker::new();

        tracker.record_epoch(1.0, 0.001);
        tracker.record_epoch(0.8, 0.001);
        tracker.record_epoch(0.6, 0.001);

        // Average of the last two epochs: (0.8 + 0.6) / 2
        let avg = tracker.avg_loss(2);
        assert!((avg - 0.7).abs() < 1e-5);
    }

    #[test]
    fn test_metrics_is_improving() {
        let mut tracker = MetricsTracker::new();

        // Decreasing losses = improving
        tracker.record_epoch(1.0, 0.001);
        tracker.record_epoch(0.8, 0.001);
        tracker.record_epoch(0.6, 0.001);

        assert!(tracker.is_improving(2));
    }

    #[test]
    fn test_gradient_accumulation_builder() {
        let config = TrainConfig::new().with_gradient_accumulation(4);
        assert_eq!(config.gradient_accumulation_steps, 4);
    }

    #[test]
    fn test_gradient_accumulation_min_value() {
        // Should clamp to minimum of 1
        let config = TrainConfig::new().with_gradient_accumulation(0);
        assert_eq!(config.gradient_accumulation_steps, 1);
    }

    #[test]
    fn test_validation_loss_tracking() {
        let mut tracker = MetricsTracker::new();

        tracker.record_epoch(1.0, 0.001);
        tracker.record_val_loss(0.9);
        tracker.record_epoch(0.8, 0.001);
        tracker.record_val_loss(0.7);
        tracker.record_epoch(0.6, 0.001);
        tracker.record_val_loss(0.5);

        assert_eq!(tracker.val_losses.len(), 3);
        assert_eq!(tracker.best_val_loss(), Some(0.5));
    }

    #[test]
    fn test_validation_is_improving() {
        let mut tracker = MetricsTracker::new();

        // Decreasing val losses = improving
        tracker.record_val_loss(0.9);
        tracker.record_val_loss(0.7);
        tracker.record_val_loss(0.5);

        assert!(tracker.is_val_improving(2));
    }

    #[test]
    fn test_validation_not_improving() {
        let mut tracker = MetricsTracker::new();

        // Increasing val losses = not improving
        tracker.record_val_loss(0.5);
        tracker.record_val_loss(0.6);
        tracker.record_val_loss(0.7);

        assert!(!tracker.is_val_improving(2));
    }

    #[test]
    fn test_with_checkpoints() {
        let config = TrainConfig::new().with_checkpoints(5, PathBuf::from("/tmp/checkpoints"));
        assert_eq!(config.save_interval, Some(5));
        assert_eq!(config.checkpoint_dir, Some(PathBuf::from("/tmp/checkpoints")));
    }

    #[test]
    fn test_increment_step() {
        let mut tracker = MetricsTracker::new();
        assert_eq!(tracker.steps, 0);
        tracker.increment_step();
        assert_eq!(tracker.steps, 1);
        tracker.increment_step();
        assert_eq!(tracker.steps, 2);
    }

    #[test]
    fn test_metrics_tracker_default() {
        let tracker = MetricsTracker::default();
        assert!(tracker.losses.is_empty());
        assert!(tracker.val_losses.is_empty());
        assert_eq!(tracker.steps, 0);
        assert_eq!(tracker.epoch, 0);
    }

    #[test]
    fn test_avg_loss_empty() {
        let tracker = MetricsTracker::new();
        assert_eq!(tracker.avg_loss(5), 0.0);
    }

    #[test]
    fn test_best_loss_empty() {
        let tracker = MetricsTracker::new();
        assert!(tracker.best_loss().is_none());
    }

    #[test]
    fn test_best_val_loss_empty() {
        let tracker = MetricsTracker::new();
        assert!(tracker.best_val_loss().is_none());
    }

    #[test]
    fn test_is_improving_insufficient_data() {
        let mut tracker = MetricsTracker::new();
        tracker.record_epoch(1.0, 0.001);
        // With patience=3 and only 1 data point, should return true
        assert!(tracker.is_improving(3));
    }

    #[test]
    fn test_is_val_improving_insufficient_data() {
        let mut tracker = MetricsTracker::new();
        tracker.record_val_loss(0.5);
        // With patience=3 and only 1 data point, should return true
        assert!(tracker.is_val_improving(3));
    }

    #[test]
    fn test_train_config_clone() {
        let config = TrainConfig::new().with_grad_clip(0.5);
        let cloned = config.clone();
        assert_eq!(config.max_grad_norm, cloned.max_grad_norm);
    }

    #[test]
    fn test_metrics_tracker_clone() {
        let mut tracker = MetricsTracker::new();
        tracker.record_epoch(1.0, 0.001);
        let cloned = tracker.clone();
        assert_eq!(tracker.losses, cloned.losses);
    }
}