burn_train/metric/
learning_rate.rs1use super::{
2 state::{FormatOptions, NumericMetricState},
3 MetricMetadata, Numeric,
4};
5use crate::metric::{Metric, MetricEntry};
6
/// Metric that records the learning rate carried in [`MetricMetadata`] at each update.
///
/// The recorded values are accumulated in a [`NumericMetricState`].
pub struct LearningRateMetric {
    /// Numeric state that receives every recorded learning rate.
    state: NumericMetricState,
}
11
12impl LearningRateMetric {
13 pub fn new() -> Self {
15 Self {
16 state: NumericMetricState::new(),
17 }
18 }
19}
20
21impl Default for LearningRateMetric {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl Metric for LearningRateMetric {
28 const NAME: &'static str = "Learning Rate";
29
30 type Input = ();
31
32 fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> MetricEntry {
33 let lr = metadata.lr.unwrap_or(0.0);
34
35 self.state
36 .update(lr, 1, FormatOptions::new("Learning Rate").precision(2))
37 }
38
39 fn clear(&mut self) {
40 self.state.reset()
41 }
42}
43
44impl Numeric for LearningRateMetric {
45 fn value(&self) -> f64 {
46 self.state.value()
47 }
48}