entrenar/monitor/tui/render/
epoch.rs1use super::super::state::TrainingSnapshot;
4
/// Aggregated training statistics for a single epoch, as displayed by the TUI.
///
/// Loss statistics only consider finite loss values recorded during the epoch.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochSummary {
    /// 1-based epoch number.
    pub epoch: usize,
    /// Mean of the finite losses recorded during the epoch.
    pub avg_loss: f32,
    /// Smallest finite loss recorded during the epoch.
    pub min_loss: f32,
    /// Largest finite loss recorded during the epoch.
    pub max_loss: f32,
    /// Last finite loss recorded during the epoch.
    pub end_loss: f32,
    /// Gradient norm (never negative).
    ///
    /// NOTE(review): `compute_epoch_summaries` fills this from the snapshot's current
    /// gradient norm, so it is not an average over the epoch despite the name.
    pub avg_grad: f32,
    /// Mean learning rate over the epoch's steps, or the current learning rate when no
    /// per-step history covers the epoch.
    pub lr: f32,
    /// Training throughput in tokens per second (never negative).
    ///
    /// NOTE(review): filled from the snapshot's current throughput, not a per-epoch value.
    pub tokens_per_sec: f32,
}
16
17pub fn compute_epoch_summaries(snapshot: &TrainingSnapshot) -> Vec<EpochSummary> {
18 if snapshot.steps_per_epoch == 0 || snapshot.loss_history.is_empty() {
19 return Vec::new();
20 }
21
22 let steps = snapshot.steps_per_epoch;
23 let mut summaries = Vec::new();
24
25 for (epoch_idx, chunk) in snapshot.loss_history.chunks(steps).enumerate() {
26 let valid: Vec<f32> = chunk.iter().copied().filter(|v| v.is_finite()).collect();
27 if valid.is_empty() {
28 continue;
29 }
30
31 let valid_count = valid.len().min(usize::from(u16::MAX)) as f32;
33 let avg_loss = valid.iter().sum::<f32>() / valid_count;
34 let min_loss = valid.iter().copied().fold(f32::INFINITY, f32::min);
35 let max_loss = valid.iter().copied().fold(f32::NEG_INFINITY, f32::max);
36 let end_loss = *valid.last().unwrap_or(&0.0);
37
38 let lr = if snapshot.lr_history.is_empty() {
39 snapshot.learning_rate
40 } else {
41 let lr_start = epoch_idx * steps;
42 let lr_end = (lr_start + steps).min(snapshot.lr_history.len());
43 if lr_start < snapshot.lr_history.len() {
44 let lr_span = (lr_end - lr_start).max(1).min(usize::from(u16::MAX)) as f32;
45 snapshot.lr_history[lr_start..lr_end].iter().sum::<f32>() / lr_span
46 } else {
47 snapshot.learning_rate
48 }
49 };
50
51 summaries.push(EpochSummary {
52 epoch: epoch_idx + 1,
53 avg_loss,
54 min_loss,
55 max_loss,
56 end_loss,
57 avg_grad: snapshot.gradient_norm.max(0.0),
58 lr,
59 tokens_per_sec: snapshot.tokens_per_second.max(0.0),
60 });
61 }
62 summaries
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison for the assertions below.
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-4
    }

    #[test]
    fn test_epoch_summaries() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 4,
            loss_history: vec![10.0, 9.5, 9.0, 8.5, 8.0, 7.5, 7.0, 6.5, 6.0, 5.5, 5.0, 4.5],
            ..Default::default()
        };

        let summaries = compute_epoch_summaries(&snapshot);
        assert_eq!(summaries.len(), 3);
        assert!(close(summaries[0].avg_loss, 9.25));
        assert!(close(summaries[0].min_loss, 8.5));
        assert!(close(summaries[0].max_loss, 10.0));
        assert!(close(summaries[0].end_loss, 8.5));

        // Later epochs are numbered from 1 and summarize their own chunk only.
        let epochs: Vec<usize> = summaries.iter().map(|s| s.epoch).collect();
        assert_eq!(epochs, vec![1, 2, 3]);
        assert!(close(summaries[1].avg_loss, 7.25));
        assert!(close(summaries[2].end_loss, 4.5));
    }

    #[test]
    fn test_epoch_summaries_empty_inputs() {
        let no_steps = TrainingSnapshot {
            steps_per_epoch: 0,
            loss_history: vec![1.0, 2.0],
            ..Default::default()
        };
        assert!(compute_epoch_summaries(&no_steps).is_empty());

        let no_history = TrainingSnapshot {
            steps_per_epoch: 4,
            loss_history: Vec::new(),
            ..Default::default()
        };
        assert!(compute_epoch_summaries(&no_history).is_empty());
    }

    #[test]
    fn test_epoch_summaries_partial_final_epoch() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 4,
            loss_history: vec![4.0, 3.0, 2.0, 1.0, 0.5],
            ..Default::default()
        };

        let summaries = compute_epoch_summaries(&snapshot);
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[1].epoch, 2);
        assert!(close(summaries[1].avg_loss, 0.5));
        assert!(close(summaries[1].min_loss, 0.5));
        assert!(close(summaries[1].max_loss, 0.5));
    }

    #[test]
    fn test_epoch_summaries_skips_non_finite_losses() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 2,
            loss_history: vec![f32::NAN, f32::INFINITY, 3.0, 1.0, 5.0, f32::NAN],
            ..Default::default()
        };

        let summaries = compute_epoch_summaries(&snapshot);
        // The all-non-finite first epoch is dropped; later epochs keep their numbers.
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].epoch, 2);
        assert!(close(summaries[0].avg_loss, 2.0));
        assert!(close(summaries[0].min_loss, 1.0));
        assert!(close(summaries[0].max_loss, 3.0));
        assert_eq!(summaries[1].epoch, 3);
        assert!(close(summaries[1].avg_loss, 5.0));
        assert!(close(summaries[1].end_loss, 5.0));
    }

    #[test]
    fn test_epoch_summaries_lr_average_and_fallback() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 2,
            loss_history: vec![1.0, 2.0, 3.0],
            lr_history: vec![0.1, 0.3],
            learning_rate: 0.05,
            ..Default::default()
        };

        let summaries = compute_epoch_summaries(&snapshot);
        assert_eq!(summaries.len(), 2);
        // First epoch averages the recorded per-step learning rates.
        assert!(close(summaries[0].lr, 0.2));
        // No recorded learning rates cover the second epoch: fall back to the current one.
        assert!(close(summaries[1].lr, 0.05));
    }
}