#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrainingState {
pub epoch: usize,
pub step: usize,
pub global_step: usize,
pub learning_rate: f32,
pub running_loss: f32,
pub running_correct: usize,
pub running_total: usize,
pub best_metric: f32,
pub patience_counter: usize,
}
impl TrainingState {
pub fn new(base_lr: f32) -> Self {
Self {
epoch: 0,
step: 0,
global_step: 0,
learning_rate: base_lr,
running_loss: 0.0,
running_correct: 0,
running_total: 0,
best_metric: f32::INFINITY,
patience_counter: 0,
}
}
pub fn reset_epoch(&mut self) {
self.step = 0;
self.running_loss = 0.0;
self.running_correct = 0;
self.running_total = 0;
}
pub fn epoch_accuracy(&self) -> f32 {
if self.running_total == 0 {
return 0.0;
}
self.running_correct as f32 / self.running_total as f32
}
pub fn epoch_loss(&self) -> f32 {
if self.step == 0 {
return 0.0;
}
self.running_loss / self.step as f32
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_training_state() {
let mut state = TrainingState::new(0.001);
state.running_correct = 80;
state.running_total = 100;
state.running_loss = 0.5;
state.step = 10;
assert_eq!(state.epoch_accuracy(), 0.8);
assert_eq!(state.epoch_loss(), 0.05);
state.reset_epoch();
assert_eq!(state.running_total, 0);
}
}