burn_train/metric/
learning_rate.rs1use super::{
2 MetricMetadata, Numeric,
3 state::{FormatOptions, NumericMetricState},
4};
5use crate::metric::{Metric, MetricEntry};
6
7pub struct LearningRateMetric {
9 state: NumericMetricState,
10}
11
12impl LearningRateMetric {
13 pub fn new() -> Self {
15 Self {
16 state: NumericMetricState::new(),
17 }
18 }
19}
20
21impl Default for LearningRateMetric {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl Metric for LearningRateMetric {
28 type Input = ();
29
30 fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
31 let lr = metadata.lr.unwrap_or(0.0);
32
33 self.state
34 .update(lr, 1, FormatOptions::new(self.name()).precision(2))
35 }
36
37 fn clear(&mut self) {
38 self.state.reset()
39 }
40
41 fn name(&self) -> String {
42 "Learning Rate".to_string()
43 }
44}
45
46impl Numeric for LearningRateMetric {
47 fn value(&self) -> f64 {
48 self.state.value()
49 }
50}