burn_train/metric/base.rs

1use std::sync::Arc;
2
3use burn_core::data::dataloader::Progress;
4use burn_optim::LearningRate;
5
6/// Metric metadata that can be used when computing metrics.
7pub struct MetricMetadata {
8    /// The current progress.
9    pub progress: Progress,
10
11    /// The current epoch.
12    pub epoch: usize,
13
14    /// The total number of epochs.
15    pub epoch_total: usize,
16
17    /// The current iteration.
18    pub iteration: usize,
19
20    /// The current learning rate.
21    pub lr: Option<LearningRate>,
22}
23
24impl MetricMetadata {
25    /// Fake metric metadata
26    #[cfg(test)]
27    pub fn fake() -> Self {
28        Self {
29            progress: Progress {
30                items_processed: 1,
31                items_total: 1,
32            },
33            epoch: 0,
34            epoch_total: 1,
35            iteration: 0,
36            lr: None,
37        }
38    }
39}
40
41/// Metric attributes define the properties intrinsic to different types of metric.
42#[derive(Clone, Debug)]
43pub enum MetricAttributes {
44    /// Numeric attributes.
45    Numeric(NumericAttributes),
46    /// No attributes.
47    None,
48}
49
50/// Definition of a metric.
51///
52/// This is used to register a metric with the [learner builder](crate::learner::LearnerBuilder).
53#[derive(Clone, Debug)]
54pub struct MetricDefinition {
55    /// The name of the metric.
56    pub name: String,
57    /// The description of the metric.
58    pub description: Option<String>,
59    /// The attributes of the metric.
60    pub attributes: MetricAttributes,
61}
62
63impl<Me: Metric> From<&Me> for MetricDefinition {
64    fn from(metric: &Me) -> Self {
65        Self {
66            name: metric.name().to_string(),
67            description: metric.description(),
68            attributes: metric.attributes(),
69        }
70    }
71}
72
73/// Metric trait.
74///
75/// # Notes
76///
77/// Implementations should define their own input type only used by the metric.
78/// This is important since some conflict may happen when the model output is adapted for each
79/// metric's input type.
80pub trait Metric: Send + Sync + Clone {
81    /// The input type of the metric.
82    type Input;
83
84    /// The parameterized name of the metric.
85    ///
86    /// This should be unique, so avoid using short generic names, prefer using the long name.
87    ///
88    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different
89    /// values of k), the name should be unique for each instance.
90    fn name(&self) -> MetricName;
91
92    /// A short description of the metric.
93    fn description(&self) -> Option<String> {
94        None
95    }
96
97    /// Attributes of the metric.
98    ///
99    /// By default, metrics have no attributes.
100    fn attributes(&self) -> MetricAttributes {
101        MetricAttributes::None
102    }
103
104    /// Update the metric state and returns the current metric entry.
105    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
106
107    /// Clear the metric state.
108    fn clear(&mut self);
109}
110
/// Shared metric name: cloning a `MetricName` only increments a reference count instead of
/// copying the string.
pub type MetricName = Arc<String>;
113
/// Converts a value into the input type expected by a metric.
///
/// A model's output type should implement this trait for every [metric input](Metric::Input)
/// registered with the [learner builder](crate::learner::LearnerBuilder).
pub trait Adaptor<T> {
    /// Produces the value passed to a [metric](Metric).
    fn adapt(&self) -> T;
}

