use super::Error;
use crate::{Distribution, Uniform, uniform::SampleUniform};
use alloc::{boxed::Box, vec, vec::Vec};
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use rand::{Rng, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A distribution over the indices `0..n`, built with the alias method from a
/// list of `n` weights.
///
/// Each sample picks a table entry uniformly at random, then compares a second
/// uniform draw against that entry's threshold to decide between the entry
/// itself and its alias.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "W: Serialize, W::Sampler: Serialize",
        deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"
    ))
)]
pub struct WeightedAliasIndex<W: AliasableWeight> {
    /// For each entry, the index returned when the threshold test fails.
    aliases: Box<[u32]>,
    /// For each entry, the threshold below which the entry's own index is kept.
    no_alias_odds: Box<[W]>,
    /// Chooses a table entry in `0..n`.
    uniform_index: Uniform<u32>,
    /// Produces the value compared against `no_alias_odds`.
    uniform_within_weight_sum: Uniform<W>,
    /// Sum of the input weights (clamped to `W::MAX`); a "full" entry has this odds.
    weight_sum: W,
}
impl<W: AliasableWeight> WeightedAliasIndex<W> {
    /// Creates a new [`WeightedAliasIndex`] from the given weights.
    ///
    /// Sampling the result yields index `i` with probability
    /// `weights[i] / sum(weights)`.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidInput`] if `weights` is empty or has more than `u32::MAX`
    ///   entries.
    /// - [`Error::InvalidWeight`] if any weight is negative, NaN, or greater than
    ///   `W::MAX / weights.len()`. If `weights.len()` is not representable in `W`,
    ///   every non-zero weight is rejected this way.
    /// - [`Error::InsufficientNonZero`] if every weight is zero.
    pub fn new(weights: Vec<W>) -> Result<Self, Error> {
        let n = weights.len();
        // Indices are stored as `u32`; with `n <= u32::MAX` every index is strictly
        // below `u32::MAX`, which the construction lists below use as a sentinel.
        if n == 0 || n > u32::MAX as usize {
            return Err(Error::InvalidInput);
        }
        let n = n as u32;

        // Weights are multiplied by `n` below, so each must be at most `MAX / n`.
        // A failed conversion makes the bound zero, which rejects any positive weight.
        let max_weight_size = W::try_from_u32_lossy(n)
            .map(|n| W::MAX / n)
            .unwrap_or(W::ZERO);
        // `W::ZERO <= w` is false for NaN, so NaN weights are rejected here too.
        if !weights
            .iter()
            .all(|&w| W::ZERO <= w && w <= max_weight_size)
        {
            return Err(Error::InvalidWeight);
        }

        // Clamp in case the (floating-point) sum rounds past `W::MAX`.
        let weight_sum = AliasableWeight::sum(weights.as_slice());
        let weight_sum = if weight_sum > W::MAX {
            W::MAX
        } else {
            weight_sum
        };
        if weight_sum == W::ZERO {
            return Err(Error::InsufficientNonZero);
        }

        // Cannot fail: if `n` were not representable in `W`, `max_weight_size`
        // would be zero and we would have returned `InsufficientNonZero` above.
        let n_converted =
            W::try_from_u32_lossy(n).expect("weight count must be representable in W");

        // Scale each weight by `n`: the scaled values then average `weight_sum`,
        // so an entry is exactly "full" when its odds equal `weight_sum`.
        let mut no_alias_odds = weights.into_boxed_slice();
        for odds in no_alias_odds.iter_mut() {
            *odds *= n_converted;
            *odds = if *odds > W::MAX { W::MAX } else { *odds };
        }

        /// Two singly linked lists threaded through `aliases`: "small" entries
        /// (odds below `weight_sum`) and "big" entries (odds at or above it).
        /// While an index is on a list, `aliases[idx]` holds the next index;
        /// `u32::MAX` marks the end of a list.
        struct Aliases {
            aliases: Box<[u32]>,
            smalls_head: u32,
            bigs_head: u32,
        }

        impl Aliases {
            fn new(size: u32) -> Self {
                Aliases {
                    aliases: vec![0; size as usize].into_boxed_slice(),
                    smalls_head: u32::MAX,
                    bigs_head: u32::MAX,
                }
            }

            fn push_small(&mut self, idx: u32) {
                self.aliases[idx as usize] = self.smalls_head;
                self.smalls_head = idx;
            }

            fn push_big(&mut self, idx: u32) {
                self.aliases[idx as usize] = self.bigs_head;
                self.bigs_head = idx;
            }

            fn pop_small(&mut self) -> u32 {
                let popped = self.smalls_head;
                self.smalls_head = self.aliases[popped as usize];
                popped
            }

            fn pop_big(&mut self) -> u32 {
                let popped = self.bigs_head;
                self.bigs_head = self.aliases[popped as usize];
                popped
            }

            fn smalls_is_empty(&self) -> bool {
                self.smalls_head == u32::MAX
            }

            fn bigs_is_empty(&self) -> bool {
                self.bigs_head == u32::MAX
            }

            /// Overwrites the (no longer needed) list link of a popped entry
            /// with its final alias.
            fn set_alias(&mut self, idx: u32, alias: u32) {
                self.aliases[idx as usize] = alias;
            }
        }

        let mut aliases = Aliases::new(n);

        // Classify every entry as small (under-full) or big.
        for (index, &odds) in no_alias_odds.iter().enumerate() {
            if odds < weight_sum {
                aliases.push_small(index as u32);
            } else {
                aliases.push_big(index as u32);
            }
        }

        // Fill a small entry `s` with mass from a big entry `b`: `s` aliases to `b`,
        // and `b` gives up `weight_sum - odds[s]` of its odds. Subtracting before
        // adding keeps the intermediate within range. `b` is then re-filed.
        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
            let s = aliases.pop_small();
            let b = aliases.pop_big();

            aliases.set_alias(s, b);
            no_alias_odds[b as usize] =
                no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];

