Skip to main content

oxicuda_rl/buffer/
prioritized.rs

1//! # Prioritized Experience Replay (PER)
2//!
//! Implements the proportional variant from Schaul et al., "Prioritized
//! Experience Replay" (arXiv 2015; published at ICLR 2016).
5//!
6//! ## Algorithm
7//!
8//! Transitions are sampled with probability proportional to priority raised to
9//! exponent `α`:
10//! ```text
11//! P(i) = p_i^α / Σ_j p_j^α
12//! ```
13//!
14//! To correct for the sampling bias introduced by prioritised sampling, each
15//! sampled transition receives an importance-sampling (IS) weight:
16//! ```text
17//! w_i = (1 / (N * P(i)))^β  (normalised by max_i w_i)
18//! ```
19//!
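//! β should be annealed from `beta_start` towards 1 over the course of
//! training; the buffer does not do this on its own — call `anneal_beta` or
//! `set_beta` from the training loop.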
21//!
22//! ## Implementation
23//!
//! Priorities (already raised to `α`) are stored in a **binary segment tree**
//! whose leaf count `n` is `capacity` rounded up to the next power of two, so
//! that:
//! * `O(log N)` priority update
//! * `O(log N)` lookup per stratified sample
//! * `O(1)` total-priority query
//!
//! The tree is stored as a flat array of size `2 * n` (1-indexed, root at
//! index 1); leaves occupy indices `[n, 2n)`, and leaves at or beyond
//! `capacity` are padding with no corresponding transition.
32
33use crate::buffer::replay::Transition;
34use crate::error::{RlError, RlResult};
35use crate::handle::RlHandle;
36
37// ─── Segment tree ────────────────────────────────────────────────────────────
38
/// Sum segment tree, with a parallel min tree, over per-slot priorities.
///
/// Both trees use an implicit binary-heap layout: the root is at index 1, the
/// children of node `k` are at `2k` and `2k + 1`, and leaf `i` lives at
/// `n + i`. Index 0 is unused. This gives `O(log n)` leaf updates and `O(1)`
/// total / minimum queries. `n` is expected to be a power of two so that all
/// leaves sit at the same depth and leaf order matches index order.
#[derive(Debug, Clone)]
struct SumTree {
    /// Number of leaves.
    n: usize,
    /// Subtree sums (`2 * n` nodes); `tree[1]` is the grand total.
    tree: Vec<f64>,
    /// Subtree minima, same layout as `tree`; leaves never written hold `f64::MAX`.
    min_tree: Vec<f64>,
}
46
47impl SumTree {
48    fn new(n: usize) -> Self {
49        Self {
50            n,
51            tree: vec![0.0; 2 * n],
52            min_tree: vec![f64::MAX; 2 * n],
53        }
54    }
55
56    /// Update priority at leaf `i` (0-indexed).
57    fn update(&mut self, i: usize, priority: f64) {
58        let pos = i + self.n; // leaf position in tree
59        self.tree[pos] = priority;
60        self.min_tree[pos] = priority;
61        let mut p = pos >> 1;
62        while p >= 1 {
63            self.tree[p] = self.tree[2 * p] + self.tree[2 * p + 1];
64            self.min_tree[p] = self.min_tree[2 * p].min(self.min_tree[2 * p + 1]);
65            p >>= 1;
66        }
67    }
68
69    /// Total sum of all priorities.
70    #[inline]
71    fn total(&self) -> f64 {
72        self.tree[1]
73    }
74
75    /// Minimum priority among all leaves.
76    #[inline]
77    fn min_priority(&self) -> f64 {
78        self.min_tree[1]
79    }
80
81    /// Find the leaf index (0-based) whose cumulative sum is just >= `value`.
82    fn find(&self, value: f64) -> usize {
83        let mut node = 1_usize;
84        let mut v = value;
85        while node < self.n {
86            let left = 2 * node;
87            if self.tree[left] >= v {
88                node = left;
89            } else {
90                v -= self.tree[left];
91                node = left + 1;
92            }
93        }
94        node - self.n // convert back to 0-based leaf index
95    }
96
97    /// Priority at leaf `i` (0-indexed).
98    fn priority_at(&self, i: usize) -> f64 {
99        self.tree[i + self.n]
100    }
101}
102
103// ─── PER buffer ──────────────────────────────────────────────────────────────
104
105/// A sample returned by [`PrioritizedReplayBuffer::sample`].
106#[derive(Debug, Clone)]
107pub struct PrioritySample {
108    /// The sampled transition.
109    pub transition: Transition,
110    /// Buffer index of this transition (needed for priority update).
111    pub index: usize,
112    /// Importance-sampling weight for this transition.
113    pub weight: f32,
114}
115
116/// Prioritized Experience Replay buffer with proportional priority and IS
117/// weights.
118#[derive(Debug, Clone)]
119pub struct PrioritizedReplayBuffer {
120    capacity: usize,
121    obs_dim: usize,
122    act_dim: usize,
123    // Transition storage (struct-of-arrays)
124    obs: Vec<f32>,
125    actions: Vec<f32>,
126    rewards: Vec<f32>,
127    next_obs: Vec<f32>,
128    dones: Vec<f32>,
129    // Segment tree
130    tree: SumTree,
131    /// Priority exponent α ∈ [0, 1].  α=0 → uniform; α=1 → full priority.
132    alpha: f64,
133    /// IS-weight exponent β ∈ [0, 1].  Annealed from `beta_start` to 1.
134    beta: f64,
135    /// Maximum priority seen so far (used for new-experience priority).
136    max_priority: f64,
137    /// Write cursor.
138    head: usize,
139    /// Current size.
140    size: usize,
141}
142
143impl PrioritizedReplayBuffer {
144    /// Create a PER buffer.
145    ///
146    /// * `capacity` — maximum number of transitions.
147    /// * `obs_dim`, `act_dim` — observation / action dimensions.
148    /// * `alpha` — priority exponent (default 0.6).
149    /// * `beta_start` — initial IS exponent (default 0.4, annealed to 1).
150    pub fn new(
151        capacity: usize,
152        obs_dim: usize,
153        act_dim: usize,
154        alpha: f64,
155        beta_start: f64,
156    ) -> Self {
157        assert!(capacity > 0, "capacity must be > 0");
158        // Tree size must be a power of 2 for the segment tree to work correctly.
159        let cap2 = capacity.next_power_of_two();
160        Self {
161            capacity,
162            obs_dim,
163            act_dim,
164            obs: vec![0.0; capacity * obs_dim],
165            actions: vec![0.0; capacity * act_dim],
166            rewards: vec![0.0; capacity],
167            next_obs: vec![0.0; capacity * obs_dim],
168            dones: vec![0.0; capacity],
169            tree: SumTree::new(cap2),
170            alpha,
171            beta: beta_start,
172            max_priority: 1.0,
173            head: 0,
174            size: 0,
175        }
176    }
177
178    /// Push a transition with maximum priority (so it will be sampled at least
179    /// once).
180    pub fn push(
181        &mut self,
182        obs: impl AsRef<[f32]>,
183        action: impl AsRef<[f32]>,
184        reward: f32,
185        next_obs: impl AsRef<[f32]>,
186        done: bool,
187    ) {
188        let obs = obs.as_ref();
189        let action = action.as_ref();
190        let next_obs = next_obs.as_ref();
191
192        let i = self.head;
193        self.obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(obs);
194        self.actions[i * self.act_dim..(i + 1) * self.act_dim].copy_from_slice(action);
195        self.rewards[i] = reward;
196        self.next_obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(next_obs);
197        self.dones[i] = if done { 1.0 } else { 0.0 };
198        // Assign maximum priority to new transition
199        self.tree.update(i, self.max_priority.powf(self.alpha));
200        self.head = (self.head + 1) % self.capacity;
201        if self.size < self.capacity {
202            self.size += 1;
203        }
204    }
205
206    /// Update priority for a transition that was previously sampled.
207    ///
208    /// # Arguments
209    ///
210    /// * `index` — the buffer index returned in [`PrioritySample::index`].
211    /// * `priority` — new absolute priority (typically `|TD error| + ε`).
212    pub fn update_priority(&mut self, index: usize, priority: f64) {
213        let p = priority.max(1e-6);
214        if p > self.max_priority {
215            self.max_priority = p;
216        }
217        self.tree.update(index, p.powf(self.alpha));
218    }
219
220    /// Set the current β for IS weight computation.
221    pub fn set_beta(&mut self, beta: f64) {
222        self.beta = beta.clamp(0.0, 1.0);
223    }
224
225    /// Anneal β toward 1 by `step` (additive).
226    pub fn anneal_beta(&mut self, step: f64) {
227        self.beta = (self.beta + step).min(1.0);
228    }
229
230    /// Number of stored transitions.
231    #[must_use]
232    #[inline]
233    pub fn len(&self) -> usize {
234        self.size
235    }
236
237    /// Returns `true` if empty.
238    #[must_use]
239    #[inline]
240    pub fn is_empty(&self) -> bool {
241        self.size == 0
242    }
243
244    /// Sample `batch_size` transitions using stratified proportional sampling.
245    ///
246    /// Stratified: the [0, total_priority] interval is divided into
247    /// `batch_size` equal strata and one sample is drawn uniformly from each
248    /// stratum.  This reduces variance compared to fully random sampling.
249    ///
250    /// # Errors
251    ///
252    /// * [`RlError::InsufficientTransitions`] if `size < batch_size`.
253    /// * [`RlError::ZeroPrioritySum`] if all priorities are zero.
254    pub fn sample(
255        &self,
256        batch_size: usize,
257        handle: &mut RlHandle,
258    ) -> RlResult<Vec<PrioritySample>> {
259        if self.size < batch_size {
260            return Err(RlError::InsufficientTransitions {
261                have: self.size,
262                need: batch_size,
263            });
264        }
265        let total = self.tree.total();
266        if total <= 0.0 {
267            return Err(RlError::ZeroPrioritySum);
268        }
269        let rng = handle.rng_mut();
270        let segment = total / batch_size as f64;
271        let min_p = self.tree.min_priority() / total;
272        let max_w = (1.0 / (self.size as f64 * min_p)).powf(self.beta) as f32;
273
274        let mut out = Vec::with_capacity(batch_size);
275        for k in 0..batch_size {
276            let lo = k as f64 * segment;
277            let hi = lo + segment;
278            let v = lo + rng.next_f32() as f64 * (hi - lo);
279            let idx = self.tree.find(v.min(total - 1e-9)).min(self.size - 1);
280
281            let p = self.tree.priority_at(idx) / total;
282            let w = ((1.0 / (self.size as f64 * p)).powf(self.beta) as f32 / max_w).min(1.0);
283
284            let obs = self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
285            let action = self.actions[idx * self.act_dim..(idx + 1) * self.act_dim].to_vec();
286            let reward = self.rewards[idx];
287            let next_obs = self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
288            let done = self.dones[idx] > 0.5;
289            out.push(PrioritySample {
290                transition: Transition {
291                    obs,
292                    action,
293                    reward,
294                    next_obs,
295                    done,
296                },
297                index: idx,
298                weight: w,
299            });
300        }
301        Ok(out)
302    }
303}
304
305// ─── Tests ───────────────────────────────────────────────────────────────────
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer under test: 2-D observations, 1-D actions, α = 0.6, β₀ = 0.4.
    fn make_per(cap: usize) -> PrioritizedReplayBuffer {
        let (obs_dim, act_dim, alpha, beta_start) = (2, 1, 0.6, 0.4);
        PrioritizedReplayBuffer::new(cap, obs_dim, act_dim, alpha, beta_start)
    }

    /// Push `n` deterministic, non-terminal transitions whose contents are
    /// derived from their insertion order.
    fn fill_per(buf: &mut PrioritizedReplayBuffer, n: usize) {
        for x in (0..n).map(|i| i as f32) {
            let obs = [x, x + 1.0];
            let next_obs = [x + 1.0, x + 2.0];
            buf.push(obs, [0.5_f32], x * 0.1, next_obs, false);
        }
    }

    /// Build a 4-leaf tree holding `priorities` in leaf order.
    fn tree_from(priorities: [f64; 4]) -> SumTree {
        let mut tree = SumTree::new(4);
        for (leaf, &p) in priorities.iter().enumerate() {
            tree.update(leaf, p);
        }
        tree
    }

    #[test]
    fn sum_tree_basic() {
        let t = tree_from([1.0, 2.0, 3.0, 4.0]);
        let total = t.total();
        assert!((total - 10.0).abs() < 1e-9, "total={}", total);
    }

    #[test]
    fn sum_tree_find() {
        // Cumulative sums are [1, 3, 6, 10], so 2.5 falls inside leaf 1's interval.
        let idx = tree_from([1.0, 2.0, 3.0, 4.0]).find(2.5);
        assert_eq!(idx, 1, "find(2.5) should return idx=1, got {idx}");
    }

    #[test]
    fn per_push_and_len() {
        let mut buf = make_per(32);
        fill_per(&mut buf, 20);
        assert_eq!(buf.len(), 20);
    }

    #[test]
    fn per_sample_size() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let drawn = buf.sample(16, &mut handle).unwrap();
        assert_eq!(drawn.len(), 16);
    }

    #[test]
    fn per_weights_in_range() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let drawn = buf.sample(32, &mut handle).unwrap();
        drawn.iter().map(|s| s.weight).for_each(|w| {
            assert!(w > 0.0 && w <= 1.0, "weight={}", w);
        });
    }

    #[test]
    fn per_update_priority() {
        const DRAWS: usize = 200;
        let mut buf = make_per(16);
        fill_per(&mut buf, 16);
        // Give slot 0 a priority far above every other slot.
        buf.update_priority(0, 100.0);
        let mut handle = RlHandle::default_handle();
        let hits_on_zero = (0..DRAWS)
            .filter(|_| buf.sample(1, &mut handle).unwrap()[0].index == 0)
            .count();
        // Uniform sampling would land on slot 0 about DRAWS / 16 times.
        assert!(
            hits_on_zero > DRAWS / 16,
            "high-priority index should be over-sampled"
        );
    }

    #[test]
    fn per_insufficient_error() {
        let buf = make_per(16);
        let mut handle = RlHandle::default_handle();
        let result = buf.sample(5, &mut handle);
        assert!(result.is_err());
    }

    #[test]
    fn per_anneal_beta() {
        let mut buf = make_per(16);
        buf.set_beta(0.4);
        buf.anneal_beta(0.3);
        assert!((buf.beta - 0.7).abs() < 1e-9);
        // An oversized step saturates at 1.
        buf.anneal_beta(1.0);
        assert!((buf.beta - 1.0).abs() < 1e-9);
    }

    #[test]
    fn sum_tree_min_priority() {
        let t = tree_from([5.0, 2.0, 8.0, 3.0]);
        assert!((t.min_priority() - 2.0).abs() < 1e-9);
    }
}