burn_train/metric/
iteration.rs

1use std::sync::Arc;
2
3use super::MetricEntry;
4use super::MetricMetadata;
5use super::state::FormatOptions;
6use super::state::NumericMetricState;
7use crate::metric::MetricName;
8use crate::metric::{Metric, Numeric};
9
10/// The loss metric.
11#[derive(Clone)]
12pub struct IterationSpeedMetric {
13    name: MetricName,
14    state: NumericMetricState,
15    instant: Option<std::time::Instant>,
16}
17
18impl Default for IterationSpeedMetric {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl IterationSpeedMetric {
25    /// Create the metric.
26    pub fn new() -> Self {
27        Self {
28            name: Arc::new("Iteration Speed".to_string()),
29            state: Default::default(),
30            instant: Default::default(),
31        }
32    }
33}
34
35impl Metric for IterationSpeedMetric {
36    type Input = ();
37
38    fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry {
39        let raw = match self.instant {
40            Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(),
41            None => {
42                self.instant = Some(std::time::Instant::now());
43                0.0
44            }
45        };
46
47        self.state.update(
48            raw,
49            1,
50            FormatOptions::new(self.name())
51                .unit("iter/sec")
52                .precision(2),
53        )
54    }
55
56    fn clear(&mut self) {
57        self.instant = None;
58    }
59
60    fn name(&self) -> MetricName {
61        self.name.clone()
62    }
63}
64
65impl Numeric for IterationSpeedMetric {
66    fn value(&self) -> super::NumericEntry {
67        self.state.value()
68    }
69}