rlevo-environments 0.2.0

RL benchmark environments and landscapes for rlevo (internal crate — use `rlevo` for the full API)
use rlevo_core::{
    base::{Observation, Reward},
    environment::{ConstructableEnv, Environment, EnvironmentError, EpisodeStatus, SnapshotBase},
    render::{AsciiRenderable, StyledFrame},
};

/// Wrapper that caps an episode at `max_steps` calls to `step`.
///
/// Physics and termination are delegated untouched to the inner env. Once the
/// step counter reaches `max_steps`, a snapshot whose status is still
/// `Running` is relabelled `Truncated`; any other status the inner env
/// reported (e.g. `Terminated`) is passed through as-is. This mirrors the
/// Gymnasium `TimeLimit` wrapper semantics.
///
/// Build one with [`TimeLimit::new`] from an already-constructed inner env and
/// a step budget, then call [`reset`](TimeLimit::reset) before the first
/// [`step`](TimeLimit::step). Every `reset` zeroes the step counter.
///
/// Trait coverage:
/// - [`Environment`], whenever the inner env's `SnapshotType` is [`SnapshotBase`];
/// - [`AsciiRenderable`], forwarded to the wrapped env;
/// - [`Classic2DPayloadSource`](rlevo_core::render::payload::Classic2DPayloadSource),
///   forwarded for structured post-run playback.
pub struct TimeLimit<E> {
    /// The wrapped environment.
    inner: E,
    /// Per-episode step budget; reaching it triggers truncation.
    max_steps: usize,
    /// Successful `step` calls since the last `reset`.
    steps: usize,
}

impl<E> TimeLimit<E> {
    /// Wrap `env` with a hard step cap of `max_steps`.
    ///
    /// The step counter starts at zero. With `max_steps == 0` the very first
    /// `step` is already at the cap, so it is reported as truncated unless the
    /// inner env terminates on that step.
    #[must_use]
    pub fn new(env: E, max_steps: usize) -> Self {
        Self {
            inner: env,
            max_steps,
            steps: 0,
        }
    }

    /// Access the inner environment.
    #[must_use]
    pub fn inner(&self) -> &E {
        &self.inner
    }

    /// Mutably access the inner environment.
    pub fn inner_mut(&mut self) -> &mut E {
        &mut self.inner
    }

    /// Consume the wrapper and return the inner environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        self.inner
    }

    /// The per-episode step budget enforced by this wrapper.
    ///
    /// A wrapper built through `ConstructableEnv::new` reports `usize::MAX`
    /// here, i.e. effectively no limit.
    #[must_use]
    pub fn max_steps(&self) -> usize {
        self.max_steps
    }

    /// Number of steps taken since the last `reset`.
    #[must_use]
    pub fn steps(&self) -> usize {
        self.steps
    }
}

/// Debug output lists the step counter and budget before the wrapped env.
impl<E: std::fmt::Debug> std::fmt::Debug for TimeLimit<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            inner,
            max_steps,
            steps,
        } = self;
        let mut builder = f.debug_struct("TimeLimit");
        builder.field("steps", steps);
        builder.field("max_steps", max_steps);
        builder.field("inner", inner);
        builder.finish()
    }
}

/// Generic construction of a `TimeLimit`: the inner env is built through
/// [`ConstructableEnv`] and wrapped with a budget of `usize::MAX`, which is
/// effectively no limit. This exists so generic `E: ConstructableEnv` code
/// composes; when a real cap is wanted, prefer the inherent
/// `TimeLimit::new(env, max)`.
impl<E> ConstructableEnv for TimeLimit<E>
where
    E: ConstructableEnv,
{
    fn new(render: bool) -> Self {
        let inner = E::new(render);
        // Inherent two-argument `new` (takes precedence over this trait fn).
        Self::new(inner, usize::MAX)
    }
}

/// `TimeLimit` implements `Environment` for any inner env whose `SnapshotType`
/// is `SnapshotBase<D, Obs, Rew>`. This constraint lets `step` directly
/// set `snap.status = Truncated` without trait acrobatics.
impl<const D: usize, const SD: usize, const AD: usize, E, Obs, Rew> Environment<D, SD, AD>
    for TimeLimit<E>
