use super::csprng::SeedableRng;
/// Precomputed lookup table for drawing indices in proportion to a set of
/// weights: every slot stores an acceptance threshold and a fallback index.
#[derive(Debug)]
pub struct AliasTable {
/// Per-slot acceptance threshold: `sample` returns slot `i` itself when its
/// uniform draw is below `prob[i]`.
prob: Vec<f64>,
/// Per-slot fallback: the index `sample` returns for slot `i` when the draw
/// is not below `prob[i]`.
alias: Vec<usize>,
/// Number of slots; set to the length of the weight slice given to `new`.
n: usize,
}
impl AliasTable {
    /// Rejection-sampling budget per requested item before
    /// `sample_without_replacement` switches to its exhaustive fallback.
    const REJECTION_ATTEMPTS_PER_ITEM: usize = 20;

    /// Builds an alias table from non-negative `weights`.
    ///
    /// Index `i` is subsequently drawn by [`sample`](Self::sample) with
    /// probability `weights[i] / sum(weights)`.
    ///
    /// Returns `None` when the input cannot describe a distribution:
    /// * `weights` is empty,
    /// * any weight is negative, NaN or infinite,
    /// * the sum of the weights is zero or overflows to infinity,
    /// * the weights are so small that the normalisation factor overflows.
    pub fn new(weights: &[f64]) -> Option<Self> {
        let n = weights.len();
        if n == 0 {
            return None;
        }
        // A single bad entry would silently poison every slot of the table,
        // so refuse it up front instead of producing garbage probabilities.
        if weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
            return None;
        }
        let total: f64 = weights.iter().sum();
        if !total.is_finite() || total <= 0.0 {
            return None;
        }
        // Rescale so that the average weight is exactly 1.0.
        let scale = n as f64 / total;
        if !scale.is_finite() {
            return None;
        }
        let mut scaled: Vec<f64> = weights.iter().map(|w| w * scale).collect();

        // Every slot starts as "always return myself". Slots that are never
        // paired below (including those left over purely because of
        // floating-point rounding) therefore stay valid.
        let mut prob = vec![1.0f64; n];
        let mut alias: Vec<usize> = (0..n).collect();

        // Work stacks: under-full slots and slots with mass to give away.
        let (mut small, mut large): (Vec<usize>, Vec<usize>) =
            (0..n).partition(|&i| scaled[i] < 1.0);

        // Peek at BOTH stacks before popping: popping `small` first would
        // drop an index on the floor whenever `large` is already empty.
        while let (Some(&s), Some(&l)) = (small.last(), large.last()) {
            small.pop();
            prob[s] = scaled[s];
            alias[s] = l;
            // `l` donates exactly the mass that slot `s` was missing.
            scaled[l] -= 1.0 - scaled[s];
            if scaled[l] < 1.0 {
                large.pop();
                small.push(l);
            }
        }

