Skip to main content

burn_train/learner/
summary.rs

1use core::cmp::Ordering;
2use std::{
3    collections::{HashMap, hash_map::Entry},
4    fmt::Display,
5    path::{Path, PathBuf},
6};
7
8use crate::{
9    logger::FileMetricLogger,
10    metric::store::{Aggregate, EventStore, LogEventStore, Split},
11};
12
/// Contains the metric value at a given time.
///
/// A plain value type: it is `Copy` and comparable so callers can freely duplicate and
/// compare recorded entries.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MetricEntry {
    /// The step at which the metric was recorded (i.e., epoch).
    pub step: usize,
    /// The metric value.
    pub value: f64,
}

/// Contains the summary of recorded values for a given metric.
///
/// `Clone` and `PartialEq` are derived so summaries can be duplicated and compared (e.g. in
/// tests or when combining summaries).
#[derive(Debug, Clone, PartialEq)]
pub struct MetricSummary {
    /// The metric name.
    pub name: String,
    /// The metric entries, in the order they were recorded or merged.
    pub entries: Vec<MetricEntry>,
}
30
31impl MetricSummary {
32    fn collect<E: EventStore>(
33        event_store: &mut E,
34        metric: &str,
35        split: &Split,
36        num_epochs: usize,
37    ) -> Option<Self> {
38        let entries = (1..=num_epochs)
39            .filter_map(|epoch| {
40                event_store
41                    .find_metric(metric, epoch, Aggregate::Mean, split)
42                    .map(|value| MetricEntry { step: epoch, value })
43            })
44            .collect::<Vec<_>>();
45
46        if entries.is_empty() {
47            None
48        } else {
49            Some(Self {
50                name: metric.to_string(),
51                entries,
52            })
53        }
54    }
55}
56
57/// Contains the summary of recorded metrics for the training and validation steps.
58pub struct SummaryMetrics {
59    /// Training metrics summary.
60    pub train: Vec<MetricSummary>,
61    /// Validation metrics summary.
62    pub valid: Vec<MetricSummary>,
63    /// Test metrics summary per test split tag.
64    ///
65    /// Each key corresponds to a `Split::Test(Some(tag))`.
66    /// The empty string represents `Split::Test(None)`.
67    pub test: HashMap<String, Vec<MetricSummary>>,
68}
69
70/// Detailed training summary.
71pub struct LearnerSummary {
72    /// The number of epochs completed.
73    pub epochs: usize,
74    /// The summary of recorded metrics during training.
75    pub metrics: SummaryMetrics,
76    /// The model name (only recorded within the learner).
77    pub(crate) model: Option<String>,
78}
79
80impl LearnerSummary {
81    /// Creates a new learner summary for the specified metrics.
82    ///
83    /// # Arguments
84    ///
85    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).
86    /// * `metrics` - The list of metrics to collect for the summary.
87    pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
88        let directory = directory.as_ref();
89        if !directory.exists() {
90            return Err(format!(
91                "Artifact directory does not exist at: {}",
92                directory.display()
93            ));
94        }
95
96        let mut event_store = LogEventStore::default();
97        let train_split = Split::Train;
98        let valid_split = Split::Valid;
99
100        let logger = FileMetricLogger::new(directory);
101        let test_split_root = logger.split_dir(&Split::Test(None));
102        if !logger.split_exists(&train_split)
103            && !logger.split_exists(&valid_split)
104            && test_split_root.is_none()
105        {
106            return Err(format!(
107                "No training, validation or test artifacts found at: {}",
108                directory.display()
109            ));
110        }
111
112        // Number of recorded epochs
113        let epochs = logger.epochs();
114
115        event_store.register_logger(logger);
116
117        let train_summary = metrics
118            .iter()
119            .filter_map(|metric| {
120                MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs)
121            })
122            .collect::<Vec<_>>();
123
124        let valid_summary = metrics
125            .iter()
126            .filter_map(|metric| {
127                MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs)
128            })
129            .collect::<Vec<_>>();
130
131        let test_summary = match test_split_root {
132            Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs),
133            None => Default::default(),
134        };
135
136        Ok(Self {
137            epochs,
138            metrics: SummaryMetrics {
139                train: train_summary,
140                valid: valid_summary,
141                test: test_summary,
142            },
143            model: None,
144        })
145    }
146
147    pub(crate) fn with_model(mut self, name: String) -> Self {
148        self.model = Some(name);
149        self
150    }
151
152    /// Merges another summary into this one, combining all metric entries.
153    pub(crate) fn merge(mut self, other: LearnerSummary) -> Self {
154        fn merge_metrics(
155            base: Vec<MetricSummary>,
156            incoming: Vec<MetricSummary>,
157        ) -> Vec<MetricSummary> {
158            let mut map: HashMap<String, MetricSummary> =
159                base.into_iter().map(|m| (m.name.clone(), m)).collect();
160
161            for metric in incoming {
162                match map.entry(metric.name.clone()) {
163                    Entry::Occupied(mut entry) => {
164                        entry.get_mut().entries.extend(metric.entries);
165                    }
166                    Entry::Vacant(entry) => {
167                        entry.insert(metric);
168                    }
169                }
170            }
171            map.into_values().collect()
172        }
173
174        self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train);
175        self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid);
176
177        for (tag, metrics) in other.metrics.test {
178            match self.metrics.test.entry(tag) {
179                Entry::Occupied(mut entry) => {
180                    let current = std::mem::take(entry.get_mut());
181                    let merged = merge_metrics(current, metrics);
182                    *entry.get_mut() = merged;
183                }
184                Entry::Vacant(entry) => {
185                    entry.insert(metrics);
186                }
187            }
188        }
189
190        if self.model != other.model {
191            self.model = None;
192        }
193
194        self
195    }
196}
197
198fn collect_test_split_metrics<P: AsRef<Path>, S: AsRef<str>>(
199    root: P,
200    metrics: &[S],
201    event_store: &mut LogEventStore,
202    epochs: usize,
203) -> HashMap<String, Vec<MetricSummary>> {
204    // Collect immediate child directories
205    let dirs = match std::fs::read_dir(root) {
206        Ok(entries) => entries
207            .filter_map(|entry| {
208                let entry = entry.ok()?;
209                let file_type = entry.file_type().ok()?;
210                if file_type.is_dir() {
211                    Some(entry.file_name().to_string_lossy().to_string())
212                } else {
213                    None
214                }
215            })
216            .collect::<Vec<_>>(),
217        Err(_) => Vec::new(),
218    };
219
220    let mut map = HashMap::new();
221
222    if dirs.is_empty() {
223        return map;
224    }
225
226    // Detect if all directories are epoch directories
227    let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir);
228
229    if all_epochs {
230        // Single untagged test split
231        let split = Split::Test(None);
232
233        let summaries = metrics
234            .iter()
235            .filter_map(|metric| {
236                MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
237            })
238            .collect::<Vec<_>>();
239
240        // Untagged marked with empty string
241        map.insert("".to_string(), summaries);
242    } else {
243        // Tagged splits
244        for tag in dirs {
245            let split = Split::Test(Some(tag.clone().into()));
246
247            let summaries = metrics
248                .iter()
249                .filter_map(|metric| {
250                    MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
251                })
252                .collect::<Vec<_>>();
253
254            map.insert(tag, summaries);
255        }
256    }
257
258    map
259}
260
261impl Display for LearnerSummary {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        // Compute the max length for each column
264        let mut max_split_len = 5; // "Train"
265        let mut max_metric_len = "Metric".len();
266        for metric in self.metrics.train.iter() {
267            max_metric_len = max_metric_len.max(metric.name.len());
268        }
269        for metric in self.metrics.valid.iter() {
270            max_metric_len = max_metric_len.max(metric.name.len());
271        }
272        for (tag, metrics) in self.metrics.test.iter() {
273            let split_name = if tag.is_empty() {
274                "Test".to_string()
275            } else {
276                format!("Test ({tag})")
277            };
278
279            max_split_len = max_split_len.max(split_name.len());
280
281            for metric in metrics {
282                max_metric_len = max_metric_len.max(metric.name.len());
283            }
284        }
285
286        // Summary header
287        writeln!(
288            f,
289            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
290            "",
291            "",
292            width_symbol = 24,
293        )?;
294
295        if let Some(model) = &self.model {
296            writeln!(f, "Model:\n{model}")?;
297        }
298        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
299
300        // Metrics table header
301        writeln!(
302            f,
303            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
304            "Split",
305            "Metric",
306            "",
307            "",
308            width_split = max_split_len,
309            width_metric = max_metric_len,
310        )?;
311
312        // Table entries
313        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
314            match (a.is_nan(), b.is_nan()) {
315                (true, true) => Ordering::Equal,
316                (true, false) => Ordering::Greater,
317                (false, true) => Ordering::Less,
318                _ => a.partial_cmp(b).unwrap(),
319            }
320        }
321
322        fn fmt_val(val: f64) -> String {
323            if val < 1e-2 {
324                // Use scientific notation for small values which would otherwise be truncated
325                format!("{val:<9.3e}")
326            } else {
327                format!("{val:<9.3}")
328            }
329        }
330
331        let mut write_metrics_summary =
332            |metrics: &[MetricSummary], split: String| -> std::fmt::Result {
333                for metric in metrics.iter() {
334                    if metric.entries.is_empty() {
335                        continue; // skip metrics with no recorded values
336                    }
337
338                    // Compute the min & max for each metric
339                    let metric_min = metric
340                        .entries
341                        .iter()
342                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
343                        .unwrap();
344                    let metric_max = metric
345                        .entries
346                        .iter()
347                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
348                        .unwrap();
349
350                    writeln!(
351                        f,
352                        "| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
353                        split,
354                        metric.name,
355                        fmt_val(metric_min.value),
356                        metric_min.step,
357                        fmt_val(metric_max.value),
358                        metric_max.step,
359                        width_split = max_split_len,
360                        width_metric = max_metric_len,
361                    )?;
362                }
363
364                Ok(())
365            };
366
367        write_metrics_summary(&self.metrics.train, format!("{:?}", Split::Train))?;
368        write_metrics_summary(&self.metrics.valid, format!("{:?}", Split::Valid))?;
369
370        for (tag, metrics) in &self.metrics.test {
371            let split_name = if tag.is_empty() {
372                "Test".to_string()
373            } else {
374                format!("Test ({tag})")
375            };
376
377            write_metrics_summary(metrics, split_name)?;
378        }
379
380        Ok(())
381    }
382}
383
384// TODO: rename to `ExperimentSummary`? Used in learner + evaluator.
385
/// Learning summary config.
///
/// Holds the artifact directory and the metric names from which a [`LearnerSummary`] is
/// built.
#[derive(Clone, Debug)]
pub struct LearnerSummaryConfig {
    /// Directory containing the training artifacts (checkpoints and logs).
    pub(crate) directory: PathBuf,
    /// Names of the metrics to include in the summary.
    pub(crate) metrics: Vec<String>,
}
392
393impl LearnerSummaryConfig {
394    /// Create the learning summary.
395    pub fn init(&self) -> Result<LearnerSummary, String> {
396        LearnerSummary::new(&self.directory, &self.metrics[..])
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `name` joined onto the system temporary directory, after removing anything a
    /// previous (possibly aborted) run left there, so every test starts from a known state.
    ///
    /// `std::env::temp_dir` is used instead of a hard-coded `/tmp` so the tests also run on
    /// platforms without `/tmp`.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(name);
        // Usually fails with `NotFound`, which is exactly the state we want.
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        let dir = scratch_dir("learner-summary-not-found");
        let _summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        let dir = scratch_dir("test-learner-summary-empty");
        std::fs::create_dir_all(&dir).unwrap();
        let result = LearnerSummary::new(&dir, &["Loss"]);
        // Clean up before `expect` panics so the test leaves nothing behind.
        std::fs::remove_dir_all(&dir).ok();
        let _summary = result.expect("Summary artifacts should exist");
    }

    #[test]
    fn test_summary_should_be_empty() {
        let dir = scratch_dir("test-learner-summary-empty-metrics");
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
        let summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(&dir).unwrap();
    }

    #[test]
    fn test_summary_should_be_collected() {
        let dir = scratch_dir("test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary =
            LearnerSummary::new(&dir, &["Loss"]).expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // Only Loss metric
        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        // Aggregated train metric entries for 1 epoch
        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let entry = &train_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.5); // (1 + 2) / 2

        // Aggregated valid metric entries for 1 epoch
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let entry = &valid_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.0);

        std::fs::remove_dir_all(&dir).unwrap();
    }
}