// burn_train/metric/learning_rate.rs

use super::{
    state::{FormatOptions, NumericMetricState},
    MetricMetadata, Numeric,
};
use crate::metric::{Metric, MetricEntry};

/// Tracks the learning rate across iterations.
///
/// Each call to `Metric::update` records the learning rate reported in the
/// metric metadata; `Numeric::value` exposes the state's current value.
pub struct LearningRateMetric {
    /// Numeric state fed with the learning rate on every update.
    state: NumericMetricState,
}

12impl LearningRateMetric {
13    /// Creates a new learning rate metric.
14    pub fn new() -> Self {
15        Self {
16            state: NumericMetricState::new(),
17        }
18    }
19}
20
21impl Default for LearningRateMetric {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl Metric for LearningRateMetric {
28    const NAME: &'static str = "Learning Rate";
29
30    type Input = ();
31
32    fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
33        let lr = metadata.lr.unwrap_or(0.0);
34
35        self.state
36            .update(lr, 1, FormatOptions::new("Learning Rate").precision(2))
37    }
38
39    fn clear(&mut self) {
40        self.state.reset()
41    }
42}
43
44impl Numeric for LearningRateMetric {
45    fn value(&self) -> f64 {
46        self.state.value()
47    }
48}