burn_train/metric/base.rs
1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8 /// The current progress.
9 pub progress: Progress,
10
11 /// The current epoch.
12 pub epoch: usize,
13
14 /// The total number of epochs.
15 pub epoch_total: usize,
16
17 /// The current iteration.
18 pub iteration: usize,
19
20 /// The current learning rate.
21 pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25 /// Fake metric metadata
26 #[cfg(test)]
27 pub fn fake() -> Self {
28 Self {
29 progress: Progress {
30 items_processed: 1,
31 items_total: 1,
32 },
33 epoch: 0,
34 epoch_total: 1,
35 iteration: 0,
36 lr: None,
37 }
38 }
39}
40
/// Identifier used to compare metrics and to retrieve entries belonging to the same metric.
///
/// For now the metric name is used as the id, so that the same metric keeps the same id
/// across different runs.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MetricId {
    /// The shared identifier string.
    id: Arc<String>,
}

impl MetricId {
    /// Creates a metric id from the given shared string.
    pub fn new(id: Arc<String>) -> Self {
        Self { id }
    }
}
48
49/// Metric attributes define the properties intrinsic to different types of metric.
50#[derive(Clone, Debug)]
51pub enum MetricAttributes {
52 /// Numeric attributes.
53 Numeric(NumericAttributes),
54 /// No attributes.
55 None,
56}
57
58/// Definition of a metric.
59#[derive(Clone, Debug)]
60pub struct MetricDefinition {
61 /// The metric's id.
62 pub metric_id: MetricId,
63 /// The name of the metric.
64 pub name: String,
65 /// The description of the metric.
66 pub description: Option<String>,
67 /// The attributes of the metric.
68 pub attributes: MetricAttributes,
69}
70
71impl MetricDefinition {
72 /// Create a new metric definition given the metric and a unique id.
73 pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
74 Self {
75 metric_id,
76 name: metric.name().to_string(),
77 description: metric.description(),
78 attributes: metric.attributes(),
79 }
80 }
81}
82
/// Metric trait.
///
/// # Notes
///
/// Implementations should define their own input type, used only by the metric.
/// This matters because conflicts may arise when the model output is adapted to each
/// metric's input type.
pub trait Metric: Send + Sync + Clone {
    /// The input type of the metric, received by reference in [`Metric::update`].
    type Input;

    /// The parameterized name of the metric.
    ///
    /// This should be unique, so avoid short generic names and prefer the long name.
    ///
    /// For a metric that can exist with different parameters (e.g., top-k accuracy for different
    /// values of k), the name should be unique for each instance.
    fn name(&self) -> MetricName;

    /// A short description of the metric.
    ///
    /// Defaults to `None`.
    fn description(&self) -> Option<String> {
        None
    }

    /// Attributes of the metric.
    ///
    /// By default, metrics have no attributes ([`MetricAttributes::None`]).
    fn attributes(&self) -> MetricAttributes {
        MetricAttributes::None
    }

    /// Updates the metric state with `item` and returns the current metric entry.
    ///
    /// `metadata` provides context such as the progress, epoch, iteration and learning rate.
    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;

    /// Clears the metric state.
    fn clear(&mut self);
}

