burn_train/metric/
base.rs

1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8    /// The current progress.
9    pub progress: Progress,
10
11    /// The current epoch.
12    pub epoch: usize,
13
14    /// The total number of epochs.
15    pub epoch_total: usize,
16
17    /// The current iteration.
18    pub iteration: usize,
19
20    /// The current learning rate.
21    pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25    /// Fake metric metadata
26    #[cfg(test)]
27    pub fn fake() -> Self {
28        Self {
29            progress: Progress {
30                items_processed: 1,
31                items_total: 1,
32            },
33            epoch: 0,
34            epoch_total: 1,
35            iteration: 0,
36            lr: None,
37        }
38    }
39}
40
41/// Metric trait.
42///
43/// # Notes
44///
45/// Implementations should define their own input type only used by the metric.
46/// This is important since some conflict may happen when the model output is adapted for each
47/// metric's input type.
48pub trait Metric: Send + Sync + Clone {
49    /// The input type of the metric.
50    type Input;
51
52    /// The parameterized name of the metric.
53    ///
54    /// This should be unique, so avoid using short generic names, prefer using the long name.
55    ///
56    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
57    /// values of k), the name should be unique for each instance.
58    fn name(&self) -> MetricName;
59
60    /// Update the metric state and returns the current metric entry.
61    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
62    /// Clear the metric state.
63    fn clear(&mut self);
64}
65
/// Reference-counted string used to store metric names efficiently.
pub type MetricName = Arc<String>;
68
/// Adaptors transform a type so that it can be used by metrics.
///
/// A model's output type should implement this trait for every [metric input](Metric::Input)
/// registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Adapts `self` into the type passed to a [metric](Metric).
    fn adapt(&self) -> T;
}
77
78impl<T> Adaptor<()> for T {
79    fn adapt(&self) {}
80}
81
82/// Declare a metric to be numeric.
83///
84/// This is useful to plot the values of a metric during training.
85pub trait Numeric {
86    /// Returns the numeric value of the metric.
87    fn value(&self) -> NumericEntry;
88}
89
/// Snapshot of the state of a metric at a given time.
#[derive(Clone, Debug)]
pub struct MetricEntry {
    /// Name of the metric.
    pub name: Arc<String>,
    /// String meant to be displayed.
    pub formatted: String,
    /// String meant to be saved.
    pub serialize: String,
    /// Tags linked to the metric.
    pub tags: Vec<Arc<String>>,
}
102
103impl MetricEntry {
104    /// Create a new metric.
105    pub fn new(name: Arc<String>, formatted: String, serialize: String) -> Self {
106        Self {
107            name,
108            formatted,
109            serialize,
110            tags: Vec::new(),
111        }
112    }
113}
114
/// Entry of a numeric metric.
#[derive(Clone, Debug)]
pub enum NumericEntry {
    /// A single numeric value.
    Value(f64),
    /// An aggregated numeric (value, number of elements).
    Aggregated {
        /// Sum of all entries.
        sum: f64,
        /// Number of entries present in the sum.
        count: usize,
        /// Current aggregated value.
        current: f64,
    },
}
130
131impl NumericEntry {
132    /// Gets the current aggregated value of the metric.
133    pub fn current(&self) -> f64 {
134        match self {
135            NumericEntry::Value(val) => *val,
136            NumericEntry::Aggregated { current, .. } => *current,
137        }
138    }
139}
140
141impl NumericEntry {
142    /// Returns a String representing the NumericEntry
143    pub fn serialize(&self) -> String {
144        match self {
145            Self::Value(v) => v.to_string(),
146            Self::Aggregated { sum, count, .. } => format!("{sum},{count}"),
147        }
148    }
149
150    /// De-serializes a string representing a NumericEntry and returns a Result containing the corresponding NumericEntry.
151    pub fn deserialize(entry: &str) -> Result<Self, String> {
152        // Check for comma separated values
153        let values = entry.split(',').collect::<Vec<_>>();
154        let num_values = values.len();
155
156        if num_values == 1 {
157            // Numeric value
158            match values[0].parse::<f64>() {
159                Ok(value) => Ok(NumericEntry::Value(value)),
160                Err(err) => Err(err.to_string()),
161            }
162        } else if num_values == 2 {
163            // Aggregated numeric (value, number of elements)
164            let (value, numel) = (values[0], values[1]);
165            match value.parse::<f64>() {
166                Ok(value) => match numel.parse::<usize>() {
167                    Ok(numel) => Ok(NumericEntry::Aggregated {
168                        sum: value,
169                        count: numel,
170                        current: value,
171                    }),
172                    Err(err) => Err(err.to_string()),
173                },
174                Err(err) => Err(err.to_string()),
175            }
176        } else {
177            Err("Invalid number of values for numeric entry".to_string())
178        }
179    }
180}
181
/// Formats a float with the given precision, using scientific notation for small magnitudes.
///
/// Scientific notation is used when the magnitude of `float` is at most
/// `0.1^(precision - 1)`, i.e. when fixed notation with `precision` decimals would show few or
/// no significant digits. Larger magnitudes are written in fixed notation. The sign does not
/// influence the choice of notation.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude rather than the signed value: otherwise every negative number
    // falls below the (positive) threshold and is forced into scientific notation.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
190}