Skip to main content

burn_train/metric/rl/
ep_len.rs

1use std::sync::Arc;
2
3use super::super::{
4    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
5    state::{FormatOptions, NumericMetricState},
6};
7use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
8
/// Metric for the length of the last completed episode.
///
/// Holds the metric's display name together with the numeric state that
/// `update` feeds and `clear` resets.
#[derive(Clone)]
pub struct EpisodeLengthMetric {
    /// Name returned by `Metric::name` and handed to the formatter on each update.
    name: MetricName,
    /// Numeric state backing `value` and `running_value`.
    state: NumericMetricState,
}
15
16impl EpisodeLengthMetric {
17    /// Creates a new episode length metric.
18    pub fn new() -> Self {
19        Self {
20            name: Arc::new("Episode length".to_string()),
21            state: NumericMetricState::new(),
22        }
23    }
24}
25
26impl Default for EpisodeLengthMetric {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.
///
/// Wraps the single value that one call to the metric's `update` records.
// NOTE(review): `#[derive(new)]` is presumably the derive-new macro generating
// `EpisodeLengthInput::new(ep_len)`; the macro import is not visible here — confirm.
#[derive(new)]
pub struct EpisodeLengthInput {
    /// Episode length as a float; the metric reports it with the unit "steps".
    ep_len: f64,
}
37
38impl Metric for EpisodeLengthMetric {
39    type Input = EpisodeLengthInput;
40
41    fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
42        self.state
43            .update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
44    }
45
46    fn clear(&mut self) {
47        self.state.reset()
48    }
49
50    fn name(&self) -> MetricName {
51        self.name.clone()
52    }
53
54    fn attributes(&self) -> MetricAttributes {
55        NumericAttributes {
56            unit: Some(String::from("steps")),
57            higher_is_better: true,
58        }
59        .into()
60    }
61}
62
impl Numeric for EpisodeLengthMetric {
    /// Returns the numeric state's current value.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Returns the numeric state's running value.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
71}