use rand::Rng;
use serde::{Deserialize, Serialize};
/// A single experience record held by the replay buffer.
///
/// Cloneable and (de)serializable via serde. The per-field descriptions below
/// are inferred from the field names only — TODO confirm against the code that
/// produces transitions.
#[derive(Clone, Serialize, Deserialize)]
pub struct Transition {
/// Presumably the observation before the action was taken.
pub state: Vec<f64>,
/// Presumably the index of the discrete action that was taken.
pub action: usize,
/// Presumably the scalar reward received for the step.
pub reward: f64,
/// Presumably the observation after the action was taken.
pub next_state: Vec<f64>,
/// Presumably `true` when the step ended the episode.
pub done: bool,
}
/// The result of one call to `PrioritizedReplayBuffer::sample`.
///
/// The three vectors are parallel: entry `i` of each describes the same draw.
pub struct PrioritizedBatch {
/// Cloned copies of the sampled transitions.
pub transitions: Vec<Transition>,
/// Buffer slot index of each sampled transition; these are the indices
/// `PrioritizedReplayBuffer::update_priorities` expects back.
pub indices: Vec<usize>,
/// Per-sample weight computed by `sample`; never greater than `1.0`.
pub weights: Vec<f64>,
}
/// Fixed-size binary sum tree over `f64` priorities, stored in a flat array.
///
/// Layout: `data[1]` is the root, node `i` has children `2 * i` and
/// `2 * i + 1`, and leaf `k` lives at `data[capacity + k]`; `data[0]` is
/// unused. Every internal node holds the sum of its two children, so the root
/// holds the sum of all leaves.
struct SumTree {
    /// Flat node storage of length `2 * capacity`.
    data: Vec<f64>,
    /// Number of leaves; always a power of two.
    capacity: usize,
}

impl SumTree {
    /// Creates a tree with all priorities at zero.
    ///
    /// `capacity` is rounded up to the next power of two (so `0` becomes `1`).
    fn new(capacity: usize) -> Self {
        let cap = capacity.next_power_of_two();
        Self { data: vec![0.0; 2 * cap], capacity: cap }
    }

    /// Returns the sum of all leaf priorities (the root node).
    fn total(&self) -> f64 {
        self.data[1]
    }

    /// Sets the priority of `leaf` and refreshes every ancestor sum.
    ///
    /// Ancestors are recomputed from their two children instead of being
    /// adjusted by a delta. Delta propagation lets rounding error accumulate
    /// in the internal nodes indefinitely, and a single NaN/inf priority would
    /// poison every ancestor permanently; recomputing keeps each node equal to
    /// the sum of its current children.
    ///
    /// # Panics
    ///
    /// Panics if `leaf >= capacity`.
    fn update(&mut self, leaf: usize, priority: f64) {
        assert!(
            leaf < self.capacity,
            "leaf index {} out of range (capacity {})",
            leaf,
            self.capacity
        );
        let mut idx = self.capacity + leaf;
        self.data[idx] = priority;
        while idx > 1 {
            idx /= 2;
            self.data[idx] = self.data[2 * idx] + self.data[2 * idx + 1];
        }
    }

    /// Returns the leaf whose cumulative-priority interval contains `value`.
    ///
    /// `value` is expected to lie in `[0, total())`. Values outside that range
    /// (for example a slight overshoot caused by rounding) resolve to the
    /// nearest leaf with non-zero priority rather than to an empty leaf. If
    /// every leaf is zero, leaf `0` is returned.
    fn find(&self, mut value: f64) -> usize {
        let mut idx = 1;
        while idx < self.capacity {
            let left = idx * 2;
            let right = left + 1;
            let left_mass = self.data[left];
            // Never step into a massless subtree while its sibling has mass:
            // go left if the right side is empty, and go right if the left
            // side is empty even when `value` is exactly zero.
            let go_left =
                self.data[right] <= 0.0 || (left_mass > 0.0 && value <= left_mass);
            if go_left {
                idx = left;
            } else {
                value -= left_mass;
                idx = right;
            }
        }
        idx - self.capacity
    }

    /// Returns the priority currently stored at `leaf`.
    fn get(&self, leaf: usize) -> f64 {
        self.data[self.capacity + leaf]
    }
}
/// Ring buffer of transitions with a sum tree of per-slot sampling priorities.
pub struct PrioritizedReplayBuffer {
/// Slot storage; `None` marks a slot that has never been written.
transitions: Vec<Option<Transition>>,
/// Priority tree; leaf `i` holds the priority of `transitions[i]`.
tree: SumTree,
/// Number of slots (the requested capacity rounded up to a power of two).
capacity: usize,
/// Slot the next `push` writes to; wraps around modulo `capacity`.
write_head: usize,
/// Number of filled slots; saturates at `capacity`.
len: usize,
/// Running maximum priority recorded by `update_priorities`; used to seed
/// the priority of newly pushed transitions. Starts at `1.0`.
max_priority: f64,
/// Exponent applied to priorities before they enter the tree.
alpha: f64,
/// Exponent used for the sample weights; raised by `beta_increment` on
/// each `sample` call and capped at `1.0`.
beta: f64,
/// Amount added to `beta` per `sample` call.
beta_increment: f64,
/// Small constant added to `|td_error|` in `update_priorities`.
epsilon: f64,
}
impl PrioritizedReplayBuffer {
    /// Creates an empty buffer with room for at least `capacity` transitions.
    ///
    /// The requested capacity is rounded up to the next power of two (so `0`
    /// becomes `1`). Hyper-parameters are fixed: `alpha = 0.6`, `beta = 0.4`
    /// annealed by `0.001` per [`sample`](Self::sample) call, `epsilon = 1e-6`.
    pub fn new(capacity: usize) -> Self {
        let cap = capacity.next_power_of_two();
        Self {
            transitions: vec![None; cap],
            tree: SumTree::new(cap),
            capacity: cap,
            write_head: 0,
            len: 0,
            max_priority: 1.0,
            alpha: 0.6,
            beta: 0.4,
            beta_increment: 0.001,
            epsilon: 1e-6,
        }
    }

