//! Training-loop state (`entrenar/hf_pipeline/trainer/state.rs`): step and
//! epoch counters, loss histories, and wall-clock throughput/ETA helpers.

use std::time::{Duration, Instant};

/// Mutable bookkeeping for a training run: epoch/step counters, the best
/// validation loss seen so far, per-step loss histories, and the run's
/// start time.
#[derive(Clone, Debug)]
pub struct TrainingState {
    /// Epoch counter; starts at 0 and is advanced by `new_epoch`.
    pub epoch: usize,
    /// Steps taken across all epochs; advanced by `step` and never reset.
    pub global_step: usize,
    /// Steps taken within the current epoch; reset to 0 by `new_epoch`.
    pub epoch_step: usize,
    /// Lowest validation loss recorded so far (`f32::INFINITY` until the
    /// first one is recorded via `new`).
    pub best_val_loss: f32,
    /// Moment this state was created; the basis for `elapsed`,
    /// `steps_per_second` and `eta`.
    pub start_time: Instant,
    /// `(global_step, loss)` pairs, one per recorded training loss.
    pub loss_history: Vec<(usize, f32)>,
    /// `(global_step, loss)` pairs, one per recorded validation loss.
    pub val_loss_history: Vec<(usize, f32)>,
}

24impl Default for TrainingState {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl TrainingState {
31 #[must_use]
33 pub fn new() -> Self {
34 Self {
35 epoch: 0,
36 global_step: 0,
37 epoch_step: 0,
38 best_val_loss: f32::INFINITY,
39 start_time: Instant::now(),
40 loss_history: Vec::new(),
41 val_loss_history: Vec::new(),
42 }
43 }
44
45 pub fn record_loss(&mut self, loss: f32) {
47 self.loss_history.push((self.global_step, loss));
48 }
49
50 pub fn record_val_loss(&mut self, loss: f32) -> bool {
52 self.val_loss_history.push((self.global_step, loss));
53 if loss < self.best_val_loss {
54 self.best_val_loss = loss;
55 true } else {
57 false
58 }
59 }
60
61 pub fn step(&mut self) {
63 self.global_step += 1;
64 self.epoch_step += 1;
65 }
66
67 pub fn new_epoch(&mut self) {
69 self.epoch += 1;
70 self.epoch_step = 0;
71 }
72
73 #[must_use]
75 pub fn elapsed(&self) -> Duration {
76 self.start_time.elapsed()
77 }
78
79 #[must_use]
81 pub fn avg_loss(&self, n: usize) -> Option<f32> {
82 if self.loss_history.is_empty() {
83 return None;
84 }
85 let start = self.loss_history.len().saturating_sub(n);
86 let sum: f32 = self.loss_history[start..].iter().map(|(_, l)| l).sum();
87 Some(sum / (self.loss_history.len() - start) as f32)
88 }
89
90 #[must_use]
92 pub fn steps_per_second(&self) -> f32 {
93 let elapsed = self.elapsed().as_secs_f32();
94 if elapsed > 0.0 {
95 self.global_step as f32 / elapsed
96 } else {
97 0.0
98 }
99 }
100
101 #[must_use]
103 pub fn eta(&self, total_steps: usize) -> Duration {
104 let sps = self.steps_per_second();
105 if sps > 0.0 {
106 let remaining = total_steps.saturating_sub(self.global_step);
107 Duration::from_secs_f32(remaining as f32 / sps)
108 } else {
109 Duration::ZERO
110 }
111 }
112}