Skip to main content

oxicuda_rl/buffer/
replay.rs

//! # Uniform Replay Buffer
//!
//! A fixed-capacity circular buffer that stores `(s, a, r, s', done)`
//! transitions and supports uniform random sampling without replacement.
//!
//! ## Design
//!
//! * Pre-allocated, never resizes — capacity chosen once at construction.
//! * Circular: once full, the oldest transition is overwritten.
//! * `sample(batch_size)` returns a batch of owned, cloned transitions;
//!   `sample_into` writes the batch into caller-provided flat slices instead.
//!   Indices are chosen using the handle's LCG RNG for reproducibility.
//!
//! ## Usage
//!
//! ```rust
//! use oxicuda_rl::buffer::UniformReplayBuffer;
//! use oxicuda_rl::handle::RlHandle;
//!
//! let mut handle = RlHandle::default_handle();
//! let mut buf = UniformReplayBuffer::new(1024, 4, 2);
//!
//! for i in 0..100_usize {
//!     buf.push([i as f32; 4], [0.0_f32; 2], 1.0, [i as f32 + 1.0; 4], false);
//! }
//!
//! let batch = buf.sample(32, &mut handle).unwrap();
//! assert_eq!(batch.len(), 32);
//! ```

30use crate::error::{RlError, RlResult};
31use crate::handle::RlHandle;

// ─── Transition ───────────────────────────────────────────────────────────────

/// A single `(s, a, r, s', done)` experience tuple.
///
/// This is the owned, one-transition-at-a-time form of the data: it is what
/// [`UniformReplayBuffer::push_transition`] accepts and what
/// [`UniformReplayBuffer::sample`] returns.
#[derive(Debug, Clone, PartialEq)]
pub struct Transition {
    /// Observation at time `t`.
    pub obs: Vec<f32>,
    /// Action taken at time `t`.
    pub action: Vec<f32>,
    /// Reward received at time `t`.
    pub reward: f32,
    /// Observation at time `t+1`.
    pub next_obs: Vec<f32>,
    /// Whether `t+1` is a terminal state.
    pub done: bool,
}

impl Transition {
    /// Create a new transition.
    ///
    /// `obs`, `action` and `next_obs` accept anything convertible into a
    /// `Vec<f32>`, e.g. `[f32; N]` arrays, `&[f32]` slices, or an existing
    /// `Vec<f32>` (which is moved in without copying).
    #[must_use]
    pub fn new(
        obs: impl Into<Vec<f32>>,
        action: impl Into<Vec<f32>>,
        reward: f32,
        next_obs: impl Into<Vec<f32>>,
        done: bool,
    ) -> Self {
        Self {
            obs: obs.into(),
            action: action.into(),
            reward,
            next_obs: next_obs.into(),
            done,
        }
    }
}

// ─── UniformReplayBuffer ──────────────────────────────────────────────────────

