burn_train/metric/store/
base.rs

1use std::sync::Arc;
2
3use crate::metric::{MetricEntry, NumericEntry};
4
5/// Event happening during the training/validation process.
6pub enum Event {
7    /// Signal that metrics have been updated.
8    MetricsUpdate(MetricsUpdate),
9    /// Signal the end of an epoch.
10    EndEpoch(usize),
11}
12
/// Contains all metric information.
///
/// Entries are kept in two separate lists: one for non-numeric metrics and one for numeric
/// metrics, where each entry is paired with its numeric value.
///
/// NOTE(review): `new` is not a standard derive; presumably it comes from the `derive-new`
/// crate and generates a `new` constructor taking the fields in declaration order — confirm
/// before reordering fields.
#[derive(new, Clone, Debug)]
pub struct MetricsUpdate {
    /// Metrics information related to non-numeric metrics.
    pub entries: Vec<MetricEntry>,
    /// Metrics information related to numeric metrics.
    ///
    /// Each element pairs a [`MetricEntry`] with its associated [`NumericEntry`].
    pub entries_numeric: Vec<(MetricEntry, NumericEntry)>,
}
21
22impl MetricsUpdate {
23    /// Appends a tag to the config.
24    pub fn tag(&mut self, tag: Arc<String>) {
25        self.entries.iter_mut().for_each(|entry| {
26            entry.tags.push(tag.clone());
27        });
28        self.entries_numeric.iter_mut().for_each(|(entry, _)| {
29            entry.tags.push(tag.clone());
30        });
31    }
32}
33
/// Defines how training and validation events are collected and searched.
///
/// This trait also exposes methods that use the collected data to compute useful information.
///
/// Implementors must be `Send`. Note that the query methods take `&mut self`, so an
/// implementation is allowed to update internal state while answering a query.
pub trait EventStore: Send {
    /// Collect a training/validation event.
    ///
    /// `split` identifies which data split (train, valid or test) the event belongs to.
    fn add_event(&mut self, event: Event, split: Split);

    /// Find the epoch following the given criteria from the collected data.
    ///
    /// # Arguments
    ///
    /// * `name` - Name used to look up the metric (presumably the metric's name — confirm
    ///   against implementors).
    /// * `aggregate` - How the metric values are aggregated.
    /// * `direction` - Whether the lowest or the highest aggregated value is wanted.
    /// * `split` - The data split to search in.
    ///
    /// # Returns
    ///
    /// The selected epoch, or `None`; the exact conditions producing `None` are defined by
    /// the implementor.
    fn find_epoch(
        &mut self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
    ) -> Option<usize>;

    /// Find the metric value for the given epoch following the given criteria.
    ///
    /// # Arguments
    ///
    /// * `name` - Name used to look up the metric (presumably the metric's name — confirm
    ///   against implementors).
    /// * `epoch` - The epoch to query.
    /// * `aggregate` - How the metric values are aggregated.
    /// * `split` - The data split to search in.
    ///
    /// # Returns
    ///
    /// The aggregated value as an `f64`, or `None`; the exact conditions producing `None`
    /// are defined by the implementor.
    fn find_metric(
        &mut self,
        name: &str,
        epoch: usize,
        aggregate: Aggregate,
        split: Split,
    ) -> Option<f64>;
}
59
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
/// How to aggregate the metric.
///
/// `Mean` is the only aggregation currently defined.
pub enum Aggregate {
    /// Compute the average.
    Mean,
}
66
/// The split to use.
///
/// Derives the same trait set as [`Aggregate`] (`Hash`, `PartialEq`, `Eq`, `Debug`) so a
/// split can be compared, used as a map/set key, and printed in diagnostics.
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
pub enum Split {
    /// The training split.
    Train,
    /// The validation split.
    Valid,
    /// The testing split.
    Test,
}
77
/// The direction of the query.
///
/// Derives the same trait set as [`Aggregate`] (`Hash`, `PartialEq`, `Eq`, `Debug`) so a
/// direction can be compared, used as a map/set key, and printed in diagnostics.
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
pub enum Direction {
    /// Lower is better.
    Lowest,
    /// Higher is better.
    Highest,
}