where
    E: Environment<
            D,
            SD,
            AD,
            ObservationType = Obs,
            RewardType = Rew,
            SnapshotType = SnapshotBase<D, Obs, Rew>,
        >,
    Obs: Observation<D>,
    Rew: Reward,
{
    type StateType = E::StateType;
    type ObservationType = Obs;
    type ActionType = E::ActionType;
    type RewardType = Rew;
    type SnapshotType = SnapshotBase<D, Obs, Rew>;

    /// Zero the step counter and reset the inner environment.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner environment's `reset`.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        self.steps = 0;
        self.inner.reset()
    }

    /// Step the inner environment, then relabel a still-`Running` snapshot as
    /// `Truncated` once the step budget is reached.
    ///
    /// Any non-`Running` status from the inner env (e.g. `Terminated`) is
    /// passed through unchanged, so genuine termination wins over the limit.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner `step`; the step counter is not
    /// advanced in that case.
    fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError> {
        let mut snap = self.inner.step(action)?;
        // Saturate instead of `+= 1`: if an episode is stepped past
        // `usize::MAX` steps without a reset, a plain increment panics in
        // debug builds or wraps to 0 in release, which silently disables
        // truncation. That is reachable on 32-bit targets, and the
        // `ConstructableEnv` budget is exactly `usize::MAX`.
        self.steps = self.steps.saturating_add(1);
        if snap.status == EpisodeStatus::Running && self.steps >= self.max_steps {
            snap.status = EpisodeStatus::Truncated;
        }
        Ok(snap)
    }
}

/// Forward [`AsciiRenderable`] to the wrapped env.
///
/// This lets wrappers that require it (e.g.
/// `rlevo_benchmarks::env_wrappers::TuiEnvTap`) compose with `TimeLimit<E>`
/// whenever `E` is itself renderable. The wrapper draws nothing of its own,
/// so both frames come straight from the inner env. Mirrors the forwarding
/// impl on `BenchAdapter`.
impl<E: AsciiRenderable> AsciiRenderable for TimeLimit<E> {
    fn render_ascii(&self) -> String {
        <E as AsciiRenderable>::render_ascii(&self.inner)
    }

    fn render_styled(&self) -> StyledFrame {
        <E as AsciiRenderable>::render_styled(&self.inner)
    }
}

