Skip to main content

burn_train/metric/
base.rs

1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8    /// The current progress.
9    pub progress: Progress,
10
11    /// The global progress of the training (e.g. epochs).
12    pub global_progress: Progress,
13
14    /// The current iteration.
15    pub iteration: Option<usize>,
16
17    /// The current learning rate.
18    pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22    /// Fake metric metadata
23    #[cfg(test)]
24    pub fn fake() -> Self {
25        Self {
26            progress: Progress {
27                items_processed: 1,
28                items_total: 1,
29            },
30            global_progress: Progress {
31                items_processed: 0,
32                items_total: 1,
33            },
34            iteration: Some(0),
35            lr: None,
36        }
37    }
38}
39
40/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
41/// For now we take the name as id to make sure that the same metric has the same id across different runs.
42#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
43pub struct MetricId {
44    /// The metric id.
45    id: Arc<String>,
46}
47
48/// Metric attributes define the properties intrinsic to different types of metric.
49#[derive(Clone, Debug)]
50pub enum MetricAttributes {
51    /// Numeric attributes.
52    Numeric(NumericAttributes),
53    /// No attributes.
54    None,
55}
56
57/// Definition of a metric.
58#[derive(Clone, Debug)]
59pub struct MetricDefinition {
60    /// The metric's id.
61    pub metric_id: MetricId,
62    /// The name of the metric.
63    pub name: String,
64    /// The description of the metric.
65    pub description: Option<String>,
66    /// The attributes of the metric.
67    pub attributes: MetricAttributes,
68}
69
70impl MetricDefinition {
71    /// Create a new metric definition given the metric and a unique id.
72    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
73        Self {
74            metric_id,
75            name: metric.name().to_string(),
76            description: metric.description(),
77            attributes: metric.attributes(),
78        }
79    }
80}
81
82/// Metric trait.
83///
84/// # Notes
85///
86/// Implementations should define their own input type only used by the metric.
87/// This is important since some conflict may happen when the model output is adapted for each
88/// metric's input type.
89pub trait Metric: Send + Sync + Clone {
90    /// The input type of the metric.
91    type Input;
92
93    /// The parameterized name of the metric.
94    ///
95    /// This should be unique, so avoid using short generic names, prefer using the long name.
96    ///
97    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
98    /// values of k), the name should be unique for each instance.
99    fn name(&self) -> MetricName;
100
101    /// A short description of the metric.
102    fn description(&self) -> Option<String> {
103        None
104    }
105
106    /// Attributes of the metric.
107    ///
108    /// By default, metrics have no attributes.
109    fn attributes(&self) -> MetricAttributes {
110        MetricAttributes::None
111    }
112
113    /// Update the metric state and returns the current metric entry.
114    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
115
116    /// Clear the metric state.
117    fn clear(&mut self);
118}
119
/// Metric name behind a reference-counted pointer, so copies of the name share one allocation.
pub type MetricName = std::sync::Arc<String>;
122
/// Adaptors transform a type into the input type expected by a metric.
///
/// A model's output type should implement this for every [metric input](Metric::Input)
/// registered with the learning paradigm in use (e.g.
/// [SupervisedTraining](crate::SupervisedTraining)).
pub trait Adaptor<T> {
    /// Produces the value handed to a [metric](Metric).
    fn adapt(&self) -> T;
}
131
132impl<T> Adaptor<()> for T {
133    fn adapt(&self) {}
134}
135
/// Intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Unit the values are expressed in (e.g. "%", "ms", "pixels"), if any.
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
144
145impl From<NumericAttributes> for MetricAttributes {
146    fn from(attr: NumericAttributes) -> Self {
147        MetricAttributes::Numeric(attr)
148    }
149}
150
151impl Default for NumericAttributes {
152    fn default() -> Self {
153        Self {
154            unit: None,
155            higher_is_better: true,
156        }
157    }
158}
159
160/// Declare a metric to be numeric.
161///
162/// This is useful to plot the values of a metric during training.
163pub trait Numeric {
164    /// Returns the numeric value of the metric.
165    fn value(&self) -> NumericEntry;
166    /// Returns the current aggregated value of the metric over the global step (epoch).
167    fn running_value(&self) -> NumericEntry;
168}
169
170/// Serialized form of a metric entry.
171#[derive(Debug, Clone, new)]
172pub struct SerializedEntry {
173    /// The string to be displayed.
174    pub formatted: String,
175    /// The string to be saved.
176    pub serialized: String,
177}
178
179/// Data type that contains the current state of a metric at a given time.
180#[derive(Debug, Clone)]
181pub struct MetricEntry {
182    /// Id of the entry's metric.
183    pub metric_id: MetricId,
184    /// The serialized form of the entry.
185    pub serialized_entry: SerializedEntry,
186}
187
188impl MetricEntry {
189    /// Create a new metric.
190    pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
191        Self {
192            metric_id,
193            serialized_entry,
194        }
195    }
196}
197
/// Numeric metric entry.
///
/// Either a single value, or a value aggregated over `count` entries.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Gets the current value of the entry.
    ///
    /// For [`NumericEntry::Aggregated`] this is the stored `aggregated_value`; `count` is not
    /// used.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(val) => *val,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => *aggregated_value,
        }
    }

    /// Returns a String representing the NumericEntry.
    ///
    /// A `Value(v)` is written as `"v"`. An `Aggregated { aggregated_value, count }` is written
    /// as `"aggregated_value,count"`. [`NumericEntry::deserialize`] parses both forms back.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// De-serializes a string produced by [`NumericEntry::serialize`].
    ///
    /// One comma-separated component yields a [`NumericEntry::Value`]; two yield a
    /// [`NumericEntry::Aggregated`] (value, number of elements).
    ///
    /// # Errors
    ///
    /// Returns the parse error message if a component is not a valid number. Returns
    /// `"Invalid number of values for numeric entry"` if the string has more than two
    /// comma-separated components.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let values: Vec<&str> = entry.split(',').collect();

        match values.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            [value, count] => {
                // The value is parsed first, so its error is reported before the count's.
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` if this entry is strictly better than `other`.
    ///
    /// When `higher_is_better` is `true`, "better" means strictly greater. Otherwise it means
    /// strictly lower. Equal values are never better, and any comparison involving NaN returns
    /// `false`.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (candidate, reference) = (self.current(), other.current());
        // Compare strictly in both directions so that ties (and NaN) are never reported as an
        // improvement, whichever direction is preferred.
        if higher_is_better {
            candidate > reference
        } else {
            candidate < reference
        }
    }
}
269
/// Formats a float with the given precision, using scientific notation for small magnitudes.
///
/// `precision` is the number of digits after the decimal point in either notation.
/// Scientific notation is used for non-zero values whose magnitude is at most
/// `0.1^(precision - 1)`; everything else (including zero) uses fixed notation.
/// Positive and negative values are treated alike. For example,
/// `format_float(0.001, 2) == "1.00e-3"` and `format_float(-5.0, 2) == "-5.00"`.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude: a signed comparison would send every negative number into
    // scientific notation. Zero has no digits to lose, so it stays in fixed notation.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;

    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}