            if no_alias_odds[b as usize] < weight_sum {
                aliases.push_small(b);
            } else {
                aliases.push_big(b);
            }
        }

        // Entries still listed are treated as exactly full. Any residual (e.g. from
        // floating-point rounding) is dropped. Their stale `aliases` link is
        // ignored by `weights()` because their odds are not below `weight_sum`.
        while !aliases.smalls_is_empty() {
            no_alias_odds[aliases.pop_small() as usize] = weight_sum;
        }
        while !aliases.bigs_is_empty() {
            no_alias_odds[aliases.pop_big() as usize] = weight_sum;
        }

        let uniform_index =
            Uniform::new(0, n).expect("n > 0, so the index range 0..n is non-empty");
        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum)
            .expect("weight_sum is positive and at most W::MAX");

        Ok(Self {
            aliases: aliases.aliases,
            no_alias_odds,
            uniform_index,
            uniform_within_weight_sum,
            weight_sum,
        })
    }

    /// Reconstructs the weights represented by this table.
    ///
    /// Each entry's weight is recovered as its own odds plus the odds other
    /// entries forward to it via their alias, divided by the number of entries.
    /// Integer weights round-trip exactly. Floating-point results may differ from
    /// the original weights by rounding error.
    ///
    /// # Panics
    ///
    /// Panics if the table length cannot be converted to `W`. Tables built by
    /// [`new`](Self::new) always satisfy this.
    pub fn weights(&self) -> Vec<W> {
        let n = self.aliases.len();
        let n_converted = W::try_from_u32_lossy(n as u32)
            .expect("table length was checked to be representable in W by `new`");

        // Mass each entry receives from other entries that alias to it. Only
        // entries below `weight_sum` forward anything; full entries' links are unused.
        let mut alias_contributions = vec![W::ZERO; n];
        for (&odds, &alias) in self.no_alias_odds.iter().zip(self.aliases.iter()) {
            if odds < self.weight_sum {
                alias_contributions[alias as usize] += self.weight_sum - odds;
            }
        }

        self.no_alias_odds
            .iter()
            .zip(&alias_contributions)
            .map(|(&no_alias_odd, &alias_contribution)| {
                (no_alias_odd + alias_contribution) / n_converted
            })
            .collect()
    }
}
impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
    /// Draws an index in two steps. First, a table entry is picked uniformly.
    /// Then a value from `uniform_within_weight_sum` is compared with the entry's
    /// threshold: below it the entry itself is returned, otherwise its alias.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        let entry: u32 = rng.sample(self.uniform_index);
        let slot = entry as usize;
        let keep_entry = rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[slot];
        if keep_entry {
            slot
        } else {
            self.aliases[slot] as usize
        }
    }
}
/// Hand-written so that only `W: Debug` and `Uniform<W>: Debug` are required,
/// rather than the bounds `#[derive(Debug)]` would add.
impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
where
    W: fmt::Debug,
    Uniform<W>: fmt::Debug,
{
    /// Formats every field of the table, including `weight_sum`, which the
    /// previous implementation omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WeightedAliasIndex")
            .field("aliases", &self.aliases)
            .field("no_alias_odds", &self.no_alias_odds)
            .field("uniform_index", &self.uniform_index)
            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
            .field("weight_sum", &self.weight_sum)
            .finish()
    }
}
impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
where
Uniform<W>: Clone,
{
fn clone(&self) -> Self {
Self {
aliases: self.aliases.clone(),
no_alias_odds: self.no_alias_odds.clone(),
uniform_index: self.uniform_index,
uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
weight_sum: self.weight_sum,
}
}
}
/// Numeric types that can serve as weights for [`WeightedAliasIndex`].
///
/// Beyond ordering, arithmetic, and uniform sampling, an implementor supplies
/// its maximum value, its zero, and a conversion from `u32` (used for the
/// number of weights).
pub trait AliasableWeight:
    Sized
    + Copy
    + SampleUniform
    + PartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
    + MulAssign
    + Div<Output = Self>
    + DivAssign
    + Sum
{
    /// The largest representable value.
    const MAX: Self;

    /// The additive identity.
    const ZERO: Self;

    /// Converts `n` into `Self`, or returns `None` if that is not possible.
    /// The conversion may round (hence "lossy").
    fn try_from_u32_lossy(n: u32) -> Option<Self>;

    /// Adds up `values`. Implementations may override this with a more accurate
    /// summation.
    fn sum(values: &[Self]) -> Self {
        <Self as Sum>::sum(values.iter().copied())
    }
}
/// Implements [`AliasableWeight`] for a primitive floating-point type.
macro_rules! impl_weight_for_float {
    ($float: ident) => {
        impl AliasableWeight for $float {
            const MAX: Self = $float::MAX;
            const ZERO: Self = 0.0;

            // Every `u32` maps to some float value (possibly rounded, e.g. for
            // `f32`), so this never returns `None`.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let converted = n as $float;
                Some(converted)
            }

            // Pairwise summation, presumably chosen to limit accumulated
            // rounding error compared with a sequential sum.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum::<Self>(values)
            }
        }
    };
}
/// Sums `values` by splitting the slice in half recursively and adding the two
/// partial sums. Slices of at most 32 elements are summed left to right.
fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
    // Slices no longer than this are summed sequentially.
    const SEQUENTIAL_LEN: usize = 32;
    match values.len() {
        len if len > SEQUENTIAL_LEN => {
            let (front, back) = values.split_at(len / 2);
            pairwise_sum(front) + pairwise_sum(back)
        }
        _ => values.iter().copied().sum(),
    }
}
/// Implements [`AliasableWeight`] for a primitive integer type.
macro_rules! impl_weight_for_int {
    ($int: ident) => {
        impl AliasableWeight for $int {
            const MAX: Self = $int::MAX;
            const ZERO: Self = 0;

            // `Some` only if `n` survives the round trip through `Self`: the
            // `as` cast must neither wrap to a negative value nor truncate.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let candidate = n as Self;
                let non_negative = candidate >= Self::ZERO;
                let round_trips = candidate as u32 == n;
                (non_negative && round_trips).then_some(candidate)
            }
        }
    };
}
// Floating-point weight types.
impl_weight_for_float!(f32);
impl_weight_for_float!(f64);

