burn_train/metric/learning_rate.rs

1use super::{
2    MetricMetadata, Numeric,
3    state::{FormatOptions, NumericMetricState},
4};
5use crate::metric::{Metric, MetricEntry};
6
7/// Track the learning rate across iterations.
8pub struct LearningRateMetric {
9    state: NumericMetricState,
10}
11
12impl LearningRateMetric {
13    /// Creates a new learning rate metric.
14    pub fn new() -> Self {
15        Self {
16            state: NumericMetricState::new(),
17        }
18    }
19}
20
21impl Default for LearningRateMetric {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl Metric for LearningRateMetric {
28    type Input = ();
29
30    fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
31        let lr = metadata.lr.unwrap_or(0.0);
32
33        self.state
34            .update(lr, 1, FormatOptions::new(self.name()).precision(2))
35    }
36
37    fn clear(&mut self) {
38        self.state.reset()
39    }
40
41    fn name(&self) -> String {
42        "Learning Rate".to_string()
43    }
44}
45
46impl Numeric for LearningRateMetric {
47    fn value(&self) -> f64 {
48        self.state.value()
49    }
50}