1use crate::buffer::replay::Transition;
34use crate::error::{RlError, RlResult};
35use crate::handle::RlHandle;
36
/// Array-backed binary tree over `n` leaves that maintains, for every node,
/// both the sum and the minimum of the leaf priorities beneath it.
///
/// Node `1` is the root, node `k` has children `2k` and `2k + 1`, and leaf `i`
/// lives at position `n + i`. Slot `0` is never read.
#[derive(Debug, Clone)]
struct SumTree {
    /// Number of leaves.
    n: usize,
    /// Per-node priority sums (`2 * n` slots, leaves start at zero).
    tree: Vec<f64>,
    /// Per-node priority minima (`2 * n` slots, leaves start at `f64::MAX`).
    min_tree: Vec<f64>,
}

impl SumTree {
    /// Creates a tree with `n` leaves, all sums zero and all minima `f64::MAX`.
    fn new(n: usize) -> Self {
        let slots = 2 * n;
        Self {
            n,
            tree: vec![0.0; slots],
            min_tree: vec![f64::MAX; slots],
        }
    }

    /// Writes `priority` into leaf `i` and refreshes every ancestor up to the root.
    fn update(&mut self, i: usize, priority: f64) {
        let leaf = self.n + i;
        self.tree[leaf] = priority;
        self.min_tree[leaf] = priority;

        // Walk towards the root; node 0 is not part of the tree, so stop there.
        let mut parent = leaf / 2;
        while parent > 0 {
            let (left, right) = (2 * parent, 2 * parent + 1);
            self.tree[parent] = self.tree[left] + self.tree[right];
            self.min_tree[parent] = self.min_tree[left].min(self.min_tree[right]);
            parent /= 2;
        }
    }

    /// Sum stored at the root.
    #[inline]
    fn total(&self) -> f64 {
        self.tree[1]
    }

    /// Minimum stored at the root.
    #[inline]
    fn min_priority(&self) -> f64 {
        self.min_tree[1]
    }

    /// Descends from the root to a leaf and returns that leaf's index.
    ///
    /// At each internal node the search goes left when the left subtree's sum
    /// is at least the remaining value; otherwise it subtracts that sum from
    /// the remaining value and goes right.
    fn find(&self, value: f64) -> usize {
        let mut remaining = value;
        let mut node = 1_usize;
        while node < self.n {
            let left = node * 2;
            let left_mass = self.tree[left];
            node = if left_mass >= remaining {
                left
            } else {
                remaining -= left_mass;
                left + 1
            };
        }
        node - self.n
    }

