rand_distr 0.4.1

Sampling from random number distributions
Documentation
// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of the alias method for sampling
//! random indices with probabilities proportional to a collection of weights.

use super::WeightedError;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use rand::Rng;
use alloc::{boxed::Box, vec, vec::Vec};

/// A distribution using weighted sampling to pick a discretely selected item.
///
/// Sampling a [`WeightedAliasIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedAliasIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which an implementation of
/// [`AliasableWeight`] exists.
///
/// # Performance
///
/// Given that `n` is the number of items in the vector used to create a
/// [`WeightedAliasIndex<W>`], it will require `O(n)` amount of memory.
/// More specifically it takes up some constant amount of memory plus
/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
///
/// Time complexity for the creation of a [`WeightedAliasIndex<W>`] is `O(n)`.
/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
/// to [`Uniform<W>::sample`].
///
/// # Example
///
/// ```
/// use rand_distr::WeightedAliasIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
/// let weights = vec![2, 1, 1];
/// let dist = WeightedAliasIndex::new(weights).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
///     println!("{}", choices[dist.sample(&mut rng)]);
/// }
///
/// let items = [('a', 0), ('b', 3), ('c', 7)];
/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
/// for _ in 0..100 {
///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
///     println!("{}", items[dist2.sample(&mut rng)].0);
/// }
/// ```
///
/// [`WeightedAliasIndex<W>`]: WeightedAliasIndex
/// [`Vec<u32>`]: Vec
/// [`Uniform<u32>::sample`]: Distribution::sample
/// [`Uniform<W>::sample`]: Distribution::sample
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub struct WeightedAliasIndex<W: AliasableWeight> {
    /// Alias map: `aliases[i]` is the index returned instead of `i` when the
    /// odds test for candidate `i` fails during sampling.
    aliases: Box<[u32]>,
    /// Per-index threshold, on a scale where the sum of all input weights
    /// means 100%: candidate `i` is returned as-is when a value drawn from
    /// `uniform_within_weight_sum` is below `no_alias_odds[i]`.
    no_alias_odds: Box<[W]>,
    /// Pre-built distribution over `0..n` used to pick the candidate index.
    uniform_index: Uniform<u32>,
    /// Pre-built distribution over `ZERO..weight_sum` used for the odds test.
    uniform_within_weight_sum: Uniform<W>,
}

impl<W: AliasableWeight> WeightedAliasIndex<W> {
    /// Creates a new [`WeightedAliasIndex`].
    ///
    /// `weights[i]` is the relative weight of index `i`. The vector is
    /// consumed and turned into the internal odds table.
    ///
    /// # Errors
    ///
    /// - [`WeightedError::NoItem`] if the vector is empty.
    /// - [`WeightedError::TooMany`] if the vector is longer than `u32::MAX`.
    /// - [`WeightedError::InvalidWeight`] if for any weight `w`: `w < 0` or
    ///   `w > max` where `max = W::MAX / weights.len()`. A weight for which
    ///   both comparisons are false (e.g. a floating point NaN) is rejected
    ///   the same way.
    /// - [`WeightedError::AllWeightsZero`] if the sum of weights is zero.
    pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
        let n = weights.len();
        if n == 0 {
            return Err(WeightedError::NoItem);
        } else if n > ::core::u32::MAX as usize {
            return Err(WeightedError::TooMany);
        }
        // Lossless: `n <= u32::MAX` was checked just above.
        let n = n as u32;

        // Every weight is later multiplied by `n`, so it must not exceed
        // `W::MAX / n`. If `n` itself is not representable in `W`, the limit
        // becomes zero and only all-zero inputs pass this check (they are then
        // rejected below as `AllWeightsZero`).
        let max_weight_size = W::try_from_u32_lossy(n)
            .map(|n| W::MAX / n)
            .unwrap_or(W::ZERO);
        if !weights
            .iter()
            .all(|&w| W::ZERO <= w && w <= max_weight_size)
        {
            return Err(WeightedError::InvalidWeight);
        }

        // The sum of weights will represent 100% of no alias odds.
        let weight_sum = AliasableWeight::sum(weights.as_slice());
        // Prevent floating point overflow due to rounding errors.
        let weight_sum = if weight_sum > W::MAX {
            W::MAX
        } else {
            weight_sum
        };
        if weight_sum == W::ZERO {
            return Err(WeightedError::AllWeightsZero);
        }

