burn_train/metric/
iteration.rs1use std::sync::Arc;
2
3use super::MetricEntry;
4use super::MetricMetadata;
5use super::state::FormatOptions;
6use super::state::NumericMetricState;
7use crate::metric::MetricName;
8use crate::metric::{Metric, Numeric};
9
10#[derive(Clone)]
12pub struct IterationSpeedMetric {
13 name: MetricName,
14 state: NumericMetricState,
15 instant: Option<std::time::Instant>,
16}
17
18impl Default for IterationSpeedMetric {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl IterationSpeedMetric {
25 pub fn new() -> Self {
27 Self {
28 name: Arc::new("Iteration Speed".to_string()),
29 state: Default::default(),
30 instant: Default::default(),
31 }
32 }
33}
34
35impl Metric for IterationSpeedMetric {
36 type Input = ();
37
38 fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry {
39 let raw = match self.instant {
40 Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(),
41 None => {
42 self.instant = Some(std::time::Instant::now());
43 0.0
44 }
45 };
46
47 self.state.update(
48 raw,
49 1,
50 FormatOptions::new(self.name())
51 .unit("iter/sec")
52 .precision(2),
53 )
54 }
55
56 fn clear(&mut self) {
57 self.instant = None;
58 }
59
60 fn name(&self) -> MetricName {
61 self.name.clone()
62 }
63}
64
65impl Numeric for IterationSpeedMetric {
66 fn value(&self) -> super::NumericEntry {
67 self.state.value()
68 }
69}