use std::collections::VecDeque;
use crate::core::{Env, StepResult};
use super::Wrapper;
/// Environment wrapper that tracks statistics about the episodes played
/// through it.
///
/// It keeps running totals (return and length) for the episode currently in
/// progress, a count of all finished episodes, and a bounded, oldest-first
/// history of the returns and lengths of the most recently finished episodes.
pub struct RecordEpisodeStatistics<E: Env> {
    /// The wrapped environment.
    env: E,
    /// Sum of rewards accumulated since the running totals were last cleared.
    episode_return: f64,
    /// Number of steps taken since the running totals were last cleared.
    episode_length: usize,
    /// Total number of episodes recorded so far.
    episode_count: usize,
    /// Returns of the most recently recorded episodes, oldest at the front.
    return_history: VecDeque<f64>,
    /// Lengths of the most recently recorded episodes, oldest at the front.
    length_history: VecDeque<usize>,
    /// Maximum number of episodes retained in each history buffer.
    buffer_length: usize,
}

impl<E: Env> RecordEpisodeStatistics<E> {
    /// Wraps `env`, retaining statistics for the last 100 episodes.
    pub fn new(env: E) -> Self {
        Self::with_buffer(env, 100)
    }

    /// Wraps `env`, retaining statistics for at most `buffer_length` episodes.
    ///
    /// A `buffer_length` of `0` disables the history entirely; finished
    /// episodes are still counted by [`episode_count`](Self::episode_count).
    pub fn with_buffer(env: E, buffer_length: usize) -> Self {
        Self {
            env,
            episode_return: 0.0,
            episode_length: 0,
            episode_count: 0,
            return_history: VecDeque::with_capacity(buffer_length),
            length_history: VecDeque::with_capacity(buffer_length),
            buffer_length,
        }
    }

    /// Cumulative reward of the episode currently being tracked.
    pub fn episode_return(&self) -> f64 {
        self.episode_return
    }

    /// Number of steps of the episode currently being tracked.
    pub fn episode_length(&self) -> usize {
        self.episode_length
    }

    /// Total number of episodes recorded so far.
    pub fn episode_count(&self) -> usize {
        self.episode_count
    }

    /// Returns of the most recently recorded episodes, oldest first.
    pub fn return_history(&self) -> &VecDeque<f64> {
        &self.return_history
    }

    /// Lengths of the most recently recorded episodes, oldest first.
    pub fn length_history(&self) -> &VecDeque<usize> {
        &self.length_history
    }

    /// Records the current running totals as one finished episode.
    ///
    /// Always increments the episode counter. The totals are appended to the
    /// history buffers, evicting the oldest entry once `buffer_length`
    /// entries are held.
    fn record_episode(&mut self) {
        self.episode_count += 1;

        // A zero-sized window must stay empty. Without this guard the
        // eviction below would pop from an empty deque (a no-op) and the
        // push would still store one entry, exceeding the requested capacity.
        if self.buffer_length == 0 {
            return;
        }

        // Both deques are always pushed and popped together, so checking one
        // length is sufficient.
        if self.return_history.len() >= self.buffer_length {
            self.return_history.pop_front();
            self.length_history.pop_front();
        }
        self.return_history.push_back(self.episode_return);
        self.length_history.push_back(self.episode_length);
    }
}
/// Pass-through [`Env`] implementation that updates the episode statistics
/// around calls to the wrapped environment.
impl<E: Env> Env for RecordEpisodeStatistics<E> {
    // Every associated type is taken unchanged from the wrapped environment.
    type Action = E::Action;
    type Observation = E::Observation;
    type ActionSpace = E::ActionSpace;
    type ObservationSpace = E::ObservationSpace;
    type ResetOptions = E::ResetOptions;

    /// Steps the wrapped environment and folds the outcome into the running
    /// totals; when the step is flagged as terminated or truncated the
    /// episode is recorded. The inner step result is returned untouched.
    fn step(&mut self, action: Self::Action) -> StepResult<Self::Observation> {
        let outcome = self.env.step(action);

        self.episode_return += outcome.reward;
        self.episode_length += 1;

        let episode_over = outcome.terminated || outcome.truncated;
        if episode_over {
            self.record_episode();
        }

        outcome
    }

    /// Clears the running return and length, then resets the wrapped
    /// environment with the given `seed` and `options`.
    fn reset(&mut self, seed: Option<u64>, options: Self::ResetOptions) -> Self::Observation {
        // Only the in-progress totals are cleared; the episode counter and
        // the history buffers are left as they are.
        self.episode_length = 0;
        self.episode_return = 0.0;

        self.env.reset(seed, options)
    }

    /// Action space of the wrapped environment.
    fn action_space(&self) -> &Self::ActionSpace {
        self.env.action_space()
    }

    /// Observation space of the wrapped environment.
    fn observation_space(&self) -> &Self::ObservationSpace {
        self.env.observation_space()
    }

    /// Closes the wrapped environment.
    fn close(&mut self) {
        self.env.close();
    }
}
/// Gives access to the environment wrapped by the statistics recorder.
impl<E: Env> Wrapper for RecordEpisodeStatistics<E> {
    type Inner = E;

    /// Shared reference to the wrapped environment.
    fn inner(&self) -> &Self::Inner {
        &self.env
    }

    /// Exclusive reference to the wrapped environment.
    fn inner_mut(&mut self) -> &mut Self::Inner {
        &mut self.env
    }

    /// Consumes the wrapper and hands back the wrapped environment; the
    /// collected statistics are dropped.
    fn into_inner(self) -> Self::Inner {
        let Self { env, .. } = self;
        env
    }
}