burn_train/metric/base.rs
1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8 /// The current progress.
9 pub progress: Progress,
10
11 /// The global progress of the training (e.g. epochs).
12 pub global_progress: Progress,
13
14 /// The current iteration.
15 pub iteration: Option<usize>,
16
17 /// The current learning rate.
18 pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22 /// Fake metric metadata
23 #[cfg(test)]
24 pub fn fake() -> Self {
25 Self {
26 progress: Progress {
27 items_processed: 1,
28 items_total: 1,
29 },
30 global_progress: Progress {
31 items_processed: 0,
32 items_total: 1,
33 },
34 iteration: Some(0),
35 lr: None,
36 }
37 }
38}
39
40/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
41/// For now we take the name as id to make sure that the same metric has the same id across different runs.
42#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
43pub struct MetricId {
44 /// The metric id.
45 id: Arc<String>,
46}
47
48/// Metric attributes define the properties intrinsic to different types of metric.
49#[derive(Clone, Debug)]
50pub enum MetricAttributes {
51 /// Numeric attributes.
52 Numeric(NumericAttributes),
53 /// No attributes.
54 None,
55}
56
57/// Definition of a metric.
58#[derive(Clone, Debug)]
59pub struct MetricDefinition {
60 /// The metric's id.
61 pub metric_id: MetricId,
62 /// The name of the metric.
63 pub name: String,
64 /// The description of the metric.
65 pub description: Option<String>,
66 /// The attributes of the metric.
67 pub attributes: MetricAttributes,
68}
69
70impl MetricDefinition {
71 /// Create a new metric definition given the metric and a unique id.
72 pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
73 Self {
74 metric_id,
75 name: metric.name().to_string(),
76 description: metric.description(),
77 attributes: metric.attributes(),
78 }
79 }
80}
81
82/// Metric trait.
83///
84/// # Notes
85///
86/// Implementations should define their own input type only used by the metric.
87/// This is important since some conflict may happen when the model output is adapted for each
88/// metric's input type.
89pub trait Metric: Send + Sync + Clone {
90 /// The input type of the metric.
91 type Input;
92
93 /// The parameterized name of the metric.
94 ///
95 /// This should be unique, so avoid using short generic names, prefer using the long name.
96 ///
97 /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
98 /// values of k), the name should be unique for each instance.
99 fn name(&self) -> MetricName;
100
101 /// A short description of the metric.
102 fn description(&self) -> Option<String> {
103 None
104 }
105
106 /// Attributes of the metric.
107 ///
108 /// By default, metrics have no attributes.
109 fn attributes(&self) -> MetricAttributes {
110 MetricAttributes::None
111 }
112
113 /// Update the metric state and returns the current metric entry.
114 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
115
116 /// Clear the metric state.
117 fn clear(&mut self);
118}
119
/// Shared, reference-counted string used for metric names.
///
/// Cloning a name only bumps the reference count instead of copying the text.
pub type MetricName = Arc<String>;
122
/// Adaptors convert a value into the type a metric expects.
///
/// A model's output type should implement this trait for every [metric input](Metric::Input)
/// registered with the specific learning paradigm
/// (i.e. [SupervisedTraining](crate::SupervisedTraining)).
pub trait Adaptor<T> {
    /// Produce the value that is passed to a [metric](Metric).
    fn adapt(&self) -> T;
}
131
132impl<T> Adaptor<()> for T {
133 fn adapt(&self) {}
134}
135
/// Intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Optional unit of the values (e.g. "%", "ms", "pixels").
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
144
145impl From<NumericAttributes> for MetricAttributes {
146 fn from(attr: NumericAttributes) -> Self {
147 MetricAttributes::Numeric(attr)
148 }
149}
150
151impl Default for NumericAttributes {
152 fn default() -> Self {
153 Self {
154 unit: None,
155 higher_is_better: true,
156 }
157 }
158}
159
160/// Declare a metric to be numeric.
161///
162/// This is useful to plot the values of a metric during training.
163pub trait Numeric {
164 /// Returns the numeric value of the metric.
165 fn value(&self) -> NumericEntry;
166 /// Returns the current aggregated value of the metric over the global step (epoch).
167 fn running_value(&self) -> NumericEntry;
168}
169
170/// Serialized form of a metric entry.
171#[derive(Debug, Clone, new)]
172pub struct SerializedEntry {
173 /// The string to be displayed.
174 pub formatted: String,
175 /// The string to be saved.
176 pub serialized: String,
177}
178
179/// Data type that contains the current state of a metric at a given time.
180#[derive(Debug, Clone)]
181pub struct MetricEntry {
182 /// Id of the entry's metric.
183 pub metric_id: MetricId,
184 /// The serialized form of the entry.
185 pub serialized_entry: SerializedEntry,
186}
187
188impl MetricEntry {
189 /// Create a new metric.
190 pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
191 Self {
192 metric_id,
193 serialized_entry,
194 }
195 }
196}
197
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Gets the current value of the entry.
    ///
    /// Returns the wrapped value for [`NumericEntry::Value`] and the aggregated value for
    /// [`NumericEntry::Aggregated`] (the element count is ignored).
    pub fn current(&self) -> f64 {
        match *self {
            NumericEntry::Value(value) => value,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => aggregated_value,
        }
    }

    /// Returns a String representing the NumericEntry.
    ///
    /// A single value is written as the number itself; an aggregated entry is written as
    /// `"<aggregated_value>,<count>"`. The output can be read back with
    /// [`NumericEntry::deserialize`].
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(value) => value.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// De-serializes a string representing a NumericEntry.
    ///
    /// The input is split on commas: one field is parsed as a single value, two fields are
    /// parsed as `(aggregated_value, count)`.
    ///
    /// # Errors
    ///
    /// Returns the parse error message when a field is not a valid `f64` (value) or `usize`
    /// (count), and a fixed message when the input has more than two comma separated fields.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let mut fields = entry.split(',');

        // `split` always yields at least one field and keeps returning `None` once
        // exhausted, so pulling three items is enough to tell 1, 2 and "too many" apart.
        match (fields.next(), fields.next(), fields.next()) {
            // Numeric value
            (Some(value), None, _) => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            // Aggregated numeric (value, number of elements)
            (Some(value), Some(numel), None) => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = numel.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Compare this numeric metric's value with another one using the specified direction.
    ///
    /// Returns `true` only when this entry is *strictly* better than `other`: strictly
    /// greater when `higher_is_better` is `true`, strictly smaller otherwise. Equal values
    /// and comparisons involving NaN are never considered better, in either direction.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (mine, theirs) = (self.current(), other.current());

        // Compare explicitly per direction: negating `>` would treat ties (and NaN) as
        // an improvement when lower values are better.
        if higher_is_better {
            mine > theirs
        } else {
            mine < theirs
        }
    }
}
269
/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// Scientific notation is used when the magnitude of `float` is at most
/// `0.1^(precision - 1)`, i.e. when fixed notation with `precision` decimals would leave
/// few or no significant digits. The sign of the value does not influence the choice of
/// notation, so `-5.0` and `5.0` are both rendered in fixed notation. Zero falls below
/// every threshold and is therefore rendered in scientific notation.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare against the magnitude: comparing the signed value would send every
    // negative number down the scientific-notation branch.
    if scientific_notation_threshold >= float.abs() {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
278}