burn_train/metric/base.rs
1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8 /// The current progress.
9 pub progress: Progress,
10
11 /// The current epoch.
12 pub epoch: usize,
13
14 /// The total number of epochs.
15 pub epoch_total: usize,
16
17 /// The current iteration.
18 pub iteration: usize,
19
20 /// The current learning rate.
21 pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25 /// Fake metric metadata
26 #[cfg(test)]
27 pub fn fake() -> Self {
28 Self {
29 progress: Progress {
30 items_processed: 1,
31 items_total: 1,
32 },
33 epoch: 0,
34 epoch_total: 1,
35 iteration: 0,
36 lr: None,
37 }
38 }
39}
40
41/// Metric id that can be used to compare metrics and retrieve entries of the same metric.
42/// For now we take the name as id to make sure that the same metric has the same id across different runs.
43#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
44pub struct MetricId {
45 /// The metric id.
46 id: Arc<String>,
47}
48
49/// Metric attributes define the properties intrinsic to different types of metric.
50#[derive(Clone, Debug)]
51pub enum MetricAttributes {
52 /// Numeric attributes.
53 Numeric(NumericAttributes),
54 /// No attributes.
55 None,
56}
57
58/// Definition of a metric.
59///
60/// This is used to register a metric with the [learner builder](crate::learner::LearnerBuilder).
61#[derive(Clone, Debug)]
62pub struct MetricDefinition {
63 /// The metric's id.
64 pub metric_id: MetricId,
65 /// The name of the metric.
66 pub name: String,
67 /// The description of the metric.
68 pub description: Option<String>,
69 /// The attributes of the metric.
70 pub attributes: MetricAttributes,
71}
72
73impl MetricDefinition {
74 /// Create a new metric definition given the metric and a unique id.
75 pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
76 Self {
77 metric_id,
78 name: metric.name().to_string(),
79 description: metric.description(),
80 attributes: metric.attributes(),
81 }
82 }
83}
84
85/// Metric trait.
86///
87/// # Notes
88///
89/// Implementations should define their own input type only used by the metric.
90/// This is important since some conflict may happen when the model output is adapted for each
91/// metric's input type.
92pub trait Metric: Send + Sync + Clone {
93 /// The input type of the metric.
94 type Input;
95
96 /// The parameterized name of the metric.
97 ///
98 /// This should be unique, so avoid using short generic names, prefer using the long name.
99 ///
100 /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
101 /// values of k), the name should be unique for each instance.
102 fn name(&self) -> MetricName;
103
104 /// A short description of the metric.
105 fn description(&self) -> Option<String> {
106 None
107 }
108
109 /// Attributes of the metric.
110 ///
111 /// By default, metrics have no attributes.
112 fn attributes(&self) -> MetricAttributes {
113 MetricAttributes::None
114 }
115
116 /// Update the metric state and returns the current metric entry.
117 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;
118
119 /// Clear the metric state.
120 fn clear(&mut self);
121}
122
/// Metric name stored behind an [`Arc`], so cloning it only bumps a reference count.
pub type MetricName = Arc<String>;
125
/// Converts a value into the input type expected by a metric.
///
/// A model's output type should implement this for every [metric input](Metric::Input)
/// registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Produces the value handed to a [metric](Metric).
    fn adapt(&self) -> T;
}
134
135impl<T> Adaptor<()> for T {
136 fn adapt(&self) {}
137}
138
/// Intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Optional unit of the values, e.g. `"%"`, `"ms"` or `"pixels"`.
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
147
148impl From<NumericAttributes> for MetricAttributes {
149 fn from(attr: NumericAttributes) -> Self {
150 MetricAttributes::Numeric(attr)
151 }
152}
153
154impl Default for NumericAttributes {
155 fn default() -> Self {
156 Self {
157 unit: None,
158 higher_is_better: true,
159 }
160 }
161}
162
163/// Declare a metric to be numeric.
164///
165/// This is useful to plot the values of a metric during training.
166pub trait Numeric {
167 /// Returns the numeric value of the metric.
168 fn value(&self) -> NumericEntry;
169}
170
171/// Serialized form of a metric entry.
172#[derive(Debug, Clone, new)]
173pub struct SerializedEntry {
174 /// The string to be displayed.
175 pub formatted: String,
176 /// The string to be saved.
177 pub serialized: String,
178}
179
180/// Data type that contains the current state of a metric at a given time.
181#[derive(Debug, Clone)]
182pub struct MetricEntry {
183 /// Id of the entry's metric.
184 pub metric_id: MetricId,
185 /// The serialized form of the entry.
186 pub serialized_entry: SerializedEntry,
187}
188
189impl MetricEntry {
190 /// Create a new metric.
191 pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
192 Self {
193 metric_id,
194 serialized_entry,
195 }
196 }
197}
198
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric value together with the number of entries it was built from.
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Returns the value held by this entry.
    ///
    /// This is the plain value for [`NumericEntry::Value`] and `aggregated_value` for
    /// [`NumericEntry::Aggregated`].
    pub fn current(&self) -> f64 {
        match self {
            Self::Value(value) | Self::Aggregated {
                aggregated_value: value,
                ..
            } => *value,
        }
    }

    /// Serializes the entry to a string.
    ///
    /// A plain value becomes `"<value>"`; an aggregated entry becomes `"<value>,<count>"`.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses a string produced by [`NumericEntry::serialize`].
    ///
    /// # Errors
    ///
    /// Returns the parser's error message when a component is not a valid number, or
    /// `"Invalid number of values for numeric entry"` when the string does not contain exactly
    /// one or two comma-separated values.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let parts: Vec<&str> = entry.split(',').collect();

        match parts.as_slice() {
            // Plain numeric value.
            [value] => value
                .parse::<f64>()
                .map(Self::Value)
                .map_err(|err| err.to_string()),
            // Aggregated value: the value is parsed before the count, so a bad value is
            // reported first.
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(Self::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` when this entry is strictly better than `other`.
    ///
    /// With `higher_is_better == true`, larger values are better; otherwise smaller values are
    /// better. Equal values are never "better" in either direction. Any comparison involving NaN
    /// returns `false`.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (current, reference) = (self.current(), other.current());

        // Use a strict comparison in both directions. The previous `(a > b) == higher_is_better`
        // form treated ties (and NaN) as improvements when lower is better.
        if higher_is_better {
            current > reference
        } else {
            current < reference
        }
    }
}
270
/// Formats `float` with `precision` digits after the decimal point.
///
/// Non-zero values whose magnitude is at most `10^(1 - precision)` are written in scientific
/// notation, with `precision` mantissa digits. All other values use fixed-point notation,
/// including zero, NaN and infinities.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude, not the signed value. Otherwise every negative number would be
    // forced into scientific notation. Zero never needs scientific notation.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;

    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
279}