use crate::env::{Env, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that begins a new episode automatically.
///
/// The [`Env`] implementation for this type resets the wrapped environment
/// from inside `step` whenever the inner step reports `terminated` or
/// `truncated`.
#[derive(Debug)]
pub struct Autoreset<E>
where
    E: Env,
{
    /// The wrapped environment.
    env: E,
}
impl<E: Env> Autoreset<E> {
#[must_use]
pub const fn new(env: E) -> Self {
Self { env }
}
#[must_use]
pub const fn inner(&self) -> &E {
&self.env
}
#[must_use]
pub const fn inner_mut(&mut self) -> &mut E {
&mut self.env
}
#[must_use]
pub fn into_inner(self) -> E {
self.env
}
}
impl<E> Env for Autoreset<E>
where
    E: Env,
{
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment, resetting it when the episode ends.
    ///
    /// When the inner step reports neither `terminated` nor `truncated`, its
    /// result is returned untouched. Otherwise the wrapped environment is
    /// reset with `reset(None)` and the returned result carries:
    ///
    /// * the observation produced by that reset (the observation of the
    ///   finishing step is not returned),
    /// * the reward, `terminated` and `truncated` values of the finishing step,
    /// * the finishing step's info with an added `"autoreset"` entry set to
    ///   `InfoValue::Bool(true)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped environment's `step` or
    /// `reset`.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        let mut outcome = self.env.step(action)?;

        let episode_over = outcome.terminated || outcome.truncated;
        if episode_over {
            // Start the next episode right away; only the observation of the
            // reset is kept, everything else comes from the finishing step.
            let fresh = self.env.reset(None)?;
            outcome.obs = fresh.obs;
            outcome
                .info
                .insert(String::from("autoreset"), crate::env::InfoValue::Bool(true));
        }

        Ok(outcome)
    }

    // NOTE(review): `delegate_env!` presumably generates the listed methods by
    // forwarding them to `self.env` — confirm against the macro definition.
    delegate_env!(
        env,
        reset,
        render,
        close,
        render_mode,
        observation_space,
        action_space
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// Steps a wrapped CartPole with a constant action until the episode ends,
    /// then checks that the wrapper flagged the autoreset and keeps stepping.
    #[test]
    fn autoreset_on_termination() {
        let cart_pole = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = Autoreset::new(cart_pole);
        wrapped.reset(Some(0)).unwrap();

        // Lazily step with action `1` (at most 500 times) and stop at the
        // first step that reports the end of the episode.
        let last = (0..500)
            .map(|_| wrapped.step(&1).unwrap())
            .find(|step| step.terminated || step.truncated)
            .expect("should terminate within 500 steps");
        assert_eq!(last.obs.len(), 4);
        assert!(last.info.contains_key("autoreset"));

        // The wrapper must remain usable after the automatic reset.
        let next = wrapped.step(&0).unwrap();
        assert_eq!(next.obs.len(), 4);
    }
}