burn_train/metric/base.rs
1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8 /// The current progress.
9 pub progress: Progress,
10
11 /// The current epoch.
12 pub epoch: usize,
13
14 /// The total number of epochs.
15 pub epoch_total: usize,
16
17 /// The current iteration.
18 pub iteration: usize,
19
20 /// The current learning rate.
21 pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25 /// Fake metric metadata
26 #[cfg(test)]
27 pub fn fake() -> Self {
28 Self {
29 progress: Progress {
30 items_processed: 1,
31 items_total: 1,
32 },
33 epoch: 0,
34 epoch_total: 1,
35 iteration: 0,
36 lr: None,
37 }
38 }
39}
40
41/// Metric attributes define the properties intrinsic to different types of metric.
42#[derive(Clone, Debug)]
43pub enum MetricAttributes {
44 /// Numeric attributes.
45 Numeric(NumericAttributes),
46 /// No attributes.
47 None,
48}
49
50/// Definition of a metric.
51///
52/// This is used to register a metric with the [learner builder](crate::learner::LearnerBuilder).
53#[derive(Clone, Debug)]
54pub struct MetricDefinition {
55 /// The name of the metric.
56 pub name: String,
57 /// The description of the metric.
58 pub description: Option<String>,
59 /// The attributes of the metric.
60 pub attributes: MetricAttributes,
61}
62
63impl<Me: Metric> From<&Me> for MetricDefinition {
64 fn from(metric: &Me) -> Self {
65 Self {
66 name: metric.name().to_string(),
67 description: metric.description(),
68 attributes: metric.attributes(),
69 }
70 }
71}
72
73/// Metric trait.
74///
75/// # Notes
76///
77/// Implementations should define their own input type only used by the metric.
78/// This is important since some conflict may happen when the model output is adapted for each
79/// metric's input type.
80pub trait Metric: Send + Sync + Clone {
81 /// The input type of the metric.
82 type Input;
83
84 /// The parameterized name of the metric.
85 ///
86 /// This should be unique, so avoid using short generic names, prefer using the long name.
87 ///
88 /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
89 /// values of k), the name should be unique for each instance.
90 fn name(&self) -> MetricName;
91
92 /// A short description of the metric.
93 fn description(&self) -> Option<String> {
94 None
95 }
96
97 /// Attributes of the metric.
98 ///
99 /// By default, metrics have no attributes.
100 fn attributes(&self) -> MetricAttributes {
101 MetricAttributes::None
102 }
103
104 /// Update the metric state and returns the current metric entry.
105 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
106
107 /// Clear the metric state.
108 fn clear(&mut self);
109}
110
/// Reference-counted metric name; cloning it shares the string instead of copying it.
pub type MetricName = Arc<String>;
113
/// Adaptors convert values into the types that metrics consume.
///
/// A model's output type should implement this for every [metric input](Metric::Input)
/// registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Converts `self` into the input expected by a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Every type can be adapted to the unit type `()`, which requires no conversion at all.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {
        // Nothing to produce: the unit value is returned implicitly.
    }
}
126
/// Intrinsic properties of a numeric metric.
#[derive(Clone, Debug)]
pub struct NumericAttributes {
    /// Optional unit of measurement, e.g. "%", "ms" or "pixels".
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
135
136impl From<NumericAttributes> for MetricAttributes {
137 fn from(attr: NumericAttributes) -> Self {
138 MetricAttributes::Numeric(attr)
139 }
140}
141
142impl Default for NumericAttributes {
143 fn default() -> Self {
144 Self {
145 unit: None,
146 higher_is_better: true,
147 }
148 }
149}
150
151/// Declare a metric to be numeric.
152///
153/// This is useful to plot the values of a metric during training.
154pub trait Numeric {
155 /// Returns the numeric value of the metric.
156 fn value(&self) -> NumericEntry;
157}
158
/// Snapshot of a metric's state at a given point in time.
#[derive(Debug, Clone)]
pub struct MetricEntry {
    /// Name of the metric.
    pub name: Arc<String>,
    /// Representation shown to the user.
    pub formatted: String,
    /// Representation written when the entry is saved.
    pub serialize: String,
    /// Tags attached to the metric.
    pub tags: Vec<Arc<String>>,
}

impl MetricEntry {
    /// Creates a new metric entry that has no tags yet.
    pub fn new(name: Arc<String>, formatted: String, serialize: String) -> Self {
        let tags = Vec::new();
        Self {
            name,
            formatted,
            serialize,
            tags,
        }
    }
}
183
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Returns the current value of the metric.
    ///
    /// For [`NumericEntry::Aggregated`] this is `aggregated_value`; the count is ignored.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(value)
            | NumericEntry::Aggregated {
                aggregated_value: value,
                ..
            } => *value,
        }
    }

    /// Serializes the entry to a string.
    ///
    /// A [`NumericEntry::Value`] becomes the number itself (e.g. `"0.5"`). A
    /// [`NumericEntry::Aggregated`] becomes `"<aggregated_value>,<count>"` (e.g. `"0.5,32"`).
    /// [`NumericEntry::deserialize`] parses both forms back.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses a string produced by [`NumericEntry::serialize`].
    ///
    /// # Errors
    ///
    /// Returns the parser's error message when a component is not a valid `f64` (the value)
    /// or `usize` (the count). Returns a fixed message when the input does not have exactly
    /// one or two comma-separated components.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let parts: Vec<&str> = entry.split(',').collect();

        match parts.as_slice() {
            // Plain numeric value.
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            // Aggregated numeric (value, number of elements); the value is validated first.
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` when this entry is strictly better than `other`.
    ///
    /// "Better" means greater when `higher_is_better` is `true`, and smaller otherwise. Equal
    /// values are never better than each other, in either direction. The same holds for any
    /// comparison involving NaN.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (this, that) = (self.current(), other.current());
        // Use a strict comparison in both directions. `(this > that) == higher_is_better`
        // would wrongly report ties and NaN as improvements when lower is better.
        if higher_is_better {
            this > that
        } else {
            this < that
        }
    }
}
255
/// Formats `float` with `precision` digits after the decimal point.
///
/// Non-zero values whose magnitude is at most `0.1^(precision - 1)` are rendered in scientific
/// notation, where `precision` applies to the mantissa. In fixed-point, such values would lose
/// most of their significant digits. All other values use fixed-point notation: larger
/// magnitudes, zero, infinities and NaN.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
    // Compare the magnitude, not the signed value. Otherwise every negative number would
    // fall below the threshold and be forced into scientific notation. Exact zero loses
    // nothing in fixed-point, so it stays there.
    let magnitude = float.abs();

    if magnitude != 0.0 && magnitude <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
264}