burn_train/metric/base.rs
1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8 /// The current progress.
9 pub progress: Progress,
10
11 /// The current epoch.
12 pub epoch: usize,
13
14 /// The total number of epochs.
15 pub epoch_total: usize,
16
17 /// The current iteration.
18 pub iteration: usize,
19
20 /// The current learning rate.
21 pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25 /// Fake metric metadata
26 #[cfg(test)]
27 pub fn fake() -> Self {
28 Self {
29 progress: Progress {
30 items_processed: 1,
31 items_total: 1,
32 },
33 epoch: 0,
34 epoch_total: 1,
35 iteration: 0,
36 lr: None,
37 }
38 }
39}
40
41/// Metric trait.
42///
43/// # Notes
44///
45/// Implementations should define their own input type only used by the metric.
46/// This is important since some conflict may happen when the model output is adapted for each
47/// metric's input type.
48pub trait Metric: Send + Sync + Clone {
49 /// The input type of the metric.
50 type Input;
51
52 /// The parameterized name of the metric.
53 ///
54 /// This should be unique, so avoid using short generic names, prefer using the long name.
55 ///
56 /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
57 /// values of k), the name should be unique for each instance.
58 fn name(&self) -> MetricName;
59
60 /// Update the metric state and returns the current metric entry.
61 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
62 /// Clear the metric state.
63 fn clear(&mut self);
64}
65
/// Shared, reference-counted metric name.
///
/// Because the string sits behind an [`Arc`], cloning a name only bumps a reference count
/// instead of copying the text.
pub type MetricName = Arc<String>;
68
/// Converts a value into a type that a metric can consume.
///
/// A model's output type should implement this trait for every
/// [metric input](Metric::Input) registered with the
/// [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Produces the input for a [metric](Metric) from `self`.
    fn adapt(&self) -> T;
}
77
78impl<T> Adaptor<()> for T {
79 fn adapt(&self) {}
80}
81
82/// Declare a metric to be numeric.
83///
84/// This is useful to plot the values of a metric during training.
85pub trait Numeric {
86 /// Returns the numeric value of the metric.
87 fn value(&self) -> NumericEntry;
88}
89
/// Snapshot of a metric's state at a given point in time.
#[derive(Debug, Clone)]
pub struct MetricEntry {
    /// Name of the metric this entry belongs to.
    pub name: Arc<String>,
    /// Human-readable representation, intended for display.
    pub formatted: String,
    /// Representation intended to be saved.
    pub serialize: String,
    /// Tags attached to the metric.
    pub tags: Vec<Arc<String>>,
}

impl MetricEntry {
    /// Creates an entry from a metric name plus its display and serialized representations.
    ///
    /// The new entry has no tags.
    pub fn new(name: Arc<String>, formatted: String, serialize: String) -> Self {
        let tags = Vec::new();

        Self {
            name,
            formatted,
            serialize,
            tags,
        }
    }
}
114
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value: a sum over `count` entries plus the current value.
    ///
    /// Serialized as `"{sum},{count}"`; `current` is not part of the serialized form.
    Aggregated {
        /// The sum of all entries.
        sum: f64,
        /// The number of entries present in the sum.
        count: usize,
        /// The current aggregated value.
        current: f64,
    },
}

impl NumericEntry {
    /// Returns the value to report for this entry.
    ///
    /// This is the wrapped number for [`NumericEntry::Value`] and the `current` field for
    /// [`NumericEntry::Aggregated`].
    pub fn current(&self) -> f64 {
        match self {
            Self::Value(value) => *value,
            Self::Aggregated { current, .. } => *current,
        }
    }

    /// Returns a string representation of the entry.
    ///
    /// A [`NumericEntry::Value`] is written as the number itself. A
    /// [`NumericEntry::Aggregated`] is written as `"{sum},{count}"`; `current` is omitted, so
    /// [`NumericEntry::deserialize`] has to reconstruct it.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(value) => value.to_string(),
            Self::Aggregated { sum, count, .. } => format!("{sum},{count}"),
        }
    }

    /// Parses a string in the format produced by [`NumericEntry::serialize`].
    ///
    /// - `"<f64>"` yields [`NumericEntry::Value`].
    /// - `"<f64>,<usize>"` yields [`NumericEntry::Aggregated`], with `current` set to the
    ///   parsed `sum` because the serialized form does not carry `current`.
    ///
    /// # Errors
    ///
    /// Returns the standard parser's message when a component is not a valid number.
    /// Returns `"Invalid number of values for numeric entry"` when the input contains more
    /// than two comma-separated components.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        match entry.split(',').collect::<Vec<_>>().as_slice() {
            [value] => value
                .parse::<f64>()
                .map(Self::Value)
                .map_err(|err| err.to_string()),
            [sum, count] => {
                // Parse order matters: a bad `sum` is reported before a bad `count`.
                let sum = sum.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                // NOTE(review): `current` is seeded from `sum`. If `sum` really is a sum over
                // `count` entries, readers may expect `sum / count` instead; confirm against
                // the code that writes these entries.
                Ok(Self::Aggregated {
                    sum,
                    count,
                    current: sum,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }
}
181
/// Formats `float` with `precision` digits after the decimal point.
///
/// Values whose magnitude is at most `0.1^(precision - 1)` are written in scientific notation
/// (e.g. `5.00e-2`). With fixed notation they would lose most of their digits: for example,
/// `0.001` at precision 2 would print as `0.00`. All other values, including `NaN` and
/// infinities, use fixed notation.
///
/// The threshold is compared against the absolute value, so negative numbers follow the same
/// rule as positive ones instead of always falling into scientific notation.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // `NaN <= x` is false, so NaN takes the fixed-notation branch.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}