/// Type used to store metric names efficiently.
///
/// Cloning an `Arc<String>` shares the string instead of copying it.
pub type MetricName = Arc<String>;
123
/// Adaptors are used to transform types so that they can be consumed by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
/// registered with the specific learning paradigm (e.g. [SupervisedTraining](crate::SupervisedTraining)).
pub trait Adaptor<T> {
    /// Adapts the type so it can be passed to a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Blanket implementation: every type can be adapted to the unit type `()`.
///
/// NOTE(review): presumably this lets metrics whose input is `()` work with any output type;
/// confirm against the metric implementations.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {}
}
136
137/// Attributes that describe intrinsic properties of a numeric metric.
138#[derive(Clone, Debug)]
139pub struct NumericAttributes {
140 /// Optional unit (e.g. "%", "ms", "pixels")
141 pub unit: Option<String>,
142 /// Whether larger values are better (true) or smaller are better (false).
143 pub higher_is_better: bool,
144}
145
146impl From<NumericAttributes> for MetricAttributes {
147 fn from(attr: NumericAttributes) -> Self {
148 MetricAttributes::Numeric(attr)
149 }
150}
151
152impl Default for NumericAttributes {
153 fn default() -> Self {
154 Self {
155 unit: None,
156 higher_is_better: true,
157 }
158 }
159}
160
/// Declares a metric to be numeric.
///
/// This is useful for plotting the values of a metric during training.
pub trait Numeric {
    /// Returns the numeric value of the metric.
    fn value(&self) -> NumericEntry;
    /// Returns the current aggregated value of the metric over the global step (epoch).
    fn running_value(&self) -> NumericEntry;
}
170
/// A metric entry in textual form: one string for display, one for storage.
#[derive(Debug, Clone)]
pub struct SerializedEntry {
    /// The string to be displayed.
    pub formatted: String,
    /// The string to be saved.
    pub serialized: String,
}

impl SerializedEntry {
    /// Creates an entry from its display string and its storage string.
    pub fn new(formatted: String, serialized: String) -> Self {
        SerializedEntry {
            formatted,
            serialized,
        }
    }
}
179
180/// Data type that contains the current state of a metric at a given time.
181#[derive(Debug, Clone)]
182pub struct MetricEntry {
183 /// Id of the entry's metric.
184 pub metric_id: MetricId,
185 /// The serialized form of the entry.
186 pub serialized_entry: SerializedEntry,
187}
188
189impl MetricEntry {
190 /// Create a new metric.
191 pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
192 Self {
193 metric_id,
194 serialized_entry,
195 }
196 }
197}
198
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value together with the number of entries it covers.
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Returns the value held by this entry.
    ///
    /// For [`NumericEntry::Aggregated`] this is `aggregated_value`; `count` is ignored.
    pub fn current(&self) -> f64 {
        match self {
            Self::Value(value)
            | Self::Aggregated {
                aggregated_value: value,
                ..
            } => *value,
        }
    }

    /// Serializes the entry to a string.
    ///
    /// A [`NumericEntry::Value`] is written as the bare number. A [`NumericEntry::Aggregated`]
    /// is written as `"{aggregated_value},{count}"`. [`NumericEntry::deserialize`] parses
    /// both forms back.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses a string produced by [`NumericEntry::serialize`].
    ///
    /// One comma-separated value yields [`NumericEntry::Value`]. Two values yield
    /// [`NumericEntry::Aggregated`], as `(aggregated_value, count)`.
    ///
    /// # Errors
    ///
    /// - Returns the parser's message when a component is not a valid number. The value is
    ///   checked before the count.
    /// - Returns `"Invalid number of values for numeric entry"` for any other number of
    ///   comma-separated values.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let values = entry.split(',').collect::<Vec<_>>();

        match values.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` when this entry is strictly better than `other` in the given direction.
    ///
    /// - Equal values are never considered better, whichever direction is used.
    /// - NaN is treated as the worst possible value. A NaN entry is never better than
    ///   anything, and any non-NaN entry is better than a NaN one.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (current, reference) = (self.current(), other.current());

        // Handle NaN explicitly. Every comparison with NaN is false, which would otherwise let
        // a NaN entry look better whenever lower values are preferred.
        if current.is_nan() {
            return false;
        }
        if reference.is_nan() {
            return true;
        }

        if higher_is_better {
            current > reference
        } else {
            current < reference
        }
    }
}
270
/// Formats a float with the given precision, switching to scientific notation for small values.
///
/// A non-zero value whose magnitude is at most `10^-(precision - 1)` is written in scientific
/// notation (e.g. `0.05` at precision 2 gives `5.00e-2`). Everything else is written in
/// fixed-point notation.
///
/// The sign does not influence the choice, so `-5.0` at precision 2 gives `-5.00`. Zero is
/// always written in fixed-point notation.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
    // Compare the magnitude: comparing the signed value would send every negative number
    // (and zero) to scientific notation.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;

    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}