72/// Uniform circular experience replay buffer.
73#[derive(Debug, Clone)]
74pub struct UniformReplayBuffer {
75    capacity: usize,
76    obs_dim: usize,
77    act_dim: usize,
78    // Storage in struct-of-arrays layout for cache efficiency.
79    obs: Vec<f32>,      // capacity × obs_dim
80    actions: Vec<f32>,  // capacity × act_dim
81    rewards: Vec<f32>,  // capacity
82    next_obs: Vec<f32>, // capacity × obs_dim
83    dones: Vec<f32>,    // capacity (0.0 or 1.0)
84    /// Write cursor (next insertion index, wraps).
85    head: usize,
86    /// Number of valid entries (≤ capacity).
87    size: usize,
88}
89
90impl UniformReplayBuffer {
91    /// Create a buffer with the given `capacity`, observation dimension
92    /// `obs_dim`, and action dimension `act_dim`.
93    ///
94    /// # Errors
95    ///
96    /// Returns [`RlError::ZeroCapacity`] if `capacity == 0`.
97    pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
98        assert!(capacity > 0, "capacity must be > 0");
99        Self {
100            capacity,
101            obs_dim,
102            act_dim,
103            obs: vec![0.0; capacity * obs_dim],
104            actions: vec![0.0; capacity * act_dim],
105            rewards: vec![0.0; capacity],
106            next_obs: vec![0.0; capacity * obs_dim],
107            dones: vec![0.0; capacity],
108            head: 0,
109            size: 0,
110        }
111    }
112
113    /// Number of transitions currently stored.
114    #[must_use]
115    #[inline]
116    pub fn len(&self) -> usize {
117        self.size
118    }
119
120    /// Buffer capacity.
121    #[must_use]
122    #[inline]
123    pub fn capacity(&self) -> usize {
124        self.capacity
125    }
126
127    /// Returns `true` if no transitions are stored.
128    #[must_use]
129    #[inline]
130    pub fn is_empty(&self) -> bool {
131        self.size == 0
132    }
133
134    /// Returns `true` if the buffer has been filled to capacity at least once.
135    #[must_use]
136    #[inline]
137    pub fn is_full(&self) -> bool {
138        self.size == self.capacity
139    }
140
141    /// Observation dimension.
142    #[must_use]
143    #[inline]
144    pub fn obs_dim(&self) -> usize {
145        self.obs_dim
146    }
147
148    /// Action dimension.
149    #[must_use]
150    #[inline]
151    pub fn act_dim(&self) -> usize {
152        self.act_dim
153    }
154
155    /// Push a new transition into the buffer.
156    ///
157    /// # Panics
158    ///
159    /// Panics (debug only) if `obs.len() != obs_dim` or `action.len() !=
160    /// act_dim`.
161    pub fn push(
162        &mut self,
163        obs: impl AsRef<[f32]>,
164        action: impl AsRef<[f32]>,
165        reward: f32,
166        next_obs: impl AsRef<[f32]>,
167        done: bool,
168    ) {
169        let obs = obs.as_ref();
170        let action = action.as_ref();
171        let next_obs = next_obs.as_ref();
172        debug_assert_eq!(obs.len(), self.obs_dim);
173        debug_assert_eq!(action.len(), self.act_dim);
174        debug_assert_eq!(next_obs.len(), self.obs_dim);
175
176        let i = self.head;
177        self.obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(obs);
178        self.actions[i * self.act_dim..(i + 1) * self.act_dim].copy_from_slice(action);
179        self.rewards[i] = reward;
180        self.next_obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(next_obs);
181        self.dones[i] = if done { 1.0 } else { 0.0 };
182
183        self.head = (self.head + 1) % self.capacity;
184        if self.size < self.capacity {
185            self.size += 1;
186        }
187    }
188
189    /// Push a [`Transition`] struct.
190    pub fn push_transition(&mut self, t: Transition) {
191        self.push(t.obs, t.action, t.reward, t.next_obs, t.done);
192    }
193
194    /// Sample `batch_size` transitions uniformly without replacement.
195    ///
196    /// Returns a `Vec<Transition>` of length `batch_size`.
197    ///
198    /// # Errors
199    ///
200    /// * [`RlError::InsufficientTransitions`] if `size < batch_size`.
201    pub fn sample(&self, batch_size: usize, handle: &mut RlHandle) -> RlResult<Vec<Transition>> {
202        if self.size < batch_size {
203            return Err(RlError::InsufficientTransitions {
204                have: self.size,
205                need: batch_size,
206            });
207        }
208        let rng = handle.rng_mut();
209        let mut out = Vec::with_capacity(batch_size);
210        // Reservoir-like: pick batch_size indices without replacement from [0, size)
211        // For small batch/size ratios: rejection sampling is fast.
212        let mut indices: Vec<usize> = Vec::with_capacity(batch_size);
213        while indices.len() < batch_size {
214            let idx = rng.next_usize(self.size);
215            if !indices.contains(&idx) {
216                indices.push(idx);
217            }
218        }
219        for idx in indices {
220            let obs = self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
221            let action = self.actions[idx * self.act_dim..(idx + 1) * self.act_dim].to_vec();
222            let reward = self.rewards[idx];
223            let next_obs = self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
224            let done = self.dones[idx] > 0.5;
225            out.push(Transition {
226                obs,
227                action,
228                reward,
229                next_obs,
230                done,
231            });
232        }
233        Ok(out)
234    }
235
236    /// Sample into pre-allocated contiguous arrays (zero-copy path).
237    ///
238    /// Fills:
239    /// * `obs_out`      — `batch_size × obs_dim` f32 slice
240    /// * `action_out`   — `batch_size × act_dim` f32 slice
241    /// * `reward_out`   — `batch_size` f32 slice
242    /// * `next_obs_out` — `batch_size × obs_dim` f32 slice
243    /// * `done_out`     — `batch_size` f32 slice (0.0 / 1.0)
244    ///
245    /// # Errors
246    ///
247    /// [`RlError::InsufficientTransitions`] or [`RlError::DimensionMismatch`]
248    /// if any slice has the wrong length.
249    #[allow(clippy::too_many_arguments)]
250    pub fn sample_into(
251        &self,
252        batch_size: usize,
253        obs_out: &mut [f32],
254        action_out: &mut [f32],
255        reward_out: &mut [f32],
256        next_obs_out: &mut [f32],
257        done_out: &mut [f32],
258        handle: &mut RlHandle,
259    ) -> RlResult<()> {
260        if self.size < batch_size {
261            return Err(RlError::InsufficientTransitions {
262                have: self.size,
263                need: batch_size,
264            });
265        }
266        let rng = handle.rng_mut();
267        let mut indices: Vec<usize> = Vec::with_capacity(batch_size);
268        while indices.len() < batch_size {
269            let idx = rng.next_usize(self.size);
270            if !indices.contains(&idx) {
271                indices.push(idx);
272            }
273        }
274        for (b, &idx) in indices.iter().enumerate() {
275            obs_out[b * self.obs_dim..(b + 1) * self.obs_dim]
276                .copy_from_slice(&self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim]);
277            action_out[b * self.act_dim..(b + 1) * self.act_dim]
278                .copy_from_slice(&self.actions[idx * self.act_dim..(idx + 1) * self.act_dim]);
279            reward_out[b] = self.rewards[idx];
280            next_obs_out[b * self.obs_dim..(b + 1) * self.obs_dim]
281                .copy_from_slice(&self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim]);
282            done_out[b] = self.dones[idx];
283        }
284        Ok(())
285    }
286
287    /// Clear all stored transitions.
288    pub fn clear(&mut self) {
289        self.head = 0;
290        self.size = 0;
291    }
292}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Push `n` transitions where transition `i` has `obs = [i; obs_dim]`,
    /// `action = [0; act_dim]`, `reward = i * 0.1`, `next_obs = [i + 1; obs_dim]`
    /// and `done = (i % 10 == 9)`, so every sampled field can be traced back
    /// to `i`.
    fn push_n(buf: &mut UniformReplayBuffer, n: usize) {
        let od = buf.obs_dim();
        let ad = buf.act_dim();
        for i in 0..n {
            buf.push(
                vec![i as f32; od],
                vec![0.0_f32; ad],
                i as f32 * 0.1,
                vec![i as f32 + 1.0; od],
                i % 10 == 9,
            );
        }
    }

    /// Assert that one sampled row is exactly what `push_n` stored for the
    /// index encoded in its observation.
    fn assert_push_n_row(obs: &[f32], action: &[f32], reward: f32, next_obs: &[f32], done: bool) {
        let i = obs[0];
        assert!(obs.iter().all(|&v| v == i), "obs row not uniform: {obs:?}");
        assert!(action.iter().all(|&v| v == 0.0), "unexpected action: {action:?}");
        assert!(
            next_obs.iter().all(|&v| v == i + 1.0),
            "next_obs {next_obs:?} does not follow obs {obs:?}"
        );
        assert!((reward - i * 0.1).abs() < 1e-6, "reward {reward} for index {i}");
        assert_eq!(done, (i as usize) % 10 == 9, "done flag for index {i}");
    }

    #[test]
    fn buffer_empty_initially() {
        let buf = UniformReplayBuffer::new(100, 4, 2);
        assert!(buf.is_empty());
        assert!(!buf.is_full());
        assert_eq!(buf.len(), 0);
        assert_eq!(buf.capacity(), 100);
        assert_eq!(buf.obs_dim(), 4);
        assert_eq!(buf.act_dim(), 2);
    }

    #[test]
    #[should_panic]
    fn zero_capacity_panics() {
        let _ = UniformReplayBuffer::new(0, 1, 1);
    }

    #[test]
    fn buffer_grows_to_capacity() {
        let mut buf = UniformReplayBuffer::new(10, 2, 1);
        push_n(&mut buf, 9);
        assert_eq!(buf.len(), 9);
        assert!(!buf.is_full());
        push_n(&mut buf, 1);
        assert_eq!(buf.len(), 10);
        assert!(buf.is_full());
    }

    #[test]
    fn buffer_overwrites_oldest() {
        let mut buf = UniformReplayBuffer::new(5, 1, 1);
        for i in 0..7_usize {
            buf.push([i as f32], [0.0], i as f32, [i as f32 + 1.0], false);
        }
        // size is still capped at capacity
        assert_eq!(buf.len(), 5);

        // Sampling every stored transition must show that the two oldest
        // (0 and 1) were the ones overwritten.
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(5, &mut handle).unwrap();
        let mut stored: Vec<usize> = batch.iter().map(|t| t.obs[0] as usize).collect();
        stored.sort_unstable();
        assert_eq!(stored, vec![2, 3, 4, 5, 6]);
        for t in &batch {
            assert_eq!(t.reward, t.obs[0]);
            assert_eq!(t.next_obs[0], t.obs[0] + 1.0);
        }
    }

    #[test]
    fn sample_correct_size() {
        let mut buf = UniformReplayBuffer::new(100, 4, 2);
        push_n(&mut buf, 100);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(32, &mut handle).unwrap();
        assert_eq!(batch.len(), 32);
        for t in &batch {
            assert_eq!(t.obs.len(), 4);
            assert_eq!(t.action.len(), 2);
            assert_eq!(t.next_obs.len(), 4);
            assert_push_n_row(&t.obs, &t.action, t.reward, &t.next_obs, t.done);
        }
    }

    #[test]
    fn sample_no_duplicates() {
        let mut buf = UniformReplayBuffer::new(100, 1, 1);
        push_n(&mut buf, 100);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(50, &mut handle).unwrap();
        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for t in &batch {
            let idx = t.obs[0] as usize;
            assert!(seen.insert(idx), "duplicate index {idx}");
        }
    }

    #[test]
    fn sample_zero_batch_is_empty() {
        let mut buf = UniformReplayBuffer::new(10, 1, 1);
        push_n(&mut buf, 3);
        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(0, &mut handle).unwrap().is_empty());
    }

    #[test]
    fn sample_insufficient_error() {
        let buf = UniformReplayBuffer::new(100, 4, 2);
        let mut handle = RlHandle::default_handle();
        let err = buf.sample(10, &mut handle).unwrap_err();
        assert!(matches!(
            err,
            RlError::InsufficientTransitions { have: 0, need: 10 }
        ));
    }

    #[test]
    fn push_transition_struct() {
        let mut buf = UniformReplayBuffer::new(10, 3, 2);
        buf.push_transition(Transition::new(
            [1.0, 2.0, 3.0],
            [4.0, 5.0],
            1.0,
            [2.0, 3.0, 4.0],
            false,
        ));
        assert_eq!(buf.len(), 1);

        // The single stored transition must round-trip unchanged.
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(1, &mut handle).unwrap();
        let t = &batch[0];
        assert_eq!(t.obs, vec![1.0, 2.0, 3.0]);
        assert_eq!(t.action, vec![4.0, 5.0]);
        assert_eq!(t.reward, 1.0);
        assert_eq!(t.next_obs, vec![2.0, 3.0, 4.0]);
        assert!(!t.done);
    }

    #[test]
    fn sample_into_fills_slices() {
        let mut buf = UniformReplayBuffer::new(64, 4, 2);
        push_n(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let bs = 16;
        // Pre-fill with a sentinel that no pushed value can take, so any
        // entry left unwritten is detected.
        let mut obs = vec![-1.0_f32; bs * 4];
        let mut act = vec![-1.0_f32; bs * 2];
        let mut rew = vec![-1.0_f32; bs];
        let mut nobs = vec![-1.0_f32; bs * 4];
        let mut done = vec![-1.0_f32; bs];
        buf.sample_into(
            bs,
            &mut obs,
            &mut act,
            &mut rew,
            &mut nobs,
            &mut done,
            &mut handle,
        )
        .unwrap();

        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for b in 0..bs {
            let o = &obs[b * 4..(b + 1) * 4];
            assert!(o[0] >= 0.0 && o[0] < 64.0, "obs value {} out of range", o[0]);
            assert!(done[b] == 0.0 || done[b] == 1.0, "done not 0/1: {}", done[b]);
            assert_push_n_row(
                o,
                &act[b * 2..(b + 1) * 2],
                rew[b],
                &nobs[b * 4..(b + 1) * 4],
                done[b] > 0.5,
            );
            assert!(seen.insert(o[0] as usize), "duplicate row {}", o[0]);
        }
    }

    #[test]
    fn sample_into_insufficient_error() {
        let buf = UniformReplayBuffer::new(8, 1, 1);
        let mut handle = RlHandle::default_handle();
        let mut obs = vec![0.0_f32; 4];
        let mut act = vec![0.0_f32; 4];
        let mut rew = vec![0.0_f32; 4];
        let mut nobs = vec![0.0_f32; 4];
        let mut done = vec![0.0_f32; 4];
        let err = buf
            .sample_into(
                4,
                &mut obs,
                &mut act,
                &mut rew,
                &mut nobs,
                &mut done,
                &mut handle,
            )
            .unwrap_err();
        assert!(matches!(
            err,
            RlError::InsufficientTransitions { have: 0, need: 4 }
        ));
    }

    #[test]
    #[should_panic]
    fn sample_into_short_slice_panics() {
        let mut buf = UniformReplayBuffer::new(8, 2, 1);
        push_n(&mut buf, 8);
        let mut handle = RlHandle::default_handle();
        let bs = 4;
        // One value short of bs * obs_dim.
        let mut obs = vec![0.0_f32; bs * 2 - 1];
        let mut act = vec![0.0_f32; bs];
        let mut rew = vec![0.0_f32; bs];
        let mut nobs = vec![0.0_f32; bs * 2];
        let mut done = vec![0.0_f32; bs];
        let _ = buf.sample_into(
            bs,
            &mut obs,
            &mut act,
            &mut rew,
            &mut nobs,
            &mut done,
            &mut handle,
        );
    }

    #[test]
    fn clear_resets_buffer() {
        let mut buf = UniformReplayBuffer::new(10, 2, 1);
        push_n(&mut buf, 10);
        buf.clear();
        assert!(buf.is_empty());
        assert!(!buf.is_full());
        assert_eq!(buf.capacity(), 10);
        // Cleared data must no longer be sampleable.
        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(1, &mut handle).is_err());
    }

    #[test]
    fn transition_done_flag() {
        let mut buf = UniformReplayBuffer::new(5, 1, 1);
        buf.push([0.0], [0.0], 0.0, [1.0], true);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(1, &mut handle).unwrap();
        assert!(batch[0].done);
    }
}