1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
use rand_distr::WeightedAliasIndex;
/// A discrete probability distribution in which each element (for example, a move) carries a
/// relative weight.
///
/// The weight of an element expresses how likely it is to be chosen relative to the other
/// elements of the distribution; weights need not sum to one.
///
/// # Examples
///
/// A biased coin for which heads comes up three times as often as tails. Sampling `dist` should
/// yield `"heads"` about 75% of the time and `"tails"` about 25% of the time.
///
/// ```
/// use t4t::Distribution;
///
/// let dist = Distribution::new(vec![("heads", 3.0), ("tails", 1.0)]).unwrap();
/// let coin = dist.sample();
/// ```
///
/// Here `'A'` is 2.5 times as likely as `'B'` and 5 times as likely as `'C'`. Sampling should
/// therefore produce `'A'` about 62.5% (5/8) of the time, `'B'` about 25% of the time, and `'C'`
/// about 12.5% of the time.
///
/// ```
/// use t4t::Distribution;
///
/// let dist = Distribution::new(vec![('A', 2.5), ('B', 1.0), ('C', 0.5)]).unwrap();
/// let abc = dist.sample();
/// ```
#[derive(Debug, Clone)]
pub struct Distribution<T> {
    /// The elements that can be sampled, stored in the order they were supplied.
    elements: Vec<T>,
    /// Weighted sampler over indices into `elements`, one weight per element.
    dist: WeightedAliasIndex<f64>,
}
impl<T> Distribution<T> {
    /// Create a new weighted distribution from an association list of elements and their weights.
    ///
    /// Weights are relative: an element with weight `2.0` is twice as likely to be sampled as an
    /// element with weight `1.0`. They do not need to sum to any particular value.
    ///
    /// # Errors
    ///
    /// Logs an error and returns `None` if:
    /// - The vector is empty.
    /// - The vector is longer than `u32::MAX`.
    /// - Any weight `w` is NaN, or satisfies `w < 0.0` or `w > max`,
    ///   where `max = f64::MAX / weighted_elements.len()`.
    /// - The sum of the weights is zero.
    pub fn new(weighted_elements: Vec<(T, f64)>) -> Option<Self> {
        // Split the pairs while preserving order, so index `i` of the sampler always refers to
        // `elements[i]`.
        let (elements, weights): (Vec<T>, Vec<f64>) = weighted_elements.into_iter().unzip();
        match WeightedAliasIndex::new(weights) {
            Ok(dist) => Some(Self { elements, dist }),
            Err(err) => {
                log::error!(
                    "Distribution::new: Error creating weighted probability distribution: {:?}",
                    err
                );
                None
            }
        }
    }

    /// Create a new flat (uniform) distribution over the given elements, in which every element
    /// is equally likely to be sampled.
    ///
    /// # Errors
    ///
    /// Logs an error and returns `None` if:
    /// - The vector is empty.
    /// - The vector is longer than `u32::MAX`.
    pub fn flat(elements: Vec<T>) -> Option<Self> {
        Self::new(elements.into_iter().map(|element| (element, 1.0)).collect())
    }

    /// Create a trivial distribution consisting of a single element, which is always sampled.
    pub fn singleton(element: T) -> Self {
        // One element with a finite, positive weight satisfies every precondition of `new`.
        Self::new(vec![(element, 1.0)])
            .expect("a single element with weight 1.0 is always a valid distribution")
    }

    /// Sample a random value from the distribution using `rng` as the source of randomness.
    pub fn sample_using<R: rand::Rng>(&self, rng: &mut R) -> &T {
        // In bounds: `dist` was built from exactly one weight per element.
        let index = self.weighted_index(rng);
        &self.elements[index]
    }

    /// Sample a random value from the distribution using `rng` as the source of randomness,
    /// returning a mutable reference to the sampled element.
    pub fn sample_using_mut<R: rand::Rng>(&mut self, rng: &mut R) -> &mut T {
        // In bounds: `dist` was built from exactly one weight per element.
        let index = self.weighted_index(rng);
        &mut self.elements[index]
    }

    /// Sample a random value from the distribution using `rand::thread_rng()` as the source of
    /// randomness.
    pub fn sample(&self) -> &T {
        self.sample_using(&mut rand::thread_rng())
    }

    /// Sample a random value from the distribution using `rand::thread_rng()` as the source of
    /// randomness, returning a mutable reference to the sampled element.
    pub fn sample_mut(&mut self) -> &mut T {
        self.sample_using_mut(&mut rand::thread_rng())
    }

    /// Draw an index into `self.elements` according to the weights of the distribution.
    fn weighted_index<R: rand::Rng>(&self, rng: &mut R) -> usize {
        // Fully qualified because this type is itself named `Distribution`. Naming the
        // `rand_distr` trait by path avoids a clash with it.
        <WeightedAliasIndex<f64> as rand_distr::Distribution<usize>>::sample(&self.dist, rng)
    }
}