use std::collections::VecDeque;
use crate::env::{Env, ResetResult, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Observation wrapper that hands back each observation of the inner environment a
/// fixed number of calls after the inner environment produced it.
///
/// While the internal buffer is still filling, a placeholder observation built by
/// `zero_fn` is returned instead of a real one.
#[derive(Debug)]
pub struct DelayObservation<E>
where
    E: Env,
    E::Obs: Clone,
{
    /// The wrapped environment.
    env: E,
    /// Number of calls an observation is held back before being returned.
    delay: usize,
    /// Buffered observations, oldest at the front.
    queue: VecDeque<E::Obs>,
    /// Builds the placeholder observation returned while the buffer is filling.
    zero_fn: fn() -> E::Obs,
}
impl<E> DelayObservation<E>
where
    E: Env,
    E::Obs: Clone,
{
    /// Wraps `env` so that every observation it produces is returned `delay` calls later.
    ///
    /// Until enough observations have been buffered, the wrapper returns the placeholder
    /// built by `zero_fn`. A `delay` of `0` passes observations through unchanged.
    ///
    /// # Panics
    ///
    /// May panic if preallocating `delay + 1` queue slots overflows (for example
    /// `delay == usize::MAX` in a debug build).
    #[must_use]
    pub fn new(env: E, delay: usize, zero_fn: fn() -> E::Obs) -> Self {
        // The queue briefly holds `delay + 1` items: one push happens before the pop.
        let queue = VecDeque::with_capacity(delay + 1);
        Self {
            env,
            delay,
            queue,
            zero_fn,
        }
    }

    /// Number of calls by which observations are held back.
    #[must_use]
    pub const fn delay(&self) -> usize {
        self.delay
    }

    /// Shared access to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Exclusive access to the wrapped environment.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Buffers `obs` as the newest observation and returns the one due now.
    ///
    /// Once more than `delay` observations are buffered, the oldest is removed and
    /// returned. Otherwise a fresh placeholder from `zero_fn` is returned.
    fn delayed_obs(&mut self, obs: E::Obs) -> E::Obs {
        self.queue.push_back(obs);
        let due = if self.queue.len() <= self.delay {
            None
        } else {
            // Non-empty here: the length exceeds `delay`, which is at least 0.
            self.queue.pop_front()
        };
        due.unwrap_or_else(self.zero_fn)
    }
}
impl<E> Env for DelayObservation<E>
where
    E: Env,
    E::Obs: Clone,
{
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the inner environment, replacing its observation with the delayed one.
    /// All other fields of the step result are passed through untouched.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        self.env.step(action).map(|mut outcome| {
            outcome.obs = self.delayed_obs(outcome.obs);
            outcome
        })
    }

    /// Discards all buffered observations, then resets the inner environment.
    /// With a non-zero delay the reset observation is therefore a placeholder.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
        // Clear up front so observations from the previous episode can never be
        // returned after a reset, even if the inner reset fails.
        self.queue.clear();
        self.env.reset(seed).map(|mut outcome| {
            outcome.obs = self.delayed_obs(outcome.obs);
            outcome
        })
    }

    // Remaining trait items are generated by `delegate_env!`, which presumably
    // forwards them to the `env` field.
    delegate_env!(env);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};
    use crate::space::Space;

    /// With `delay == 0` both `reset` and `step` must return exactly what the
    /// unwrapped environment returns.
    #[test]
    fn delay_zero_is_passthrough() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut direct = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut delayed = DelayObservation::new(env, 0, || vec![0.0_f32; 4]);
        let r1 = direct.reset(Some(42)).unwrap();
        let r2 = delayed.reset(Some(42)).unwrap();
        assert_eq!(r1.obs, r2.obs);
        let s1 = direct.step(&0).unwrap();
        let s2 = delayed.step(&0).unwrap();
        assert_eq!(s1.obs, s2.obs);
    }

    #[test]
    fn delay_returns_zeros_initially() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut delayed = DelayObservation::new(env, 2, || vec![0.0_f32; 4]);
        let r = delayed.reset(Some(42)).unwrap();
        assert_eq!(r.obs, vec![0.0_f32; 4]);
        let s1 = delayed.step(&0).unwrap();
        assert_eq!(s1.obs, vec![0.0_f32; 4]);
    }

    #[test]
    fn delay_returns_real_obs_after_delay() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut direct = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut delayed = DelayObservation::new(env, 1, || vec![0.0_f32; 4]);
        let direct_reset = direct.reset(Some(42)).unwrap();
        let _ = delayed.reset(Some(42)).unwrap();
        let s = delayed.step(&0).unwrap();
        assert_eq!(s.obs, direct_reset.obs);
    }

    /// With `delay == 2` the wrapped stream must be the direct stream shifted by two
    /// calls, in order: placeholder, placeholder, reset obs, first step obs.
    #[test]
    fn delay_two_preserves_order() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut direct = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut delayed = DelayObservation::new(env, 2, || vec![0.0_f32; 4]);
        let zeros = vec![0.0_f32; 4];
        let d0 = direct.reset(Some(42)).unwrap().obs;
        let d1 = direct.step(&0).unwrap().obs;
        assert_eq!(delayed.reset(Some(42)).unwrap().obs, zeros);
        assert_eq!(delayed.step(&0).unwrap().obs, zeros);
        assert_eq!(delayed.step(&0).unwrap().obs, d0);
        assert_eq!(delayed.step(&0).unwrap().obs, d1);
    }

    #[test]
    fn reset_clears_queue() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut delayed = DelayObservation::new(env, 1, || vec![0.0_f32; 4]);
        delayed.reset(Some(42)).unwrap();
        delayed.step(&0).unwrap();
        let r = delayed.reset(Some(99)).unwrap();
        assert_eq!(r.obs, vec![0.0_f32; 4]);
    }

    #[test]
    fn delegates_spaces() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let wrapped = DelayObservation::new(env, 3, || vec![0.0_f32; 4]);
        assert_eq!(wrapped.delay(), 3);
        assert_eq!(wrapped.observation_space().shape(), &[4]);
        assert_eq!(wrapped.action_space().n, 2);
    }
}