content_extractor_rl/replay_buffer.rs

use rand::Rng;
use std::collections::VecDeque;

/// One stored transition: (state, action, reward, next_state, done).
#[derive(Clone)]
pub struct Experience {
    pub state: Vec<f32>,
    pub action: (usize, Vec<f32>),
    pub reward: f32,
    pub next_state: Vec<f32>,
    pub done: bool,
}

/// Replay buffer with proportional prioritization: experiences are sampled
/// with probability proportional to priority^alpha, and importance-sampling
/// weights (controlled by beta) correct for the resulting bias.
pub struct PrioritizedReplayBuffer {
    capacity: usize,
    buffer: VecDeque<Experience>,
    priorities: Vec<f32>,
    position: usize,
    alpha: f64,
    beta: f64,
}

impl PrioritizedReplayBuffer {
    pub fn new(capacity: usize, alpha: f64, beta: f64) -> Self {
        Self {
            capacity,
            buffer: VecDeque::with_capacity(capacity),
            priorities: vec![1.0; capacity],
            position: 0,
            alpha,
            beta,
        }
    }

    /// Insert an experience with the current maximum priority so that new
    /// transitions are sampled at least once before being down-weighted.
    pub fn add(&mut self, experience: Experience) {
        let max_priority = self.priorities.iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .copied()
            .unwrap_or(1.0);

        if self.buffer.len() < self.capacity {
            self.buffer.push_back(experience);
        } else {
            // Buffer is full: overwrite the slot at the current write position.
            self.buffer[self.position] = experience;
        }

        self.priorities[self.position] = max_priority;
        self.position = (self.position + 1) % self.capacity;
    }

    /// Sample `batch_size` experiences proportionally to priority^alpha and
    /// return them with their indices and normalized importance-sampling weights.
    pub fn sample(&self, batch_size: usize) -> Option<SampledBatch> {
        if self.buffer.len() < batch_size {
            return None;
        }

        let mut rng = rand::rng();

        // Sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha,
        // restricted to the slots that are actually filled.
        let priorities: Vec<f32> = self.priorities[..self.buffer.len()]
            .iter()
            .map(|&p| p.powf(self.alpha as f32))
            .collect();

        let sum: f32 = priorities.iter().sum();
        let probs: Vec<f32> = priorities.iter().map(|&p| p / sum).collect();

        let mut indices = Vec::with_capacity(batch_size);
        let mut experiences = Vec::with_capacity(batch_size);

        for _ in 0..batch_size {
            // Inverse-CDF sampling over the categorical distribution `probs`.
            let r: f32 = rng.random();
            let mut cumsum = 0.0;
            // Default to the last index in case rounding keeps cumsum below r.
            let mut idx = probs.len() - 1;

            for (i, &prob) in probs.iter().enumerate() {
                cumsum += prob;
                if r <= cumsum {
                    idx = i;
                    break;
                }
            }

            indices.push(idx);
            experiences.push(self.buffer[idx].clone());
        }

        // Importance-sampling weights w_i = (N * P(i))^-beta, normalized by the
        // maximum weight so they stay in (0, 1].
        let total = self.buffer.len() as f32;
        let weights: Vec<f32> = indices.iter()
            .map(|&idx| {
                let prob = probs[idx];
                (total * prob).powf(-self.beta as f32)
            })
            .collect();

        let max_weight = weights.iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .copied()
            .unwrap_or(1.0);

        let normalized_weights: Vec<f32> = weights.iter()
            .map(|&w| w / max_weight)
            .collect();

        Some(SampledBatch {
            experiences,
            indices,
            weights: normalized_weights,
        })
    }

    /// Update priorities from new TD errors; the small epsilon keeps every
    /// transition sampleable even when its TD error is zero.
    pub fn update_priorities(&mut self, indices: &[usize], td_errors: &[f32]) {
        for (&idx, &error) in indices.iter().zip(td_errors.iter()) {
            self.priorities[idx] = error.abs() + 1e-6;
        }
    }

    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}

/// A sampled minibatch along with the buffer indices (needed for later
/// priority updates) and the normalized importance-sampling weights.
pub struct SampledBatch {
    pub experiences: Vec<Experience>,
    pub indices: Vec<usize>,
    pub weights: Vec<f32>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer() {
        let mut buffer = PrioritizedReplayBuffer::new(100, 0.6, 0.4);

        let exp = Experience {
            state: vec![0.0; 300],
            action: (0, vec![0.0; 6]),
            reward: 1.0,
            next_state: vec![0.0; 300],
            done: false,
        };

        buffer.add(exp);
        assert_eq!(buffer.len(), 1);
    }
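
    // A minimal sketch of the intended add/sample/update cycle, exercising only
    // the API defined above; the state size, batch size, and TD-error values
    // here are illustrative assumptions, not taken from the training code.
    #[test]
    fn test_sample_and_update_priorities() {
        let mut buffer = PrioritizedReplayBuffer::new(8, 0.6, 0.4);

        // Fill past capacity so the circular overwrite path is also covered.
        for i in 0..10_usize {
            buffer.add(Experience {
                state: vec![i as f32; 4],
                action: (i % 2, vec![0.0; 2]),
                reward: i as f32,
                next_state: vec![i as f32 + 1.0; 4],
                done: false,
            });
        }
        assert_eq!(buffer.len(), 8);

        // Requesting more experiences than the buffer holds is refused.
        assert!(buffer.sample(16).is_none());

        let batch = buffer.sample(4).expect("buffer holds enough experiences");
        assert_eq!(batch.experiences.len(), 4);
        assert_eq!(batch.indices.len(), 4);
        // Weights are normalized by the maximum weight, so they lie in (0, 1].
        assert!(batch.weights.iter().all(|&w| w > 0.0 && w <= 1.0));

        // Feed back (hypothetical) TD errors so sampled transitions are re-prioritized.
        let td_errors = vec![0.5; batch.indices.len()];
        buffer.update_priorities(&batch.indices, &td_errors);
    }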
}