burn_train/metric/
base.rs

1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8    /// The current progress.
9    pub progress: Progress,
10
11    /// The current epoch.
12    pub epoch: usize,
13
14    /// The total number of epochs.
15    pub epoch_total: usize,
16
17    /// The current iteration.
18    pub iteration: usize,
19
20    /// The current learning rate.
21    pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25    /// Fake metric metadata
26    #[cfg(test)]
27    pub fn fake() -> Self {
28        Self {
29            progress: Progress {
30                items_processed: 1,
31                items_total: 1,
32            },
33            epoch: 0,
34            epoch_total: 1,
35            iteration: 0,
36            lr: None,
37        }
38    }
39}
40
41/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
42/// For now we take the name as id to make sure that the same metric has the same id across different runs.
43#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
44pub struct MetricId {
45    /// The metric id.
46    id: Arc<String>,
47}
48
49/// Metric attributes define the properties intrinsic to different types of metric.
50#[derive(Clone, Debug)]
51pub enum MetricAttributes {
52    /// Numeric attributes.
53    Numeric(NumericAttributes),
54    /// No attributes.
55    None,
56}
57
58/// Definition of a metric.
59#[derive(Clone, Debug)]
60pub struct MetricDefinition {
61    /// The metric's id.
62    pub metric_id: MetricId,
63    /// The name of the metric.
64    pub name: String,
65    /// The description of the metric.
66    pub description: Option<String>,
67    /// The attributes of the metric.
68    pub attributes: MetricAttributes,
69}
70
71impl MetricDefinition {
72    /// Create a new metric definition given the metric and a unique id.
73    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
74        Self {
75            metric_id,
76            name: metric.name().to_string(),
77            description: metric.description(),
78            attributes: metric.attributes(),
79        }
80    }
81}
82
83/// Metric trait.
84///
85/// # Notes
86///
87/// Implementations should define their own input type only used by the metric.
88/// This is important since some conflict may happen when the model output is adapted for each
89/// metric's input type.
90pub trait Metric: Send + Sync + Clone {
91    /// The input type of the metric.
92    type Input;
93
94    /// The parameterized name of the metric.
95    ///
96    /// This should be unique, so avoid using short generic names, prefer using the long name.
97    ///
98    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
99    /// values of k), the name should be unique for each instance.
100    fn name(&self) -> MetricName;
101
102    /// A short description of the metric.
103    fn description(&self) -> Option<String> {
104        None
105    }
106
107    /// Attributes of the metric.
108    ///
109    /// By default, metrics have no attributes.
110    fn attributes(&self) -> MetricAttributes {
111        MetricAttributes::None
112    }
113
114    /// Update the metric state and returns the current metric entry.
115    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
116
117    /// Clear the metric state.
118    fn clear(&mut self);
119}
120
/// Alias for the type in which metric names are stored efficiently.
pub type MetricName = Arc<String>;
123
/// Adaptors transform a type so that it can be used by metrics.
///
/// A model's output type should implement this trait for every
/// [metric input](Metric::Input) registered with the specific learning paradigm
/// (i.e. [SupervisedTraining](crate::SupervisedTraining)).
pub trait Adaptor<T> {
    /// Produces the value to be passed to a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Every type can be adapted to the unit type, which carries no data.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {}
}
136
/// Intrinsic properties describing a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Unit of the values, if any (e.g. "%", "ms", "pixels").
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
145
146impl From<NumericAttributes> for MetricAttributes {
147    fn from(attr: NumericAttributes) -> Self {
148        MetricAttributes::Numeric(attr)
149    }
150}
151
152impl Default for NumericAttributes {
153    fn default() -> Self {
154        Self {
155            unit: None,
156            higher_is_better: true,
157        }
158    }
159}
160
161/// Declare a metric to be numeric.
162///
163/// This is useful to plot the values of a metric during training.
164pub trait Numeric {
165    /// Returns the numeric value of the metric.
166    fn value(&self) -> NumericEntry;
167    /// Returns the current aggregated value of the metric over the global step (epoch).
168    fn running_value(&self) -> NumericEntry;
169}
170
171/// Serialized form of a metric entry.
172#[derive(Debug, Clone, new)]
173pub struct SerializedEntry {
174    /// The string to be displayed.
175    pub formatted: String,
176    /// The string to be saved.
177    pub serialized: String,
178}
179
180/// Data type that contains the current state of a metric at a given time.
181#[derive(Debug, Clone)]
182pub struct MetricEntry {
183    /// Id of the entry's metric.
184    pub metric_id: MetricId,
185    /// The serialized form of the entry.
186    pub serialized_entry: SerializedEntry,
187}
188
189impl MetricEntry {
190    /// Create a new metric.
191    pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
192        Self {
193            metric_id,
194            serialized_entry,
195        }
196    }
197}
198
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Gets the current value of the metric: the single value itself, or the aggregated
    /// value when the entry is an aggregate.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(value) => *value,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => *aggregated_value,
        }
    }

    /// Returns a String representing the NumericEntry.
    ///
    /// A single value is written as the number alone; an aggregate is written as
    /// `"<aggregated_value>,<count>"`.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(value) => value.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// De-serializes a string representing a NumericEntry and returns a Result containing
    /// the corresponding NumericEntry.
    ///
    /// # Errors
    ///
    /// Returns the parser's error message when a field is not a valid number, or a fixed
    /// message when the entry does not hold exactly one or two comma separated fields.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        // The number of comma separated fields selects the variant.
        let fields = entry.split(',').collect::<Vec<_>>();

        match fields.as_slice() {
            // Numeric value
            [value] => {
                let value = value.parse::<f64>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Value(value))
            }
            // Aggregated numeric (value, number of elements)
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Compare this numeric metric's value with another one using the specified direction.
    ///
    /// Returns `true` only when this entry is strictly better than `other`: strictly
    /// greater when `higher_is_better` is `true`, strictly smaller otherwise. Equal values
    /// are never better than each other, and neither is any comparison involving NaN.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (mine, theirs) = (self.current(), other.current());

        // Comparing `(mine > theirs) == higher_is_better` would report ties (and NaN) as
        // "better" in the lower-is-better direction, hence the explicit strict comparisons.
        if higher_is_better {
            mine > theirs
        } else {
            mine < theirs
        }
    }
}
270
/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// Scientific notation is used when the magnitude of `float` is at most
/// `0.1^(precision - 1)`; larger magnitudes are printed in plain decimal form. In both cases
/// `precision` digits are printed after the decimal point.
///
/// The decision is based on the absolute value, so a negative number is formatted with the
/// same notation as its positive counterpart.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude: comparing the signed value would send every negative number
    // to scientific notation regardless of its size.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
279}