        // `weight_sum` would have been zero if `try_from_u32_lossy` causes an
        // error here, so this `unwrap` cannot fail at this point.
        let n_converted = W::try_from_u32_lossy(n).unwrap();

        // Scale every weight by `n`. On this scale an entry equal to
        // `weight_sum` has exactly the average weight, i.e. 100% odds of being
        // returned without consulting its alias.
        let mut no_alias_odds = weights.into_boxed_slice();
        for odds in no_alias_odds.iter_mut() {
            *odds *= n_converted;
            // Prevent floating point overflow due to rounding errors.
            *odds = if *odds > W::MAX { W::MAX } else { *odds };
        }

        /// This struct is designed to contain three data structures at once,
        /// sharing the same memory. More precisely it contains two linked lists
        /// and an alias map, which will be the output of this method. To keep
        /// the three data structures from getting in each other's way, it must
        /// be ensured that a single index is only ever in one of them at the
        /// same time.
        ///
        /// While an index is in one of the lists, its slot in `aliases` holds
        /// the next index of that list; `u32::MAX` marks the end of a list.
        struct Aliases {
            aliases: Box<[u32]>,
            smalls_head: u32,
            bigs_head: u32,
        }

        impl Aliases {
            // Creates `size` zeroed slots with both lists empty.
            fn new(size: u32) -> Self {
                Aliases {
                    aliases: vec![0; size as usize].into_boxed_slice(),
                    smalls_head: ::core::u32::MAX,
                    bigs_head: ::core::u32::MAX,
                }
            }

            // Pushes `idx` onto the front of the "small" list.
            fn push_small(&mut self, idx: u32) {
                self.aliases[idx as usize] = self.smalls_head;
                self.smalls_head = idx;
            }

            // Pushes `idx` onto the front of the "big" list.
            fn push_big(&mut self, idx: u32) {
                self.aliases[idx as usize] = self.bigs_head;
                self.bigs_head = idx;
            }

            // Pops the front of the "small" list. Callers check
            // `smalls_is_empty` first; the sentinel head is not a valid slot.
            fn pop_small(&mut self) -> u32 {
                let popped = self.smalls_head;
                self.smalls_head = self.aliases[popped as usize];
                popped
            }

            // Pops the front of the "big" list. Callers check
            // `bigs_is_empty` first; the sentinel head is not a valid slot.
            fn pop_big(&mut self) -> u32 {
                let popped = self.bigs_head;
                self.bigs_head = self.aliases[popped as usize];
                popped
            }

            fn smalls_is_empty(&self) -> bool {
                self.smalls_head == ::core::u32::MAX
            }

            fn bigs_is_empty(&self) -> bool {
                self.bigs_head == ::core::u32::MAX
            }

            // Records the final alias for `idx`, which must no longer be in
            // either list because the slot is shared with the list links.
            fn set_alias(&mut self, idx: u32, alias: u32) {
                self.aliases[idx as usize] = alias;
            }
        }

        let mut aliases = Aliases::new(n);

        // Split indices into those with small weights and those with big weights.
        // "Small" means strictly below the average (`weight_sum` on the scaled
        // axis); everything else is "big".
        for (index, &odds) in no_alias_odds.iter().enumerate() {
            if odds < weight_sum {
                aliases.push_small(index as u32);
            } else {
                aliases.push_big(index as u32);
            }
        }

        // Build the alias map by finding an alias with big weight for each index with
        // small weight.
        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
            let s = aliases.pop_small();
            let b = aliases.pop_big();

            // `s` is finished: the part of its slot not covered by its own
            // odds is handed to `b`.
            aliases.set_alias(s, b);
            // `b` gives away `weight_sum - no_alias_odds[s]`. Subtracting
            // before adding keeps the intermediate value within range.
            no_alias_odds[b as usize] =
                no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];

            // Re-classify `b` according to what it has left.
            if no_alias_odds[b as usize] < weight_sum {
                aliases.push_small(b);
            } else {
                aliases.push_big(b);
            }
        }

        // The remaining indices should have no alias odds of about 100%. This is due to
        // numeric accuracy. Otherwise they would be exactly 100%.
        // Their alias slots still hold list links rather than real aliases;
        // pinning the odds to `weight_sum` is what keeps `sample` from using
        // them (assuming `Uniform::new` excludes its upper bound — see rand's
        // documentation of the half-open range).
        while !aliases.smalls_is_empty() {
            no_alias_odds[aliases.pop_small() as usize] = weight_sum;
        }
        while !aliases.bigs_is_empty() {
            no_alias_odds[aliases.pop_big() as usize] = weight_sum;
        }

        // Prepare distributions for sampling. Creating them beforehand improves
        // sampling performance.
        let uniform_index = Uniform::new(0, n);
        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum);

        Ok(Self {
            aliases: aliases.aliases,
            no_alias_odds,
            uniform_index,
            uniform_within_weight_sum,
        })
    }
}

impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
    /// Draws one index with probability proportional to its weight.
    ///
    /// Two values are taken from `rng`, always in this order: a candidate
    /// slot from `uniform_index`, then a threshold from
    /// `uniform_within_weight_sum`. The candidate is returned when the
    /// threshold is below the candidate's no-alias odds; otherwise the
    /// candidate's alias is returned.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // First draw: which slot of the table to look at.
        let picked: u32 = rng.sample(self.uniform_index);
        let slot = picked as usize;

        // Second draw: decides between the slot itself and its alias.
        let roll: W = rng.sample(&self.uniform_within_weight_sum);

        if roll < self.no_alias_odds[slot] {
            slot
        } else {
            self.aliases[slot] as usize
        }
    }
}

impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
where
    W: fmt::Debug,
    Uniform<W>: fmt::Debug,
{
    /// Formats all four fields, in declaration order, as a struct named
    /// `WeightedAliasIndex`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Destructuring makes the compiler flag this impl if a field is added.
        let Self {
            aliases,
            no_alias_odds,
            uniform_index,
            uniform_within_weight_sum,
        } = self;

        let mut builder = f.debug_struct("WeightedAliasIndex");
        builder.field("aliases", aliases);
        builder.field("no_alias_odds", no_alias_odds);
        builder.field("uniform_index", uniform_index);
        builder.field("uniform_within_weight_sum", uniform_within_weight_sum);
        builder.finish()
    }
}

impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
where Uniform<W>: Clone
{
    /// Returns an independent copy: both tables are duplicated and both
    /// pre-built uniform distributions are carried over.
    // NOTE(review): presumably hand-written rather than derived so that the
    // bound can be `Uniform<W>: Clone` — confirm before switching to derive.
    fn clone(&self) -> Self {
        let aliases = self.aliases.clone();
        let no_alias_odds = self.no_alias_odds.clone();
        // `Uniform<u32>` is copied by value; no `clone` call is needed.
        let uniform_index = self.uniform_index;
        let uniform_within_weight_sum = self.uniform_within_weight_sum.clone();

        Self {
            aliases,
            no_alias_odds,
            uniform_index,
            uniform_within_weight_sum,
        }
    }
}

/// Numeric types that can serve as weights for [`WeightedAliasIndex`].
///
/// Implementations exist for the primitive integer and floating point types.
/// Custom implementations are possible, but currently no guarantees on the
/// correctness of [`WeightedAliasIndex`] are given for them.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub trait AliasableWeight:
    Sized
    + Copy
    + SampleUniform
    + PartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
    + MulAssign
    + Div<Output = Self>
    + DivAssign
    + Sum
{
    /// The largest value `Self` can represent.
    const MAX: Self;

    /// The value of `Self` that stands for 0.
    const ZERO: Self;

    /// Converts a `u32` into `Self`.
    ///
    /// Returns `None` when `n` is out of range for `Self`. Where `Self` is a
    /// floating point type, losing precision is acceptable.
    fn try_from_u32_lossy(n: u32) -> Option<Self>;

    /// Adds up every element of `values`.
    ///
    /// The default goes through the type's [`Sum`] implementation;
    /// implementors may override it with a different summation strategy.
    fn sum(values: &[Self]) -> Self {
        // Spelled out in full because `Self::sum` would be ambiguous between
        // this method and `Sum::sum`.
        <Self as Sum>::sum(values.iter().copied())
    }
}

