use crate::metric::{format_float, MetricEntry, Numeric, NumericEntry};
/// Accumulated state behind a numeric metric.
///
/// Tracks a batch-size-weighted sum of every value recorded, the total number
/// of items those values cover, and the most recently recorded value.
#[derive(Debug, Clone)]
pub struct NumericMetricState {
    /// Sum of `value * batch_size` over every update.
    sum: f64,
    /// Total number of items (sum of `batch_size`) over every update.
    count: usize,
    /// Value passed to the latest update; `NaN` before the first one.
    current: f64,
}
/// Display options applied when a numeric metric state renders an update.
#[derive(Debug, Clone)]
pub struct FormatOptions {
    /// Metric name attached to the produced entry.
    name: String,
    /// Optional unit printed after each displayed number.
    unit: Option<String>,
    /// Optional precision forwarded to `format_float`; `None` renders numbers
    /// with their default `Display` formatting.
    precision: Option<usize>,
}

impl FormatOptions {
    /// Creates options carrying `name`, with no unit and no precision.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            unit: None,
            precision: None,
        }
    }

    /// Sets the unit printed after each displayed number.
    ///
    /// Consumes and returns `self` so calls can be chained; discarding the
    /// result discards the configuration, hence `#[must_use]`.
    #[must_use]
    pub fn unit(mut self, unit: &str) -> Self {
        self.unit = Some(unit.to_string());
        self
    }

    /// Sets the precision passed to `format_float` when rendering numbers.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn precision(mut self, precision: usize) -> Self {
        self.precision = Some(precision);
        self
    }
}
45impl NumericMetricState {
46 pub fn new() -> Self {
48 Self {
49 sum: 0.0,
50 count: 0,
51 current: f64::NAN,
52 }
53 }
54
55 pub fn reset(&mut self) {
57 self.sum = 0.0;
58 self.count = 0;
59 self.current = f64::NAN;
60 }
61
62 pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
64 self.sum += value * batch_size as f64;
65 self.count += batch_size;
66 self.current = value;
67
68 let value_current = value;
69 let value_running = self.sum / self.count as f64;
70 let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();
72
73 let (formatted_current, formatted_running) = match format.precision {
74 Some(precision) => (
75 format_float(value_current, precision),
76 format_float(value_running, precision),
77 ),
78 None => (format!("{value_current}"), format!("{value_running}")),
79 };
80
81 let formatted = match format.unit {
82 Some(unit) => {
83 format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}")
84 }
85 None => format!("epoch {formatted_running} - batch {formatted_current}"),
86 };
87
88 MetricEntry::new(format.name, formatted, serialized)
89 }
90}
91
92impl Numeric for NumericMetricState {
93 fn value(&self) -> f64 {
94 self.current
95 }
96}
97
98impl Default for NumericMetricState {
99 fn default() -> Self {
100 Self::new()
101 }
102}