1use rand::Rng;
2
3use crate::environment::Environment;
4use crate::episode::{EpisodeStatus, StepResult};
5
6pub trait Wrapper: Environment {
24 type Inner: Environment;
25
26 fn inner(&self) -> &Self::Inner;
27 fn inner_mut(&mut self) -> &mut Self::Inner;
28
29 fn unwrapped(&self) -> &Self::Inner {
31 self.inner()
32 }
33}
34
/// Environment wrapper that gives every episode a fixed step budget.
///
/// The wrapper counts the steps taken since construction or the last reset.
/// Once that count reaches `max_steps` while the wrapped environment still
/// reports the episode as continuing, the step result's status is rewritten
/// to truncated.
///
/// The type itself places no bound on `E`; the `Environment` requirement
/// lives on the `impl` blocks that actually need it, so the struct can also
/// derive the common traits whenever `E` supports them.
#[derive(Debug, Clone)]
pub struct TimeLimit<E> {
    /// The wrapped environment that every call is delegated to.
    env: E,
    /// Number of steps an episode may take before it is truncated.
    max_steps: usize,
    /// Steps taken since construction or the most recent reset.
    current_step: usize,
}
49
50impl<E: Environment> TimeLimit<E> {
51 pub fn new(env: E, max_steps: usize) -> Self {
52 Self {
53 env,
54 max_steps,
55 current_step: 0,
56 }
57 }
58
59 pub fn elapsed_steps(&self) -> usize {
60 self.current_step
61 }
62
63 pub fn remaining_steps(&self) -> usize {
64 self.max_steps.saturating_sub(self.current_step)
65 }
66}
67
68impl<E: Environment> Environment for TimeLimit<E> {
69 type Observation = E::Observation;
70 type Action = E::Action;
71 type Info = E::Info;
72
73 fn step(&mut self, action: Self::Action) -> StepResult<Self::Observation, Self::Info> {
74 let mut result = self.env.step(action);
75 self.current_step += 1;
76
77 if self.current_step >= self.max_steps && result.status == EpisodeStatus::Continuing {
80 result.status = EpisodeStatus::Truncated;
81 }
82
83 result
84 }
85
86 fn reset(&mut self, seed: Option<u64>) -> (Self::Observation, Self::Info) {
87 self.current_step = 0;
88 self.env.reset(seed)
89 }
90
91 fn sample_action(&self, rng: &mut impl Rng) -> Self::Action {
92 self.env.sample_action(rng)
93 }
94}
95
96impl<E: Environment> Wrapper for TimeLimit<E> {
97 type Inner = E;
98 fn inner(&self) -> &E {
99 &self.env
100 }
101 fn inner_mut(&mut self) -> &mut E {
102 &mut self.env
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::episode::{EpisodeStatus, StepResult};

    /// Minimal environment whose per-step status is chosen by a closure.
    ///
    /// The closure receives the 1-based index of the step within the current
    /// episode (the counter is cleared by `reset`).
    struct MockEnv<F: FnMut(usize) -> EpisodeStatus> {
        step_count: usize,
        status_fn: F,
    }

    impl<F: FnMut(usize) -> EpisodeStatus> MockEnv<F> {
        fn new(status_fn: F) -> Self {
            Self { step_count: 0, status_fn }
        }
    }

    impl<F: FnMut(usize) -> EpisodeStatus + Send + Sync + 'static> Environment for MockEnv<F> {
        type Observation = ();
        type Action = ();
        type Info = ();

        fn step(&mut self, _action: ()) -> StepResult<(), ()> {
            self.step_count += 1;
            let status = (self.status_fn)(self.step_count);
            StepResult::new((), 1.0, status, ())
        }

        fn reset(&mut self, _seed: Option<u64>) -> ((), ()) {
            self.step_count = 0;
            ((), ())
        }

        fn sample_action(&self, _rng: &mut impl rand::Rng) -> Self::Action {}
    }

    #[test]
    fn steps_below_limit_pass_through_status_unchanged() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            5,
        );
        env.reset(None);
        for _ in 0..4 {
            let r = env.step(());
            assert_eq!(r.status, EpisodeStatus::Continuing);
        }
    }

    #[test]
    fn truncates_at_limit_when_inner_is_continuing() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            3,
        );
        env.reset(None);
        env.step(());
        env.step(());
        let r = env.step(());
        assert_eq!(r.status, EpisodeStatus::Truncated);
    }

    #[test]
    fn natural_termination_wins_over_time_limit_on_same_step() {
        let mut env = TimeLimit::new(
            MockEnv::new(|n| {
                if n >= 3 { EpisodeStatus::Terminated } else { EpisodeStatus::Continuing }
            }),
            3,
        );
        env.reset(None);
        env.step(());
        env.step(());
        let r = env.step(());
        assert_eq!(r.status, EpisodeStatus::Terminated);
    }

    #[test]
    fn natural_termination_before_limit_passes_through() {
        let mut env = TimeLimit::new(
            MockEnv::new(|n| {
                if n >= 2 { EpisodeStatus::Terminated } else { EpisodeStatus::Continuing }
            }),
            5,
        );
        env.reset(None);
        assert_eq!(env.step(()).status, EpisodeStatus::Continuing);
        assert_eq!(env.step(()).status, EpisodeStatus::Terminated);
    }

    #[test]
    fn reset_restores_step_counter_so_limit_applies_again() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            2,
        );
        env.reset(None);
        env.step(());
        let r = env.step(());
        assert_eq!(r.status, EpisodeStatus::Truncated);

        env.reset(None);
        let r = env.step(());
        assert_eq!(r.status, EpisodeStatus::Continuing, "step counter not reset");
    }

    #[test]
    fn zero_limit_truncates_on_first_step() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            0,
        );
        env.reset(None);
        assert_eq!(env.step(()).status, EpisodeStatus::Truncated);
    }

    #[test]
    fn stepping_past_limit_keeps_reporting_truncated() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            2,
        );
        env.reset(None);
        env.step(());
        assert_eq!(env.step(()).status, EpisodeStatus::Truncated);
        // The wrapper does not refuse further steps; it keeps flagging them.
        assert_eq!(env.step(()).status, EpisodeStatus::Truncated);
    }

    #[test]
    fn step_counters_track_progress_and_saturate_past_limit() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            2,
        );
        env.reset(None);
        assert_eq!((env.elapsed_steps(), env.remaining_steps()), (0, 2));

        env.step(());
        assert_eq!((env.elapsed_steps(), env.remaining_steps()), (1, 1));

        env.step(());
        assert_eq!((env.elapsed_steps(), env.remaining_steps()), (2, 0));

        // One step beyond the budget: elapsed keeps counting, remaining stays at zero.
        env.step(());
        assert_eq!((env.elapsed_steps(), env.remaining_steps()), (3, 0));

        env.reset(None);
        assert_eq!((env.elapsed_steps(), env.remaining_steps()), (0, 2));
    }

    #[test]
    fn wrapper_accessors_expose_the_inner_environment() {
        let mut env = TimeLimit::new(
            MockEnv::new(|_| EpisodeStatus::Continuing),
            5,
        );
        env.reset(None);
        env.step(());
        env.step(());
        assert_eq!(env.inner().step_count, 2);
        assert_eq!(env.unwrapped().step_count, 2);

        // Mutating the inner environment must not touch the wrapper's own counter.
        env.inner_mut().step_count = 0;
        assert_eq!(env.inner().step_count, 0);
        assert_eq!(env.elapsed_steps(), 2);
    }
}