use rand_distr::WeightedAliasIndex;
/// A set of elements together with a weighted index sampler, allowing
/// elements to be drawn at random in proportion to their weights.
#[derive(Debug, Clone)]
pub struct Distribution<T> {
    /// The values that sampling can return; indexed by the output of `dist`.
    elements: Vec<T>,
    /// Alias-method sampler over `f64` weights that yields indices into `elements`.
    dist: WeightedAliasIndex<f64>,
}
impl<T> Distribution<T> {
    /// Builds a distribution from `(element, weight)` pairs.
    ///
    /// Returns `None` when `WeightedAliasIndex::new` rejects the weights; the
    /// underlying error is reported through `log::error!` before returning.
    pub fn new(weighted_elements: Vec<(T, f64)>) -> Option<Self> {
        let (elements, weights): (Vec<T>, Vec<f64>) = weighted_elements.into_iter().unzip();
        Self::from_parts(elements, weights)
    }

    /// Builds a distribution in which every element carries the same weight (`1.0`).
    ///
    /// Returns `None` under the same conditions as [`Distribution::new`].
    pub fn flat(elements: Vec<T>) -> Option<Self> {
        // Build the weights directly instead of zipping into a temporary
        // `Vec<(T, f64)>` that would immediately be unzipped again.
        let weights = vec![1.0; elements.len()];
        Self::from_parts(elements, weights)
    }

    /// Builds a distribution that always yields `element`.
    ///
    /// # Panics
    ///
    /// Panics only if `WeightedAliasIndex::new` rejects a single weight of
    /// `1.0`, which would indicate a broken invariant rather than bad input.
    pub fn singleton(element: T) -> Self {
        Self::new(vec![(element, 1.0)])
            .expect("a single element with weight 1.0 must form a valid distribution")
    }

    /// Draws an element using the supplied random number generator and
    /// returns a shared reference to it.
    pub fn sample_using<R: rand::Rng>(&self, rng: &mut R) -> &T {
        let index = self.weighted_index(rng);
        // `elements` and the sampler's weights are built from vectors of the
        // same length in `from_parts`.
        &self.elements[index]
    }

    /// Draws an element using the supplied random number generator and
    /// returns a mutable reference to it.
    pub fn sample_using_mut<R: rand::Rng>(&mut self, rng: &mut R) -> &mut T {
        let index = self.weighted_index(rng);
        &mut self.elements[index]
    }

    /// Draws an element using `rand::thread_rng()` and returns a shared
    /// reference to it.
    pub fn sample(&self) -> &T {
        self.sample_using(&mut rand::thread_rng())
    }

    /// Draws an element using `rand::thread_rng()` and returns a mutable
    /// reference to it.
    pub fn sample_mut(&mut self) -> &mut T {
        self.sample_using_mut(&mut rand::thread_rng())
    }

    /// Shared constructor: pairs `elements` with a sampler built from `weights`.
    ///
    /// Logs the error and returns `None` if the sampler cannot be created.
    fn from_parts(elements: Vec<T>, weights: Vec<f64>) -> Option<Self> {
        match WeightedAliasIndex::new(weights) {
            Ok(dist) => Some(Self { elements, dist }),
            Err(err) => {
                log::error!(
                    "Distribution::new: Error creating weighted probability distribution: {:?}",
                    err
                );
                None
            }
        }
    }

    /// Draws a random index into `elements` from the weighted sampler.
    fn weighted_index<R: rand::Rng>(&self, rng: &mut R) -> usize {
        // Fully-qualified call: the sampling trait is referenced by path here
        // and shares its name (`Distribution`) with this struct.
        <WeightedAliasIndex<f64> as rand_distr::Distribution<usize>>::sample(&self.dist, rng)
    }
}