Skip to main content

content_extractor_rl/
replay_buffer.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/replay_buffer.rs
3// ============================================================================
4
5use rand::RngExt;
6use std::collections::VecDeque;
7
/// A single transition stored in the replay buffer.
///
/// Derives `Debug` and `PartialEq` in addition to `Clone` so transitions can
/// be printed in diagnostics and compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct Experience {
    /// Observation before the action was taken.
    pub state: Vec<f32>,
    /// Action taken, stored as an index paired with a vector of `f32` values.
    pub action: (usize, Vec<f32>),
    /// Reward received for the transition.
    pub reward: f32,
    /// Observation after the action was taken.
    pub next_state: Vec<f32>,
    /// Completion flag recorded with the transition.
    pub done: bool,
}
17
18/// Prioritized replay buffer
19pub struct PrioritizedReplayBuffer {
20    capacity: usize,
21    buffer: VecDeque<Experience>,
22    priorities: Vec<f32>,
23    position: usize,
24    alpha: f64,
25    beta: f64,
26}
27
28impl PrioritizedReplayBuffer {
29    /// Create new prioritized replay buffer
30    pub fn new(capacity: usize, alpha: f64, beta: f64) -> Self {
31        Self {
32            capacity,
33            buffer: VecDeque::with_capacity(capacity),
34            priorities: vec![1.0; capacity],
35            position: 0,
36            alpha,
37            beta,
38        }
39    }
40
41    /// Add experience to buffer
42    pub fn add(&mut self, experience: Experience) {
43        let max_priority = self.priorities.iter()
44            .max_by(|a, b| a.partial_cmp(b).unwrap())
45            .copied()
46            .unwrap_or(1.0);
47
48        if self.buffer.len() < self.capacity {
49            self.buffer.push_back(experience);
50        } else {
51            self.buffer[self.position] = experience;
52        }
53
54        self.priorities[self.position] = max_priority;
55        self.position = (self.position + 1) % self.capacity;
56    }
57
58    /// Sample batch from buffer
59    pub fn sample(&self, batch_size: usize) -> Option<SampledBatch> {
60        if self.buffer.len() < batch_size {
61            return None;
62        }
63
64        let mut rng = rand::rng();
65
66        // Calculate sampling probabilities
67        let priorities: Vec<f32> = self.priorities[..self.buffer.len()]
68            .iter()
69            .map(|&p| p.powf(self.alpha as f32))
70            .collect();
71
72        let sum: f32 = priorities.iter().sum();
73        let probs: Vec<f32> = priorities.iter().map(|&p| p / sum).collect();
74
75        // Sample indices
76        let mut indices = Vec::with_capacity(batch_size);
77        let mut experiences = Vec::with_capacity(batch_size);
78
79        for _ in 0..batch_size {
80            let r: f32 = rng.random();
81            let mut cumsum = 0.0;
82            let mut idx = 0;
83
84            for (i, &prob) in probs.iter().enumerate() {
85                cumsum += prob;
86                if r <= cumsum {
87                    idx = i;
88                    break;
89                }
90            }
91
92            indices.push(idx);
93            experiences.push(self.buffer[idx].clone());
94        }
95
96        // Calculate importance sampling weights
97        let total = self.buffer.len() as f32;
98        let weights: Vec<f32> = indices.iter()
99            .map(|&idx| {
100                let prob = probs[idx];
101                (total * prob).powf(-self.beta as f32)
102            })
103            .collect();
104
105        let max_weight = weights.iter()
106            .max_by(|a, b| a.partial_cmp(b).unwrap())
107            .copied()
108            .unwrap_or(1.0);
109
110        let normalized_weights: Vec<f32> = weights.iter()
111            .map(|&w| w / max_weight)
112            .collect();
113
114        Some(SampledBatch {
115            experiences,
116            indices,
117            weights: normalized_weights,
118        })
119    }
120
121    /// Update priorities based on TD errors
122    pub fn update_priorities(&mut self, indices: &[usize], td_errors: &[f32]) {
123        for (&idx, &error) in indices.iter().zip(td_errors.iter()) {
124            self.priorities[idx] = error.abs() + 1e-6;
125        }
126    }
127
128    /// Get buffer length
129    pub fn len(&self) -> usize {
130        self.buffer.len()
131    }
132
133    /// Check if buffer is empty
134    pub fn is_empty(&self) -> bool {
135        self.buffer.is_empty()
136    }
137}
138
139/// Sampled batch from replay buffer
140pub struct SampledBatch {
141    pub experiences: Vec<Experience>,
142    pub indices: Vec<usize>,
143    pub weights: Vec<f32>,
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small experience whose reward identifies it in assertions.
    fn experience_with_reward(reward: f32) -> Experience {
        Experience {
            state: vec![0.0; 300],
            action: (0, vec![0.0; 6]),
            reward,
            next_state: vec![0.0; 300],
            done: false,
        }
    }

    #[test]
    fn test_replay_buffer() {
        let mut buffer = PrioritizedReplayBuffer::new(100, 0.6, 0.4);
        assert!(buffer.is_empty());

        buffer.add(experience_with_reward(1.0));
        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn sample_requires_enough_experiences() {
        let mut buffer = PrioritizedReplayBuffer::new(10, 0.6, 0.4);
        buffer.add(experience_with_reward(1.0));

        assert!(buffer.sample(2).is_none());
        assert!(buffer.sample(1).is_some());
    }

    #[test]
    fn single_experience_is_sampled_with_unit_weight() {
        let mut buffer = PrioritizedReplayBuffer::new(10, 0.6, 0.4);
        buffer.add(experience_with_reward(3.0));

        let batch = buffer.sample(1).expect("one experience is stored");
        assert_eq!(batch.indices, vec![0]);
        assert_eq!(batch.experiences.len(), 1);
        assert_eq!(batch.experiences[0].reward, 3.0);
        assert!((batch.weights[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn overwrites_oldest_when_full() {
        let mut buffer = PrioritizedReplayBuffer::new(2, 0.6, 0.4);
        for reward in [0.0, 1.0, 2.0] {
            buffer.add(experience_with_reward(reward));
        }
        assert_eq!(buffer.len(), 2);

        // The first experience (reward 0.0) must have been replaced.
        let batch = buffer.sample(2).expect("buffer is full");
        assert!(batch
            .experiences
            .iter()
            .all(|e| e.reward == 1.0 || e.reward == 2.0));
    }

    #[test]
    fn update_priorities_keeps_sampling_valid() {
        let mut buffer = PrioritizedReplayBuffer::new(8, 0.6, 0.4);
        for reward in [0.0, 1.0, 2.0] {
            buffer.add(experience_with_reward(reward));
        }

        buffer.update_priorities(&[0, 1, 2], &[0.5, 2.0, 0.0]);

        let batch = buffer.sample(3).expect("three experiences are stored");
        assert_eq!(batch.indices.len(), 3);
        assert!(batch.indices.iter().all(|&idx| idx < 3));
        assert!(batch.weights.iter().all(|&w| w > 0.0 && w <= 1.0));
    }
}
165}