burn_train/metric/
state.rs1use std::sync::Arc;
2
3use crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float};
4
/// Accumulator backing a numeric metric.
///
/// Tracks a weighted running total over every recorded batch together with
/// the most recently recorded value, so that both a per-batch figure and a
/// running mean can be reported.
#[derive(Debug, Clone)]
pub struct NumericMetricState {
    /// Sum of each recorded value multiplied by its batch size.
    sum: f64,
    /// Total of all batch sizes recorded so far.
    count: usize,
    /// Most recently recorded value (`NaN` until the first update).
    current: f64,
    /// Batch size that accompanied `current`.
    current_count: usize,
}
18
19pub struct FormatOptions {
21 name: Arc<String>,
22 unit: Option<String>,
23 precision: Option<usize>,
24}
25
26impl FormatOptions {
27 pub fn new(name: MetricName) -> Self {
29 Self {
30 name: name.clone(),
31 unit: None,
32 precision: None,
33 }
34 }
35
36 pub fn unit(mut self, unit: &str) -> Self {
38 self.unit = Some(unit.to_string());
39 self
40 }
41
42 pub fn precision(mut self, precision: usize) -> Self {
44 self.precision = Some(precision);
45 self
46 }
47
48 pub fn name(&self) -> &Arc<String> {
50 &self.name
51 }
52
53 pub fn unit_value(&self) -> &Option<String> {
55 &self.unit
56 }
57
58 pub fn precision_value(&self) -> Option<usize> {
60 self.precision
61 }
62}
63
64impl NumericMetricState {
65 pub fn new() -> Self {
67 Self {
68 sum: 0.0,
69 count: 0,
70 current: f64::NAN,
71 current_count: 0,
72 }
73 }
74
75 pub fn reset(&mut self) {
77 self.sum = 0.0;
78 self.count = 0;
79 self.current = f64::NAN;
80 self.current_count = 0;
81 }
82
83 pub fn update(
85 &mut self,
86 value: f64,
87 batch_size: usize,
88 format: FormatOptions,
89 ) -> SerializedEntry {
90 self.sum += value * batch_size as f64;
91 self.count += batch_size;
92 self.current = value;
93 self.current_count = batch_size;
94
95 let value_current = value;
96 let value_running = self.sum / self.count as f64;
97 let serialized = NumericEntry::Aggregated {
99 aggregated_value: value_current,
100 count: batch_size,
101 }
102 .serialize();
103
104 let (formatted_current, formatted_running) = match format.precision {
105 Some(precision) => (
106 format_float(value_current, precision),
107 format_float(value_running, precision),
108 ),
109 None => (format!("{value_current}"), format!("{value_running}")),
110 };
111
112 let formatted = match format.unit {
113 Some(unit) => {
114 format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
115 }
116 None => format!("epoch {formatted_running} - batch {formatted_current}"),
117 };
118
119 SerializedEntry::new(formatted, serialized)
120 }
121
122 pub fn current_value(&self) -> NumericEntry {
124 NumericEntry::Aggregated {
125 aggregated_value: self.current,
126 count: self.current_count,
127 }
128 }
129
130 pub fn running_value(&self) -> NumericEntry {
132 NumericEntry::Aggregated {
133 aggregated_value: self.sum / self.count as f64,
134 count: self.count,
135 }
136 }
137}
138
139impl Default for NumericMetricState {
140 fn default() -> Self {
141 Self::new()
142 }
143}