border_core/generic_replay_buffer/base.rs
//! Generic implementation of replay buffers for reinforcement learning.
//!
//! This module provides a generic implementation of replay buffers that can store
//! and sample transitions of arbitrary observation and action types. It supports:
//! - Standard experience replay
//! - Prioritized experience replay (PER)
//! - Importance sampling weights for off-policy learning

mod iw_scheduler;
mod sum_tree;
use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig};
use crate::{ExperienceBufferBase, ReplayBufferBase, TransitionBatch};
use anyhow::Result;
pub use iw_scheduler::IwScheduler;
use rand::{rngs::StdRng, RngCore, SeedableRng};
use sum_tree::SumTree;
pub use sum_tree::WeightNormalizer;

/// State management for Prioritized Experience Replay (PER).
///
/// This struct maintains the necessary state for PER, including:
/// - A sum tree for efficient priority sampling
/// - An importance weight scheduler for adjusting sample weights
struct PerState {
    /// A sum tree data structure for efficient priority sampling.
    sum_tree: SumTree,

    /// Scheduler for importance sampling weights.
    iw_scheduler: IwScheduler,
}

impl PerState {
    /// Creates a new PER state with the given configuration.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of transitions to store
    /// * `per_config` - Configuration for prioritized experience replay
    fn new(capacity: usize, per_config: &PerConfig) -> Self {
        Self {
            sum_tree: SumTree::new(capacity, per_config.alpha, per_config.normalize),
            iw_scheduler: IwScheduler::new(
                per_config.beta_0,
                per_config.beta_final,
                per_config.n_opts_final,
            ),
        }
    }
}

/// A generic implementation of a replay buffer for reinforcement learning.
///
/// This buffer can store transitions of arbitrary observation and action types,
/// making it suitable for a wide range of reinforcement learning tasks. It supports:
/// - Standard experience replay
/// - Prioritized experience replay (optional)
/// - Efficient sampling and storage
///
/// # Type Parameters
///
/// * `O` - The type of observations, must implement [`BatchBase`]
/// * `A` - The type of actions, must implement [`BatchBase`]
///
/// # Examples
///
/// ```ignore
/// let config = SimpleReplayBufferConfig {
///     capacity: 10000,
///     seed: 42,
///     per_config: Some(PerConfig {
///         alpha: 0.6,
///         beta_0: 0.4,
///         beta_final: 1.0,
///         n_opts_final: 100000,
///         normalize: true,
///     }),
/// };
///
/// let mut buffer = SimpleReplayBuffer::<Tensor, Tensor>::build(&config);
///
/// // Add transitions
/// buffer.push(transition)?;
///
/// // Sample a batch
/// let batch = buffer.batch(32)?;
/// ```
pub struct SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    /// Maximum number of transitions that can be stored.
    capacity: usize,

    /// Current insertion index.
    i: usize,

    /// Current number of stored transitions.
    size: usize,

    /// Storage for observations.
    obs: O,

    /// Storage for actions.
    act: A,

    /// Storage for next observations.
    next_obs: O,

    /// Storage for rewards.
    reward: Vec<f32>,

    /// Storage for termination flags.
    is_terminated: Vec<i8>,

    /// Storage for truncation flags.
    is_truncated: Vec<i8>,

    /// Random number generator for sampling.
    rng: StdRng,

    /// State for prioritized experience replay, if enabled.
    per_state: Option<PerState>,
}

impl<O, A> SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    /// Pushes rewards into the buffer at the specified index.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of rewards to insert
    #[inline]
    fn push_reward(&mut self, i: usize, b: &Vec<f32>) {
        let mut j = i;
        for r in b.iter() {
            self.reward[j] = *r;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Pushes termination flags into the buffer at the specified index.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of termination flags to insert
    #[inline]
    fn push_is_terminated(&mut self, i: usize, b: &Vec<i8>) {
        let mut j = i;
        for d in b.iter() {
            self.is_terminated[j] = *d;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Pushes truncation flags into the buffer at the specified index.
    ///
    /// # Arguments
    ///
    /// * `i` - Starting index for insertion
    /// * `b` - Vector of truncation flags to insert
    fn push_is_truncated(&mut self, i: usize, b: &Vec<i8>) {
        let mut j = i;
        for d in b.iter() {
            self.is_truncated[j] = *d;
            j += 1;
            if j == self.capacity {
                j = 0;
            }
        }
    }

    /// Samples rewards for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled rewards
    fn sample_reward(&self, ixs: &Vec<usize>) -> Vec<f32> {
        ixs.iter().map(|ix| self.reward[*ix]).collect()
    }

    /// Samples termination flags for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled termination flags
    fn sample_is_terminated(&self, ixs: &Vec<usize>) -> Vec<i8> {
        ixs.iter().map(|ix| self.is_terminated[*ix]).collect()
    }

    /// Samples truncation flags for the given indices.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Indices to sample from
    ///
    /// # Returns
    ///
    /// Vector of sampled truncation flags
    fn sample_is_truncated(&self, ixs: &Vec<usize>) -> Vec<i8> {
        ixs.iter().map(|ix| self.is_truncated[*ix]).collect()
    }

    /// Sets priorities for newly added samples in prioritized experience replay.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - Number of new samples to prioritize
    fn set_priority(&mut self, batch_size: usize) {
        let sum_tree = &mut self.per_state.as_mut().unwrap().sum_tree;
        let max_p = sum_tree.max();

        for j in 0..batch_size {
            let i = (self.i + j) % self.capacity;
            sum_tree.add(i, max_p);
        }
    }

    /// Returns a batch containing all actions in the buffer.
    ///
    /// # Warning
    ///
    /// This method should be used with caution on large replay buffers
    /// as it may consume significant memory.
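    ///
    /// A minimal usage sketch (`buffer` is a hypothetical buffer instance):
    ///
    /// ```ignore
    /// // Gather every stored action, e.g. for offline analysis.
    /// let all_actions = buffer.whole_actions();
    /// ```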
    pub fn whole_actions(&self) -> A {
        let ixs = (0..self.size).collect::<Vec<_>>();
        self.act.sample(&ixs)
    }

    /// Returns the number of transitions in the buffer whose terminated flag is set.
    pub fn num_terminated_flags(&self) -> usize {
        self.is_terminated
            .iter()
            .map(|is_terminated| *is_terminated as usize)
            .sum()
    }

    /// Returns the number of transitions in the buffer whose truncated flag is set.
    pub fn num_truncated_flags(&self) -> usize {
        self.is_truncated
            .iter()
            .map(|is_truncated| *is_truncated as usize)
            .sum()
    }

    /// Returns the sum of all rewards in the buffer.
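    ///
    /// Together with [`Self::num_terminated_flags`] and [`Self::num_truncated_flags`],
    /// this is handy for quick diagnostics; a sketch, where `buffer` is a hypothetical
    /// instance:
    ///
    /// ```ignore
    /// println!(
    ///     "terminated: {}, truncated: {}, reward sum: {}",
    ///     buffer.num_terminated_flags(),
    ///     buffer.num_truncated_flags(),
    ///     buffer.sum_rewards(),
    /// );
    /// ```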
    pub fn sum_rewards(&self) -> f32 {
        self.reward.iter().sum()
    }
}

impl<O, A> ExperienceBufferBase for SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    type Item = GenericTransitionBatch<O, A>;

    /// Returns the current number of transitions in the buffer.
    fn len(&self) -> usize {
        self.size
    }

    /// Adds a new transition to the buffer.
    ///
    /// The buffer is circular: once it is full, the oldest transitions are
    /// overwritten by new ones.
    ///
    /// # Arguments
    ///
    /// * `tr` - The transition to add
    ///
    /// # Returns
    ///
    /// `Ok(())`; in the current implementation pushing never fails.
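    ///
    /// A sketch of pushing a one-step transition; the field names follow the
    /// [`GenericTransitionBatch`] struct, and `obs`, `act`, and `next_obs` are
    /// assumed to be one-element batches of the buffer's `O`/`A` types:
    ///
    /// ```ignore
    /// let tr = GenericTransitionBatch {
    ///     obs,
    ///     act,
    ///     next_obs,
    ///     reward: vec![1.0],
    ///     is_terminated: vec![0],
    ///     is_truncated: vec![0],
    ///     ix_sample: None,
    ///     weight: None,
    /// };
    /// buffer.push(tr)?;
    /// ```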
    fn push(&mut self, tr: Self::Item) -> Result<()> {
        let len = tr.len(); // batch size
        let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack();
        self.obs.push(self.i, obs);
        self.act.push(self.i, act);
        self.next_obs.push(self.i, next_obs);
        self.push_reward(self.i, &reward);
        self.push_is_terminated(self.i, &is_terminated);
        self.push_is_truncated(self.i, &is_truncated);

        if self.per_state.is_some() {
            self.set_priority(len)
        };

        self.i = (self.i + len) % self.capacity;
        self.size += len;
        if self.size >= self.capacity {
            self.size = self.capacity;
        }

        Ok(())
    }
}

impl<O, A> ReplayBufferBase for SimpleReplayBuffer<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    type Config = SimpleReplayBufferConfig;
    type Batch = GenericTransitionBatch<O, A>;

    /// Creates a new replay buffer with the given configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - Configuration for the replay buffer
    ///
    /// # Returns
    ///
    /// A new instance of the replay buffer
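    ///
    /// A sketch of building a uniform-sampling buffer (no PER); `ObsBatch` and
    /// `ActBatch` stand for any types implementing [`BatchBase`]:
    ///
    /// ```ignore
    /// let config = SimpleReplayBufferConfig {
    ///     capacity: 100_000,
    ///     seed: 42,
    ///     per_config: None,
    /// };
    /// let mut buffer = SimpleReplayBuffer::<ObsBatch, ActBatch>::build(&config);
    /// ```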
    fn build(config: &Self::Config) -> Self {
        let capacity = config.capacity;
        let per_state = match &config.per_config {
            Some(per_config) => Some(PerState::new(capacity, per_config)),
            None => None,
        };

        Self {
            capacity,
            i: 0,
            size: 0,
            obs: O::new(capacity),
            act: A::new(capacity),
            next_obs: O::new(capacity),
            reward: vec![0.; capacity],
            is_terminated: vec![0; capacity],
            is_truncated: vec![0; capacity],
            rng: StdRng::seed_from_u64(config.seed as _),
            per_state,
        }
    }

    /// Samples a batch of transitions from the buffer.
    ///
    /// If prioritized experience replay is enabled, samples are selected
    /// according to their priorities; otherwise, uniform random sampling
    /// with replacement is used.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of transitions to sample
    ///
    /// # Returns
    ///
    /// A batch of sampled transitions
    ///
    /// # Panics
    ///
    /// With uniform sampling, this method panics if the buffer is empty,
    /// because sampled indices are taken modulo the current buffer size.
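    ///
    /// A sketch of sampling and unpacking a mini-batch; the last two values of
    /// `unpack()` are assumed to be the sampled indices and importance weights
    /// (present only when PER is enabled):
    ///
    /// ```ignore
    /// let batch = buffer.batch(32)?;
    /// let (obs, act, next_obs, reward, is_terminated, is_truncated, ixs, weights) =
    ///     batch.unpack();
    /// ```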
    fn batch(&mut self, size: usize) -> Result<Self::Batch> {
        let (ixs, weight) = if let Some(per_state) = &self.per_state {
            // PER: draw indices in proportion to their priorities and
            // compute the corresponding importance sampling weights.
            let sum_tree = &per_state.sum_tree;
            let beta = per_state.iw_scheduler.beta();
            let (ixs, weight) = sum_tree.sample(size, beta);
            let ixs = ixs.iter().map(|&ix| ix as usize).collect();
            (ixs, Some(weight))
        } else {
            // Uniform random sampling with replacement.
            let ixs = (0..size)
                .map(|_| (self.rng.next_u32() as usize) % self.size)
                .collect::<Vec<_>>();
            let weight = None;
            (ixs, weight)
        };

        Ok(Self::Batch {
            obs: self.obs.sample(&ixs),
            act: self.act.sample(&ixs),
            next_obs: self.next_obs.sample(&ixs),
            reward: self.sample_reward(&ixs),
            is_terminated: self.sample_is_terminated(&ixs),
            is_truncated: self.sample_is_truncated(&ixs),
            ix_sample: Some(ixs),
            weight,
        })
    }

    /// Updates the priorities of transitions in the buffer.
    ///
    /// This method is used in prioritized experience replay to adjust
    /// the sampling probabilities based on TD errors.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Optional indices of transitions to update
    /// * `td_errs` - Optional TD errors for the transitions
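    ///
    /// A sketch of the PER update step; `ixs` is the `Option<Vec<usize>>` of indices
    /// returned with the sampled batch, and `compute_td_errors` is a hypothetical
    /// helper producing one TD error per sampled transition:
    ///
    /// ```ignore
    /// let td_errs: Option<Vec<f32>> = Some(compute_td_errors(&batch));
    /// buffer.update_priority(&ixs, &td_errs);
    /// ```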
    fn update_priority(&mut self, ixs: &Option<Vec<usize>>, td_errs: &Option<Vec<f32>>) {
        if let Some(per_state) = &mut self.per_state {
            let ixs = ixs
                .as_ref()
                .expect("ixs should be Some(_) in update_priority().");
            let td_errs = td_errs
                .as_ref()
                .expect("td_errs should be Some(_) in update_priority().");
            for (&ix, &td_err) in ixs.iter().zip(td_errs.iter()) {
                per_state.sum_tree.update(ix, td_err);
            }
            per_state.iw_scheduler.add_n_opts();
        }
    }
}