burn_train/metric/
learning_rate.rs1use std::sync::Arc;
2
3use super::{
4 MetricMetadata, Numeric,
5 state::{FormatOptions, NumericMetricState},
6};
7use crate::metric::{Metric, MetricEntry, MetricName};
8
9#[derive(Clone)]
11pub struct LearningRateMetric {
12 name: MetricName,
13 state: NumericMetricState,
14}
15
16impl LearningRateMetric {
17 pub fn new() -> Self {
19 Self {
20 name: Arc::new("Learning Rate".to_string()),
21 state: NumericMetricState::new(),
22 }
23 }
24}
25
26impl Default for LearningRateMetric {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl Metric for LearningRateMetric {
33 type Input = ();
34
35 fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
36 let lr = metadata.lr.unwrap_or(0.0);
37
38 self.state
39 .update(lr, 1, FormatOptions::new(self.name()).precision(2))
40 }
41
42 fn clear(&mut self) {
43 self.state.reset()
44 }
45
46 fn name(&self) -> MetricName {
47 self.name.clone()
48 }
49}
50
51impl Numeric for LearningRateMetric {
52 fn value(&self) -> super::NumericEntry {
53 self.state.value()
54 }
55}