use crate::env::{Env, ResetResult, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that caps how many steps an episode may run.
///
/// The step counter is `None` until the wrapper is first reset; a reset
/// restarts it at zero and every successful step advances it by one.
#[derive(Debug)]
pub struct TimeLimit<E>
where
    E: Env,
{
    /// The wrapped environment that actually simulates the dynamics.
    env: E,
    /// Per-episode step budget; the constructor rejects zero.
    max_episode_steps: u64,
    /// Steps taken since the most recent reset, or `None` if never reset.
    elapsed_steps: Option<u64>,
}
impl<E: Env> TimeLimit<E> {
    /// Wraps `env` so that each episode is truncated after `max_episode_steps` steps.
    ///
    /// The step counter starts unset; the wrapper must be reset before it is stepped.
    ///
    /// # Panics
    ///
    /// Panics if `max_episode_steps` is zero.
    #[must_use]
    pub fn new(env: E, max_episode_steps: u64) -> Self {
        assert!(max_episode_steps > 0, "max_episode_steps must be positive");
        let elapsed_steps = None;
        Self {
            env,
            max_episode_steps,
            elapsed_steps,
        }
    }

    /// Borrows the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Mutably borrows the wrapped environment.
    ///
    /// Changes made through this reference bypass the step counter.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Consumes the wrapper and hands back the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        self.env
    }

    /// The per-episode step budget this wrapper was built with.
    #[must_use]
    pub const fn max_episode_steps(&self) -> u64 {
        self.max_episode_steps
    }

    /// Steps taken in the current episode, or `None` if the wrapper was never reset.
    #[must_use]
    pub const fn elapsed_steps(&self) -> Option<u64> {
        self.elapsed_steps
    }
}
/// Forwards everything to the wrapped environment and sets `truncated` once the
/// episode has used up its step budget.
impl<E: Env> Env for TimeLimit<E> {
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment and flags truncation when the budget is reached.
    ///
    /// The counter is only advanced after the inner step succeeds, so an `Err`
    /// from the wrapped environment does not consume a step.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped environment's `step`.
    ///
    /// # Panics
    ///
    /// Panics if called before a successful [`Env::reset`]. The check runs before
    /// the wrapped environment is touched, so misuse never advances its state.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        // Validate the precondition first: panicking *after* stepping the inner
        // env would leave it advanced while the wrapper's counter was not.
        let elapsed = self
            .elapsed_steps
            .expect("environment must be reset before step");
        let mut result = self.env.step(action)?;
        // Saturate rather than overflow; once at u64::MAX we are far past any budget.
        let elapsed = elapsed.saturating_add(1);
        self.elapsed_steps = Some(elapsed);
        if elapsed >= self.max_episode_steps {
            result.truncated = true;
        }
        Ok(result)
    }

    /// Resets the wrapped environment and, on success, restarts the step counter.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped environment's `reset`. In that
    /// case the counter is cleared, so `step` will panic until a reset succeeds.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
        let outcome = self.env.reset(seed);
        // Only report a fresh episode once the inner env has actually been reset.
        self.elapsed_steps = outcome.is_ok().then_some(0);
        outcome
    }

    delegate_env!(env);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};
    use crate::space::Space;

    /// A budget of 3 lets two steps through untouched and flags the third.
    #[test]
    fn truncates_at_limit() {
        let cartpole = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut limited = TimeLimit::new(cartpole, 3);
        limited.reset(Some(42)).unwrap();

        // Steps one and two stay within the budget.
        let first = limited.step(&0).unwrap();
        assert!(!first.truncated);
        assert_eq!(limited.elapsed_steps(), Some(1));
        let second = limited.step(&1).unwrap();
        assert!(!second.truncated);

        // Step three exhausts the budget and must be marked truncated.
        let third = limited.step(&0).unwrap();
        assert!(third.truncated, "should truncate at step 3");
        assert_eq!(limited.elapsed_steps(), Some(3));
    }

    /// Resetting mid-episode puts the counter back to zero.
    #[test]
    fn reset_clears_counter() {
        let cartpole = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut limited = TimeLimit::new(cartpole, 10);

        limited.reset(Some(0)).unwrap();
        limited.step(&0).unwrap();
        assert_eq!(limited.elapsed_steps(), Some(1));

        limited.reset(Some(1)).unwrap();
        assert_eq!(limited.elapsed_steps(), Some(0));
    }

    /// The wrapper exposes the inner environment's spaces unchanged.
    #[test]
    fn delegates_spaces() {
        let cartpole = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let limited = TimeLimit::new(cartpole, 500);

        let obs_space = limited.observation_space();
        assert_eq!(obs_space.shape(), &[4]);

        let act_space = limited.action_space();
        assert_eq!(act_space.n, 2);
    }
}