use std::fmt::Debug;
use fastrand::Rng;
/// Sampler for a discrete distribution over the indices `0..n`, backed by the
/// two tables (`alias` and `prob`) that `VoseAlias::new` builds.
///
/// Note that the `PartialEq` implementation compares only the `alias` and
/// `prob` tables.
#[derive(Debug, Clone)]
pub struct VoseAlias {
/// The weights exactly as they were passed to `VoseAlias::new`, i.e. a copy
/// taken before any normalisation.
pub original_probabilities: Vec<f64>,
/// For each slot `i`, the index that `sample` returns when slot `i` is drawn
/// but its threshold test fails.
alias: Vec<usize>,
/// For each slot `i`, the threshold that `sample` compares a uniform random
/// `f64` against to decide between returning `i` and returning `alias[i]`.
prob: Vec<f64>,
/// Random number generator used by `sample`.
rng: Rng,
}
impl PartialEq for VoseAlias {
    /// Two samplers compare equal when both their alias tables and their
    /// threshold tables are equal. The stored original weights and the random
    /// number generator are not part of the comparison.
    #[coverage(off)]
    fn eq(&self, other: &Self) -> bool {
        if self.alias != other.alias {
            return false;
        }
        self.prob == other.prob
    }
}
impl VoseAlias {
    /// Builds the sampling tables for the given weights using Vose's alias
    /// method.
    ///
    /// The weights do not have to sum to one: unless their sum is exactly
    /// `1.0` they are divided by it before the tables are built. A copy of the
    /// weights exactly as given is kept in `original_probabilities`.
    ///
    /// Index `i` of `probabilities` is the outcome that [`VoseAlias::sample`]
    /// reports as `i`.
    ///
    /// # Panics
    ///
    /// Panics if `probabilities` is empty, or if the weights still do not sum
    /// to approximately one after normalisation (which is the case when their
    /// sum is zero or NaN).
    ///
    /// NOTE(review): individual weights are not checked for being
    /// non-negative; callers are expected to pass non-negative weights.
    #[coverage(off)]
    pub fn new(mut probabilities: Vec<f64>) -> VoseAlias {
        assert!(
            !probabilities.is_empty(),
            "VoseAlias::new requires at least one weight"
        );
        let original_probabilities = probabilities.clone();

        let sum: f64 = probabilities.iter().sum();
        // Exact comparison is intentional: it only skips a division that would
        // be a no-op, so a near-miss simply takes the normalising path.
        #[allow(clippy::float_cmp)]
        if sum != 1.0 {
            for p in &mut probabilities {
                *p /= sum;
            }
        }
        // Catches weights that could not be normalised (the division above
        // yields NaN when the sum is zero or NaN).
        let normalized_sum: f64 = probabilities.iter().sum();
        assert!(
            (normalized_sum - 1.0).abs() < 0.1,
            "weights could not be normalised: they sum to {sum}"
        );

        let size = probabilities.len();
        let mut small = Vec::with_capacity(size);
        let mut large = Vec::with_capacity(size);
        let mut alias: Vec<usize> = vec![0; size];
        let mut prob: Vec<f64> = vec![0.0; size];

        // Scale every probability by the number of outcomes so that the
        // average is 1.0, then split the indices into those below the average
        // (`small`) and the others (`large`).
        for (i, p) in probabilities.iter_mut().enumerate() {
            *p *= size as f64;
            if *p < 1.0 {
                small.push(i);
            } else {
                large.push(i);
            }
        }

        loop {
            match (small.pop(), large.pop()) {
                (Some(l), Some(g)) => {
                    // Slot `l` keeps its own (scaled) probability and hands
                    // the rest of the slot to `g`.
                    let p_l = probabilities[l];
                    prob[l] = p_l;
                    alias[l] = g;
                    // What remains of `g` after it topped up slot `l`.
                    let p_g = (probabilities[g] + p_l) - 1.0;
                    probabilities[g] = p_g;
                    if p_g < 1.0 {
                        small.push(g);
                    } else {
                        large.push(g);
                    }
                }
                // Only one list still has entries: whatever is left fills its
                // own slot entirely (any deviation from 1.0 is rounding error).
                (Some(i), None) | (None, Some(i)) => prob[i] = 1.0,
                (None, None) => break,
            }
        }

        VoseAlias {
            original_probabilities,
            alias,
            prob,
            rng: Rng::default(),
        }
    }

    /// Draws one index according to the distribution given to
    /// [`VoseAlias::new`].
    ///
    /// A slot is picked uniformly at random, then a uniform `f64` decides
    /// between the slot's own index and its alias.
    #[coverage(off)]
    pub fn sample(&self) -> usize {
        let i = self.rng.usize(..self.prob.len());
        // SAFETY: `i` was drawn from `..self.prob.len()`, so it is in bounds
        // for `prob`.
        let threshold = unsafe { *self.prob.get_unchecked(i) };
        // Strict comparison: the random `f64` can be exactly 0.0, and a slot
        // whose threshold is 0.0 (a zero-weight outcome) must never return its
        // own index. A threshold of 1.0 is still always accepted because the
        // random value is below 1.0.
        if self.rng.f64() < threshold {
            i
        } else {
            // SAFETY: `alias` and `prob` are private, are both created with
            // the same length in `new` and never resized, so `i` is in bounds
            // for `alias` as well.
            unsafe { *self.alias.get_unchecked(i) }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::VoseAlias;

    /// Number of samples drawn per test.
    const SAMPLES: usize = 100_000;
    /// Maximum accepted absolute difference between an observed frequency and
    /// the expected probability. With `SAMPLES` draws this is more than ten
    /// standard deviations for every probability used below, so the tests do
    /// not fail spuriously.
    const TOLERANCE: f64 = 0.02;

    /// Samples `alias` `SAMPLES` times and asserts that each index comes up
    /// with a frequency within `TOLERANCE` of `expected`.
    #[coverage(off)]
    fn assert_distribution(alias: &VoseAlias, expected: &[f64]) {
        let mut choices = vec![0usize; expected.len()];
        for _ in 0..SAMPLES {
            let i = alias.sample();
            choices[i] += 1;
        }
        for (i, (&count, &p)) in choices.iter().zip(expected).enumerate() {
            let frequency = count as f64 / SAMPLES as f64;
            assert!(
                (frequency - p).abs() < TOLERANCE,
                "index {i}: observed frequency {frequency}, expected {p} (counts: {choices:?})"
            );
        }
    }

    /// Weights that already sum to one are sampled with matching frequencies.
    #[test]
    #[coverage(off)]
    fn test_probabilities_1() {
        let weights = vec![0.1, 0.4, 0.2, 0.3];
        let alias = VoseAlias::new(weights.clone());
        assert_eq!(alias.original_probabilities, weights);
        assert_distribution(&alias, &weights);
    }

    /// Weights that do not sum to one are normalised, the original weights are
    /// kept untouched, and a zero-weight outcome is (practically) never drawn.
    #[test]
    #[coverage(off)]
    fn test_probabilities_2() {
        let weights = vec![1.0, 4.0, 2.0, 3.0, 0.0];
        let alias = VoseAlias::new(weights.clone());
        assert_eq!(alias.original_probabilities, weights);
        assert_distribution(&alias, &[0.1, 0.4, 0.2, 0.3, 0.0]);
    }
}