burn_train/metric/
base.rs

1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8    /// The current progress.
9    pub progress: Progress,
10
11    /// The current epoch.
12    pub epoch: usize,
13
14    /// The total number of epochs.
15    pub epoch_total: usize,
16
17    /// The current iteration.
18    pub iteration: usize,
19
20    /// The current learning rate.
21    pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25    /// Fake metric metadata
26    #[cfg(test)]
27    pub fn fake() -> Self {
28        Self {
29            progress: Progress {
30                items_processed: 1,
31                items_total: 1,
32            },
33            epoch: 0,
34            epoch_total: 1,
35            iteration: 0,
36            lr: None,
37        }
38    }
39}
40
41/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
42/// For now we take the name as id to make sure that the same metric has the same id across different runs.
43#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
44pub struct MetricId {
45    /// The metric id.
46    id: Arc<String>,
47}
48
49/// Metric attributes define the properties intrinsic to different types of metric.
50#[derive(Clone, Debug)]
51pub enum MetricAttributes {
52    /// Numeric attributes.
53    Numeric(NumericAttributes),
54    /// No attributes.
55    None,
56}
57
58/// Definition of a metric.
59///
60/// This is used to register a metric with the [learner builder](crate::learner::LearnerBuilder).
61#[derive(Clone, Debug)]
62pub struct MetricDefinition {
63    /// The metric's id.
64    pub metric_id: MetricId,
65    /// The name of the metric.
66    pub name: String,
67    /// The description of the metric.
68    pub description: Option<String>,
69    /// The attributes of the metric.
70    pub attributes: MetricAttributes,
71}
72
73impl MetricDefinition {
74    /// Create a new metric definition given the metric and a unique id.
75    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
76        Self {
77            metric_id,
78            name: metric.name().to_string(),
79            description: metric.description(),
80            attributes: metric.attributes(),
81        }
82    }
83}
84
85/// Metric trait.
86///
87/// # Notes
88///
89/// Implementations should define their own input type only used by the metric.
90/// This is important since some conflict may happen when the model output is adapted for each
91/// metric's input type.
92pub trait Metric: Send + Sync + Clone {
93    /// The input type of the metric.
94    type Input;
95
96    /// The parameterized name of the metric.
97    ///
98    /// This should be unique, so avoid using short generic names, prefer using the long name.
99    ///
100    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
101    /// values of k), the name should be unique for each instance.
102    fn name(&self) -> MetricName;
103
104    /// A short description of the metric.
105    fn description(&self) -> Option<String> {
106        None
107    }
108
109    /// Attributes of the metric.
110    ///
111    /// By default, metrics have no attributes.
112    fn attributes(&self) -> MetricAttributes {
113        MetricAttributes::None
114    }
115
116    /// Update the metric state and returns the current metric entry.
117    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
118
119    /// Clear the metric state.
120    fn clear(&mut self);
121}
122
/// Reference-counted string type used for metric names, so that a name can be shared without
/// copying its contents.
pub type MetricName = Arc<String>;
125
/// Adaptors transform a type so that it can be consumed by metrics.
///
/// A model's output type should implement this for every [metric input](Metric::Input)
/// registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Produces the value to be passed to a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Every type can be adapted to the unit type; the adaptation does nothing.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {}
}
138
/// Intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Unit of the metric values, if any (e.g. "%", "ms", "pixels").
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
147
148impl From<NumericAttributes> for MetricAttributes {
149    fn from(attr: NumericAttributes) -> Self {
150        MetricAttributes::Numeric(attr)
151    }
152}
153
154impl Default for NumericAttributes {
155    fn default() -> Self {
156        Self {
157            unit: None,
158            higher_is_better: true,
159        }
160    }
161}
162
163/// Declare a metric to be numeric.
164///
165/// This is useful to plot the values of a metric during training.
166pub trait Numeric {
167    /// Returns the numeric value of the metric.
168    fn value(&self) -> NumericEntry;
169}
170
171/// Serialized form of a metric entry.
172#[derive(Debug, Clone, new)]
173pub struct SerializedEntry {
174    /// The string to be displayed.
175    pub formatted: String,
176    /// The string to be saved.
177    pub serialized: String,
178}
179
180/// Data type that contains the current state of a metric at a given time.
181#[derive(Debug, Clone)]
182pub struct MetricEntry {
183    /// Id of the entry's metric.
184    pub metric_id: MetricId,
185    /// The serialized form of the entry.
186    pub serialized_entry: SerializedEntry,
187}
188
189impl MetricEntry {
190    /// Create a new metric.
191    pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
192        Self {
193            metric_id,
194            serialized_entry,
195        }
196    }
197}
198
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Gets the current value of the metric: the single value, or the aggregated value for an
    /// aggregated entry.
    pub fn current(&self) -> f64 {
        match *self {
            Self::Value(value) => value,
            Self::Aggregated {
                aggregated_value, ..
            } => aggregated_value,
        }
    }

    /// Returns a `String` representing the entry.
    ///
    /// A single value is written as `value`; an aggregated entry as `value,count`.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(value) => value.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses a string produced by [`serialize`](Self::serialize) back into an entry.
    ///
    /// # Errors
    ///
    /// Returns the parse error message when a field is not a valid number, or a fixed message
    /// when the input does not contain exactly one or two comma-separated fields.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let fields = entry.split(',').collect::<Vec<_>>();

        match fields.as_slice() {
            // Single numeric value.
            [value] => {
                let value = value.parse::<f64>().map_err(|err| err.to_string())?;
                Ok(Self::Value(value))
            }
            // Aggregated numeric (value, number of elements).
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(Self::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Compares this entry's value with another one using the specified direction.
    ///
    /// Returns `true` only when this entry is strictly better than `other`: strictly greater
    /// when `higher_is_better` is `true`, strictly smaller otherwise. Equal values, and
    /// comparisons involving NaN, are never considered better in either direction.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (mine, theirs) = (self.current(), other.current());

        // A strict comparison in the requested direction keeps both directions symmetric;
        // negating `>` would wrongly count ties and NaN as "better" when lower is better.
        if higher_is_better {
            mine > theirs
        } else {
            mine < theirs
        }
    }
}
270
/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// Scientific notation is used when the magnitude of `float` is at or below
/// `0.1^(precision - 1)` (this includes zero); otherwise fixed-point notation is used. In both
/// cases `precision` digits are printed after the decimal point.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude, not the signed value: otherwise every negative number falls below
    // the (positive) threshold and is rendered in scientific notation.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}