Skip to main content

burn_train/metric/rl/
exploration_rate.rs

1use std::sync::Arc;
2
3use super::super::{
4    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
5    state::{FormatOptions, NumericMetricState},
6};
7use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
8
9/// Metric for the length of the last completed episode.
10#[derive(Clone)]
11pub struct ExplorationRateMetric {
12    name: MetricName,
13    state: NumericMetricState,
14}
15
16impl ExplorationRateMetric {
17    /// Creates a new episode length metric.
18    pub fn new() -> Self {
19        Self {
20            name: Arc::new("Exploration rate".to_string()),
21            state: NumericMetricState::new(),
22        }
23    }
24}
25
26impl Default for ExplorationRateMetric {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
/// The [`ExplorationRateMetric`] input type.
///
/// Carries the exploration rate value to record on the next metric update.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ExplorationRateInput {
    exploration_rate: f64,
}

impl ExplorationRateInput {
    /// Creates a new input carrying the given exploration rate.
    pub fn new(exploration_rate: f64) -> Self {
        Self { exploration_rate }
    }
}
37
38impl Metric for ExplorationRateMetric {
39    type Input = ExplorationRateInput;
40
41    fn update(
42        &mut self,
43        item: &ExplorationRateInput,
44        _metadata: &MetricMetadata,
45    ) -> SerializedEntry {
46        self.state.update(
47            item.exploration_rate,
48            1,
49            FormatOptions::new(self.name()).precision(3),
50        )
51    }
52
53    fn clear(&mut self) {
54        self.state.reset()
55    }
56
57    fn name(&self) -> MetricName {
58        self.name.clone()
59    }
60
61    fn attributes(&self) -> MetricAttributes {
62        NumericAttributes {
63            unit: Some(String::from("%")),
64            higher_is_better: false,
65        }
66        .into()
67    }
68}
69
70impl Numeric for ExplorationRateMetric {
71    fn value(&self) -> NumericEntry {
72        self.state.current_value()
73    }
74
75    fn running_value(&self) -> NumericEntry {
76        self.state.running_value()
77    }
78}