    /// Stores `t`, overwriting the oldest slot once the buffer is full.
    ///
    /// The new transition is given the maximum priority seen so far so that it
    /// is likely to be sampled at least once.
    pub fn push(&mut self, t: Transition) {
        let idx = self.write_head;
        self.transitions[idx] = Some(t);
        // `max_priority` is kept un-exponentiated; alpha is applied exactly
        // once, here and in `update_priorities`.
        self.tree.update(idx, self.max_priority.powf(self.alpha));
        self.write_head = (self.write_head + 1) % self.capacity;
        self.len = (self.len + 1).min(self.capacity);
    }

    /// Draws up to `n` transitions with probability proportional to priority.
    ///
    /// The total priority mass is split into `n` equal segments and one value
    /// is drawn uniformly from each. At most `len()` items are returned; an
    /// empty buffer or `n == 0` yields an empty batch and leaves `beta`
    /// untouched. Each returned weight is
    /// `(len * p_i / total)^(-beta)` divided by the largest such weight in the
    /// buffer, capped at `1.0`. After a non-empty draw `beta` is increased by
    /// `beta_increment`, up to `1.0`.
    pub fn sample(&mut self, n: usize) -> PrioritizedBatch {
        let n = n.min(self.len);
        let mut transitions = Vec::with_capacity(n);
        let mut indices = Vec::with_capacity(n);
        let mut weights = Vec::with_capacity(n);
        if n == 0 {
            return PrioritizedBatch { transitions, indices, weights };
        }

        let total = self.tree.total();
        let segment = total / n as f64;
        let size = self.len as f64;
        let min_prob = self.tree.get(self.min_priority_leaf()) / total;
        let max_weight = (size * min_prob).powf(-self.beta);
        // Slots `0..len` are always filled (the ring fills from index 0), so
        // clamping to the last filled slot guarantees a real transition.
        let last_filled = self.len - 1;

        let mut rng = rand::thread_rng();
        for i in 0..n {
            let lo = segment * i as f64;
            let hi = (segment * (i + 1) as f64).min(total - 1e-10);
            // `gen_range` panics on an empty range; fall back to the lower edge.
            let value = if lo < hi { rng.gen_range(lo..hi) } else { lo };
            let leaf = self.tree.find(value).min(last_filled);

            let transition = match &self.transitions[leaf] {
                Some(stored) => stored.clone(),
                // Unreachable while the fill invariant above holds. Skip rather
                // than fabricate a transition that was never observed.
                None => continue,
            };

            let priority = self.tree.get(leaf).max(self.epsilon);
            let prob = priority / total;
            let weight = ((size * prob).powf(-self.beta) / max_weight).min(1.0);

            transitions.push(transition);
            indices.push(leaf);
            weights.push(weight);
        }

        self.beta = (self.beta + self.beta_increment).min(1.0);
        PrioritizedBatch { transitions, indices, weights }
    }

    /// Re-prioritizes previously sampled transitions from their TD errors.
    ///
    /// `indices[i]` is paired with `td_errors[i]`; if the slices differ in
    /// length the extra entries of the longer one are ignored. The stored
    /// priority is `(|td_error| + epsilon)^alpha`.
    ///
    /// Indices that do not refer to a filled slot are skipped, and a
    /// non-finite TD error (NaN or infinity) is replaced by the current
    /// maximum priority so that it cannot corrupt the tree sums.
    pub fn update_priorities(&mut self, indices: &[usize], td_errors: &[f64]) {
        for (&idx, &err) in indices.iter().zip(td_errors.iter()) {
            if idx >= self.len {
                continue;
            }
            let raw = if err.is_finite() {
                err.abs() + self.epsilon
            } else {
                self.max_priority
            };
            self.tree.update(idx, raw.powf(self.alpha));
            // Track the un-exponentiated value: `push` applies alpha itself,
            // so storing `raw^alpha` here would exponentiate twice.
            self.max_priority = self.max_priority.max(raw);
        }
    }

    /// Returns the number of stored transitions.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no transition has been stored yet.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the filled slot with the smallest strictly positive priority.
    ///
    /// Ties resolve to the lowest index. Returns `0` when no filled slot has a
    /// positive priority. This is a linear scan over the filled slots.
    fn min_priority_leaf(&self) -> usize {
        let mut min_leaf = 0;
        let mut min_val = f64::MAX;
        for i in 0..self.len {
            let v = self.tree.get(i);
            if v > 0.0 && v < min_val {
                min_val = v;
                min_leaf = i;
            }
        }
        min_leaf
    }
}