burn_train/metric/
state.rs

1use crate::metric::{format_float, MetricEntry, Numeric, NumericEntry};
2
/// Useful utility to implement numeric metrics.
///
/// Keeps a batch-size-weighted running sum together with the most recently
/// reported value, so that both a running mean and the latest value can be shown.
///
/// # Notes
///
/// The numeric metric state stores its values as floats.
/// Even if some metrics are integers, their means are floats.
#[derive(Debug, Clone)]
pub struct NumericMetricState {
    /// Sum of every reported value multiplied by its batch size.
    sum: f64,
    /// Total of all batch sizes accumulated so far.
    count: usize,
    /// Value given to the most recent update; NaN before the first one.
    current: f64,
}
14
/// Formatting options for the [numeric metric state](NumericMetricState).
///
/// Start from [`FormatOptions::new`] and chain [`FormatOptions::unit`] and/or
/// [`FormatOptions::precision`] to refine how values are rendered.
#[derive(Debug, Clone)]
pub struct FormatOptions {
    /// Name attached to the metric entry.
    name: String,
    /// Optional unit written after each formatted value.
    unit: Option<String>,
    /// Optional precision forwarded to the float formatter; when absent the
    /// values are rendered with their default `Display` output.
    precision: Option<usize>,
}

impl FormatOptions {
    /// Create the [formatting options](FormatOptions) with a name.
    ///
    /// No unit and no precision are set initially.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            unit: None,
            precision: None,
        }
    }

    /// Specify the metric unit.
    ///
    /// Consumes the options and returns them with the unit set, replacing any
    /// unit chosen earlier.
    pub fn unit(self, unit: &str) -> Self {
        Self {
            unit: Some(unit.to_owned()),
            ..self
        }
    }

    /// Specify the floating point precision.
    ///
    /// Consumes the options and returns them with the precision set, replacing
    /// any precision chosen earlier.
    pub fn precision(self, precision: usize) -> Self {
        Self {
            precision: Some(precision),
            ..self
        }
    }
}
44
45impl NumericMetricState {
46    /// Create a new [numeric metric state](NumericMetricState).
47    pub fn new() -> Self {
48        Self {
49            sum: 0.0,
50            count: 0,
51            current: f64::NAN,
52        }
53    }
54
55    /// Reset the state.
56    pub fn reset(&mut self) {
57        self.sum = 0.0;
58        self.count = 0;
59        self.current = f64::NAN;
60    }
61
62    /// Update the state.
63    pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
64        self.sum += value * batch_size as f64;
65        self.count += batch_size;
66        self.current = value;
67
68        let value_current = value;
69        let value_running = self.sum / self.count as f64;
70        // Numeric metric state is an aggregated value
71        let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();
72
73        let (formatted_current, formatted_running) = match format.precision {
74            Some(precision) => (
75                format_float(value_current, precision),
76                format_float(value_running, precision),
77            ),
78            None => (format!("{value_current}"), format!("{value_running}")),
79        };
80
81        let formatted = match format.unit {
82            Some(unit) => {
83                format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
84            }
85            None => format!("epoch {formatted_running} - batch {formatted_current}"),
86        };
87
88        MetricEntry::new(format.name, formatted, serialized)
89    }
90}
91
92impl Numeric for NumericMetricState {
93    fn value(&self) -> f64 {
94        self.current
95    }
96}
97
98impl Default for NumericMetricState {
99    fn default() -> Self {
100        Self::new()
101    }
102}