burn_train/metric/
state.rs

1use std::sync::Arc;
2
3use crate::metric::{MetricEntry, MetricName, Numeric, NumericEntry, format_float};
4
/// Useful utility to implement numeric metrics.
///
/// # Notes
///
/// The numeric metric stores its values as floats.
/// Even if a metric is integer-valued, its mean is a float.
#[derive(Debug, Clone)]
pub struct NumericMetricState {
    /// Sum of every reported value, each weighted by its batch size.
    sum: f64,
    /// Total of the batch sizes accumulated since creation or the last reset.
    count: usize,
    /// Most recently reported value (`NaN` until the first update).
    current: f64,
}
17
/// Formatting options for the [numeric metric state](NumericMetricState).
#[derive(Debug, Clone)]
pub struct FormatOptions {
    /// Name attached to the metric entry.
    name: Arc<String>,
    /// Optional unit appended after each formatted value.
    unit: Option<String>,
    /// Optional floating point precision; when unset, values use their default formatting.
    precision: Option<usize>,
}
24
25impl FormatOptions {
26    /// Create the [formatting options](FormatOptions) with a name.
27    pub fn new(name: MetricName) -> Self {
28        Self {
29            name: name.clone(),
30            unit: None,
31            precision: None,
32        }
33    }
34
35    /// Specify the metric unit.
36    pub fn unit(mut self, unit: &str) -> Self {
37        self.unit = Some(unit.to_string());
38        self
39    }
40
41    /// Specify the floating point precision.
42    pub fn precision(mut self, precision: usize) -> Self {
43        self.precision = Some(precision);
44        self
45    }
46
47    /// Get the metric name.
48    pub fn name(&self) -> &Arc<String> {
49        &self.name
50    }
51
52    /// Get the metric unit.
53    pub fn unit_value(&self) -> &Option<String> {
54        &self.unit
55    }
56
57    /// Get the precision.
58    pub fn precision_value(&self) -> Option<usize> {
59        self.precision
60    }
61}
62
63impl NumericMetricState {
64    /// Create a new [numeric metric state](NumericMetricState).
65    pub fn new() -> Self {
66        Self {
67            sum: 0.0,
68            count: 0,
69            current: f64::NAN,
70        }
71    }
72
73    /// Reset the state.
74    pub fn reset(&mut self) {
75        self.sum = 0.0;
76        self.count = 0;
77        self.current = f64::NAN;
78    }
79
80    /// Update the state.
81    pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
82        self.sum += value * batch_size as f64;
83        self.count += batch_size;
84        self.current = value;
85
86        let value_current = value;
87        let value_running = self.sum / self.count as f64;
88        // Numeric metric state is an aggregated value
89        let serialized = NumericEntry::Aggregated {
90            sum: value_current,
91            count: batch_size,
92            current: value_current,
93        }
94        .serialize();
95
96        let (formatted_current, formatted_running) = match format.precision {
97            Some(precision) => (
98                format_float(value_current, precision),
99                format_float(value_running, precision),
100            ),
101            None => (format!("{value_current}"), format!("{value_running}")),
102        };
103
104        let formatted = match format.unit {
105            Some(unit) => {
106                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
107            }
108            None => format!("epoch {formatted_running} - batch {formatted_current}"),
109        };
110
111        MetricEntry::new(format.name, formatted, serialized)
112    }
113}
114
115impl Numeric for NumericMetricState {
116    fn value(&self) -> NumericEntry {
117        NumericEntry::Aggregated {
118            sum: self.sum,
119            count: self.count,
120            current: self.current,
121        }
122    }
123}
124
125impl Default for NumericMetricState {
126    fn default() -> Self {
127        Self::new()
128    }
129}