Skip to main content

burn_train/metric/rl/
cum_reward.rs

1use std::sync::Arc;
2
3use super::super::{
4    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
5    state::{FormatOptions, NumericMetricState},
6};
7use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
8
9/// Metric for the cumulative reward of the last completed episode.
10#[derive(Clone)]
11pub struct CumulativeRewardMetric {
12    name: MetricName,
13    state: NumericMetricState,
14}
15
16impl CumulativeRewardMetric {
17    /// Creates a new episode length metric.
18    pub fn new() -> Self {
19        Self {
20            name: Arc::new("Cum. Reward".to_string()),
21            state: NumericMetricState::new(),
22        }
23    }
24}
25
26impl Default for CumulativeRewardMetric {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.
33#[derive(new)]
34pub struct CumulativeRewardInput {
35    cum_reward: f64,
36}
37
38impl Metric for CumulativeRewardMetric {
39    type Input = CumulativeRewardInput;
40
41    fn update(
42        &mut self,
43        item: &CumulativeRewardInput,
44        _metadata: &MetricMetadata,
45    ) -> SerializedEntry {
46        self.state.update(
47            item.cum_reward,
48            1,
49            FormatOptions::new(self.name()).precision(2),
50        )
51    }
52
53    fn clear(&mut self) {
54        self.state.reset()
55    }
56
57    fn name(&self) -> MetricName {
58        self.name.clone()
59    }
60
61    fn attributes(&self) -> MetricAttributes {
62        NumericAttributes {
63            unit: None,
64            higher_is_better: true,
65        }
66        .into()
67    }
68}
69
70impl Numeric for CumulativeRewardMetric {
71    fn value(&self) -> NumericEntry {
72        self.state.current_value()
73    }
74
75    fn running_value(&self) -> NumericEntry {
76        self.state.running_value()
77    }
78}