use crate::env::{Env, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that repeats every action for a fixed number of inner
/// steps and reports them as a single step.
///
/// Only environments whose observation type is `Vec<f32>` can be wrapped.
/// The aggregation itself (reward summing and element-wise maximum of the
/// last two frames) is implemented in this type's `Env::step`.
#[derive(Debug)]
pub struct MaxAndSkipObservation<E: Env<Obs = Vec<f32>>> {
    /// The wrapped environment that actually executes each inner step.
    env: E,
    /// Maximum number of inner steps taken per outer step; the constructor
    /// rejects zero.
    skip: usize,
}
impl<E> MaxAndSkipObservation<E>
where
E: Env<Obs = Vec<f32>>,
{
pub fn new(env: E, skip: usize) -> Result<Self> {
if skip == 0 {
return Err(crate::error::Error::InvalidSpace {
reason: "skip must be >= 1".to_owned(),
});
}
Ok(Self { env, skip })
}
#[must_use]
pub const fn inner(&self) -> &E {
&self.env
}
#[must_use]
pub const fn inner_mut(&mut self) -> &mut E {
&mut self.env
}
#[must_use]
pub fn into_inner(self) -> E {
self.env
}
}
impl<E> Env for MaxAndSkipObservation<E>
where
    E: Env<Obs = Vec<f32>>,
{
    type Obs = Vec<f32>;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Applies `action` to the wrapped environment up to `skip` times and
    /// folds the inner steps into one result.
    ///
    /// * The returned reward is the sum of the rewards of every inner step
    ///   that was executed.
    /// * Stepping stops early as soon as an inner step reports `terminated`
    ///   or `truncated`.
    /// * If the second-to-last scheduled frame was reached without the
    ///   episode ending (only possible when `skip >= 2`), the returned
    ///   observation is the element-wise maximum of that frame and the final
    ///   one. Otherwise the final frame's observation is returned unchanged.
    ///   Both frames are expected to have the same length; should the earlier
    ///   frame be shorter, the surplus elements of the final frame are kept
    ///   as they are.
    /// * All remaining fields are those of the last inner step.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by the wrapped environment's
    /// `step`; rewards gathered up to that point are discarded.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        // Index of the second-to-last frame. `None` when `skip < 2`, in which
        // case there is no earlier frame to pool with and no copy is made.
        let penultimate = self.skip.checked_sub(2);
        let mut total_reward = 0.0;
        let mut prev_obs: Option<Vec<f32>> = None;

        let mut result = self.env.step(action)?;
        for i in 0..self.skip {
            // The first frame was produced above so `result` is always
            // initialised; later iterations replace it.
            if i > 0 {
                result = self.env.step(action)?;
            }
            total_reward += result.reward;
            if result.terminated || result.truncated {
                break;
            }
            if penultimate == Some(i) {
                // One more iteration is guaranteed to follow and it replaces
                // `result` wholesale (or returns the error), so the buffer
                // can be moved out instead of cloned.
                prev_obs = Some(std::mem::take(&mut result.obs));
            }
        }

        if let Some(prev) = prev_obs {
            // Pool in place: no second allocation for the merged frame.
            for (latest, &earlier) in result.obs.iter_mut().zip(&prev) {
                *latest = (*latest).max(earlier);
            }
        }
        result.reward = total_reward;
        Ok(result)
    }

    // NOTE(review): presumably forwards each listed method to `self.env`
    // unchanged; confirm against the `delegate_env!` definition.
    delegate_env!(
        env,
        reset,
        render,
        close,
        render_mode,
        observation_space,
        action_space
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// A four-step skip yields a summed reward of at least 1.0 and an
    /// observation of length 4.
    #[test]
    fn skip_accumulates_reward() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = MaxAndSkipObservation::new(base, 4).unwrap();
        wrapped.reset(Some(42)).unwrap();

        let outcome = wrapped.step(&0).unwrap();

        assert!(outcome.reward >= 1.0);
        assert_eq!(outcome.obs.len(), 4);
    }

    /// With `skip == 1` a single step reports a reward of exactly 1.0.
    #[test]
    fn skip_one_is_passthrough() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = MaxAndSkipObservation::new(base, 1).unwrap();
        wrapped.reset(Some(42)).unwrap();

        let outcome = wrapped.step(&0).unwrap();

        let deviation = (outcome.reward - 1.0).abs();
        assert!(deviation < f64::EPSILON);
    }

    /// The constructor refuses a skip count of zero.
    #[test]
    fn rejects_zero_skip() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let attempt = MaxAndSkipObservation::new(base, 0);
        assert!(attempt.is_err());
    }
}