        Some(Self { prob, alias, n })
    }

    /// Draws one index according to the table's weights.
    ///
    /// Consumes one `gen_range` call followed by one `gen_f64` call on `rng`.
    pub fn sample(&self, rng: &mut SeedableRng) -> usize {
        let slot = rng.gen_range(self.n as u64) as usize;
        let u = rng.gen_f64();
        if u < self.prob[slot] {
            slot
        } else {
            self.alias[slot]
        }
    }

    /// Draws `min(count, len())` distinct indices, favouring heavier items.
    ///
    /// * `count >= len()` returns every index in ascending order.
    /// * Otherwise items are chosen as if by repeated weighted draws that
    ///   skip already-chosen indices. Small requests (`count <= len() / 2`)
    ///   first try cheap rejection sampling with a bounded number of
    ///   attempts; whatever is still missing afterwards is filled by an
    ///   exhaustive weighted pass over the remaining indices.
    /// * Zero-weight indices are used only as padding, once every
    ///   positive-weight index has already been taken.
    ///
    /// The returned vector always contains exactly `min(count, len())`
    /// entries, with no duplicates.
    pub fn sample_without_replacement(&self, rng: &mut SeedableRng, count: usize) -> Vec<usize> {
        let count = count.min(self.n);
        if count == self.n {
            return (0..self.n).collect();
        }

        let mut chosen = std::collections::HashSet::with_capacity(count);
        let mut result = Vec::with_capacity(count);

        if count <= self.n / 2 {
            // Rejecting duplicates of weighted draws yields the sequential
            // weighted-without-replacement distribution. The cap keeps a
            // very skewed table from spinning here for a long time.
            let max_attempts = count.saturating_mul(Self::REJECTION_ATTEMPTS_PER_ITEM);
            for _ in 0..max_attempts {
                if result.len() == count {
                    break;
                }
                let idx = self.sample(rng);
                if chosen.insert(idx) {
                    result.push(idx);
                }
            }
        }

        if result.len() < count {
            self.fill_by_weighted_keys(rng, &chosen, &mut result, count);
        }
        result
    }

    /// Draws `count` indices independently; duplicates are allowed.
    pub fn sample_with_replacement(&self, rng: &mut SeedableRng, count: usize) -> Vec<usize> {
        let mut picks = Vec::with_capacity(count);
        for _ in 0..count {
            picks.push(self.sample(rng));
        }
        picks
    }

    /// Number of indices the table can produce.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Whether the table has no slots. Tables built by [`new`](Self::new)
    /// are never empty because `new` rejects empty input.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Recovers each index's relative weight from the table: the mass kept
    /// in its own slot plus the mass other slots redirect to it.
    fn slot_masses(&self) -> Vec<f64> {
        let mut mass = self.prob.clone();
        for (&p, &target) in self.prob.iter().zip(&self.alias) {
            mass[target] += 1.0 - p;
        }
        mass
    }

    /// Appends indices not in `chosen` to `out` until it holds `count`
    /// entries, preferring heavier indices (exponential-key selection).
    fn fill_by_weighted_keys(
        &self,
        rng: &mut SeedableRng,
        chosen: &std::collections::HashSet<usize>,
        out: &mut Vec<usize>,
        count: usize,
    ) {
        let mass = self.slot_masses();
        // (tier, key, index): tier 0 = positive weight, tier 1 = zero weight.
        let mut keyed: Vec<(u8, f64, usize)> =
            Vec::with_capacity(self.n.saturating_sub(chosen.len()));
        for idx in 0..self.n {
            if chosen.contains(&idx) {
                continue;
            }
            // NOTE(review): assumes `gen_f64` yields values in [0, 1), which
            // `sample` already relies on for slots with probability 1.0.
            let u: f64 = rng.gen_f64();
            let entry = if mass[idx] > 0.0 {
                // Exponential variate with rate `mass`: taking the smallest
                // keys is equivalent to sequential weighted draws.
                (0u8, -(1.0_f64 - u).ln() / mass[idx], idx)
            } else {
                // Zero-weight padding is ordered uniformly at random.
                (1u8, u, idx)
            };
            keyed.push(entry);
        }
        keyed.sort_by(|a, b| {
            a.0.cmp(&b.0)
                .then(a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                .then(a.2.cmp(&b.2))
        });
        let missing = count.saturating_sub(out.len());
        out.extend(keyed.into_iter().take(missing).map(|(_, _, idx)| idx));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws `draws` samples from `table` and counts how often each index appears.
    fn tally(table: &AliasTable, rng: &mut SeedableRng, draws: usize) -> Vec<u32> {
        let mut hits = vec![0u32; table.len()];
        for _ in 0..draws {
            hits[table.sample(rng)] += 1;
        }
        hits
    }

    /// Equal weights should give every index roughly a quarter of the draws.
    #[test]
    fn uniform_distribution() {
        let table = AliasTable::new(&[1.0, 1.0, 1.0, 1.0]).unwrap();
        assert_eq!(table.len(), 4);

        let mut rng = SeedableRng::from_seed_str("uniform-test");
        let hits = tally(&table, &mut rng, 10_000);

        for (i, c) in hits.into_iter().enumerate() {
            assert!(
                (2001..3000).contains(&c),
                "item {i} sampled {c} times, expected ~2500"
            );
        }
    }

    /// A 9:1 weighting must clearly favour the heavy index without starving the light one.
    #[test]
    fn skewed_distribution() {
        let table = AliasTable::new(&[9.0, 1.0]).unwrap();
        let mut rng = SeedableRng::from_seed_str("skewed-test");

        let hits = tally(&table, &mut rng, 10_000);

        assert!(hits[0] > 8000, "item 0: {}", hits[0]);
        assert!(hits[1] > 500, "item 1: {}", hits[1]);
    }

    /// Indices whose weight is zero must never come out of `sample`.
    #[test]
    fn zero_weight_never_selected() {
        let table = AliasTable::new(&[0.0, 1.0, 0.0, 1.0]).unwrap();
        let mut rng = SeedableRng::from_seed_str("zero-test");

        let mut remaining = 1000;
        while remaining > 0 {
            let idx = table.sample(&mut rng);
            assert!(matches!(idx, 1 | 3), "got zero-weight index {idx}");
            remaining -= 1;
        }
    }

    /// Degenerate inputs (no mass, no items) do not produce a table.
    #[test]
    fn all_zero_returns_none() {
        assert!(matches!(AliasTable::new(&[0.0, 0.0]), None));
        assert!(matches!(AliasTable::new(&[]), None));
    }

    /// A partial draw without replacement yields the requested number of distinct indices.
    #[test]
    fn without_replacement() {
        let table = AliasTable::new(&[1.0, 1.0, 1.0, 1.0, 1.0]).unwrap();
        let mut rng = SeedableRng::from_seed_str("no-replace");

        let selected = table.sample_without_replacement(&mut rng, 3);
        assert_eq!(selected.len(), 3);

        let distinct: std::collections::HashSet<usize> = selected.iter().copied().collect();
        assert_eq!(distinct.len(), 3);
    }

    /// Asking for more items than exist is clamped to the table size.
    #[test]
    fn without_replacement_all() {
        let table = AliasTable::new(&[1.0, 1.0, 1.0]).unwrap();
        let mut rng = SeedableRng::from_seed_str("all");

        let selected = table.sample_without_replacement(&mut rng, 10);
        assert_eq!(selected.len(), 3);
    }

    /// With a single item, drawing with replacement repeats that item.
    #[test]
    fn with_replacement_allows_duplicates() {
        let table = AliasTable::new(&[1.0]).unwrap();
        let mut rng = SeedableRng::from_seed_str("replace");

        let selected = table.sample_with_replacement(&mut rng, 5);
        assert_eq!(selected.len(), 5);
        assert!(selected.iter().all(|&i| i == 0));
    }

    /// Two generators built from the same seed string produce the same sample sequence.
    #[test]
    fn deterministic_with_seed() {
        let table = AliasTable::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();

        let run = |seed: &str| -> Vec<usize> {
            let mut rng = SeedableRng::from_seed_str(seed);
            (0..20).map(|_| table.sample(&mut rng)).collect()
        };

        assert_eq!(run("deterministic"), run("deterministic"));
    }
}