// Unsigned integer weight types.
impl_weight_for_int!(u8);
impl_weight_for_int!(u16);
impl_weight_for_int!(u32);
impl_weight_for_int!(u64);
impl_weight_for_int!(u128);
impl_weight_for_int!(usize);

// Signed integer weight types; negative weights are rejected by
// `WeightedAliasIndex::new`.
impl_weight_for_int!(i8);
impl_weight_for_int!(i16);
impl_weight_for_int!(i32);
impl_weight_for_int!(i64);
impl_weight_for_int!(i128);
// Unit tests for `WeightedAliasIndex`: statistical sampling checks per weight
// type, input validation, weight reconstruction, and pinned sample sequences.
#[cfg(test)]
mod test {
    use super::*;

    // f32: statistical check, plus rejection of infinite, negative and NaN
    // weights. A lone `-0.0` passes validation and is reported as all-zero.
    #[test]
    #[cfg_attr(miri, ignore)] fn test_weighted_index_f32() {
        test_weighted_index(f32::into);
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // u128: statistical check only (expected counts computed via f64).
    #[test]
    #[cfg_attr(miri, ignore)] fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    // i128: statistical check plus rejection of negative weights.
    #[test]
    #[cfg_attr(miri, ignore)] fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // u8: statistical check with a very small weight range.
    #[test]
    #[cfg_attr(miri, ignore)] fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    // i8: statistical check plus rejection of negative weights.
    #[test]
    #[cfg_attr(miri, ignore)] fn test_weighted_index_i8() {
        test_weighted_index(i8::into);
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // Shared statistical test for weight type `W`.
    //
    // Draws `NUM_WEIGHTS` random weights in `[0, W::MAX / NUM_WEIGHTS]` with a
    // fixed seed and forces one of them to zero. It then takes `NUM_SAMPLES`
    // samples and checks that the zero-weight index is never drawn and that each
    // index's count is within 10% of `NUM_SAMPLES / NUM_WEIGHTS` of its expected
    // count. `w_to_f64` converts weights to `f64` for the expected-count math.
    // Finally checks the empty / all-zero / too-large-weight error cases.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where
        WeightedAliasIndex<W>: fmt::Debug,
    {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        // Upper bound `MAX / NUM_WEIGHTS` keeps every weight accepted by `new`.
        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            )
            .unwrap();
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
            assert!(difference <= max_allowed_difference);
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            Error::InsufficientNonZero
        );
        // Each weight must be at most `MAX / 2` when there are two weights.
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // `weights()` must return the construction weights: exactly for integer
    // types, and within `EPSILON` for f64.
    #[test]
    fn test_weights_reconstruction() {
        {
            let weights_i32 = vec![10, 2, 8, 0, 30, 5];
            let dist_i32 = WeightedAliasIndex::new(weights_i32.clone()).unwrap();
            assert_eq!(weights_i32, dist_i32.weights());
        }
        {
            let weights_u64 = vec![1, 1, 1, 1, 1];
            let dist_u64 = WeightedAliasIndex::new(weights_u64.clone()).unwrap();
            assert_eq!(weights_u64, dist_u64.weights());
        }
        {
            const EPSILON: f64 = 1e-9;
            let weights_f64 = vec![0.5, 0.2, 0.3, 0.0, 1.5, 0.88];
            let dist_f64 = WeightedAliasIndex::new(weights_f64.clone()).unwrap();
            let reconstructed_f64 = dist_f64.weights();
            assert_eq!(weights_f64.len(), reconstructed_f64.len());
            for (original, reconstructed) in weights_f64.iter().zip(reconstructed_f64.iter()) {
                assert!(
                    f64::abs(original - reconstructed) < EPSILON,
                    "Weight reconstruction failed: original {}, reconstructed {}",
                    original,
                    reconstructed
                );
            }
        }
        {
            let weights_single = vec![42_u32];
            let dist_single = WeightedAliasIndex::new(weights_single.clone()).unwrap();
            assert_eq!(weights_single, dist_single.weights());
        }
    }

    // Pins the exact sample sequence for a fixed seed. Any change to table
    // construction or to the order of RNG draws in `sample` shows up here.
    #[test]
    fn value_stability() {
        fn test_samples<W: AliasableWeight>(
            weights: Vec<W>,
            buf: &mut [usize],
            expected: &[usize],
        ) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
        );
        test_samples(
            vec![0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
        );
        test_samples(
            vec![1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
        );
    }
}