/// Forward the optional `Classic2DPayloadSource` to the wrapped env, so a
/// `TimeLimit` over a classic-control env remains structurally renderable
/// (ADR-0013). One example is a `RecordingTap` recording a `TimeLimit`-wrapped env.
impl<E> rlevo_core::render::payload::Classic2DPayloadSource for TimeLimit<E>
where
    E: rlevo_core::render::payload::Classic2DPayloadSource,
{
    fn classic2d_snapshot(&self) -> rlevo_core::render::payload::Classic2DSnapshot {
        // The step counter has no visual representation; only the inner env's
        // state is reported.
        let Self { inner, .. } = self;
        inner.classic2d_snapshot()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rlevo_core::{
        base::{Action, Observation, State},
        environment::{Environment, EnvironmentError, EpisodeStatus, Snapshot, SnapshotBase},
        reward::ScalarReward,
    };
    use serde::{Deserialize, Serialize};

    // Minimal stub environment: terminates when position reaches GOAL.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    struct StubObs {
        pos: i32,
    }

    impl Observation<1> for StubObs {
        fn shape() -> [usize; 1] {
            [1]
        }
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct StubState {
        pos: i32,
    }

    impl State<1> for StubState {
        type Observation = StubObs;

        fn shape() -> [usize; 1] {
            [1]
        }

        fn is_valid(&self) -> bool {
            true
        }

        fn numel(&self) -> usize {
            1
        }

        fn observe(&self) -> StubObs {
            StubObs { pos: self.pos }
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct StubAction;

    impl Action<1> for StubAction {
        fn shape() -> [usize; 1] {
            [1]
        }

        fn is_valid(&self) -> bool {
            true
        }
    }

    struct StubEnv {
        pos: i32,
        goal: i32,
    }

    impl StubEnv {
        /// Stub starting at 0 that terminates once `pos` reaches `goal`.
        fn new_at_goal(goal: i32) -> Self {
            Self { pos: 0, goal }
        }
    }

    impl ConstructableEnv for StubEnv {
        fn new(_render: bool) -> Self {
            Self {
                pos: 0,
                goal: i32::MAX,
            }
        }
    }

    impl Environment<1, 1, 1> for StubEnv {
        type StateType = StubState;
        type ObservationType = StubObs;
        type ActionType = StubAction;
        type RewardType = ScalarReward;
        type SnapshotType = SnapshotBase<1, StubObs, ScalarReward>;

        fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
            self.pos = 0;
            Ok(SnapshotBase::running(StubObs { pos: 0 }, ScalarReward(0.0)))
        }

        fn step(&mut self, _action: StubAction) -> Result<Self::SnapshotType, EnvironmentError> {
            self.pos += 1;
            if self.pos >= self.goal {
                Ok(SnapshotBase::terminated(
                    StubObs { pos: self.pos },
                    ScalarReward(1.0),
                ))
            } else {
                Ok(SnapshotBase::running(
                    StubObs { pos: self.pos },
                    ScalarReward(0.0),
                ))
            }
        }
    }

    #[test]
    fn truncated_at_step_cap() {
        let env = StubEnv::new_at_goal(100); // goal unreachable in 3 steps
        let mut timed = TimeLimit::new(env, 3);
        timed.reset().unwrap();

        let s1 = timed.step(StubAction).unwrap();
        assert_eq!(s1.status, EpisodeStatus::Running);

        let s2 = timed.step(StubAction).unwrap();
        assert_eq!(s2.status, EpisodeStatus::Running);

        let s3 = timed.step(StubAction).unwrap();
        assert_eq!(s3.status, EpisodeStatus::Truncated);
        assert!(s3.is_truncated());
        assert!(!s3.is_terminated());
        assert!(s3.is_done());
    }

    #[test]
    fn terminated_before_cap() {
        let env = StubEnv::new_at_goal(2); // terminates at step 2
        let mut timed = TimeLimit::new(env, 10);
        timed.reset().unwrap();

        let s1 = timed.step(StubAction).unwrap();
        assert_eq!(s1.status, EpisodeStatus::Running);

        let s2 = timed.step(StubAction).unwrap();
        assert_eq!(s2.status, EpisodeStatus::Terminated);
        assert!(!s2.is_truncated());
    }

    #[test]
    fn termination_on_cap_step_is_not_downgraded() {
        // Inner env terminates on exactly the step where the cap is reached:
        // genuine termination must win over the time limit.
        let env = StubEnv::new_at_goal(3);
        let mut timed = TimeLimit::new(env, 3);
        timed.reset().unwrap();

        timed.step(StubAction).unwrap();
        timed.step(StubAction).unwrap();
        let s3 = timed.step(StubAction).unwrap();
        assert_eq!(s3.status, EpisodeStatus::Terminated);
        assert!(s3.is_terminated());
        assert!(!s3.is_truncated());
    }

    #[test]
    fn reset_clears_step_count() {
        let env = StubEnv::new_at_goal(100);
        let mut timed = TimeLimit::new(env, 2);
        timed.reset().unwrap();

        timed.step(StubAction).unwrap();
        timed.step(StubAction).unwrap();

        // After reset, step count should restart
        timed.reset().unwrap();
        assert_eq!(timed.steps(), 0);

        let s1 = timed.step(StubAction).unwrap();
        assert_eq!(s1.status, EpisodeStatus::Running);
        let s2 = timed.step(StubAction).unwrap();
        assert_eq!(s2.status, EpisodeStatus::Truncated);
    }

    #[test]
    fn steps_counts_each_successful_step() {
        let mut timed = TimeLimit::new(StubEnv::new_at_goal(100), 10);
        timed.reset().unwrap();
        assert_eq!(timed.steps(), 0);

        for expected in 1..=4 {
            timed.step(StubAction).unwrap();
            assert_eq!(timed.steps(), expected);
        }
    }

    #[test]
    fn constructable_env_has_effectively_no_limit() {
        let mut timed = <TimeLimit<StubEnv> as ConstructableEnv>::new(false);
        assert_eq!(timed.max_steps, usize::MAX);
        timed.reset().unwrap();

        for _ in 0..16 {
            let snap = timed.step(StubAction).unwrap();
            assert_eq!(snap.status, EpisodeStatus::Running);
        }
    }
}