ember_rl/algorithms/dqn/
replay.rs1use rand::Rng;
2use rl_traits::{Experience, ReplayBuffer};
3
4pub struct CircularBuffer<O, A> {
16 storage: Vec<Experience<O, A>>,
17 capacity: usize,
18 cursor: usize,
19 len: usize,
20}
21
22impl<O: Clone + Send + Sync, A: Clone + Send + Sync> CircularBuffer<O, A> {
23 pub fn new(capacity: usize) -> Self {
24 assert!(capacity > 0, "buffer capacity must be > 0");
25 Self {
26 storage: Vec::with_capacity(capacity),
27 capacity,
28 cursor: 0,
29 len: 0,
30 }
31 }
32}
33
34impl<O, A> ReplayBuffer<O, A> for CircularBuffer<O, A>
35where
36 O: Clone + Send + Sync,
37 A: Clone + Send + Sync,
38{
39 fn push(&mut self, experience: Experience<O, A>) {
40 if self.storage.len() < self.capacity {
41 self.storage.push(experience);
42 } else {
43 self.storage[self.cursor] = experience;
44 }
45 self.cursor = (self.cursor + 1) % self.capacity;
46 self.len = (self.len + 1).min(self.capacity);
47 }
48
49 fn sample(&self, batch_size: usize, rng: &mut impl Rng) -> Vec<Experience<O, A>> {
50 assert!(
51 batch_size <= self.len,
52 "cannot sample {batch_size} from buffer of size {}",
53 self.len
54 );
55 (0..batch_size)
56 .map(|_| {
57 let idx = rng.gen_range(0..self.len);
58 self.storage[idx].clone()
59 })
60 .collect()
61 }
62
63 fn len(&self) -> usize {
64 self.len
65 }
66
67 fn capacity(&self) -> Option<usize> {
68 Some(self.capacity)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;
    use rl_traits::EpisodeStatus;

    /// Builds a throwaway transition; only `reward` varies between calls.
    fn make_exp(reward: f64) -> Experience<f32, usize> {
        Experience::new(0.0f32, 0usize, reward, 0.0f32, EpisodeStatus::Continuing)
    }

    #[test]
    fn push_and_len() {
        let mut buf: CircularBuffer<f32, usize> = CircularBuffer::new(4);
        assert!(buf.is_empty());
        buf.push(make_exp(1.0));
        buf.push(make_exp(2.0));
        assert_eq!(buf.len(), 2);
    }

    #[test]
    #[should_panic(expected = "buffer capacity must be > 0")]
    fn zero_capacity_panics() {
        let _buf: CircularBuffer<f32, usize> = CircularBuffer::new(0);
    }

    #[test]
    fn overwrites_when_full() {
        let mut buf: CircularBuffer<f32, usize> = CircularBuffer::new(3);
        buf.push(make_exp(1.0));
        buf.push(make_exp(2.0));
        buf.push(make_exp(3.0));
        // Filling the buffer exactly wraps the write cursor back to slot 0.
        assert_eq!(buf.cursor, 0);

        buf.push(make_exp(4.0));
        assert_eq!(buf.len(), 3);
        assert!(buf.is_full());
        // The fourth push overwrote slot 0 in place instead of growing storage.
        assert_eq!(buf.storage.len(), 3);
        assert_eq!(buf.cursor, 1);
    }

    #[test]
    fn sample_returns_correct_batch_size() {
        let mut buf: CircularBuffer<f32, usize> = CircularBuffer::new(10);
        for i in 0..10 {
            buf.push(make_exp(i as f64));
        }
        let mut rng = SmallRng::seed_from_u64(42);
        let batch = buf.sample(4, &mut rng);
        assert_eq!(batch.len(), 4);
    }

    #[test]
    fn sample_zero_from_empty_is_empty() {
        let buf: CircularBuffer<f32, usize> = CircularBuffer::new(2);
        let mut rng = SmallRng::seed_from_u64(0);
        assert!(buf.sample(0, &mut rng).is_empty());
    }

    #[test]
    #[should_panic(expected = "cannot sample")]
    fn sample_more_than_len_panics() {
        let mut buf: CircularBuffer<f32, usize> = CircularBuffer::new(4);
        buf.push(make_exp(1.0));
        let mut rng = SmallRng::seed_from_u64(0);
        let _ = buf.sample(2, &mut rng);
    }

    #[test]
    fn ready_for_respects_warmup() {
        let mut buf: CircularBuffer<f32, usize> = CircularBuffer::new(100);
        assert!(!buf.ready_for(64));
        for i in 0..64 {
            buf.push(make_exp(i as f64));
        }
        assert!(buf.ready_for(64));
    }
}