1use crate::error::{RlError, RlResult};
17
/// One n-step transition produced by [`NStepBuffer`].
///
/// The transition is rooted at the oldest buffered step (`obs`, `action`) and carries the
/// discounted reward sum accumulated from that step onwards, together with what is needed
/// to bootstrap the remainder of the return.
#[derive(Debug, Clone, PartialEq)]
pub struct NStepTransition {
    /// Observation of the step this transition is rooted at.
    pub obs: Vec<f32>,
    /// Action taken at `obs`.
    pub action: Vec<f32>,
    /// Discounted reward sum `r_0 + gamma * r_1 + gamma^2 * r_2 + ...`, stopping after the
    /// first terminal step.
    pub n_step_return: f32,
    /// `next_obs` of the last step whose reward is included in `n_step_return`.
    pub bootstrap_obs: Vec<f32>,
    /// `true` if accumulation stopped because a terminal step was reached (conventionally,
    /// callers then drop the bootstrap term).
    pub done: bool,
    /// Number of steps the return was computed over (at most the buffer's `n`).
    pub actual_n: usize,
    /// `gamma` raised to the number of rewards summed; the discount to apply to the
    /// bootstrap value estimated at `bootstrap_obs`.
    pub gamma_n: f32,
}
38
/// One raw environment step as stored in the buffer's ring storage.
#[derive(Debug, Clone)]
struct Step {
    /// Observation the action was taken from.
    obs: Vec<f32>,
    /// Action taken at `obs`.
    action: Vec<f32>,
    /// Reward received for this step.
    reward: f32,
    /// Observation after this step; becomes the bootstrap observation when this is the
    /// last step included in a return.
    next_obs: Vec<f32>,
    /// Whether this step ended the episode; return accumulation stops after it.
    done: bool,
}
49
/// Fixed-capacity ring buffer that turns single environment steps into n-step
/// transitions.
///
/// `steps` is used as a ring of length `n`: `head` is the slot the next push writes
/// and `count` is how many slots are occupied, so the oldest occupied slot is
/// `(head + n - count) % n`.
#[derive(Debug, Clone)]
pub struct NStepBuffer {
    /// Return horizon and ring capacity (asserted `> 0` in `new`).
    n: usize,
    /// Per-step discount factor.
    gamma: f32,
    /// Ring storage of length `n`; `None` marks an empty slot.
    steps: Vec<Option<Step>>,
    /// Slot the next `push` writes into.
    head: usize,
    /// Number of occupied slots, never more than `n`.
    count: usize,
}
61
62impl NStepBuffer {
63 #[must_use]
72 pub fn new(n: usize, gamma: f32) -> Self {
73 assert!(n > 0, "n must be > 0");
74 Self {
75 n,
76 gamma,
77 steps: vec![None; n],
78 head: 0,
79 count: 0,
80 }
81 }
82
83 #[must_use]
85 #[inline]
86 pub fn n(&self) -> usize {
87 self.n
88 }
89
90 #[must_use]
92 #[inline]
93 pub fn gamma(&self) -> f32 {
94 self.gamma
95 }
96
97 #[must_use]
99 #[inline]
100 pub fn count(&self) -> usize {
101 self.count
102 }
103
104 pub fn push(
110 &mut self,
111 obs: impl Into<Vec<f32>>,
112 action: impl Into<Vec<f32>>,
113 reward: f32,
114 next_obs: impl Into<Vec<f32>>,
115 done: bool,
116 ) -> Option<NStepTransition> {
117 let step = Step {
118 obs: obs.into(),
119 action: action.into(),
120 reward,
121 next_obs: next_obs.into(),
122 done,
123 };
124 self.steps[self.head] = Some(step);
125 self.head = (self.head + 1) % self.n;
126 if self.count < self.n {
127 self.count += 1;
128 }
129
130 if self.count == self.n {
131 Some(self.compute_return())
132 } else if done {
133 Some(self.compute_partial_return())
134 } else {
135 None
136 }
137 }
138
139 pub fn flush(&mut self) -> Vec<NStepTransition> {
143 let mut out = Vec::new();
144 while self.count > 0 {
145 out.push(self.compute_partial_return());
146 let oldest = (self.head + self.n - self.count) % self.n;
148 self.steps[oldest] = None;
149 self.count -= 1;
150 }
151 out
152 }
153
154 fn compute_return(&self) -> NStepTransition {
156 self.compute_n_step_return(self.n)
157 }
158
159 fn compute_partial_return(&self) -> NStepTransition {
161 self.compute_n_step_return(self.count)
162 }
163
164 fn compute_n_step_return(&self, steps: usize) -> NStepTransition {
165 let oldest = (self.head + self.n - self.count) % self.n;
167 let first = self.steps[oldest]
168 .as_ref()
169 .expect("oldest step must be Some");
170
171 let mut cumulative = 0.0_f32;
172 let mut gamma_k = 1.0_f32;
173 let mut last_next_obs = first.next_obs.clone();
174 let mut terminated = false;
175
176 for k in 0..steps {
177 let idx = (oldest + k) % self.n;
178 let step = self.steps[idx].as_ref().expect("step must be Some");
179 cumulative += gamma_k * step.reward;
180 gamma_k *= self.gamma;
181 last_next_obs = step.next_obs.clone();
182 if step.done {
183 terminated = true;
184 break;
185 }
186 }
187
188 NStepTransition {
189 obs: first.obs.clone(),
190 action: first.action.clone(),
191 n_step_return: cumulative,
192 bootstrap_obs: last_next_obs,
193 done: terminated,
194 actual_n: steps,
195 gamma_n: gamma_k,
196 }
197 }
198
199 pub fn reset(&mut self) {
201 for s in self.steps.iter_mut() {
202 *s = None;
203 }
204 self.head = 0;
205 self.count = 0;
206 }
207
208 pub fn try_get(&self) -> RlResult<NStepTransition> {
215 if self.count < self.n {
216 return Err(RlError::NStepIncomplete {
217 have: self.count,
218 need: self.n,
219 });
220 }
221 Ok(self.compute_return())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh buffer so every test constructs it the same way.
    fn buffer(n: usize, gamma: f32) -> NStepBuffer {
        NStepBuffer::new(n, gamma)
    }

    /// Fewer than `n` non-terminal pushes must not emit anything.
    #[test]
    fn none_before_n_steps() {
        let mut buf = buffer(3, 0.99);
        let after_one = buf.push([0.0], [0.0], 1.0, [1.0], false);
        let after_two = buf.push([1.0], [0.0], 1.0, [2.0], false);
        assert!(after_one.is_none(), "should be None after 1 step");
        assert!(after_two.is_none(), "should be None after 2 steps");
    }

    /// The push that fills the window emits a return covering all `n` rewards.
    #[test]
    fn returns_transition_at_n() {
        let mut buf = buffer(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 1.0, [2.0], false);
        let emitted = buf.push([2.0], [0.0], 1.0, [3.0], false);
        assert!(emitted.is_some(), "should return transition at n=3");
        let emitted = emitted.unwrap();
        // 1 + gamma + gamma^2 with gamma = 0.99.
        let expected = 1.0 + 0.99 + 0.99_f32 * 0.99;
        assert!(
            (emitted.n_step_return - expected).abs() < 1e-4,
            "n_step_return={}",
            emitted.n_step_return
        );
        assert_eq!(emitted.actual_n, 3);
    }

    /// Rewards are weighted by successive powers of gamma, and gamma_n is gamma^n.
    #[test]
    fn discount_applied_correctly() {
        let mut buf = buffer(2, 0.5);
        buf.push([0.0], [0.0], 2.0, [1.0], false);
        let emitted = buf.push([1.0], [0.0], 4.0, [2.0], false).unwrap();
        // 2 + 0.5 * 4.
        let expected = 4.0;
        assert!(
            (emitted.n_step_return - expected).abs() < 1e-5,
            "n_step_return={}",
            emitted.n_step_return
        );
        let expected_gamma_n = 0.25;
        assert!(
            (emitted.gamma_n - expected_gamma_n).abs() < 1e-5,
            "gamma_n={}",
            emitted.gamma_n
        );
    }

    /// A terminal step emits a truncated, `done` transition before the window is full.
    #[test]
    fn terminal_truncates_return() {
        let mut buf = buffer(5, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        let emitted = buf.push([1.0], [0.0], 2.0, [2.0], true);
        assert!(emitted.is_some(), "terminal step should emit transition early");
        let emitted = emitted.unwrap();
        assert!(emitted.done, "done flag should be set");
        // 1 + gamma * 2 with gamma = 0.99.
        let expected: f32 = 1.0 + 0.99 * 2.0;
        assert!(
            (emitted.n_step_return - expected).abs() < 1e-4,
            "n_step_return={}",
            emitted.n_step_return
        );
    }

    /// Flushing a partially filled buffer yields at least one transition.
    #[test]
    fn flush_returns_remaining() {
        let mut buf = buffer(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        let drained = buf.flush();
        assert!(
            !drained.is_empty(),
            "flush should return partial transitions"
        );
    }

    /// The buffer is empty after a flush.
    #[test]
    fn flush_clears_buffer() {
        let mut buf = buffer(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.flush();
        assert_eq!(buf.count(), 0);
    }

    /// `try_get` reports an error while the window is not full.
    #[test]
    fn try_get_before_n_error() {
        let mut buf = buffer(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        assert!(buf.try_get().is_err());
    }

    /// `try_get` succeeds once `n` steps are buffered.
    #[test]
    fn try_get_after_n_ok() {
        let mut buf = buffer(2, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        assert!(buf.try_get().is_ok());
    }

    /// `reset` discards every buffered step.
    #[test]
    fn reset_clears() {
        let mut buf = buffer(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        buf.reset();
        assert_eq!(buf.count(), 0);
        assert!(buf.try_get().is_err());
    }

    /// Observation, action and bootstrap observation are copied through unchanged.
    #[test]
    fn obs_preserved_correctly() {
        let mut buf = buffer(1, 0.9);
        let emitted = buf
            .push([7.0, 8.0], [3.0], 5.0, [9.0, 10.0], false)
            .unwrap();
        assert_eq!(emitted.obs, vec![7.0, 8.0]);
        assert_eq!(emitted.action, vec![3.0]);
        assert_eq!(emitted.bootstrap_obs, vec![9.0, 10.0]);
    }
}