use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// Precomputed table for Walker's alias method: each draw picks a column
/// uniformly at random, then flips one biased coin, so sampling costs O(1)
/// regardless of the number of outcomes.
///
/// Tables are presumably produced by the crate's `WalkerTableBuilder`
/// (as the tests in this file do); `new` accepts raw, already-computed data.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Default)]
pub struct WalkerTable {
    /// `aliases[i]` is the index returned when column `i` is chosen and its
    /// coin flip selects the alias.
    aliases: Vec<usize>,
    /// `probs[i]` is the probability that column `i` yields `aliases[i]`
    /// instead of `i` itself. Note this is the probability of taking the
    /// *alias*, not of keeping the column.
    probs: Vec<f32>,
}

impl WalkerTable {
    /// Creates a table from precomputed alias indices and alias probabilities.
    ///
    /// No validation is performed. `aliases` must be at least as long as
    /// `probs` (sampling indexes `aliases` with column numbers drawn from
    /// `0..probs.len()`), and normally both have one entry per outcome.
    pub fn new(aliases: Vec<usize>, probs: Vec<f32>) -> Self {
        Self { aliases, probs }
    }

    /// Draws one index using the thread-local random number generator.
    ///
    /// # Panics
    /// Under the same conditions as [`WalkerTable::next_rng`].
    pub fn next(&self) -> usize {
        self.next_rng(&mut rand::thread_rng())
    }

    /// Draws one index using the supplied random number generator.
    ///
    /// A column `i` is chosen uniformly from `0..probs.len()`. Then a uniform
    /// `f32` decides the result: `aliases[i]` if that value is below
    /// `probs[i]`, otherwise `i`.
    ///
    /// # Panics
    /// - If the table is empty (for example `WalkerTable::default()`), because
    ///   the column range is empty.
    /// - If `aliases` is shorter than `probs` and the alias branch is taken.
    pub fn next_rng(&self, rng: &mut impl Rng) -> usize {
        // The column must be drawn before the coin flip to keep the RNG
        // consumption order stable (reproducible sequences for seeded RNGs).
        let column = rng.gen_range(0..self.probs.len());
        if rng.gen::<f32>() < self.probs[column] {
            self.aliases[column]
        } else {
            column
        }
    }
}
#[cfg(test)]
mod table_test {
    use crate::builder::*;
    use rand;

    /// Number of samples drawn per test. With a ±5% band, even the rarest
    /// expected outcome (10%) has a margin of several standard deviations, so
    /// spurious failures are very unlikely.
    const N: usize = 100_000;

    /// Returns how many elements of `slice` equal `target`. The result is an
    /// `f32` so it can be compared directly with fractional expectations.
    fn count<T: PartialEq>(slice: &[T], target: T) -> f32 {
        slice.iter().filter(|&item| *item == target).count() as f32
    }

    /// Asserts that `observed` lies strictly within ±5% of `expected`. On
    /// failure, the message names the offending `index` and both counts.
    fn assert_near(index: usize, observed: f32, expected: f32) {
        assert!(
            expected * 0.95 < observed && observed < expected * 1.05,
            "index {}: observed {} samples, expected {} (±5%)",
            index,
            observed,
            expected
        );
    }

    #[test]
    fn unweighted_random_sampling() {
        // All-zero weights are expected to produce a uniform distribution.
        const EXPT: f32 = N as f32 * 0.25;
        let index_weights = [0; 4];
        let builder = WalkerTableBuilder::new(&index_weights);
        let wa_table = builder.build();
        let mut rng = rand::thread_rng();
        let idxs = (0..N)
            .map(|_| wa_table.next_rng(&mut rng))
            .collect::<Vec<usize>>();
        for index in 0..index_weights.len() {
            assert_near(index, count(&idxs, index), EXPT);
        }
    }

    #[test]
    fn weighted_random_sampling() {
        // Expected counts are proportional to the weights (total weight 10).
        const EXPT: [f32; 3] = [N as f32 * 0.2, N as f32 * 0.1, N as f32 * 0.7];
        let index_weights = [2, 1, 7, 0];
        let builder = WalkerTableBuilder::new(&index_weights);
        let wa_table = builder.build();
        // Exercise the thread-local `next()` path, not `next_rng`.
        let idxs = (0..N).map(|_| wa_table.next()).collect::<Vec<usize>>();
        for (index, &expected) in EXPT.iter().enumerate() {
            assert_near(index, count(&idxs, index), expected);
        }
        // Index 3 has weight 0, so it must never be drawn.
        assert_eq!(count(&idxs, 3), 0.0, "zero-weight index 3 was sampled");
    }
}