use crate::buffer::replay::Transition;
use crate::error::{RlError, RlResult};
use crate::handle::RlHandle;
/// Binary sum tree over `n` leaves, stored implicitly in arrays of length `2 * n`.
///
/// Internal node `k` (for `1 <= k < n`) has children `2k` and `2k + 1`; leaf `i`
/// lives at array index `n + i`. `tree` holds subtree sums (untouched leaves are
/// `0.0`) and `min_tree` holds subtree minima (untouched leaves are `f64::MAX`).
/// Array index 0 is never used.
#[derive(Debug, Clone)]
struct SumTree {
    /// Number of leaves.
    n: usize,
    /// Subtree sums; `tree[1]` is the grand total.
    tree: Vec<f64>,
    /// Subtree minima; `min_tree[1]` is the smallest leaf value.
    min_tree: Vec<f64>,
}

impl SumTree {
    /// Builds a tree with `n` leaves, all summing to zero.
    fn new(n: usize) -> Self {
        let len = 2 * n;
        Self {
            n,
            tree: vec![0.0; len],
            min_tree: vec![f64::MAX; len],
        }
    }

    /// Writes `priority` into leaf `i`, then walks up to the root refreshing each
    /// ancestor's sum and minimum.
    fn update(&mut self, i: usize, priority: f64) {
        let mut node = self.n + i;
        self.tree[node] = priority;
        self.min_tree[node] = priority;
        node /= 2;
        while node != 0 {
            let (l, r) = (2 * node, 2 * node + 1);
            self.tree[node] = self.tree[l] + self.tree[r];
            self.min_tree[node] = self.min_tree[l].min(self.min_tree[r]);
            node /= 2;
        }
    }

    /// Sum of every leaf value.
    #[inline]
    fn total(&self) -> f64 {
        self.tree[1]
    }

    /// Smallest leaf value (`f64::MAX` until some leaf has been written).
    #[inline]
    fn min_priority(&self) -> f64 {
        self.min_tree[1]
    }

    /// Descends from the root to the leaf whose cumulative-sum interval (in
    /// left-to-right tree order) contains `value`, returning the leaf index.
    /// Ties at a boundary go to the left subtree.
    fn find(&self, value: f64) -> usize {
        let (mut node, mut remaining) = (1_usize, value);
        while node < self.n {
            let left = node * 2;
            let left_sum = self.tree[left];
            if remaining <= left_sum {
                node = left;
            } else {
                remaining -= left_sum;
                node = left + 1;
            }
        }
        node - self.n
    }

