// oxicuda_rl/buffer/n_step.rs

//! # N-step Return Buffer
//!
//! Accumulates up to `n` successive transitions and computes the n-step return:
//!
//! ```text
//! R_t^(n) = r_t + γ r_{t+1} + γ² r_{t+2} + … + γ^{n-1} r_{t+n-1}
//!           + γ^n V(s_{t+n}) * (1 - done)
//! ```
//!
//! When the trajectory terminates before `n` steps, the return is truncated at
//! the terminal transition.
//!
//! The buffer only sums the discounted rewards. The caller adds the bootstrap
//! term using the `bootstrap_obs`, `gamma_n` and `done` fields of
//! [`NStepTransition`].
//!
//! N-step returns trade bias against variance in the bootstrapped targets of
//! off-policy algorithms such as DQN and SAC.

use std::collections::VecDeque;

use crate::error::{RlError, RlResult};

// ─── NStepTransition ─────────────────────────────────────────────────────────

/// A transition with its n-step return pre-computed.
///
/// Produced by the n-step buffer for the step at the start of the window.
#[derive(Debug, Clone, PartialEq)]
pub struct NStepTransition {
    /// Observation at the start of the n-step window.
    pub obs: Vec<f32>,
    /// Action taken at the start of the window.
    pub action: Vec<f32>,
    /// Discounted sum of the rewards in the window, `R_t^(n)` without the
    /// bootstrap term.
    pub n_step_return: f32,
    /// `next_obs` of the last step included in the window: the observation at
    /// `t + n`, or the one at which the episode ended.
    pub bootstrap_obs: Vec<f32>,
    /// Whether the episode ended within the n-step window.
    pub done: bool,
    /// Length of the window (at most `n`; shorter when the episode ended early).
    pub actual_n: usize,
    /// Discount to apply to the bootstrap value: γ raised to the number of
    /// rewards summed into `n_step_return`.
    pub gamma_n: f32,
}

// ─── Internal single step ─────────────────────────────────────────────────────

/// One raw environment step as pushed by the caller, before any n-step
/// aggregation.
#[derive(Clone, Debug)]
struct Step {
    /// Observation the action was taken from.
    obs: Vec<f32>,
    /// Action taken.
    action: Vec<f32>,
    /// Immediate reward for this step.
    reward: f32,
    /// Observation reached after the action.
    next_obs: Vec<f32>,
    /// Whether this step ended the episode.
    done: bool,
}

// ─── NStepBuffer ─────────────────────────────────────────────────────────────

