// burn_train/metric/rl/ep_len.rs
use std::sync::Arc;

use super::super::{
    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,
    state::{FormatOptions, NumericMetricState},
};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};

9#[derive(Clone)]
11pub struct EpisodeLengthMetric {
12 name: MetricName,
13 state: NumericMetricState,
14}
15
16impl EpisodeLengthMetric {
17 pub fn new() -> Self {
19 Self {
20 name: Arc::new("Episode length".to_string()),
21 state: NumericMetricState::new(),
22 }
23 }
24}
25
26impl Default for EpisodeLengthMetric {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32#[derive(new)]
34pub struct EpisodeLengthInput {
35 ep_len: f64,
36}
37
38impl Metric for EpisodeLengthMetric {
39 type Input = EpisodeLengthInput;
40
41 fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {
42 self.state
43 .update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))
44 }
45
46 fn clear(&mut self) {
47 self.state.reset()
48 }
49
50 fn name(&self) -> MetricName {
51 self.name.clone()
52 }
53
54 fn attributes(&self) -> MetricAttributes {
55 NumericAttributes {
56 unit: Some(String::from("steps")),
57 higher_is_better: true,
58 }
59 .into()
60 }
61}
62
63impl Numeric for EpisodeLengthMetric {
64 fn value(&self) -> NumericEntry {
65 self.state.current_value()
66 }
67
68 fn running_value(&self) -> NumericEntry {
69 self.state.running_value()
70 }
71}