use core::ops::SubAssign;
use super::{Error, Weight};
use crate::Distribution;
use alloc::vec::Vec;
use rand::distr::uniform::{SampleBorrow, SampleUniform};
use rand::{Rng, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A distribution over the indices `0..len()` whose weights can be updated,
/// pushed and popped after construction.
///
/// The weights are stored as an implicit binary tree inside a `Vec`: node `i`
/// has children `2 * i + 1` and `2 * i + 2`. Each entry of `subtotals` holds
/// the node's own weight plus the weights of all of its descendants, so entry
/// 0 is the total weight and updates only touch one root-to-node path.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Only `Vec<W>` is stored, so (de)serialization needs nothing from
// `W::Sampler`; requiring it would only reject otherwise valid weight types.
#[cfg_attr(
    feature = "serde",
    serde(bound(serialize = "W: Serialize", deserialize = "W: Deserialize<'de>"))
)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct WeightedTreeIndex<
    W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
> {
    /// Per-node subtotals: own weight plus the weights of all descendants.
    subtotals: Vec<W>,
}
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
WeightedTreeIndex<W>
{
pub fn new<I>(weights: I) -> Result<Self, Error>
where
I: IntoIterator,
I::Item: SampleBorrow<W>,
{
let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
for weight in subtotals.iter() {
if !(*weight >= W::ZERO) {
return Err(Error::InvalidWeight);
}
}
let n = subtotals.len();
for i in (1..n).rev() {
let w = subtotals[i].clone();
let parent = (i - 1) / 2;
subtotals[parent]
.checked_add_assign(&w)
.map_err(|()| Error::Overflow)?;
}
Ok(Self { subtotals })
}
pub fn is_empty(&self) -> bool {
self.subtotals.is_empty()
}
pub fn len(&self) -> usize {
self.subtotals.len()
}
pub fn is_valid(&self) -> bool {
if let Some(weight) = self.subtotals.first() {
*weight > W::ZERO
} else {
false
}
}
pub fn get(&self, index: usize) -> W {
let left_index = 2 * index + 1;
let right_index = 2 * index + 2;
let mut w = self.subtotals[index].clone();
w -= self.subtotal(left_index);
w -= self.subtotal(right_index);
w
}
pub fn pop(&mut self) -> Option<W> {
self.subtotals.pop().inspect(|weight| {
let mut index = self.len();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] -= weight.clone();
}
})
}
pub fn push(&mut self, weight: W) -> Result<(), Error> {
if !(weight >= W::ZERO) {
return Err(Error::InvalidWeight);
}
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&weight).is_err() {
return Err(Error::Overflow);
}
}
let mut index = self.len();
self.subtotals.push(weight.clone());
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index].checked_add_assign(&weight).unwrap();
}
Ok(())
}
pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> {
if !(weight >= W::ZERO) {
return Err(Error::InvalidWeight);
}
let old_weight = self.get(index);
if weight > old_weight {
let mut difference = weight;
difference -= old_weight;
if let Some(total) = self.subtotals.first() {
let mut total = total.clone();
if total.checked_add_assign(&difference).is_err() {
return Err(Error::Overflow);
}
}
self.subtotals[index]
.checked_add_assign(&difference)
.unwrap();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index]
.checked_add_assign(&difference)
.unwrap();
}
} else if weight < old_weight {
let mut difference = old_weight;
difference -= weight;
self.subtotals[index] -= difference.clone();
while index != 0 {
index = (index - 1) / 2;
self.subtotals[index] -= difference.clone();
}
}
Ok(())
}
fn subtotal(&self, index: usize) -> W {
if index < self.subtotals.len() {
self.subtotals[index].clone()
} else {
W::ZERO
}
}
}
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
    WeightedTreeIndex<W>
{
    /// Samples an index with probability proportional to its weight.
    ///
    /// A uniform target is drawn from `[0, total)` and the tree is descended
    /// from the root. Within a node's range the left subtree's weight comes
    /// first, then the right subtree's, and finally the node's own weight; the
    /// node whose own slice contains the target is returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InsufficientNonZero`] if the tree is empty or its total
    /// weight is zero.
    ///
    /// # Panics
    ///
    /// Panics if the leftover target does not lie within the selected item's
    /// own weight. NOTE(review): with floating-point `W`, rounding in the
    /// stored subtotals could presumably trigger this in rare edge cases —
    /// confirm whether these should remain hard assertions.
    pub fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, Error> {
        // `subtotal(0)` is the root subtotal, or zero for an empty tree.
        let total = self.subtotal(0);
        if total == W::ZERO {
            return Err(Error::InsufficientNonZero);
        }

        let mut remaining = rng.random_range(W::ZERO..total);
        let mut node = 0;
        'descend: loop {
            // Try the left child, then the right child; skipping a child
            // consumes its whole subtotal from the remaining target.
            for child in [2 * node + 1, 2 * node + 2] {
                let child_total = self.subtotal(child);
                if remaining < child_total {
                    node = child;
                    continue 'descend;
                }
                remaining -= child_total;
            }
            // Neither subtree holds the target, so it falls in `node` itself.
            break;
        }

        assert!(remaining >= W::ZERO);
        assert!(remaining < self.get(node));
        Ok(node)
    }
}
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
    for WeightedTreeIndex<W>
{
    /// Samples an index with probability proportional to its weight.
    ///
    /// # Panics
    ///
    /// Panics if [`WeightedTreeIndex::try_sample`] returns an error, i.e. when
    /// the tree is empty or its total weight is zero. `#[track_caller]` makes
    /// the panic point at the caller of `sample`.
    #[track_caller]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        let outcome = self.try_sample(rng);
        outcome.unwrap()
    }
}
// Unit tests covering error reporting, incremental modification and sampling
// frequencies of `WeightedTreeIndex`.
#[cfg(test)]
mod test {
use super::*;
// Sampling from a tree with no items must report `InsufficientNonZero`.
#[test]
fn test_no_item_error() {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
// Clippy suggests dropping the `&`; it is kept, presumably so that `new` is
// exercised with borrowed (`&f64`) items.
#[allow(clippy::needless_borrows_for_generic_args)]
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
Error::InsufficientNonZero
);
}
// Construction, `push` and `update` must each reject a total that overflows
// `i32`, while an update that brings the total to exactly `i32::MAX` succeeds.
#[test]
fn test_overflow_error() {
// i32::MAX + 2 overflows while building the subtotals.
assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow));
// The total here is i32::MAX - 1.
let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
// Would raise the total to i32::MAX + 2.
assert_eq!(tree.push(3), Err(Error::Overflow));
// Raising index 1 from 1 to 4 adds 3, again exceeding i32::MAX.
assert_eq!(tree.update(1, 4), Err(Error::Overflow));
// Adds 1, reaching exactly i32::MAX.
tree.update(1, 2).unwrap();
}
// A non-empty tree whose weights are all zero still cannot be sampled from.
#[test]
fn test_all_weights_zero_error() {
let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!(
tree.try_sample(&mut rng).unwrap_err(),
Error::InsufficientNonZero
);
}
// Negative weights are rejected by `new`, `push` and `update`.
#[test]
fn test_invalid_weight_error() {
assert_eq!(
WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
Error::InvalidWeight
);
// Same borrowed-items construction as in `test_no_item_error`.
#[allow(clippy::needless_borrows_for_generic_args)]
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight);
tree.push(1).unwrap();
assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight);
}
// After pushes, an update and a pop, the tree must compare equal to one built
// directly from the resulting weights [0, 1, 2, 3].
#[test]
fn test_tree_modifications() {
let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
tree.push(3).unwrap();
tree.push(5).unwrap();
tree.update(0, 0).unwrap();
assert_eq!(tree.pop(), Some(5));
let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
assert_eq!(tree, expected);
}
// Sets weights through `update` and checks that empirical sample frequencies
// match the normalized weights.
#[test]
// The loops below index several parallel vectors (`weights`, `counts`).
#[allow(clippy::needless_range_loop)]
fn test_sample_counts_match_probabilities() {
let start = 1;
let end = 3;
let samples = 20;
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
// Initial weights are random; the `update` loop below overwrites all of them.
let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
let mut tree = WeightedTreeIndex::new(weights).unwrap();
let mut total_weight = 0.0;
let mut weights = alloc::vec![0.0; end];
// Give index `i` weight `i`, mirroring the values in `weights`.
for i in 0..end {
tree.update(i, i as f64).unwrap();
weights[i] = i as f64;
total_weight += i as f64;
}
// Indices below `start` are forced to weight zero and must never be drawn.
for i in 0..start {
tree.update(i, 0.0).unwrap();
weights[i] = 0.0;
total_weight -= i as f64;
}
let mut counts = alloc::vec![0_usize; end];
for _ in 0..samples {
let i = tree.sample(&mut rng);
counts[i] += 1;
}
for i in 0..start {
assert_eq!(counts[i], 0);
}
// A tolerance of 0.05 is one sample out of 20, so this check is tight and
// relies on the fixed RNG seed above.
for i in start..end {
let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight;
assert!(diff.abs() < 0.05);
}
}
}