    /// Priority currently stored in leaf `i`.
    fn priority_at(&self, i: usize) -> f64 {
        self.tree[self.n + i]
    }
}
102
/// One entry of a batch returned by `PrioritizedReplayBuffer::sample`.
#[derive(Debug, Clone)]
pub struct PrioritySample {
 /// Owned copy of the sampled transition's data.
 pub transition: Transition,
 /// Slot of the transition inside the buffer at the time of sampling.
 pub index: usize,
 /// Weight computed by `sample` for this transition; capped at `1.0`.
 pub weight: f32,
}
115
116#[derive(Debug, Clone)]
119pub struct PrioritizedReplayBuffer {
120 capacity: usize,
121 obs_dim: usize,
122 act_dim: usize,
123 obs: Vec<f32>,
125 actions: Vec<f32>,
126 rewards: Vec<f32>,
127 next_obs: Vec<f32>,
128 dones: Vec<f32>,
129 tree: SumTree,
131 alpha: f64,
133 beta: f64,
135 max_priority: f64,
137 head: usize,
139 size: usize,
141}
142
143impl PrioritizedReplayBuffer {
144 pub fn new(
151 capacity: usize,
152 obs_dim: usize,
153 act_dim: usize,
154 alpha: f64,
155 beta_start: f64,
156 ) -> Self {
157 assert!(capacity > 0, "capacity must be > 0");
158 let cap2 = capacity.next_power_of_two();
160 Self {
161 capacity,
162 obs_dim,
163 act_dim,
164 obs: vec![0.0; capacity * obs_dim],
165 actions: vec![0.0; capacity * act_dim],
166 rewards: vec![0.0; capacity],
167 next_obs: vec![0.0; capacity * obs_dim],
168 dones: vec![0.0; capacity],
169 tree: SumTree::new(cap2),
170 alpha,
171 beta: beta_start,
172 max_priority: 1.0,
173 head: 0,
174 size: 0,
175 }
176 }
177
178 pub fn push(
181 &mut self,
182 obs: impl AsRef<[f32]>,
183 action: impl AsRef<[f32]>,
184 reward: f32,
185 next_obs: impl AsRef<[f32]>,
186 done: bool,
187 ) {
188 let obs = obs.as_ref();
189 let action = action.as_ref();
190 let next_obs = next_obs.as_ref();
191
192 let i = self.head;
193 self.obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(obs);
194 self.actions[i * self.act_dim..(i + 1) * self.act_dim].copy_from_slice(action);
195 self.rewards[i] = reward;
196 self.next_obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(next_obs);
197 self.dones[i] = if done { 1.0 } else { 0.0 };
198 self.tree.update(i, self.max_priority.powf(self.alpha));
200 self.head = (self.head + 1) % self.capacity;
201 if self.size < self.capacity {
202 self.size += 1;
203 }
204 }
205
206 pub fn update_priority(&mut self, index: usize, priority: f64) {
213 let p = priority.max(1e-6);
214 if p > self.max_priority {
215 self.max_priority = p;
216 }
217 self.tree.update(index, p.powf(self.alpha));
218 }
219
220 pub fn set_beta(&mut self, beta: f64) {
222 self.beta = beta.clamp(0.0, 1.0);
223 }
224
225 pub fn anneal_beta(&mut self, step: f64) {
227 self.beta = (self.beta + step).min(1.0);
228 }
229
230 #[must_use]
232 #[inline]
233 pub fn len(&self) -> usize {
234 self.size
235 }
236
237 #[must_use]
239 #[inline]
240 pub fn is_empty(&self) -> bool {
241 self.size == 0
242 }
243
244 pub fn sample(
255 &self,
256 batch_size: usize,
257 handle: &mut RlHandle,
258 ) -> RlResult<Vec<PrioritySample>> {
259 if self.size < batch_size {
260 return Err(RlError::InsufficientTransitions {
261 have: self.size,
262 need: batch_size,
263 });
264 }
265 let total = self.tree.total();
266 if total <= 0.0 {
267 return Err(RlError::ZeroPrioritySum);
268 }
269 let rng = handle.rng_mut();
270 let segment = total / batch_size as f64;
271 let min_p = self.tree.min_priority() / total;
272 let max_w = (1.0 / (self.size as f64 * min_p)).powf(self.beta) as f32;
273
274 let mut out = Vec::with_capacity(batch_size);
275 for k in 0..batch_size {
276 let lo = k as f64 * segment;
277 let hi = lo + segment;
278 let v = lo + rng.next_f32() as f64 * (hi - lo);
279 let idx = self.tree.find(v.min(total - 1e-9)).min(self.size - 1);
280
281 let p = self.tree.priority_at(idx) / total;
282 let w = ((1.0 / (self.size as f64 * p)).powf(self.beta) as f32 / max_w).min(1.0);
283
284 let obs = self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
285 let action = self.actions[idx * self.act_dim..(idx + 1) * self.act_dim].to_vec();
286 let reward = self.rewards[idx];
287 let next_obs = self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
288 let done = self.dones[idx] > 0.5;
289 out.push(PrioritySample {
290 transition: Transition {
291 obs,
292 action,
293 reward,
294 next_obs,
295 done,
296 },
297 index: idx,
298 weight: w,
299 });
300 }
301 Ok(out)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer with 2-value observations, 1-value actions, alpha 0.6, beta 0.4.
    fn make_per(cap: usize) -> PrioritizedReplayBuffer {
        PrioritizedReplayBuffer::new(cap, 2, 1, 0.6, 0.4)
    }

    /// Pushes `n` synthetic, non-terminal transitions derived from their index.
    fn fill_per(buf: &mut PrioritizedReplayBuffer, n: usize) {
        for step in (0..n).map(|i| i as f32) {
            buf.push(
                [step, step + 1.0],
                [0.5_f32],
                step * 0.1,
                [step + 1.0, step + 2.0],
                false,
            );
        }
    }

    /// Builds a tree with one leaf per entry of `priorities`, written in order.
    fn tree_with(priorities: &[f64]) -> SumTree {
        let mut t = SumTree::new(priorities.len());
        for (leaf, &priority) in priorities.iter().enumerate() {
            t.update(leaf, priority);
        }
        t
    }

    #[test]
    fn sum_tree_basic() {
        let t = tree_with(&[1.0, 2.0, 3.0, 4.0]);
        assert!((t.total() - 10.0).abs() < 1e-9, "total={}", t.total());
    }

    #[test]
    fn sum_tree_find() {
        let t = tree_with(&[1.0, 2.0, 3.0, 4.0]);
        let idx = t.find(2.5);
        assert_eq!(idx, 1, "find(2.5) should return idx=1, got {idx}");
    }

    #[test]
    fn per_push_and_len() {
        let mut buf = make_per(32);
        fill_per(&mut buf, 20);
        assert_eq!(buf.len(), 20);
    }

    #[test]
    fn per_sample_size() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(16, &mut handle).unwrap();
        assert_eq!(batch.len(), 16);
    }

    #[test]
    fn per_weights_in_range() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(32, &mut handle).unwrap();
        for s in &batch {
            assert!(s.weight > 0.0 && s.weight <= 1.0, "weight={}", s.weight);
        }
    }

    #[test]
    fn per_update_priority() {
        let mut buf = make_per(16);
        fill_per(&mut buf, 16);
        buf.update_priority(0, 100.0);

        let mut handle = RlHandle::default_handle();
        // Count how many of 200 single-item draws land on the boosted slot.
        let hits_on_zero = (0..200)
            .filter(|_| buf.sample(1, &mut handle).unwrap()[0].index == 0)
            .count();
        assert!(
            hits_on_zero > 200 / 16,
            "high-priority index should be over-sampled"
        );
    }

    #[test]
    fn per_insufficient_error() {
        let buf = make_per(16);
        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(5, &mut handle).is_err());
    }

    #[test]
    fn per_anneal_beta() {
        let mut buf = make_per(16);
        buf.set_beta(0.4);
        buf.anneal_beta(0.3);
        assert!((buf.beta - 0.7).abs() < 1e-9);
        buf.anneal_beta(1.0);
        assert!((buf.beta - 1.0).abs() < 1e-9);
    }

    #[test]
    fn sum_tree_min_priority() {
        let t = tree_with(&[5.0, 2.0, 8.0, 3.0]);
        assert!((t.min_priority() - 2.0).abs() < 1e-9);
    }
}