burn_train/metric/
base.rs

1use burn_core::{LearningRate, data::dataloader::Progress};
2
3/// Metric metadata that can be used when computing metrics.
4pub struct MetricMetadata {
5    /// The current progress.
6    pub progress: Progress,
7
8    /// The current epoch.
9    pub epoch: usize,
10
11    /// The total number of epochs.
12    pub epoch_total: usize,
13
14    /// The current iteration.
15    pub iteration: usize,
16
17    /// The current learning rate.
18    pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22    /// Fake metric metadata
23    #[cfg(test)]
24    pub fn fake() -> Self {
25        Self {
26            progress: Progress {
27                items_processed: 1,
28                items_total: 1,
29            },
30            epoch: 0,
31            epoch_total: 1,
32            iteration: 0,
33            lr: None,
34        }
35    }
36}
37
38/// Metric trait.
39///
40/// # Notes
41///
42/// Implementations should define their own input type only used by the metric.
43/// This is important since some conflict may happen when the model output is adapted for each
44/// metric's input type.
45pub trait Metric: Send + Sync {
46    /// The input type of the metric.
47    type Input;
48
49    /// The parameterized name of the metric.
50    ///
51    /// This should be unique, so avoid using short generic names, prefer using the long name.
52    ///
53    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
54    /// values of k), the name should be unique for each instance.
55    fn name(&self) -> String;
56
57    /// Update the metric state and returns the current metric entry.
58    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
59    /// Clear the metric state.
60    fn clear(&mut self);
61}
62
/// Adaptors convert a type into the input type expected by a metric.
///
/// A model's output type should implement this trait for every
/// [metric input](Metric::Input) registered with the
/// [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Converts `self` into the value handed to a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Any type can be adapted into the unit type `()`, which carries no data.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {}
}
75
/// Marks a metric as producing a numeric value.
///
/// Numeric metrics can, for instance, have their values plotted during training.
pub trait Numeric {
    /// The metric's current value.
    fn value(&self) -> f64;
}
83
84/// Data type that contains the current state of a metric at a given time.
85#[derive(new, Debug, Clone)]
86pub struct MetricEntry {
87    /// The name of the metric.
88    pub name: String,
89    /// The string to be displayed.
90    pub formatted: String,
91    /// The string to be saved.
92    pub serialize: String,
93}
94
/// Numeric metric entry.
///
/// Serialized by [`NumericEntry::serialize`] as either `"<value>"` or
/// `"<value>,<count>"`, and parsed back by [`NumericEntry::deserialize`].
#[derive(Debug, Clone, PartialEq)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value together with the number of elements it covers.
    Aggregated(f64, usize),
}

impl NumericEntry {
    /// Serializes the entry as `"<value>"` or `"<value>,<count>"`.
    pub(crate) fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated(v, n) => format!("{v},{n}"),
        }
    }

    /// Parses an entry in the format produced by [`NumericEntry::serialize`].
    ///
    /// One comma-separated component yields [`NumericEntry::Value`]; two yield
    /// [`NumericEntry::Aggregated`].
    ///
    /// # Errors
    ///
    /// Returns the parser's error message when a component is not a valid
    /// number (the value is checked before the count). Returns
    /// `"Invalid number of values for numeric entry"` when there are more than
    /// two components.
    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
        // `split` always yields at least one part, so `_` only matches 3+ parts.
        let parts: Vec<&str> = entry.split(',').collect();

        match parts.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(Self::Value)
                .map_err(|err| err.to_string()),
            [value, numel] => {
                let value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let numel = numel.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(Self::Aggregated(value, numel))
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }
}
137
/// Formats `float` with `precision` digits after the decimal point.
///
/// Values whose magnitude is at or below `10^-(precision - 1)` would show few or
/// no significant digits in fixed notation, so they are rendered in scientific
/// notation instead (e.g. `format_float(0.001, 2)` gives `"1.00e-3"`).
/// Negative numbers follow the same rule as their absolute value. Zero is
/// always printed in fixed notation.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare magnitudes so a negative value is formatted like its positive
    // counterpart. Zero is exact in fixed notation, so never switch it to
    // scientific. NaN fails the `<=` test and falls through to fixed notation.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;

    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}