52/// Circular accumulator for n-step returns.
53#[derive(Debug, Clone)]
54pub struct NStepBuffer {
55    n: usize,
56    gamma: f32,
57    steps: Vec<Option<Step>>,
58    head: usize,
59    count: usize,
60}
61
62impl NStepBuffer {
63    /// Create an n-step buffer.
64    ///
65    /// * `n` — number of steps to accumulate.
66    /// * `gamma` — discount factor γ ∈ (0, 1].
67    ///
68    /// # Panics
69    ///
70    /// Panics if `n == 0`.
71    #[must_use]
72    pub fn new(n: usize, gamma: f32) -> Self {
73        assert!(n > 0, "n must be > 0");
74        Self {
75            n,
76            gamma,
77            steps: vec![None; n],
78            head: 0,
79            count: 0,
80        }
81    }
82
83    /// Number of steps n.
84    #[must_use]
85    #[inline]
86    pub fn n(&self) -> usize {
87        self.n
88    }
89
90    /// Discount factor γ.
91    #[must_use]
92    #[inline]
93    pub fn gamma(&self) -> f32 {
94        self.gamma
95    }
96
97    /// Current number of accumulated steps.
98    #[must_use]
99    #[inline]
100    pub fn count(&self) -> usize {
101        self.count
102    }
103
104    /// Push a `(obs, action, reward, next_obs, done)` step.
105    ///
106    /// Returns `Some(NStepTransition)` once `n` steps have been accumulated
107    /// (or immediately on terminal).  Returns `None` while the buffer is still
108    /// filling up.
109    pub fn push(
110        &mut self,
111        obs: impl Into<Vec<f32>>,
112        action: impl Into<Vec<f32>>,
113        reward: f32,
114        next_obs: impl Into<Vec<f32>>,
115        done: bool,
116    ) -> Option<NStepTransition> {
117        let step = Step {
118            obs: obs.into(),
119            action: action.into(),
120            reward,
121            next_obs: next_obs.into(),
122            done,
123        };
124        self.steps[self.head] = Some(step);
125        self.head = (self.head + 1) % self.n;
126        if self.count < self.n {
127            self.count += 1;
128        }
129
130        if self.count == self.n {
131            Some(self.compute_return())
132        } else if done {
133            Some(self.compute_partial_return())
134        } else {
135            None
136        }
137    }
138
139    /// Flush any remaining steps in the buffer (for end-of-episode cleanup).
140    ///
141    /// Returns all remaining n-step transitions.
142    pub fn flush(&mut self) -> Vec<NStepTransition> {
143        let mut out = Vec::new();
144        while self.count > 0 {
145            out.push(self.compute_partial_return());
146            // Advance the oldest step pointer
147            let oldest = (self.head + self.n - self.count) % self.n;
148            self.steps[oldest] = None;
149            self.count -= 1;
150        }
151        out
152    }
153
154    /// Build the full n-step transition (assumes `count == n`).
155    fn compute_return(&self) -> NStepTransition {
156        self.compute_n_step_return(self.n)
157    }
158
159    /// Build a partial n-step transition when episode ended early.
160    fn compute_partial_return(&self) -> NStepTransition {
161        self.compute_n_step_return(self.count)
162    }
163
164    fn compute_n_step_return(&self, steps: usize) -> NStepTransition {
165        // Oldest step index in the circular buffer
166        let oldest = (self.head + self.n - self.count) % self.n;
167        let first = self.steps[oldest]
168            .as_ref()
169            .expect("oldest step must be Some");
170
171        let mut cumulative = 0.0_f32;
172        let mut gamma_k = 1.0_f32;
173        let mut last_next_obs = first.next_obs.clone();
174        let mut terminated = false;
175
176        for k in 0..steps {
177            let idx = (oldest + k) % self.n;
178            let step = self.steps[idx].as_ref().expect("step must be Some");
179            cumulative += gamma_k * step.reward;
180            gamma_k *= self.gamma;
181            last_next_obs = step.next_obs.clone();
182            if step.done {
183                terminated = true;
184                break;
185            }
186        }
187
188        NStepTransition {
189            obs: first.obs.clone(),
190            action: first.action.clone(),
191            n_step_return: cumulative,
192            bootstrap_obs: last_next_obs,
193            done: terminated,
194            actual_n: steps,
195            gamma_n: gamma_k,
196        }
197    }
198
199    /// Clear the buffer (e.g. at episode reset).
200    pub fn reset(&mut self) {
201        for s in self.steps.iter_mut() {
202            *s = None;
203        }
204        self.head = 0;
205        self.count = 0;
206    }
207
208    /// Attempt to produce an n-step transition from the current state.
209    ///
210    /// # Errors
211    ///
212    /// * [`RlError::NStepIncomplete`] if fewer than `n` steps have been
213    ///   accumulated.
214    pub fn try_get(&self) -> RlResult<NStepTransition> {
215        if self.count < self.n {
216            return Err(RlError::NStepIncomplete {
217                have: self.count,
218                need: self.n,
219            });
220        }
221        Ok(self.compute_return())
222    }
223}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn make_buf(n: usize, gamma: f32) -> NStepBuffer {
        NStepBuffer::new(n, gamma)
    }

    // ── construction ─────────────────────────────────────────────────────────

    #[test]
    #[should_panic(expected = "n must be > 0")]
    fn zero_n_panics() {
        let _ = make_buf(0, 0.99);
    }

    // ── basic accumulation ───────────────────────────────────────────────────

    #[test]
    fn none_before_n_steps() {
        let mut buf = make_buf(3, 0.99);
        let r1 = buf.push([0.0], [0.0], 1.0, [1.0], false);
        let r2 = buf.push([1.0], [0.0], 1.0, [2.0], false);
        assert!(r1.is_none(), "should be None after 1 step");
        assert!(r2.is_none(), "should be None after 2 steps");
    }

    #[test]
    fn returns_transition_at_n() {
        let mut buf = make_buf(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 1.0, [2.0], false);
        let t = buf.push([2.0], [0.0], 1.0, [3.0], false);
        assert!(t.is_some(), "should return transition at n=3");
        let t = t.unwrap();
        // R = 1 + 0.99*1 + 0.99²*1 = 1 + 0.99 + 0.9801 = 2.9701
        assert!(
            (t.n_step_return - (1.0 + 0.99 + 0.99_f32 * 0.99)).abs() < 1e-4,
            "n_step_return={}",
            t.n_step_return
        );
        assert_eq!(t.actual_n, 3);
        assert_eq!(t.obs, vec![0.0]);
        assert_eq!(t.bootstrap_obs, vec![3.0]);
        assert!(!t.done);
    }

    #[test]
    fn discount_applied_correctly() {
        let mut buf = make_buf(2, 0.5);
        buf.push([0.0], [0.0], 2.0, [1.0], false);
        let t = buf.push([1.0], [0.0], 4.0, [2.0], false);
        let t = t.unwrap();
        // R = 2 + 0.5 * 4 = 4.0
        assert!(
            (t.n_step_return - 4.0).abs() < 1e-5,
            "n_step_return={}",
            t.n_step_return
        );
        assert!((t.gamma_n - 0.25).abs() < 1e-5, "gamma_n={}", t.gamma_n);
    }

    #[test]
    fn window_slides_one_step_per_push() {
        let mut buf = make_buf(2, 0.5);
        buf.push([0.0], [0.0], 2.0, [1.0], false);
        let first = buf.push([1.0], [0.0], 4.0, [2.0], false).unwrap();
        let second = buf.push([2.0], [0.0], 8.0, [3.0], false).unwrap();
        assert_eq!(first.obs, vec![0.0]);
        assert_eq!(second.obs, vec![1.0]);
        // R = 4 + 0.5 * 8 = 8.0
        assert!(
            (second.n_step_return - 8.0).abs() < 1e-5,
            "n_step_return={}",
            second.n_step_return
        );
        assert_eq!(second.bootstrap_obs, vec![3.0]);
    }

    // ── done flag ────────────────────────────────────────────────────────────

    #[test]
    fn terminal_truncates_return() {
        let mut buf = make_buf(5, 0.99);
        // Episode ends at step 2 (n < 5)
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        let t = buf.push([1.0], [0.0], 2.0, [2.0], true); // done
        assert!(t.is_some(), "terminal step should emit transition early");
        let t = t.unwrap();
        assert!(t.done, "done flag should be set");
        // R = 1 + 0.99 * 2 = 2.98
        assert!(
            (t.n_step_return - (1.0 + 0.99 * 2.0)).abs() < 1e-4,
            "n_step_return={}",
            t.n_step_return
        );
        assert_eq!(t.actual_n, 2);
        assert!(
            (t.gamma_n - 0.99_f32 * 0.99).abs() < 1e-5,
            "gamma_n={}",
            t.gamma_n
        );
        assert_eq!(t.bootstrap_obs, vec![2.0]);
    }

    // ── flush ────────────────────────────────────────────────────────────────

    #[test]
    fn flush_returns_remaining() {
        let mut buf = make_buf(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        let flushed = buf.flush();
        // 2 partial returns: one 2-step, one 1-step
        assert_eq!(flushed.len(), 2, "flush should return partial transitions");

        assert_eq!(flushed[0].obs, vec![0.0]);
        assert_eq!(flushed[0].actual_n, 2);
        assert!((flushed[0].n_step_return - (1.0 + 0.99 * 2.0)).abs() < 1e-4);
        assert!(!flushed[0].done);

        assert_eq!(flushed[1].obs, vec![1.0]);
        assert_eq!(flushed[1].actual_n, 1);
        assert!((flushed[1].n_step_return - 2.0).abs() < 1e-5);
        assert!(!flushed[1].done);

        for t in &flushed {
            assert_eq!(t.bootstrap_obs, vec![2.0]);
        }
    }

    #[test]
    fn flush_clears_buffer() {
        let mut buf = make_buf(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.flush();
        assert_eq!(buf.count(), 0);
        assert!(buf.try_get().is_err());
    }

    // ── try_get ──────────────────────────────────────────────────────────────

    #[test]
    fn try_get_before_n_error() {
        let mut buf = make_buf(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        assert!(buf.try_get().is_err());
    }

    #[test]
    fn try_get_after_n_ok() {
        let mut buf = make_buf(2, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        assert!(buf.try_get().is_ok());
    }

    // ── reset ────────────────────────────────────────────────────────────────

    #[test]
    fn reset_clears() {
        let mut buf = make_buf(3, 0.99);
        buf.push([0.0], [0.0], 1.0, [1.0], false);
        buf.push([1.0], [0.0], 2.0, [2.0], false);
        buf.reset();
        assert_eq!(buf.count(), 0);
        assert!(buf.try_get().is_err());
    }

    // ── obs / action preservation ────────────────────────────────────────────

    #[test]
    fn obs_preserved_correctly() {
        let mut buf = make_buf(1, 0.9);
        let t = buf.push([7.0, 8.0], [3.0], 5.0, [9.0, 10.0], false);
        let t = t.unwrap();
        assert_eq!(t.obs, vec![7.0, 8.0]);
        assert_eq!(t.action, vec![3.0]);
        assert_eq!(t.bootstrap_obs, vec![9.0, 10.0]);
    }
}