burn_train/metric/base.rs
1use burn_core::{data::dataloader::Progress, LearningRate};
2
3/// Metric metadata that can be used when computing metrics.
4pub struct MetricMetadata {
5 /// The current progress.
6 pub progress: Progress,
7
8 /// The current epoch.
9 pub epoch: usize,
10
11 /// The total number of epochs.
12 pub epoch_total: usize,
13
14 /// The current iteration.
15 pub iteration: usize,
16
17 /// The current learning rate.
18 pub lr: Option<LearningRate>,
19}
20
21impl MetricMetadata {
22 /// Fake metric metadata
23 #[cfg(test)]
24 pub fn fake() -> Self {
25 Self {
26 progress: Progress {
27 items_processed: 1,
28 items_total: 1,
29 },
30 epoch: 0,
31 epoch_total: 1,
32 iteration: 0,
33 lr: None,
34 }
35 }
36}
37
38/// Metric trait.
39///
40/// # Notes
41///
42/// Implementations should define their own input type only used by the metric.
43/// This is important since some conflict may happen when the model output is adapted for each
44/// metric's input type.
45pub trait Metric: Send + Sync {
46 /// The name of the metric.
47 ///
48 /// This should be unique, so avoid using short generic names, prefer using the long name.
49 const NAME: &'static str;
50
51 /// The input type of the metric.
52 type Input;
53
54 /// Update the metric state and returns the current metric entry.
55 fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
56 /// Clear the metric state.
57 fn clear(&mut self);
58}
59
/// Adaptors are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input)
/// that are registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Adapts the type so it can be passed to a [metric](Metric).
    fn adapt(&self) -> T;
}
68
/// Marks a metric as producing a numeric value.
///
/// Numeric metrics can be plotted over the course of training.
pub trait Numeric {
    /// Current numeric value of the metric.
    fn value(&self) -> f64;
}
76
77/// Data type that contains the current state of a metric at a given time.
78#[derive(new, Debug, Clone)]
79pub struct MetricEntry {
80 /// The name of the metric.
81 pub name: String,
82 /// The string to be displayed.
83 pub formatted: String,
84 /// The string to be saved.
85 pub serialize: String,
86}
87
/// Numeric metric entry.
///
/// Its text form (see `serialize`/`deserialize`) is either a single value such as `"0.5"`,
/// or, for aggregated entries, the value followed by the number of elements, e.g. `"0.5,32"`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated(f64, usize),
}

impl NumericEntry {
    /// Serializes the entry into the text form accepted by [`NumericEntry::deserialize`].
    pub(crate) fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated(v, n) => format!("{v},{n}"),
        }
    }

    /// Parses an entry produced by [`NumericEntry::serialize`].
    ///
    /// # Errors
    ///
    /// Returns the underlying parse error message when the value is not a valid `f64` or the
    /// element count is not a valid `usize`. Returns a fixed message when the entry has more
    /// than two comma-separated fields.
    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
        // One field is a plain value; two comma-separated fields are `value,count`.
        let fields: Vec<&str> = entry.split(',').collect();

        match fields.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            [value, numel] => {
                // Parse the value first so its error takes precedence, as before.
                let value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let numel = numel.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated(value, numel))
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }
}
130
/// Format a float with the given precision, using scientific notation when necessary.
///
/// Scientific notation (`{:.precision$e}`, e.g. `5.00e-2`) is used when the value is non-zero
/// and its magnitude is at most `0.1^(precision - 1)`. Every other value uses fixed notation
/// (`{:.precision$}`, e.g. `-3.50`). That includes zero, negative values of larger magnitude,
/// infinities and NaN.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare magnitudes so negative values are handled like their positive counterparts
    // (previously every negative number was printed in scientific notation). Zero is exact
    // in fixed notation and never needs an exponent.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;

    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}