burn_train/metric/base.rs
1use burn_core::{LearningRate, data::dataloader::Progress};
2
3/// Metric metadata that can be used when computing metrics.
4pub struct MetricMetadata {
5 /// The current progress.
6 pub progress: Progress,
7
8 /// The current epoch.
9 pub epoch: usize,
10
11 /// The total number of epochs.
12 pub epoch_total: usize,
13
14 /// The current iteration.
15 pub iteration: usize,
16
17 /// The current learning rate.
18 pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22 /// Fake metric metadata
23 #[cfg(test)]
24 pub fn fake() -> Self {
25 Self {
26 progress: Progress {
27 items_processed: 1,
28 items_total: 1,
29 },
30 epoch: 0,
31 epoch_total: 1,
32 iteration: 0,
33 lr: None,
34 }
35 }
36}
37
38/// Metric trait.
39///
40/// # Notes
41///
42/// Implementations should define their own input type only used by the metric.
43/// This is important since some conflict may happen when the model output is adapted for each
44/// metric's input type.
45pub trait Metric: Send + Sync {
46 /// The input type of the metric.
47 type Input;
48
49 /// The parameterized name of the metric.
50 ///
51 /// This should be unique, so avoid using short generic names, prefer using the long name.
52 ///
53 /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
54 /// values of k), the name should be unique for each instance.
55 fn name(&self) -> String;
56
57 /// Update the metric state and returns the current metric entry.
58 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
59 /// Clear the metric state.
60 fn clear(&mut self);
61}
62
/// Adaptors are used to transform types so that they can be consumed by metrics.
///
/// A model's output type should implement this trait for every
/// [metric input](Metric::Input) registered with the
/// [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Converts `self` into the value passed to a [metric](Metric).
    fn adapt(&self) -> T;
}
71
72impl<T> Adaptor<()> for T {
73 fn adapt(&self) {}
74}
75
/// Marks a metric as numeric.
///
/// Numeric metrics expose their state as a single `f64`, which makes it possible to plot
/// them during training.
pub trait Numeric {
    /// Returns the metric's numeric value.
    fn value(&self) -> f64;
}
83
84/// Data type that contains the current state of a metric at a given time.
85#[derive(new, Debug, Clone)]
86pub struct MetricEntry {
87 /// The name of the metric.
88 pub name: String,
89 /// The string to be displayed.
90 pub formatted: String,
91 /// The string to be saved.
92 pub serialize: String,
93}
94
/// Numeric metric entry.
///
/// Entries round-trip through [`NumericEntry::serialize`] and [`NumericEntry::deserialize`]:
/// a single value is stored as `"{value}"` and an aggregated one as `"{value},{count}"`.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value together with the number of elements it aggregates.
    Aggregated(f64, usize),
}

impl NumericEntry {
    /// Serializes the entry into its textual form.
    ///
    /// `Value(v)` becomes `"v"` and `Aggregated(v, n)` becomes `"v,n"`.
    pub(crate) fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated(v, n) => format!("{v},{n}"),
        }
    }

    /// Parses an entry in the textual form produced by [`NumericEntry::serialize`].
    ///
    /// # Errors
    ///
    /// Returns the parser's error message when a component is not a valid number, or
    /// `"Invalid number of values for numeric entry"` when `entry` does not consist of
    /// exactly one or two comma-separated values.
    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
        let values: Vec<&str> = entry.split(',').collect();

        match values.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(Self::Value)
                .map_err(|err| err.to_string()),
            [value, numel] => {
                // The value is parsed first, so its error is the one reported when both
                // components are invalid.
                let value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let numel = numel.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(Self::Aggregated(value, numel))
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }
}
137
/// Formats a float with `precision` decimal places.
///
/// Values whose magnitude is at most `0.1^(precision - 1)` (zero included) are rendered in
/// scientific notation, so that small values are not rounded away to a run of zeros; all
/// other values use fixed-point notation. The sign is kept in both cases, and negative
/// numbers are classified by their magnitude just like positive ones.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude: comparing the signed value would send every negative number
    // (e.g. -5.0) down the scientific-notation path.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}