// burn_train/metric/base.rs

1use burn_core::{data::dataloader::Progress, LearningRate};
2
3/// Metric metadata that can be used when computing metrics.
4pub struct MetricMetadata {
5    /// The current progress.
6    pub progress: Progress,
7
8    /// The current epoch.
9    pub epoch: usize,
10
11    /// The total number of epochs.
12    pub epoch_total: usize,
13
14    /// The current iteration.
15    pub iteration: usize,
16
17    /// The current learning rate.
18    pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22    /// Fake metric metadata
23    #[cfg(test)]
24    pub fn fake() -> Self {
25        Self {
26            progress: Progress {
27                items_processed: 1,
28                items_total: 1,
29            },
30            epoch: 0,
31            epoch_total: 1,
32            iteration: 0,
33            lr: None,
34        }
35    }
36}
37
38/// Metric trait.
39///
40/// # Notes
41///
42/// Implementations should define their own input type only used by the metric.
43/// This is important since some conflict may happen when the model output is adapted for each
44/// metric's input type.
45pub trait Metric: Send + Sync {
46    /// The name of the metric.
47    ///
48    /// This should be unique, so avoid using short generic names, prefer using the long name.
49    const NAME: &'static str;
50
51    /// The input type of the metric.
52    type Input;
53
54    /// Update the metric state and returns the current metric entry.
55    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
56    /// Clear the metric state.
57    fn clear(&mut self);
58}
59
/// Adaptors are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input)
/// that are registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Adapts `self` into the input type `T` expected by a [metric](Metric).
    fn adapt(&self) -> T;
}
68
/// Marks a metric as producing a numeric value.
///
/// Numeric metrics can have their values plotted over the course of training.
pub trait Numeric {
    /// Current value of the metric, as an `f64`.
    fn value(&self) -> f64;
}
76
77/// Data type that contains the current state of a metric at a given time.
78#[derive(new, Debug, Clone)]
79pub struct MetricEntry {
80    /// The name of the metric.
81    pub name: String,
82    /// The string to be displayed.
83    pub formatted: String,
84    /// The string to be saved.
85    pub serialize: String,
86}
87
/// Numeric metric entry.
///
/// Derives `Debug`, `Clone`, `Copy` and `PartialEq` so entries (e.g. ones read back from
/// serialized logs) can be printed, duplicated cheaply and compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value together with the number of elements it covers,
    /// as `(value, number of elements)`.
    Aggregated(f64, usize),
}
95
96impl NumericEntry {
97    pub(crate) fn serialize(&self) -> String {
98        match self {
99            Self::Value(v) => v.to_string(),
100            Self::Aggregated(v, n) => format!("{v},{n}"),
101        }
102    }
103
104    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
105        // Check for comma separated values
106        let values = entry.split(',').collect::<Vec<_>>();
107        let num_values = values.len();
108
109        if num_values == 1 {
110            // Numeric value
111            match values[0].parse::<f64>() {
112                Ok(value) => Ok(NumericEntry::Value(value)),
113                Err(err) => Err(err.to_string()),
114            }
115        } else if num_values == 2 {
116            // Aggregated numeric (value, number of elements)
117            let (value, numel) = (values[0], values[1]);
118            match value.parse::<f64>() {
119                Ok(value) => match numel.parse::<usize>() {
120                    Ok(numel) => Ok(NumericEntry::Aggregated(value, numel)),
121                    Err(err) => Err(err.to_string()),
122                },
123                Err(err) => Err(err.to_string()),
124            }
125        } else {
126            Err("Invalid number of values for numeric entry".to_string())
127        }
128    }
129}
130
/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// A value is rendered in scientific notation (`{:.precision$e}`) when its magnitude is
/// non-zero and at most `0.1^(precision - 1)` (e.g. `0.1` for a precision of `2`). Every
/// other value, including zero, uses fixed notation with `precision` decimal places.
///
/// The threshold is compared against the absolute value, so a negative number is formatted
/// like its positive counterpart (`-3.25` → `"-3.25"`, `-0.05` → `"-5.00e-2"` at precision 2).
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
    let magnitude = float.abs();

    // NaN fails the `<=` comparison and therefore falls through to fixed notation, as before.
    if magnitude != 0.0 && magnitude <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}