Skip to main content

entrenar/monitor/tui/render/
epoch.rs

//! Epoch summary computation from training snapshots.
2
3use super::super::state::TrainingSnapshot;
4
/// Aggregated statistics for one training epoch, as produced by
/// [`compute_epoch_summaries`].
///
/// `PartialEq` is a plain field-wise float comparison, so a summary that
/// contains a `NaN` field never compares equal to anything, itself included.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochSummary {
    /// 1-based epoch number (position of the epoch's chunk in the loss history, plus one).
    pub epoch: usize,
    /// Mean of the finite loss values recorded during the epoch.
    pub avg_loss: f32,
    /// Smallest finite loss value recorded during the epoch.
    pub min_loss: f32,
    /// Largest finite loss value recorded during the epoch.
    pub max_loss: f32,
    /// Last finite loss value recorded during the epoch.
    pub end_loss: f32,
    /// Gradient norm taken from the snapshot's current `gradient_norm`, floored at zero.
    ///
    /// This is the snapshot-wide value, not a per-epoch average.
    pub avg_grad: f32,
    /// Mean learning rate over the epoch's steps when learning-rate history covers
    /// them; otherwise the snapshot's current `learning_rate`.
    pub lr: f32,
    /// Throughput taken from the snapshot's current `tokens_per_second`, floored at zero.
    pub tokens_per_sec: f32,
}
16
17pub fn compute_epoch_summaries(snapshot: &TrainingSnapshot) -> Vec<EpochSummary> {
18    if snapshot.steps_per_epoch == 0 || snapshot.loss_history.is_empty() {
19        return Vec::new();
20    }
21
22    let steps = snapshot.steps_per_epoch;
23    let mut summaries = Vec::new();
24
25    for (epoch_idx, chunk) in snapshot.loss_history.chunks(steps).enumerate() {
26        let valid: Vec<f32> = chunk.iter().copied().filter(|v| v.is_finite()).collect();
27        if valid.is_empty() {
28            continue;
29        }
30
31        // valid.len() <= steps_per_epoch, bounded by training config, safe for f32
32        let valid_count = valid.len().min(usize::from(u16::MAX)) as f32;
33        let avg_loss = valid.iter().sum::<f32>() / valid_count;
34        let min_loss = valid.iter().copied().fold(f32::INFINITY, f32::min);
35        let max_loss = valid.iter().copied().fold(f32::NEG_INFINITY, f32::max);
36        let end_loss = *valid.last().unwrap_or(&0.0);
37
38        let lr = if snapshot.lr_history.is_empty() {
39            snapshot.learning_rate
40        } else {
41            let lr_start = epoch_idx * steps;
42            let lr_end = (lr_start + steps).min(snapshot.lr_history.len());
43            if lr_start < snapshot.lr_history.len() {
44                let lr_span = (lr_end - lr_start).max(1).min(usize::from(u16::MAX)) as f32;
45                snapshot.lr_history[lr_start..lr_end].iter().sum::<f32>() / lr_span
46            } else {
47                snapshot.learning_rate
48            }
49        };
50
51        summaries.push(EpochSummary {
52            epoch: epoch_idx + 1,
53            avg_loss,
54            min_loss,
55            max_loss,
56            end_loss,
57            avg_grad: snapshot.gradient_norm.max(0.0),
58            lr,
59            tokens_per_sec: snapshot.tokens_per_second.max(0.0),
60        });
61    }
62    summaries
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Twelve losses at four steps per epoch must yield three summaries, and
    /// the first epoch's average / minimum / maximum must match the values
    /// computed by hand from its four losses (10.0, 9.5, 9.0, 8.5).
    #[test]
    fn test_epoch_summaries() {
        // Absolute tolerance for the floating-point comparisons below.
        let close = |actual: f32, expected: f32| (actual - expected).abs() < 0.01;

        let losses = vec![10.0, 9.5, 9.0, 8.5, 8.0, 7.5, 7.0, 6.5, 6.0, 5.5, 5.0, 4.5];
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 4,
            loss_history: losses,
            ..Default::default()
        };

        let result = compute_epoch_summaries(&snapshot);
        assert_eq!(result.len(), 3);

        let first = &result[0];
        assert!(close(first.avg_loss, 9.25));
        assert!(close(first.min_loss, 8.5));
        assert!(close(first.max_loss, 10.0));
    }
}