use std::time::{Duration, Instant};
/// Mutable bookkeeping for a training run: step/epoch counters, loss
/// histories, the best validation loss seen so far, and wall-clock timing.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Number of times [`TrainingState::new_epoch`] has been called (starts at 0).
    pub epoch: usize,
    /// Total number of [`TrainingState::step`] calls since construction.
    pub global_step: usize,
    /// Number of [`TrainingState::step`] calls since the last `new_epoch` (reset to 0 there).
    pub epoch_step: usize,
    /// Lowest loss passed to [`TrainingState::record_val_loss`]; `f32::INFINITY` until one is recorded.
    pub best_val_loss: f32,
    /// Instant captured at construction; the reference point for [`TrainingState::elapsed`].
    pub start_time: Instant,
    /// `(global_step, loss)` pairs appended by [`TrainingState::record_loss`].
    pub loss_history: Vec<(usize, f32)>,
    /// `(global_step, loss)` pairs appended by [`TrainingState::record_val_loss`].
    pub val_loss_history: Vec<(usize, f32)>,
}

impl Default for TrainingState {
    /// Equivalent to [`TrainingState::new`]; the clock starts at the moment of the call.
    fn default() -> Self {
        Self::new()
    }
}

impl TrainingState {
    /// Creates a fresh state with zeroed counters, empty histories, an
    /// infinite best validation loss, and `start_time` set to now.
    #[must_use]
    pub fn new() -> Self {
        Self {
            epoch: 0,
            global_step: 0,
            epoch_step: 0,
            best_val_loss: f32::INFINITY,
            start_time: Instant::now(),
            loss_history: Vec::new(),
            val_loss_history: Vec::new(),
        }
    }

    /// Appends `loss` to the training-loss history, tagged with the current `global_step`.
    pub fn record_loss(&mut self, loss: f32) {
        self.loss_history.push((self.global_step, loss));
    }

    /// Appends `loss` to the validation-loss history (tagged with the current
    /// `global_step`) and updates `best_val_loss` if `loss` is strictly lower.
    ///
    /// Returns `true` when `loss` became the new best. A `NaN` loss is
    /// recorded but never counts as a new best, because `NaN < x` is false.
    pub fn record_val_loss(&mut self, loss: f32) -> bool {
        self.val_loss_history.push((self.global_step, loss));
        if loss < self.best_val_loss {
            self.best_val_loss = loss;
            true
        } else {
            false
        }
    }

    /// Advances both the global and the per-epoch step counters by one.
    pub fn step(&mut self) {
        self.global_step += 1;
        self.epoch_step += 1;
    }

    /// Increments the epoch counter and resets the per-epoch step counter.
    pub fn new_epoch(&mut self) {
        self.epoch += 1;
        self.epoch_step = 0;
    }

    /// Wall-clock time elapsed since `start_time`.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Mean of the last `n` recorded training losses (or of all of them when
    /// fewer than `n` exist).
    ///
    /// Returns `None` when there is nothing to average, i.e. the history is
    /// empty or `n == 0`. Previously `n == 0` produced `Some(NaN)`.
    #[must_use]
    pub fn avg_loss(&self, n: usize) -> Option<f32> {
        let window = &self.loss_history[self.loss_history.len().saturating_sub(n)..];
        if window.is_empty() {
            return None;
        }
        let sum: f32 = window.iter().map(|&(_, loss)| loss).sum();
        Some(sum / window.len() as f32)
    }

    /// Average throughput since `start_time`: `global_step / elapsed_seconds`.
    ///
    /// Returns `0.0` if no measurable time has elapsed.
    #[must_use]
    pub fn steps_per_second(&self) -> f32 {
        let elapsed = self.elapsed().as_secs_f32();
        if elapsed > 0.0 {
            self.global_step as f32 / elapsed
        } else {
            0.0
        }
    }

    /// Estimated time to reach `total_steps` at the current average rate.
    ///
    /// Returns `Duration::ZERO` when no rate is available yet (no steps taken
    /// or no elapsed time). Returns `Duration::MAX` when the estimate exceeds
    /// what a `Duration` can hold, instead of panicking.
    #[must_use]
    pub fn eta(&self, total_steps: usize) -> Duration {
        let sps = self.steps_per_second();
        if sps > 0.0 {
            let remaining = total_steps.saturating_sub(self.global_step);
            // Compute in f64: `Duration::from_secs_f*` panics on overflow, so a
            // very slow rate combined with a huge step budget must be clamped.
            let secs = remaining as f64 / f64::from(sps);
            if secs.is_finite() && secs < Duration::MAX.as_secs_f64() {
                Duration::from_secs_f64(secs)
            } else {
                Duration::MAX
            }
        } else {
            Duration::ZERO
        }
    }
}