Skip to main content

converge_knowledge/learning/
replay.rs

1//! Experience replay buffer for stable learning.
2
3use rand::Rng;
4use rand::seq::SliceRandom;
5use std::collections::VecDeque;
6
/// Fixed-capacity experience replay buffer filled by reservoir sampling.
///
/// While the buffer has room, every experience is stored. Once it is full,
/// a new experience replaces a stored one with probability
/// `capacity / total_seen`. This ensures uniform coverage of past experiences
/// for stable learning, preventing catastrophic forgetting as described in
/// the ruvector-gnn crate.
#[derive(Debug, Clone)]
pub struct ReplayBuffer<T> {
    /// Maximum number of experiences held at once.
    capacity: usize,

    /// Stored experiences; never grows beyond `capacity`.
    buffer: VecDeque<T>,

    /// Total experiences offered since creation or the last clear
    /// (drives the reservoir replacement probability).
    total_seen: u64,
}
21
22impl<T: Clone> ReplayBuffer<T> {
23    /// Create a new replay buffer with the given capacity.
24    pub fn new(capacity: usize) -> Self {
25        Self {
26            capacity,
27            buffer: VecDeque::with_capacity(capacity),
28            total_seen: 0,
29        }
30    }
31
32    /// Add an experience to the buffer using reservoir sampling.
33    pub fn add(&mut self, experience: T) {
34        self.total_seen += 1;
35
36        if self.buffer.len() < self.capacity {
37            // Buffer not full: just add
38            self.buffer.push_back(experience);
39        } else {
40            // Reservoir sampling: replace with probability capacity/total_seen
41            let mut rng = rand::thread_rng();
42            let replace_prob = self.capacity as f64 / self.total_seen as f64;
43
44            if rng.r#gen::<f64>() < replace_prob {
45                let idx = rng.gen_range(0..self.capacity);
46                self.buffer[idx] = experience;
47            }
48        }
49    }
50
51    /// Sample n experiences uniformly at random.
52    pub fn sample(&self, n: usize) -> Vec<T> {
53        if self.buffer.is_empty() {
54            return Vec::new();
55        }
56
57        let n = n.min(self.buffer.len());
58        let mut rng = rand::thread_rng();
59
60        let indices: Vec<usize> = {
61            let mut all_indices: Vec<usize> = (0..self.buffer.len()).collect();
62            all_indices.shuffle(&mut rng);
63            all_indices.into_iter().take(n).collect()
64        };
65
66        indices
67            .into_iter()
68            .filter_map(|i| self.buffer.get(i).cloned())
69            .collect()
70    }
71
72    /// Get the current buffer size.
73    pub fn len(&self) -> usize {
74        self.buffer.len()
75    }
76
77    /// Check if buffer is empty.
78    pub fn is_empty(&self) -> bool {
79        self.buffer.is_empty()
80    }
81
82    /// Get total experiences seen.
83    pub fn total_seen(&self) -> u64 {
84        self.total_seen
85    }
86
87    /// Clear the buffer.
88    pub fn clear(&mut self) {
89        self.buffer.clear();
90        self.total_seen = 0;
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Below capacity, every added experience is stored and counted.
    #[test]
    fn test_replay_buffer_add() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..50 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 50);
        assert_eq!(buffer.total_seen(), 50);
        assert!(!buffer.is_empty());
    }

    /// Past capacity, the buffer stays full and only holds values it was given.
    #[test]
    fn test_replay_buffer_reservoir() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..1000 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.total_seen(), 1000);

        // Each input value was offered once, so the retained ones are distinct.
        let kept = buffer.sample(10);
        let distinct: HashSet<i32> = kept.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
        assert!(kept.iter().all(|v| (0..1000).contains(v)));
    }

    /// Sampling returns the requested number of distinct stored experiences.
    #[test]
    fn test_replay_buffer_sample() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..100 {
            buffer.add(i);
        }

        let samples = buffer.sample(10);
        assert_eq!(samples.len(), 10);

        let distinct: HashSet<i32> = samples.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
        assert!(samples.iter().all(|v| (0..100).contains(v)));
    }

    /// Asking for more than is stored returns everything exactly once.
    #[test]
    fn test_replay_buffer_sample_clamps_to_len() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..5 {
            buffer.add(i);
        }

        let mut samples = buffer.sample(50);
        samples.sort_unstable();
        assert_eq!(samples, vec![0, 1, 2, 3, 4]);
        assert!(buffer.sample(0).is_empty());
    }

    /// An empty buffer yields no samples.
    #[test]
    fn test_replay_buffer_sample_empty() {
        let buffer: ReplayBuffer<i32> = ReplayBuffer::new(10);

        assert!(buffer.is_empty());
        assert!(buffer.sample(5).is_empty());
    }

    /// A zero-capacity buffer counts experiences but never stores or panics.
    #[test]
    fn test_replay_buffer_zero_capacity() {
        let mut buffer = ReplayBuffer::new(0);

        for i in 0..20 {
            buffer.add(i);
        }

        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.total_seen(), 20);
        assert!(buffer.sample(3).is_empty());
    }

    /// Clearing drops stored experiences and resets the seen counter.
    #[test]
    fn test_replay_buffer_clear() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..25 {
            buffer.add(i);
        }
        buffer.clear();

        assert!(buffer.is_empty());
        assert_eq!(buffer.total_seen(), 0);

        buffer.add(7);
        assert_eq!(buffer.sample(1), vec![7]);
        assert_eq!(buffer.total_seen(), 1);
    }
}