/// Blanket implementation: every type can be adapted to the unit type `()`.
impl<T> Adaptor<()> for T {
    fn adapt(&self) {
        // Nothing to build: the unit value is returned implicitly.
    }
}
126
/// Intrinsic properties of a numeric metric.
#[derive(Debug, Clone)]
pub struct NumericAttributes {
    /// Optional unit in which the values are expressed (e.g. "%", "ms", "pixels").
    pub unit: Option<String>,
    /// `true` when larger values are better, `false` when smaller values are better.
    pub higher_is_better: bool,
}
135
136impl From<NumericAttributes> for MetricAttributes {
137    fn from(attr: NumericAttributes) -> Self {
138        MetricAttributes::Numeric(attr)
139    }
140}
141
142impl Default for NumericAttributes {
143    fn default() -> Self {
144        Self {
145            unit: None,
146            higher_is_better: true,
147        }
148    }
149}
150
151/// Declare a metric to be numeric.
152///
153/// This is useful to plot the values of a metric during training.
154pub trait Numeric {
155    /// Returns the numeric value of the metric.
156    fn value(&self) -> NumericEntry;
157}
158
/// Snapshot of a metric's state at a given point in time.
#[derive(Debug, Clone)]
pub struct MetricEntry {
    /// Name of the metric.
    pub name: Arc<String>,
    /// Human-readable representation, meant to be displayed.
    pub formatted: String,
    /// Representation meant to be saved.
    pub serialize: String,
    /// Tags associated with the metric.
    pub tags: Vec<Arc<String>>,
}

impl MetricEntry {
    /// Creates an entry with the given name and representations, and no tags.
    pub fn new(name: Arc<String>, formatted: String, serialize: String) -> Self {
        let tags = Vec::new();

        Self {
            name,
            formatted,
            serialize,
            tags,
        }
    }
}
183
/// Numeric metric entry.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// Single numeric value.
    Value(f64),
    /// Aggregated numeric (value, number of elements).
    Aggregated {
        /// The aggregated value of all entries.
        aggregated_value: f64,
        /// The number of entries present in the aggregated value.
        count: usize,
    },
}

impl NumericEntry {
    /// Returns the current value of the entry.
    ///
    /// For [`NumericEntry::Value`] this is the single value. For [`NumericEntry::Aggregated`]
    /// it is `aggregated_value`; the count is ignored.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(val) => *val,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => *aggregated_value,
        }
    }

    /// Serializes the entry to a string.
    ///
    /// A [`NumericEntry::Value`] becomes `"<value>"`, and a [`NumericEntry::Aggregated`] becomes
    /// `"<aggregated_value>,<count>"`. [`NumericEntry::deserialize`] parses both forms back.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses a string produced by [`NumericEntry::serialize`].
    ///
    /// One comma-separated component yields a [`NumericEntry::Value`]. Two components yield a
    /// [`NumericEntry::Aggregated`] (value, then number of elements).
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - A component fails to parse. The message is the parse error's text; the value is
    ///   checked before the count.
    /// - The input does not have exactly one or two components. The message is
    ///   `"Invalid number of values for numeric entry"`.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let values = entry.split(',').collect::<Vec<_>>();

        match values.as_slice() {
            // Single numeric value.
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            // Aggregated numeric: value, number of elements.
            [value, count] => {
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` when this entry is strictly better than `other`.
    ///
    /// If `higher_is_better` is set, a strictly larger [`current`](NumericEntry::current) value
    /// is better; otherwise a strictly smaller one is. Ties are never better, whichever direction
    /// is used. A NaN value is never better than anything, and any non-NaN value is better than
    /// a NaN reference.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (current, reference) = (self.current(), other.current());

        if current.is_nan() {
            // A NaN value (e.g. a diverged loss) is never an improvement.
            false
        } else if reference.is_nan() {
            // Any real value improves on a NaN reference.
            true
        } else if higher_is_better {
            current > reference
        } else {
            current < reference
        }
    }
}
255
/// Formats `float` with `precision` digits after the decimal point.
///
/// Values whose magnitude is at or below `0.1^(precision - 1)` are written in scientific
/// notation (e.g. `format_float(0.001, 2) == "1.00e-3"`), including zero. This keeps small
/// values readable instead of collapsing them to zeros. Larger magnitudes use fixed-point
/// notation (e.g. `format_float(12.3456, 2) == "12.35"`). The sign does not affect which
/// notation is chosen.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude rather than the signed value. Otherwise every negative number
    // (e.g. -5.0) would be forced into scientific notation. NaN compares false here and is
    // rendered as "NaN".
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}