    /// Value currently stored in leaf `i`.
    fn priority_at(&self, i: usize) -> f64 {
        self.tree[self.n + i]
    }
}
/// One transition drawn by `PrioritizedReplayBuffer::sample`.
#[derive(Debug, Clone)]
pub struct PrioritySample {
    /// The sampled transition, copied out of the buffer's storage.
    pub transition: Transition,
    /// Buffer slot the transition was read from; pass it back to
    /// `PrioritizedReplayBuffer::update_priority` to refresh its priority.
    pub index: usize,
    /// Normalised importance-sampling weight (capped at `1.0`).
    pub weight: f32,
}
/// Fixed-capacity ring buffer of transitions that samples proportionally to
/// per-transition priorities kept in a [`SumTree`].
#[derive(Debug, Clone)]
pub struct PrioritizedReplayBuffer {
    /// Maximum number of transitions stored at once.
    capacity: usize,
    /// Length of each observation (and next observation) slice.
    obs_dim: usize,
    /// Length of each action slice.
    act_dim: usize,
    /// Flat observation storage; slot `i` occupies `i * obs_dim..(i + 1) * obs_dim`.
    obs: Vec<f32>,
    /// Flat action storage; slot `i` occupies `i * act_dim..(i + 1) * act_dim`.
    actions: Vec<f32>,
    /// One reward per slot.
    rewards: Vec<f32>,
    /// Flat next-observation storage, laid out like `obs`.
    next_obs: Vec<f32>,
    /// One done flag per slot, stored as `1.0` (done) or `0.0`.
    dones: Vec<f32>,
    /// Per-slot priorities raised to `alpha`.
    tree: SumTree,
    /// Exponent applied to raw priorities before they are stored in `tree`.
    alpha: f64,
    /// Importance-sampling exponent used when computing sample weights.
    beta: f64,
    /// Running maximum of raw (pre-`alpha`) priorities, initially `1.0`;
    /// newly pushed transitions receive it.
    max_priority: f64,
    /// Next slot to write; wraps around modulo `capacity`.
    head: usize,
    /// Number of filled slots (never exceeds `capacity`).
    size: usize,
}
impl PrioritizedReplayBuffer {
pub fn new(
capacity: usize,
obs_dim: usize,
act_dim: usize,
alpha: f64,
beta_start: f64,
) -> Self {
assert!(capacity > 0, "capacity must be > 0");
let cap2 = capacity.next_power_of_two();
Self {
capacity,
obs_dim,
act_dim,
obs: vec![0.0; capacity * obs_dim],
actions: vec![0.0; capacity * act_dim],
rewards: vec![0.0; capacity],
next_obs: vec![0.0; capacity * obs_dim],
dones: vec![0.0; capacity],
tree: SumTree::new(cap2),
alpha,
beta: beta_start,
max_priority: 1.0,
head: 0,
size: 0,
}
}
pub fn push(
&mut self,
obs: impl AsRef<[f32]>,
action: impl AsRef<[f32]>,
reward: f32,
next_obs: impl AsRef<[f32]>,
done: bool,
) {
let obs = obs.as_ref();
let action = action.as_ref();
let next_obs = next_obs.as_ref();
let i = self.head;
self.obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(obs);
self.actions[i * self.act_dim..(i + 1) * self.act_dim].copy_from_slice(action);
self.rewards[i] = reward;
self.next_obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(next_obs);
self.dones[i] = if done { 1.0 } else { 0.0 };
self.tree.update(i, self.max_priority.powf(self.alpha));
self.head = (self.head + 1) % self.capacity;
if self.size < self.capacity {
self.size += 1;
}
}
pub fn update_priority(&mut self, index: usize, priority: f64) {
let p = priority.max(1e-6);
if p > self.max_priority {
self.max_priority = p;
}
self.tree.update(index, p.powf(self.alpha));
}
pub fn set_beta(&mut self, beta: f64) {
self.beta = beta.clamp(0.0, 1.0);
}
pub fn anneal_beta(&mut self, step: f64) {
self.beta = (self.beta + step).min(1.0);
}
#[must_use]
#[inline]
pub fn len(&self) -> usize {
self.size
}
#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
self.size == 0
}
pub fn sample(
&self,
batch_size: usize,
handle: &mut RlHandle,
) -> RlResult<Vec<PrioritySample>> {
if self.size < batch_size {
return Err(RlError::InsufficientTransitions {
have: self.size,
need: batch_size,
});
}
let total = self.tree.total();
if total <= 0.0 {
return Err(RlError::ZeroPrioritySum);
}
let rng = handle.rng_mut();
let segment = total / batch_size as f64;
let min_p = self.tree.min_priority() / total;
let max_w = (1.0 / (self.size as f64 * min_p)).powf(self.beta) as f32;
let mut out = Vec::with_capacity(batch_size);
for k in 0..batch_size {
let lo = k as f64 * segment;
let hi = lo + segment;
let v = lo + rng.next_f32() as f64 * (hi - lo);
let idx = self.tree.find(v.min(total - 1e-9)).min(self.size - 1);
let p = self.tree.priority_at(idx) / total;
let w = ((1.0 / (self.size as f64 * p)).powf(self.beta) as f32 / max_w).min(1.0);
let obs = self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
let action = self.actions[idx * self.act_dim..(idx + 1) * self.act_dim].to_vec();
let reward = self.rewards[idx];
let next_obs = self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
let done = self.dones[idx] > 0.5;
out.push(PrioritySample {
transition: Transition {
obs,
action,
reward,
next_obs,
done,
},
index: idx,
weight: w,
});
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer with `obs_dim = 2`, `act_dim = 1`, `alpha = 0.6`, `beta = 0.4`.
    fn make_per(cap: usize) -> PrioritizedReplayBuffer {
        PrioritizedReplayBuffer::new(cap, 2, 1, 0.6, 0.4)
    }

    /// Pushes `n` deterministic, non-terminal transitions derived from their index.
    fn fill_per(buf: &mut PrioritizedReplayBuffer, n: usize) {
        (0..n).map(|i| i as f32).for_each(|x| {
            buf.push([x, x + 1.0], [0.5_f32], x * 0.1, [x + 1.0, x + 2.0], false);
        });
    }

    /// Four-leaf sum tree whose leaves are written in index order.
    fn tree_of(priorities: [f64; 4]) -> SumTree {
        let mut t = SumTree::new(4);
        for (i, p) in priorities.iter().copied().enumerate() {
            t.update(i, p);
        }
        t
    }

    #[test]
    fn sum_tree_basic() {
        let t = tree_of([1.0, 2.0, 3.0, 4.0]);
        assert!((t.total() - 10.0).abs() < 1e-9, "total={}", t.total());
    }

    #[test]
    fn sum_tree_find() {
        let idx = tree_of([1.0, 2.0, 3.0, 4.0]).find(2.5);
        assert_eq!(idx, 1, "find(2.5) should return idx=1, got {idx}");
    }

    #[test]
    fn per_push_and_len() {
        let mut buf = make_per(32);
        fill_per(&mut buf, 20);
        assert_eq!(buf.len(), 20);
    }

    #[test]
    fn per_sample_size() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        assert_eq!(buf.sample(16, &mut handle).unwrap().len(), 16);
    }

    #[test]
    fn per_weights_in_range() {
        let mut buf = make_per(64);
        fill_per(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(32, &mut handle).unwrap();
        for w in batch.iter().map(|s| s.weight) {
            assert!(w > 0.0 && w <= 1.0, "weight={}", w);
        }
    }

    #[test]
    fn per_update_priority() {
        const DRAWS: usize = 200;
        let mut buf = make_per(16);
        fill_per(&mut buf, 16);
        buf.update_priority(0, 100.0);
        let mut handle = RlHandle::default_handle();
        let mut counts = [0_usize; 16];
        for _ in 0..DRAWS {
            let idx = buf.sample(1, &mut handle).unwrap()[0].index;
            counts[idx] += 1;
        }
        assert!(
            counts[0] > DRAWS / 16,
            "high-priority index should be over-sampled"
        );
    }

    #[test]
    fn per_insufficient_error() {
        let buf = make_per(16);
        let mut handle = RlHandle::default_handle();
        let result = buf.sample(5, &mut handle);
        assert!(result.is_err());
    }

    #[test]
    fn per_anneal_beta() {
        let mut buf = make_per(16);
        buf.set_beta(0.4);
        buf.anneal_beta(0.3);
        assert!((buf.beta - 0.7).abs() < 1e-9);
        buf.anneal_beta(1.0);
        assert!((buf.beta - 1.0).abs() < 1e-9);
    }

    #[test]
    fn sum_tree_min_priority() {
        let t = tree_of([5.0, 2.0, 8.0, 3.0]);
        assert!((t.min_priority() - 2.0).abs() < 1e-9);
    }
}