burn_train/metric/learning_rate.rs

1use std::sync::Arc;
2
3use super::{
4    MetricMetadata, Numeric,
5    state::{FormatOptions, NumericMetricState},
6};
7use crate::metric::{Metric, MetricEntry, MetricName};
8
9/// Track the learning rate across iterations.
10#[derive(Clone)]
11pub struct LearningRateMetric {
12    name: MetricName,
13    state: NumericMetricState,
14}
15
16impl LearningRateMetric {
17    /// Creates a new learning rate metric.
18    pub fn new() -> Self {
19        Self {
20            name: Arc::new("Learning Rate".to_string()),
21            state: NumericMetricState::new(),
22        }
23    }
24}
25
26impl Default for LearningRateMetric {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl Metric for LearningRateMetric {
33    type Input = ();
34
35    fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
36        let lr = metadata.lr.unwrap_or(0.0);
37
38        self.state
39            .update(lr, 1, FormatOptions::new(self.name()).precision(2))
40    }
41
42    fn clear(&mut self) {
43        self.state.reset()
44    }
45
46    fn name(&self) -> MetricName {
47        self.name.clone()
48    }
49}
50
51impl Numeric for LearningRateMetric {
52    fn value(&self) -> super::NumericEntry {
53        self.state.value()
54    }
55}