Skip to main content

burn_train/metric/
state.rs

1use std::sync::Arc;
2
3use crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float};
4
/// Useful utility to implement numeric metrics.
///
/// Tracks the value of the most recent update together with a running,
/// batch-size-weighted mean of all values accumulated since creation or the
/// last reset.
///
/// # Notes
///
/// The numeric metric state stores values as floats.
/// Even if some metrics are integers, their means are floats.
#[derive(Clone, Debug)]
pub struct NumericMetricState {
    /// Sum of every accumulated value multiplied by its batch size.
    sum: f64,
    /// Total number of items (sum of batch sizes) accumulated in `sum`.
    count: usize,
    /// Value passed to the most recent update (`NaN` before the first one).
    current: f64,
    /// Batch size passed to the most recent update.
    current_count: usize,
}
18
/// Formatting options for the [numeric metric state](NumericMetricState).
///
/// Built with [`FormatOptions::new`] plus the `unit` / `precision` builder
/// methods, then passed by value to [`NumericMetricState::update`]. It derives
/// `Clone` so the same options can be reused across several updates.
#[derive(Debug, Clone)]
pub struct FormatOptions {
    /// Name of the metric.
    name: Arc<String>,
    /// Optional unit printed after each formatted value.
    unit: Option<String>,
    /// Optional precision forwarded to `format_float`; when `None`, values are
    /// printed with their default `Display` formatting.
    precision: Option<usize>,
}
25
26impl FormatOptions {
27    /// Create the [formatting options](FormatOptions) with a name.
28    pub fn new(name: MetricName) -> Self {
29        Self {
30            name: name.clone(),
31            unit: None,
32            precision: None,
33        }
34    }
35
36    /// Specify the metric unit.
37    pub fn unit(mut self, unit: &str) -> Self {
38        self.unit = Some(unit.to_string());
39        self
40    }
41
42    /// Specify the floating point precision.
43    pub fn precision(mut self, precision: usize) -> Self {
44        self.precision = Some(precision);
45        self
46    }
47
48    /// Get the metric name.
49    pub fn name(&self) -> &Arc<String> {
50        &self.name
51    }
52
53    /// Get the metric unit.
54    pub fn unit_value(&self) -> &Option<String> {
55        &self.unit
56    }
57
58    /// Get the precision.
59    pub fn precision_value(&self) -> Option<usize> {
60        self.precision
61    }
62}
63
64impl NumericMetricState {
65    /// Create a new [numeric metric state](NumericMetricState).
66    pub fn new() -> Self {
67        Self {
68            sum: 0.0,
69            count: 0,
70            current: f64::NAN,
71            current_count: 0,
72        }
73    }
74
75    /// Reset the state.
76    pub fn reset(&mut self) {
77        self.sum = 0.0;
78        self.count = 0;
79        self.current = f64::NAN;
80        self.current_count = 0;
81    }
82
83    /// Update the state.
84    pub fn update(
85        &mut self,
86        value: f64,
87        batch_size: usize,
88        format: FormatOptions,
89    ) -> SerializedEntry {
90        self.sum += value * batch_size as f64;
91        self.count += batch_size;
92        self.current = value;
93        self.current_count = batch_size;
94
95        let value_current = value;
96        let value_running = self.sum / self.count as f64;
97        // Numeric metric state is an aggregated value
98        let serialized = NumericEntry::Aggregated {
99            aggregated_value: value_current,
100            count: batch_size,
101        }
102        .serialize();
103
104        let (formatted_current, formatted_running) = match format.precision {
105            Some(precision) => (
106                format_float(value_current, precision),
107                format_float(value_running, precision),
108            ),
109            None => (format!("{value_current}"), format!("{value_running}")),
110        };
111
112        let formatted = match format.unit {
113            Some(unit) => {
114                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
115            }
116            None => format!("epoch {formatted_running} - batch {formatted_current}"),
117        };
118
119        SerializedEntry::new(formatted, serialized)
120    }
121
122    /// Get the numeric value.
123    pub fn current_value(&self) -> NumericEntry {
124        NumericEntry::Aggregated {
125            aggregated_value: self.current,
126            count: self.current_count,
127        }
128    }
129
130    /// Get the running aggregated value.
131    pub fn running_value(&self) -> NumericEntry {
132        NumericEntry::Aggregated {
133            aggregated_value: self.sum / self.count as f64,
134            count: self.count,
135        }
136    }
137}
138
139impl Default for NumericMetricState {
140    fn default() -> Self {
141        Self::new()
142    }
143}