// burn_train/metric/rl/exploration_rate.rs
use std::sync::Arc;

use super::super::{
    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
    state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};

9#[derive(Clone)]
11pub struct ExplorationRateMetric {
12 name: MetricName,
13 state: NumericMetricState,
14}
15
16impl ExplorationRateMetric {
17 pub fn new() -> Self {
19 Self {
20 name: Arc::new("Exploration rate".to_string()),
21 state: NumericMetricState::new(),
22 }
23 }
24}
25
26impl Default for ExplorationRateMetric {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32#[derive(new)]
34pub struct ExplorationRateInput {
35 exploration_rate: f64,
36}
37
38impl Metric for ExplorationRateMetric {
39 type Input = ExplorationRateInput;
40
41 fn update(
42 &mut self,
43 item: &ExplorationRateInput,
44 _metadata: &MetricMetadata,
45 ) -> SerializedEntry {
46 self.state.update(
47 item.exploration_rate,
48 1,
49 FormatOptions::new(self.name()).precision(3),
50 )
51 }
52
53 fn clear(&mut self) {
54 self.state.reset()
55 }
56
57 fn name(&self) -> MetricName {
58 self.name.clone()
59 }
60
61 fn attributes(&self) -> MetricAttributes {
62 NumericAttributes {
63 unit: Some(String::from("%")),
64 higher_is_better: false,
65 }
66 .into()
67 }
68}
69
70impl Numeric for ExplorationRateMetric {
71 fn value(&self) -> NumericEntry {
72 self.state.current_value()
73 }
74
75 fn running_value(&self) -> NumericEntry {
76 self.state.running_value()
77 }
78}