burn_train/metric/
state.rs1use std::sync::Arc;
2
3use crate::metric::{MetricEntry, MetricName, Numeric, NumericEntry, format_float};
4
/// Running aggregate of a numeric metric.
///
/// [`NumericMetricState::update`] folds every reported value into a
/// batch-size-weighted sum and an element count, and remembers the most
/// recently reported value.
#[derive(Debug, Clone)]
pub struct NumericMetricState {
    /// Sum of each reported value multiplied by its batch size.
    sum: f64,
    /// Sum of the batch sizes reported since the last reset.
    count: usize,
    /// Most recently reported value; `NaN` until the first update.
    current: f64,
}
17
/// Formatting options consumed by [`NumericMetricState::update`].
///
/// Built with [`FormatOptions::new`] and refined with the chained `unit` and
/// `precision` setters.
#[derive(Debug, Clone)]
pub struct FormatOptions {
    /// Name attached to the produced metric entry.
    name: Arc<String>,
    /// Optional unit label appended after each formatted number.
    unit: Option<String>,
    /// Optional precision forwarded to `format_float`; when absent the
    /// numbers are rendered with their default `Display` output.
    precision: Option<usize>,
}
24
25impl FormatOptions {
26 pub fn new(name: MetricName) -> Self {
28 Self {
29 name: name.clone(),
30 unit: None,
31 precision: None,
32 }
33 }
34
35 pub fn unit(mut self, unit: &str) -> Self {
37 self.unit = Some(unit.to_string());
38 self
39 }
40
41 pub fn precision(mut self, precision: usize) -> Self {
43 self.precision = Some(precision);
44 self
45 }
46
47 pub fn name(&self) -> &Arc<String> {
49 &self.name
50 }
51
52 pub fn unit_value(&self) -> &Option<String> {
54 &self.unit
55 }
56
57 pub fn precision_value(&self) -> Option<usize> {
59 self.precision
60 }
61}
62
63impl NumericMetricState {
64 pub fn new() -> Self {
66 Self {
67 sum: 0.0,
68 count: 0,
69 current: f64::NAN,
70 }
71 }
72
73 pub fn reset(&mut self) {
75 self.sum = 0.0;
76 self.count = 0;
77 self.current = f64::NAN;
78 }
79
80 pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
82 self.sum += value * batch_size as f64;
83 self.count += batch_size;
84 self.current = value;
85
86 let value_current = value;
87 let value_running = self.sum / self.count as f64;
88 let serialized = NumericEntry::Aggregated {
90 sum: value_current,
91 count: batch_size,
92 current: value_current,
93 }
94 .serialize();
95
96 let (formatted_current, formatted_running) = match format.precision {
97 Some(precision) => (
98 format_float(value_current, precision),
99 format_float(value_running, precision),
100 ),
101 None => (format!("{value_current}"), format!("{value_running}")),
102 };
103
104 let formatted = match format.unit {
105 Some(unit) => {
106 format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
107 }
108 None => format!("epoch {formatted_running} - batch {formatted_current}"),
109 };
110
111 MetricEntry::new(format.name, formatted, serialized)
112 }
113}
114
115impl Numeric for NumericMetricState {
116 fn value(&self) -> NumericEntry {
117 NumericEntry::Aggregated {
118 sum: self.sum,
119 count: self.count,
120 current: self.current,
121 }
122 }
123}
124
125impl Default for NumericMetricState {
126 fn default() -> Self {
127 Self::new()
128 }
129}