Skip to main content

rl_traits/
wrappers.rs

1use rand::Rng;
2
3use crate::environment::Environment;
4use crate::episode::{EpisodeStatus, StepResult};
5
6/// A marker trait for environments that wrap another environment.
7///
8/// Wrappers modify environment behaviour without changing its interface —
9/// exactly like Gymnasium's wrapper system, but type-safe.
10///
11/// # Examples of wrappers (to be implemented in user code or ember-rl)
12///
13/// - `TimeLimit<E>`: truncate episodes after N steps
14/// - `NormalizeObs<E>`: normalize observations to zero mean, unit variance
15/// - `ClipReward<E>`: clip rewards to a fixed range
16/// - `FrameStack<E>`: stack the last N observations
17///
18/// # Note on associated types
19///
20/// A wrapper may change `Observation` or `Action` types (e.g. `FrameStack`
21/// changes the observation shape). When types pass through unchanged,
22/// use `type Observation = E::Observation` etc.
23pub trait Wrapper: Environment {
24    type Inner: Environment;
25
26    fn inner(&self) -> &Self::Inner;
27    fn inner_mut(&mut self) -> &mut Self::Inner;
28
29    /// Unwrap all layers and return a reference to the base environment.
30    fn unwrapped(&self) -> &Self::Inner {
31        self.inner()
32    }
33}
34
35/// Wraps an environment and truncates episodes after `max_steps` steps.
36///
37/// This is one of the most universally needed wrappers. Without it,
38/// environments without natural termination conditions (e.g. locomotion
39/// tasks) run forever.
40///
41/// Episodes truncated by this wrapper emit `EpisodeStatus::Truncated`,
42/// not `EpisodeStatus::Terminated`, so algorithms correctly bootstrap
43/// the value of the final state.
44pub struct TimeLimit<E: Environment> {
45    env: E,
46    max_steps: usize,
47    current_step: usize,
48}
49
50impl<E: Environment> TimeLimit<E> {
51    pub fn new(env: E, max_steps: usize) -> Self {
52        Self {
53            env,
54            max_steps,
55            current_step: 0,
56        }
57    }
58
59    pub fn elapsed_steps(&self) -> usize {
60        self.current_step
61    }
62
63    pub fn remaining_steps(&self) -> usize {
64        self.max_steps.saturating_sub(self.current_step)
65    }
66}
67
68impl<E: Environment> Environment for TimeLimit<E> {
69    type Observation = E::Observation;
70    type Action = E::Action;
71    type Info = E::Info;
72
73    fn step(&mut self, action: Self::Action) -> StepResult<Self::Observation, Self::Info> {
74        let mut result = self.env.step(action);
75        self.current_step += 1;
76
77        // Only truncate if the environment hasn't already terminated naturally.
78        // We don't override Terminated with Truncated — natural termination wins.
79        if self.current_step >= self.max_steps && result.status == EpisodeStatus::Continuing {
80            result.status = EpisodeStatus::Truncated;
81        }
82
83        result
84    }
85
86    fn reset(&mut self, seed: Option<u64>) -> (Self::Observation, Self::Info) {
87        self.current_step = 0;
88        self.env.reset(seed)
89    }
90
91    fn sample_action(&self, rng: &mut impl Rng) -> Self::Action {
92        self.env.sample_action(rng)
93    }
94}
95
96impl<E: Environment> Wrapper for TimeLimit<E> {
97    type Inner = E;
98    fn inner(&self) -> &E {
99        &self.env
100    }
101    fn inner_mut(&mut self) -> &mut E {
102        &mut self.env
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::episode::{EpisodeStatus, StepResult};

    /// Bare-bones environment for tests. The status reported by each step is
    /// chosen by a caller-supplied closure, which receives the number of
    /// steps taken so far (counting the current one).
    struct MockEnv<F>
    where
        F: FnMut(usize) -> EpisodeStatus,
    {
        step_count: usize,
        status_fn: F,
    }

    impl<F> MockEnv<F>
    where
        F: FnMut(usize) -> EpisodeStatus,
    {
        fn new(status_fn: F) -> Self {
            Self {
                step_count: 0,
                status_fn,
            }
        }
    }

    impl<F> Environment for MockEnv<F>
    where
        F: FnMut(usize) -> EpisodeStatus + Send + Sync + 'static,
    {
        type Observation = ();
        type Action = ();
        type Info = ();

        fn step(&mut self, _action: ()) -> StepResult<(), ()> {
            self.step_count += 1;
            let nth = self.step_count;
            let status = (self.status_fn)(nth);
            StepResult::new((), 1.0, status, ())
        }

        fn reset(&mut self, _seed: Option<u64>) -> ((), ()) {
            self.step_count = 0;
            ((), ())
        }

        fn sample_action(&self, _rng: &mut impl rand::Rng) {}
    }

    // ── TimeLimit ────────────────────────────────────────────────────────────

    #[test]
    fn steps_below_limit_pass_through_status_unchanged() {
        let never_ending = MockEnv::new(|_| EpisodeStatus::Continuing);
        let mut limited = TimeLimit::new(never_ending, 5);
        limited.reset(None);

        // Four steps against a limit of five: none may be truncated.
        (0..4).for_each(|_| {
            let outcome = limited.step(());
            assert_eq!(outcome.status, EpisodeStatus::Continuing);
        });
    }

    #[test]
    fn truncates_at_limit_when_inner_is_continuing() {
        let never_ending = MockEnv::new(|_| EpisodeStatus::Continuing);
        let mut limited = TimeLimit::new(never_ending, 3);
        limited.reset(None);

        for _ in 0..2 {
            limited.step(());
        }
        let third = limited.step(());
        assert_eq!(third.status, EpisodeStatus::Truncated);
    }

    #[test]
    fn natural_termination_wins_over_time_limit_on_same_step() {
        // The inner environment ends the episode on step 3, which is also the
        // step on which the time limit fires. The wrapper must report
        // Terminated; reporting Truncated would make an algorithm bootstrap
        // from a genuinely terminal state.
        let ends_on_third = MockEnv::new(|step| match step {
            0..=2 => EpisodeStatus::Continuing,
            _ => EpisodeStatus::Terminated,
        });
        let mut limited = TimeLimit::new(ends_on_third, 3);
        limited.reset(None);

        for _ in 0..2 {
            limited.step(());
        }
        let third = limited.step(());
        assert_eq!(third.status, EpisodeStatus::Terminated);
    }

    #[test]
    fn reset_restores_step_counter_so_limit_applies_again() {
        let never_ending = MockEnv::new(|_| EpisodeStatus::Continuing);
        let mut limited = TimeLimit::new(never_ending, 2);

        // First episode: run straight into the limit.
        limited.reset(None);
        limited.step(());
        let at_limit = limited.step(());
        assert_eq!(at_limit.status, EpisodeStatus::Truncated);

        // Second episode: the very first step must be back under the limit.
        limited.reset(None);
        let first_again = limited.step(());
        assert_eq!(
            first_again.status,
            EpisodeStatus::Continuing,
            "step counter not reset"
        );
    }
}
205}