// Implements `AliasableWeight` for the floating point primitive named by
// `$Float`; the name doubles as the `core` module that provides `MAX`.
macro_rules! impl_weight_for_float {
    ($Float:ident) => {
        impl AliasableWeight for $Float {
            const MAX: Self = ::core::$Float::MAX;
            const ZERO: Self = 0.0;

            // Always succeeds: the cast rounds instead of failing.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                Some(n as Self)
            }

            // Replaces the default summation with `pairwise_sum`.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}

/// Sums `values` by recursively halving the slice and adding the two partial
/// sums.
///
/// In comparison to naive accumulation, the pairwise sum algorithm reduces
/// rounding errors when there are many floating point values. Slices of at
/// most 32 elements are summed directly.
fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
    // Largest slice that is still summed with plain accumulation.
    const LEAF_LEN: usize = 32;

    if values.len() > LEAF_LEN {
        let (left, right) = values.split_at(values.len() / 2);
        return pairwise_sum(left) + pairwise_sum(right);
    }
    values.iter().copied().sum()
}

// Implements `AliasableWeight` for the integer primitive named by `$Int`;
// the name doubles as the `core` module that provides `MAX`.
macro_rules! impl_weight_for_int {
    ($Int:ident) => {
        impl AliasableWeight for $Int {
            const MAX: Self = ::core::$Int::MAX;
            const ZERO: Self = 0;

            // `Some(n)` when `n` fits into `Self` unchanged, `None` when it
            // is out of range for the target integer type.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                <Self as ::core::convert::TryFrom<u32>>::try_from(n).ok()
            }
        }
    };
}

// Floating point weights: these impls sum with `pairwise_sum`.
impl_weight_for_float!(f64);
impl_weight_for_float!(f32);
// Unsigned integer weights.
impl_weight_for_int!(usize);
// NOTE(review): the 128-bit impls are compiled out on emscripten, presumably
// because 128-bit integers were not usable on that target — confirm before
// removing the cfg.
#[cfg(not(target_os = "emscripten"))]
#[cfg_attr(doc_cfg, doc(cfg(not(target_os = "emscripten"))))]
impl_weight_for_int!(u128);
impl_weight_for_int!(u64);
impl_weight_for_int!(u32);
impl_weight_for_int!(u16);
impl_weight_for_int!(u8);
// Signed integer weights; negative values are rejected by
// `WeightedAliasIndex::new`.
impl_weight_for_int!(isize);
#[cfg(not(target_os = "emscripten"))]
#[cfg_attr(doc_cfg, doc(cfg(not(target_os = "emscripten"))))]
impl_weight_for_int!(i128);
impl_weight_for_int!(i64);
impl_weight_for_int!(i32);
impl_weight_for_int!(i16);
impl_weight_for_int!(i8);

// Unit tests: a statistical check shared by several weight types, error
// cases specific to floats and signed integers, and pinned sample sequences.
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the shared check for `f32` and verifies how special floating
    /// point values are rejected.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating point special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        // Negative zero passes the range check and counts as a zero weight.
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        // NaN fails both bounds comparisons and is therefore invalid.
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Runs the shared check for `u128`.
    #[cfg(not(target_os = "emscripten"))]
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    /// Runs the shared check for `i128` and verifies that negative weights
    /// are rejected.
    #[cfg(not(target_os = "emscripten"))]
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Runs the shared check for `u8`.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    /// Runs the shared check for `i8` and verifies that negative weights are
    /// rejected.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Shared check for a weight type `W`.
    ///
    /// Builds `NUM_WEIGHTS` random weights with one of them forced to zero,
    /// samples `NUM_SAMPLES` times and compares the observed counts with the
    /// counts expected from the weights. `w_to_f64` converts a weight to
    /// `f64` for that comparison. Finishes with the error cases that apply to
    /// every weight type.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where WeightedAliasIndex<W>: fmt::Debug {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        // NOTE(review): `crate::test::rng` is presumably a deterministic RNG
        // built from the given seed — confirm in the crate's test helpers.
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            // Upper bound `W::MAX / NUM_WEIGHTS` is the largest weight that
            // `WeightedAliasIndex::new` accepts for this many items.
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            );
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        // Expected number of hits per index: its share of the total weight
        // times the number of samples.
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        // An index with zero weight must never be drawn.
        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            // Tolerance: 10% of the average number of samples per index.
            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
            assert!(difference <= max_allowed_difference);
        }

        // Error cases common to all weight types.
        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        // Each weight exceeds `W::MAX / 2`, the limit for two items.
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Pins the exact sequence of sampled indices for a fixed seed, so that a
    /// change to construction or sampling that alters the output is noticed.
    #[test]
    fn value_stability() {
        // Builds a distribution from `weights`, fills `buf` with samples from
        // a freshly seeded RNG and compares them with `expected`.
        fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            6, 5, 7, 5, 8, 7, 6, 2, 3, 7,
        ]);
        test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            2, 0, 0, 0, 0, 0, 0, 0, 1, 3,
        ]);
        test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 1, 2, 3, 2, 1, 3, 2, 1, 1,
        ]);
    }
}