border_core/generic_replay_buffer/base.rs

//! Generic implementation of replay buffers for reinforcement learning.
//!
//! This module provides a generic implementation of replay buffers that can store
//! and sample transitions of arbitrary observation and action types. It supports:
//! - Standard experience replay
//! - Prioritized experience replay (PER)
//! - Importance sampling weights for off-policy learning
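//!
//! The sketch below shows typical construction without PER (uniform sampling).
//! It assumes an observation/action type `Tensor` implementing [`BatchBase`] and
//! is illustrative rather than a compilable doctest.
//!
//! ```ignore
//! // Plain (uniform) replay buffer: no PER configuration.
//! let config = SimpleReplayBufferConfig {
//!     capacity: 10_000,
//!     seed: 42,
//!     per_config: None,
//! };
//! let mut buffer = SimpleReplayBuffer::<Tensor, Tensor>::build(&config);
//! ```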

mod iw_scheduler;
mod sum_tree;
use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig};
use crate::{ExperienceBufferBase, ReplayBufferBase, TransitionBatch};
use anyhow::Result;
pub use iw_scheduler::IwScheduler;
use rand::{rngs::StdRng, RngCore, SeedableRng};
use sum_tree::SumTree;
pub use sum_tree::WeightNormalizer;

/// State management for Prioritized Experience Replay (PER).
///
/// This struct maintains the necessary state for PER, including:
/// - A sum tree for efficient priority sampling
/// - An importance weight scheduler for adjusting sample weights
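///
/// As in standard PER (Schaul et al., 2016), a transition with priority `p_i`
/// is sampled with probability `P(i) = p_i^alpha / sum_k p_k^alpha` and its
/// importance weight is `w_i = (N * P(i))^(-beta)`, where `alpha` comes from
/// [`PerConfig`] and `beta` is annealed by the [`IwScheduler`]; whether the
/// weights are normalized is controlled by `PerConfig::normalize`.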
struct PerState {
    /// A sum tree data structure for efficient priority sampling.
    sum_tree: SumTree,

    /// Scheduler for importance sampling weights.
    iw_scheduler: IwScheduler,
}

impl PerState {
    /// Creates a new PER state with the given configuration.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of transitions to store
    /// * `per_config` - Configuration for prioritized experience replay
    fn new(capacity: usize, per_config: &PerConfig) -> Self {
        Self {
            sum_tree: SumTree::new(capacity, per_config.alpha, per_config.normalize),
            iw_scheduler: IwScheduler::new(
                per_config.beta_0,
                per_config.beta_final,
                per_config.n_opts_final,
            ),
        }
    }
}

/// A generic implementation of a replay buffer for reinforcement learning.
///
/// This buffer can store transitions of arbitrary observation and action types,
/// making it suitable for a wide range of reinforcement learning tasks. It supports:
/// - Standard experience replay
/// - Prioritized experience replay (optional)
/// - Efficient sampling and storage
///
/// # Type Parameters
///
/// * `O` - The type of observations, must implement [`BatchBase`]
/// * `A` - The type of actions, must implement [`BatchBase`]
///
/// # Examples
///
/// ```ignore
/// let config = SimpleReplayBufferConfig {
///     capacity: 10000,
///     seed: 42,
///     per_config: Some(PerConfig {
///         alpha: 0.6,
///         beta_0: 0.4,
///         beta_final: 1.0,
///         n_opts_final: 100000,
///         normalize: true,
///     }),
/// };
///
/// let mut buffer = SimpleReplayBuffer::<Tensor, Tensor>::build(&config);
///
/// // Add transitions
/// buffer.push(transition)?;
///
/// // Sample a batch
/// let batch = buffer.batch(32)?;
/// ```
pub struct SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    /// Maximum number of transitions that can be stored.
    capacity: usize,

    /// Current insertion index.
    i: usize,

    /// Current number of stored transitions.
    size: usize,

    /// Storage for observations.
    obs: O,

    /// Storage for actions.
    act: A,

    /// Storage for next observations.
    next_obs: O,

    /// Storage for rewards.
    reward: Vec<f32>,

    /// Storage for termination flags.
    is_terminated: Vec<i8>,

    /// Storage for truncation flags.
    is_truncated: Vec<i8>,

    /// Random number generator for sampling.
    rng: StdRng,

    /// State for prioritized experience replay, if enabled.
    per_state: Option<PerState>,
}

impl<O, A> SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    /// Pushes rewards into the buffer starting at index `i`, wrapping around
    /// at the buffer capacity.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of rewards to insert
    #[inline]
    fn push_reward(&mut self, i: usize, b: &Vec<f32>) {
        let mut j = i;
        for r in b.iter() {
            self.reward[j] = *r;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Pushes termination flags into the buffer starting at index `i`, wrapping
    /// around at the buffer capacity.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of termination flags to insert
    #[inline]
    fn push_is_terminated(&mut self, i: usize, b: &Vec<i8>) {
        let mut j = i;
        for d in b.iter() {
            self.is_terminated[j] = *d;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Pushes truncation flags into the buffer starting at index `i`, wrapping
    /// around at the buffer capacity.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of truncation flags to insert
    fn push_is_truncated(&mut self, i: usize, b: &Vec<i8>) {
        let mut j = i;
        for d in b.iter() {
            self.is_truncated[j] = *d;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Samples rewards for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled rewards
    fn sample_reward(&self, ixs: &Vec<usize>) -> Vec<f32> {
        ixs.iter().map(|ix| self.reward[*ix]).collect()
    }

    /// Samples termination flags for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled termination flags
    fn sample_is_terminated(&self, ixs: &Vec<usize>) -> Vec<i8> {
        ixs.iter().map(|ix| self.is_terminated[*ix]).collect()
    }

    /// Samples truncation flags for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled truncation flags
    fn sample_is_truncated(&self, ixs: &Vec<usize>) -> Vec<i8> {
        ixs.iter().map(|ix| self.is_truncated[*ix]).collect()
    }

    /// Sets priorities for newly added samples in prioritized experience replay.
    ///
    /// New transitions are assigned the current maximum priority in the sum tree,
    /// so they are likely to be sampled at least once before their priorities are
    /// updated from observed TD errors.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - Number of new samples to prioritize
    fn set_priority(&mut self, batch_size: usize) {
        let sum_tree = &mut self.per_state.as_mut().unwrap().sum_tree;
        let max_p = sum_tree.max();

        for j in 0..batch_size {
            let i = (self.i + j) % self.capacity;
            sum_tree.add(i, max_p);
        }
    }

    /// Returns a batch containing all actions in the buffer.
    ///
    /// # Warning
    ///
    /// This method should be used with caution on large replay buffers
    /// as it may consume significant memory.
    pub fn whole_actions(&self) -> A {
        let ixs = (0..self.size).collect::<Vec<_>>();
        self.act.sample(&ixs)
    }

    /// Returns the number of terminated episodes in the buffer.
    pub fn num_terminated_flags(&self) -> usize {
        self.is_terminated
            .iter()
            .map(|is_terminated| *is_terminated as usize)
            .sum()
    }

    /// Returns the number of truncated episodes in the buffer.
    pub fn num_truncated_flags(&self) -> usize {
        self.is_truncated
            .iter()
            .map(|is_truncated| *is_truncated as usize)
            .sum()
    }

    /// Returns the sum of all rewards in the buffer.
    pub fn sum_rewards(&self) -> f32 {
        self.reward.iter().sum()
    }
}

impl<O, A> ExperienceBufferBase for SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    type Item = GenericTransitionBatch<O, A>;

    /// Returns the current number of transitions in the buffer.
    fn len(&self) -> usize {
        self.size
    }

    /// Adds a batch of one or more transitions to the buffer.
    ///
    /// # Arguments
    ///
    /// * `tr` - The batch of transitions to add
    ///
    /// # Returns
    ///
    /// Always returns `Ok(())`. The buffer behaves as a ring buffer: once it is
    /// full, the oldest transitions are overwritten by new ones.
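    ///
    /// # Examples
    ///
    /// A minimal sketch; it assumes a helper `make_batch` that assembles a
    /// [`GenericTransitionBatch`] from an environment step, which is not part
    /// of this module.
    ///
    /// ```ignore
    /// // Collect one environment step and store it in the buffer.
    /// let tr = make_batch(obs, act, next_obs, reward, is_terminated, is_truncated);
    /// buffer.push(tr)?;
    /// ```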
    fn push(&mut self, tr: Self::Item) -> Result<()> {
        let len = tr.len(); // batch size
        let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack();
        self.obs.push(self.i, obs);
        self.act.push(self.i, act);
        self.next_obs.push(self.i, next_obs);
        self.push_reward(self.i, &reward);
        self.push_is_terminated(self.i, &is_terminated);
        self.push_is_truncated(self.i, &is_truncated);

        if self.per_state.is_some() {
            self.set_priority(len)
        };

        self.i = (self.i + len) % self.capacity;
        self.size += len;
        if self.size >= self.capacity {
            self.size = self.capacity;
        }

        Ok(())
    }
}

impl<O, A> ReplayBufferBase for SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    type Config = SimpleReplayBufferConfig;
    type Batch = GenericTransitionBatch<O, A>;

    /// Creates a new replay buffer with the given configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - Configuration for the replay buffer
    ///
    /// # Returns
    ///
    /// A new instance of the replay buffer
    fn build(config: &Self::Config) -> Self {
        let capacity = config.capacity;
        let per_state = config
            .per_config
            .as_ref()
            .map(|per_config| PerState::new(capacity, per_config));

        Self {
            capacity,
            i: 0,
            size: 0,
            obs: O::new(capacity),
            act: A::new(capacity),
            next_obs: O::new(capacity),
            reward: vec![0.; capacity],
            is_terminated: vec![0; capacity],
            is_truncated: vec![0; capacity],
            rng: StdRng::seed_from_u64(config.seed as _),
            per_state,
        }
    }

    /// Samples a batch of transitions from the buffer.
    ///
    /// If prioritized experience replay is enabled, samples are selected
    /// according to their priorities. Otherwise, indices are drawn uniformly
    /// at random with replacement.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of transitions to sample
    ///
    /// # Returns
    ///
    /// A batch of sampled transitions, together with the sampled indices and,
    /// when PER is enabled, the importance sampling weights.
    ///
    /// # Panics
    ///
    /// With uniform sampling, this method panics if the buffer is empty.
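    ///
    /// # Examples
    ///
    /// A minimal sketch of sampling; the field names follow
    /// [`GenericTransitionBatch`] as constructed by this method.
    ///
    /// ```ignore
    /// let batch = buffer.batch(32)?;
    /// // Indices of the sampled transitions; kept for later priority updates.
    /// let ixs = batch.ix_sample.clone();
    /// // Importance sampling weights are Some(_) only when PER is enabled.
    /// let weights = batch.weight.clone();
    /// ```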
    fn batch(&mut self, size: usize) -> Result<Self::Batch> {
        let (ixs, weight) = if let Some(per_state) = &self.per_state {
            let sum_tree = &per_state.sum_tree;
            let beta = per_state.iw_scheduler.beta();
            let (ixs, weight) = sum_tree.sample(size, beta);
            let ixs = ixs.iter().map(|&ix| ix as usize).collect();
            (ixs, Some(weight))
        } else {
            let ixs = (0..size)
                .map(|_| (self.rng.next_u32() as usize) % self.size)
                .collect::<Vec<_>>();
            let weight = None;
            (ixs, weight)
        };

        Ok(Self::Batch {
            obs: self.obs.sample(&ixs),
            act: self.act.sample(&ixs),
            next_obs: self.next_obs.sample(&ixs),
            reward: self.sample_reward(&ixs),
            is_terminated: self.sample_is_terminated(&ixs),
            is_truncated: self.sample_is_truncated(&ixs),
            ix_sample: Some(ixs),
            weight,
        })
    }

    /// Updates the priorities of transitions in the buffer.
    ///
    /// This method is used in prioritized experience replay to adjust
    /// the sampling probabilities based on TD errors. It is a no-op when
    /// PER is disabled.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Optional indices of transitions to update
    /// * `td_errs` - Optional TD errors for the transitions
    ///
    /// # Panics
    ///
    /// Panics if PER is enabled and either `ixs` or `td_errs` is `None`.
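    ///
    /// # Examples
    ///
    /// A minimal sketch of the PER update step; `compute_td_errors` stands in
    /// for the agent's own TD-error computation and is not part of this module.
    ///
    /// ```ignore
    /// let batch = buffer.batch(32)?;
    /// let td_errs = compute_td_errors(&batch);
    /// buffer.update_priority(&batch.ix_sample, &Some(td_errs));
    /// ```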
    fn update_priority(&mut self, ixs: &Option<Vec<usize>>, td_errs: &Option<Vec<f32>>) {
        if let Some(per_state) = &mut self.per_state {
            let ixs = ixs
                .as_ref()
                .expect("ixs should be Some(_) in update_priority().");
            let td_errs = td_errs
                .as_ref()
                .expect("td_errs should be Some(_) in update_priority().");
            for (&ix, &td_err) in ixs.iter().zip(td_errs.iter()) {
                per_state.sum_tree.update(ix, td_err);
            }
            per_state.iw_scheduler.add_n_opts();
        }
    }
}