Skip to main content

ember_rl/algorithms/dqn/
replay.rs

1use rand::Rng;
2use rl_traits::{Experience, ReplayBuffer};
3
4/// A fixed-capacity circular replay buffer.
5///
6/// Stores experience tuples and overwrites the oldest when full.
7/// Implements `rl_traits::ReplayBuffer` so it's usable by any algorithm
8/// that depends on that trait, not just DQN.
9///
10/// # Implementation notes
11///
12/// Uses a `Vec` pre-allocated to `capacity` with a write cursor that wraps
13/// around. This avoids any allocation after construction and gives O(1) push
14/// and O(batch_size) sample.
15pub struct CircularBuffer<O, A> {
16    storage: Vec<Experience<O, A>>,
17    capacity: usize,
18    cursor: usize,
19    len: usize,
20}
21
22impl<O: Clone + Send + Sync, A: Clone + Send + Sync> CircularBuffer<O, A> {
23    pub fn new(capacity: usize) -> Self {
24        assert!(capacity > 0, "buffer capacity must be > 0");
25        Self {
26            storage: Vec::with_capacity(capacity),
27            capacity,
28            cursor: 0,
29            len: 0,
30        }
31    }
32}
33
34impl<O, A> ReplayBuffer<O, A> for CircularBuffer<O, A>
35where
36    O: Clone + Send + Sync,
37    A: Clone + Send + Sync,
38{
39    fn push(&mut self, experience: Experience<O, A>) {
40        if self.storage.len() < self.capacity {
41            self.storage.push(experience);
42        } else {
43            self.storage[self.cursor] = experience;
44        }
45        self.cursor = (self.cursor + 1) % self.capacity;
46        self.len = (self.len + 1).min(self.capacity);
47    }
48
49    fn sample(&self, batch_size: usize, rng: &mut impl Rng) -> Vec<Experience<O, A>> {
50        assert!(
51            batch_size <= self.len,
52            "cannot sample {batch_size} from buffer of size {}",
53            self.len
54        );
55        (0..batch_size)
56            .map(|_| {
57                let idx = rng.gen_range(0..self.len);
58                self.storage[idx].clone()
59            })
60            .collect()
61    }
62
63    fn len(&self) -> usize {
64        self.len
65    }
66
67    fn capacity(&self) -> Option<usize> {
68        Some(self.capacity)
69    }
70}
71
#[cfg(test)]
mod tests {
    //! Unit tests for `CircularBuffer`.
    //!
    //! NOTE(review): `is_empty`, `is_full` and `ready_for` are not defined in
    //! this file; presumably they are provided by `ReplayBuffer` — confirm.

    use super::*;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;
    use rl_traits::EpisodeStatus;

    /// Builds a placeholder experience that carries only the given reward.
    fn make_exp(reward: f64) -> Experience<f32, usize> {
        Experience::new(0.0_f32, 0_usize, reward, 0.0_f32, EpisodeStatus::Continuing)
    }

    #[test]
    fn push_and_len() {
        let mut buf = CircularBuffer::<f32, usize>::new(4);
        assert!(buf.is_empty());
        for reward in [1.0, 2.0] {
            buf.push(make_exp(reward));
        }
        assert_eq!(buf.len(), 2);
    }

    #[test]
    fn overwrites_when_full() {
        let mut buf = CircularBuffer::<f32, usize>::new(3);
        // Capacity is 3, so the fourth push lands on slot 0 again.
        for reward in [1.0, 2.0, 3.0, 4.0] {
            buf.push(make_exp(reward));
        }
        assert_eq!(buf.len(), 3);
        assert!(buf.is_full());
    }

    #[test]
    fn sample_returns_correct_batch_size() {
        let mut buf = CircularBuffer::<f32, usize>::new(10);
        for step in 0..10u32 {
            buf.push(make_exp(f64::from(step)));
        }
        let mut rng = SmallRng::seed_from_u64(42);
        let batch = buf.sample(4, &mut rng);
        assert_eq!(batch.len(), 4);
    }

    #[test]
    fn ready_for_respects_warmup() {
        let mut buf = CircularBuffer::<f32, usize>::new(100);
        assert!(!buf.ready_for(64));
        for step in 0..64u32 {
            buf.push(make_exp(f64::from(step)));
        }
        assert!(buf.ready_for(64));
    }
}