Skip to main content

rl_traits/
buffer.rs

1use rand::Rng;
2
3use crate::experience::Experience;
4
5/// A buffer that stores past experience for agent training.
6///
7/// Used primarily by off-policy algorithms (DQN, SAC, TD3) to break
8/// temporal correlations between training samples. On-policy algorithms
9/// (PPO, A2C) typically collect fixed-length trajectories instead and
10/// don't need this trait — they can use a plain `Vec<Experience<O, A>>`.
11///
12/// # Implementing this trait
13///
14/// The most common implementation is a circular buffer with a fixed
15/// capacity that overwrites the oldest experience when full. Concrete
16/// implementations live in ember-rl, not here.
17///
18/// # Bounds
19///
20/// `O: Clone + Send + Sync` and `A: Clone + Send + Sync` are required
21/// because sampling returns owned `Experience` values (not references),
22/// and buffers may be accessed across threads during async training.
23pub trait ReplayBuffer<O, A>
24where
25    O: Clone + Send + Sync,
26    A: Clone + Send + Sync,
27{
28    /// Add a new experience to the buffer.
29    ///
30    /// If the buffer is at capacity, implementations should overwrite the
31    /// oldest experience (FIFO eviction).
32    fn push(&mut self, experience: Experience<O, A>);
33
34    /// Sample a random batch of `batch_size` experiences.
35    ///
36    /// Sampling is done with replacement. The caller supplies the RNG so
37    /// sampling randomness can be seeded and controlled independently.
38    ///
39    /// # Panics
40    ///
41    /// Implementations may panic if `batch_size > self.len()`.
42    fn sample(&self, batch_size: usize, rng: &mut impl Rng) -> Vec<Experience<O, A>>;
43
44    /// The number of experiences currently stored.
45    fn len(&self) -> usize;
46
47    /// Returns `true` if the buffer contains no experiences.
48    fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51
52    /// The maximum number of experiences the buffer can hold, if bounded.
53    ///
54    /// Returns `None` for unbounded buffers (e.g. trajectory collectors).
55    fn capacity(&self) -> Option<usize>;
56
57    /// Returns `true` if the buffer is at capacity.
58    ///
59    /// When full, the next `push()` will overwrite the oldest experience.
60    fn is_full(&self) -> bool {
61        self.capacity().is_some_and(|cap| self.len() >= cap)
62    }
63
64    /// Returns `true` if the buffer has enough experience to sample a
65    /// batch of the given size.
66    ///
67    /// Useful for deciding when to start training in the warm-up phase.
68    fn ready_for(&self, batch_size: usize) -> bool {
69        self.len() >= batch_size
70    }
71}