// converge_knowledge/learning/replay.rs

use std::collections::VecDeque;

use rand::seq::SliceRandom;
use rand::Rng;

/// Fixed-capacity experience store maintained with reservoir sampling.
///
/// Until `capacity` experiences have been added, every one is kept. After that,
/// each new experience may replace a randomly chosen stored one. The buffer
/// therefore holds a uniformly random subset of everything added since creation
/// or the last `clear`.
#[derive(Debug, Clone)]
pub struct ReplayBuffer<T> {
    /// Maximum number of experiences retained at once.
    capacity: usize,

    /// The retained experiences; its length never exceeds `capacity`.
    buffer: VecDeque<T>,

    /// Number of `add` calls since creation or the last `clear`, counting
    /// experiences that were not retained as well.
    total_seen: u64,
}

22impl<T: Clone> ReplayBuffer<T> {
23 pub fn new(capacity: usize) -> Self {
25 Self {
26 capacity,
27 buffer: VecDeque::with_capacity(capacity),
28 total_seen: 0,
29 }
30 }
31
32 pub fn add(&mut self, experience: T) {
34 self.total_seen += 1;
35
36 if self.buffer.len() < self.capacity {
37 self.buffer.push_back(experience);
39 } else {
40 let mut rng = rand::thread_rng();
42 let replace_prob = self.capacity as f64 / self.total_seen as f64;
43
44 if rng.r#gen::<f64>() < replace_prob {
45 let idx = rng.gen_range(0..self.capacity);
46 self.buffer[idx] = experience;
47 }
48 }
49 }
50
51 pub fn sample(&self, n: usize) -> Vec<T> {
53 if self.buffer.is_empty() {
54 return Vec::new();
55 }
56
57 let n = n.min(self.buffer.len());
58 let mut rng = rand::thread_rng();
59
60 let indices: Vec<usize> = {
61 let mut all_indices: Vec<usize> = (0..self.buffer.len()).collect();
62 all_indices.shuffle(&mut rng);
63 all_indices.into_iter().take(n).collect()
64 };
65
66 indices
67 .into_iter()
68 .filter_map(|i| self.buffer.get(i).cloned())
69 .collect()
70 }
71
72 pub fn len(&self) -> usize {
74 self.buffer.len()
75 }
76
77 pub fn is_empty(&self) -> bool {
79 self.buffer.is_empty()
80 }
81
82 pub fn total_seen(&self) -> u64 {
84 self.total_seen
85 }
86
87 pub fn clear(&mut self) {
89 self.buffer.clear();
90 self.total_seen = 0;
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_replay_buffer_add() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..50 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 50);
        assert_eq!(buffer.total_seen(), 50);
    }

    #[test]
    fn test_replay_buffer_reservoir() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..1000 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.total_seen(), 1000);

        // Each offered value is placed at most once, so the retained values
        // must be distinct and drawn from the stream.
        let retained: HashSet<i32> = buffer.sample(10).into_iter().collect();
        assert_eq!(retained.len(), 10);
        assert!(retained.iter().all(|v| (0..1000).contains(v)));
    }

    #[test]
    fn test_replay_buffer_sample() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..100 {
            buffer.add(i);
        }

        let samples = buffer.sample(10);
        assert_eq!(samples.len(), 10);

        // Sampling is without replacement and only returns stored values.
        let unique: HashSet<i32> = samples.iter().copied().collect();
        assert_eq!(unique.len(), 10);
        assert!(unique.iter().all(|v| (0..100).contains(v)));
    }

    #[test]
    fn test_sample_clamps_to_len() {
        let mut buffer = ReplayBuffer::new(10);
        for i in 0..5 {
            buffer.add(i);
        }

        let samples = buffer.sample(50);
        assert_eq!(samples.len(), 5);
        let unique: HashSet<i32> = samples.into_iter().collect();
        assert_eq!(unique, (0..5).collect::<HashSet<i32>>());
    }

    #[test]
    fn test_sample_empty_buffer() {
        let buffer = ReplayBuffer::<i32>::new(10);
        assert!(buffer.is_empty());
        assert!(buffer.sample(3).is_empty());
    }

    #[test]
    fn test_zero_capacity_keeps_nothing() {
        let mut buffer = ReplayBuffer::new(0);
        for i in 0..10 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.total_seen(), 10);
        assert!(buffer.sample(5).is_empty());
    }

    #[test]
    fn test_clear_resets_state() {
        let mut buffer = ReplayBuffer::new(10);
        for i in 0..20 {
            buffer.add(i);
        }

        buffer.clear();
        assert!(buffer.is_empty());
        assert_eq!(buffer.total_seen(), 0);

        // The buffer is usable again after clearing.
        for i in 0..3 {
            buffer.add(i);
        }
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.total_seen(), 3);
    }
}