// burn_train/learner/summary.rs

use core::cmp::Ordering;
use std::{
    fmt::Display,
    path::{Path, PathBuf},
};

use crate::{
    logger::FileMetricLogger,
    metric::store::{Aggregate, EventStore, LogEventStore, Split},
};
11
/// Contains the metric value at a given time.
#[derive(Debug, Clone, PartialEq)]
pub struct MetricEntry {
    /// The step at which the metric was recorded (i.e., epoch).
    pub step: usize,
    /// The metric value.
    pub value: f64,
}
19
20/// Contains the summary of recorded values for a given metric.
21pub struct MetricSummary {
22    /// The metric name.
23    pub name: String,
24    /// The metric entries.
25    pub entries: Vec<MetricEntry>,
26}
27
28impl MetricSummary {
29    fn new<E: EventStore>(
30        event_store: &mut E,
31        metric: &str,
32        split: Split,
33        num_epochs: usize,
34    ) -> Option<Self> {
35        let entries = (1..=num_epochs)
36            .filter_map(|epoch| {
37                event_store
38                    .find_metric(metric, epoch, Aggregate::Mean, split)
39                    .map(|value| MetricEntry { step: epoch, value })
40            })
41            .collect::<Vec<_>>();
42
43        if entries.is_empty() {
44            None
45        } else {
46            Some(Self {
47                name: metric.to_string(),
48                entries,
49            })
50        }
51    }
52}
53
54/// Contains the summary of recorded metrics for the training and validation steps.
55pub struct SummaryMetrics {
56    /// Training metrics summary.
57    pub train: Vec<MetricSummary>,
58    /// Validation metrics summary.
59    pub valid: Vec<MetricSummary>,
60}
61
62/// Detailed training summary.
63pub struct LearnerSummary {
64    /// The number of epochs completed.
65    pub epochs: usize,
66    /// The summary of recorded metrics during training.
67    pub metrics: SummaryMetrics,
68    /// The model name (only recorded within the learner).
69    pub(crate) model: Option<String>,
70}
71
72impl LearnerSummary {
73    /// Creates a new learner summary for the specified metrics.
74    ///
75    /// # Arguments
76    ///
77    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).
78    /// * `metrics` - The list of metrics to collect for the summary.
79    pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
80        let directory = directory.as_ref();
81        if !directory.exists() {
82            return Err(format!(
83                "Artifact directory does not exist at: {}",
84                directory.display()
85            ));
86        }
87        let train_dir = directory.join("train");
88        let valid_dir = directory.join("valid");
89        if !train_dir.exists() & !valid_dir.exists() {
90            return Err(format!(
91                "No training or validation artifacts found at: {}",
92                directory.display()
93            ));
94        }
95
96        let mut event_store = LogEventStore::default();
97
98        let train_logger = FileMetricLogger::new_train(train_dir.to_str().unwrap());
99        let valid_logger = FileMetricLogger::new_train(valid_dir.to_str().unwrap());
100
101        // Number of recorded epochs
102        let epochs = train_logger.epochs();
103
104        event_store.register_logger_train(train_logger);
105        event_store.register_logger_valid(valid_logger);
106
107        let train_summary = metrics
108            .iter()
109            .filter_map(|metric| {
110                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Train, epochs)
111            })
112            .collect::<Vec<_>>();
113
114        let valid_summary = metrics
115            .iter()
116            .filter_map(|metric| {
117                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Valid, epochs)
118            })
119            .collect::<Vec<_>>();
120
121        Ok(Self {
122            epochs,
123            metrics: SummaryMetrics {
124                train: train_summary,
125                valid: valid_summary,
126            },
127            model: None,
128        })
129    }
130
131    pub(crate) fn with_model(mut self, name: String) -> Self {
132        self.model = Some(name);
133        self
134    }
135}
136
137impl Display for LearnerSummary {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        // Compute the max length for each column
140        let split_train = "Train";
141        let split_valid = "Valid";
142        let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
143        let mut max_metric_len = "Metric".len();
144        for metric in self.metrics.train.iter() {
145            max_metric_len = max_metric_len.max(metric.name.len());
146        }
147        for metric in self.metrics.valid.iter() {
148            max_metric_len = max_metric_len.max(metric.name.len());
149        }
150
151        // Summary header
152        writeln!(
153            f,
154            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
155            "",
156            "",
157            width_symbol = 24,
158        )?;
159
160        if let Some(model) = &self.model {
161            writeln!(f, "Model:\n{model}")?;
162        }
163        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
164
165        // Metrics table header
166        writeln!(
167            f,
168            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
169            "Split",
170            "Metric",
171            "",
172            "",
173            width_split = max_split_len,
174            width_metric = max_metric_len,
175        )?;
176
177        // Table entries
178        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
179            match (a.is_nan(), b.is_nan()) {
180                (true, true) => Ordering::Equal,
181                (true, false) => Ordering::Greater,
182                (false, true) => Ordering::Less,
183                _ => a.partial_cmp(b).unwrap(),
184            }
185        }
186
187        fn fmt_val(val: f64) -> String {
188            if val < 1e-2 {
189                // Use scientific notation for small values which would otherwise be truncated
190                format!("{val:<9.3e}")
191            } else {
192                format!("{val:<9.3}")
193            }
194        }
195
196        let mut write_metrics_summary =
197            |metrics: &[MetricSummary], split: &str| -> std::fmt::Result {
198                for metric in metrics.iter() {
199                    if metric.entries.is_empty() {
200                        continue; // skip metrics with no recorded values
201                    }
202
203                    // Compute the min & max for each metric
204                    let metric_min = metric
205                        .entries
206                        .iter()
207                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
208                        .unwrap();
209                    let metric_max = metric
210                        .entries
211                        .iter()
212                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
213                        .unwrap();
214
215                    writeln!(
216                        f,
217                        "| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
218                        split,
219                        metric.name,
220                        fmt_val(metric_min.value),
221                        metric_min.step,
222                        fmt_val(metric_max.value),
223                        metric_max.step,
224                        width_split = max_split_len,
225                        width_metric = max_metric_len,
226                    )?;
227                }
228
229                Ok(())
230            };
231
232        write_metrics_summary(&self.metrics.train, split_train)?;
233        write_metrics_summary(&self.metrics.valid, split_valid)?;
234
235        Ok(())
236    }
237}
238
/// Settings used to build a [`LearnerSummary`].
pub(crate) struct LearnerSummaryConfig {
    /// The directory handed to [`LearnerSummary::new`] as the artifact directory.
    pub(crate) directory: PathBuf,
    /// The names of the metrics to collect for the summary.
    pub(crate) metrics: Vec<String>,
}
243
244impl LearnerSummaryConfig {
245    pub fn init(&self) -> Result<LearnerSummary, String> {
246        LearnerSummary::new(&self.directory, &self.metrics[..])
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the path of a test folder inside the system temporary directory, so the
    /// tests do not depend on a hard-coded `/tmp`.
    fn artifact_dir(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        let dir = artifact_dir("learner-summary-not-found");
        let _summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        let dir = artifact_dir("test-learner-summary-empty");
        std::fs::create_dir_all(&dir).ok();
        let result = LearnerSummary::new(&dir, &["Loss"]);
        // Clean up before the expected panic so the directory does not outlive the test.
        std::fs::remove_dir_all(&dir).ok();
        let _summary = result.expect("Summary artifacts should exist");
    }

    #[test]
    fn test_summary_should_be_empty() {
        let dir = artifact_dir("test-learner-summary-empty-metrics");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
        let summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // No log file was written, so no metric is collected
        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(&dir).unwrap();
    }

    #[test]
    fn test_summary_should_be_collected() {
        let dir = artifact_dir("test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // Only Loss metric
        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        // Aggregated train metric entries for 1 epoch
        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let entry = &train_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.5); // (1 + 2) / 2

        // Aggregated valid metric entries for 1 epoch
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let entry = &valid_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.0);

        std::fs::remove_dir_all(&dir).unwrap();
    }
}