Skip to main content

entrenar/hf_pipeline/trainer/
state.rs

1//! Training state tracking.
2
3use std::time::{Duration, Instant};
4
/// Training state tracking.
///
/// Bundles the progress counters, the best validation loss seen so far, the
/// wall-clock start time and the recorded loss histories of a training run.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current epoch (0-indexed); incremented by `new_epoch`.
    pub epoch: usize,
    /// Current global step; incremented by `step` and not reset by `new_epoch`.
    pub global_step: usize,
    /// Steps completed in current epoch; reset to 0 by `new_epoch`.
    pub epoch_step: usize,
    /// Best (lowest) validation loss seen; `new` initialises it to
    /// `f32::INFINITY` so the first finite validation loss becomes the best.
    pub best_val_loss: f32,
    /// Training start time, captured with `Instant::now()` when the state is created.
    pub start_time: Instant,
    /// Loss history as `(global_step, loss)` pairs, in the order they were recorded.
    pub loss_history: Vec<(usize, f32)>,
    /// Validation loss history as `(global_step, loss)` pairs, in the order they were recorded.
    pub val_loss_history: Vec<(usize, f32)>,
}
23
24impl Default for TrainingState {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl TrainingState {
31    /// Create new training state
32    #[must_use]
33    pub fn new() -> Self {
34        Self {
35            epoch: 0,
36            global_step: 0,
37            epoch_step: 0,
38            best_val_loss: f32::INFINITY,
39            start_time: Instant::now(),
40            loss_history: Vec::new(),
41            val_loss_history: Vec::new(),
42        }
43    }
44
45    /// Record training loss
46    pub fn record_loss(&mut self, loss: f32) {
47        self.loss_history.push((self.global_step, loss));
48    }
49
50    /// Record validation loss
51    pub fn record_val_loss(&mut self, loss: f32) -> bool {
52        self.val_loss_history.push((self.global_step, loss));
53        if loss < self.best_val_loss {
54            self.best_val_loss = loss;
55            true // New best
56        } else {
57            false
58        }
59    }
60
61    /// Advance one step
62    pub fn step(&mut self) {
63        self.global_step += 1;
64        self.epoch_step += 1;
65    }
66
67    /// Start new epoch
68    pub fn new_epoch(&mut self) {
69        self.epoch += 1;
70        self.epoch_step = 0;
71    }
72
73    /// Get elapsed time
74    #[must_use]
75    pub fn elapsed(&self) -> Duration {
76        self.start_time.elapsed()
77    }
78
79    /// Get average loss over last N steps
80    #[must_use]
81    pub fn avg_loss(&self, n: usize) -> Option<f32> {
82        if self.loss_history.is_empty() {
83            return None;
84        }
85        let start = self.loss_history.len().saturating_sub(n);
86        let sum: f32 = self.loss_history[start..].iter().map(|(_, l)| l).sum();
87        Some(sum / (self.loss_history.len() - start) as f32)
88    }
89
90    /// Get steps per second
91    #[must_use]
92    pub fn steps_per_second(&self) -> f32 {
93        let elapsed = self.elapsed().as_secs_f32();
94        if elapsed > 0.0 {
95            self.global_step as f32 / elapsed
96        } else {
97            0.0
98        }
99    }
100
101    /// Get estimated time remaining
102    #[must_use]
103    pub fn eta(&self, total_steps: usize) -> Duration {
104        let sps = self.steps_per_second();
105        if sps > 0.0 {
106            let remaining = total_steps.saturating_sub(self.global_step);
107            Duration::from_secs_f32(remaining as f32 / sps)
108        } else {
109            Duration::ZERO
110        }
111    }
112}