wicker 0.2.0

Weighted probability picker for Rust
Documentation
#![doc = include_str!("../README.md")]

pub mod iter;
use iter::*;

use itertools::{Either, Itertools};
use rand::Rng;

/// A weighted random picker over a fixed list of items.
///
/// Sampling (see [`WeightedPicker::get_idx`]) chooses a uniformly random
/// column and then either keeps that column or follows its alias, so the
/// three vectors below are parallel: `new` builds them all with the same
/// length.
#[derive(Debug, Clone)]
pub struct WeightedPicker<T> {
    /// `prob[i]` is the threshold column `i`'s coin toss is compared against:
    /// when a uniform `f64` sample is below it, index `i` itself is returned.
    prob: Vec<f64>,
    /// `alias[i]` is the index returned instead when column `i` loses its
    /// coin toss.
    alias: Vec<usize>,
    /// The `(value, weight)` pairs handed to `new`, in their original order.
    items: Vec<(T, f64)>,
}

impl<T> WeightedPicker<T> {
    /// Initialize a WeightedPicker from the given items and weights.
    ///
    /// The lookup tables are built with the alias method, so that
    /// [`WeightedPicker::get_idx`] returns index `i` with probability
    /// `weight_i / total_weight` using one uniform index and one uniform
    /// float per draw.
    ///
    /// Weights are not validated. They are expected to be finite and
    /// non-negative with a positive sum; for any other input the resulting
    /// distribution is unspecified.
    ///
    /// # Panics
    ///
    /// Panics if you pass it an empty Vec.
    #[must_use]
    pub fn new(entries: Vec<(T, f64)>) -> Self {
        assert_ne!(entries.len(), 0, "Cannot use an empty vec!");

        let len = entries.len();
        let total_weight: f64 = entries.iter().map(|(_, weight)| *weight).sum();

        // Each entry's probability multiplied by `len`, so the mean is exactly
        // 1.0: a column is "under-full" below 1.0 and "over-full" otherwise.
        let mut scaled = entries
            .iter()
            .map(|(_, weight)| *weight / total_weight * len as f64)
            .collect_vec();

        let (mut small, mut large): (Vec<usize>, Vec<usize>) =
            scaled.iter().enumerate().partition_map(|(idx, &share)| {
                if share < 1.0 {
                    Either::Left(idx)
                } else {
                    Either::Right(idx)
                }
            });

        // Every column aliases itself until it is paired with a donor, so a
        // lookup can never land on an unrelated index.
        let mut alias = (0..len).collect_vec();
        let mut prob = vec![0.0; len];

        // Peek before popping: popping both stacks unconditionally would drop
        // an index whenever only one of them is empty.
        while let (Some(&less), Some(&more)) = (small.last(), large.last()) {
            small.pop();
            large.pop();

            // The under-full column keeps its own share and borrows the rest
            // of its unit of probability from the over-full one.
            prob[less] = scaled[less];
            alias[less] = more;

            // The donor loses exactly what it gave away: `1.0 - scaled[less]`.
            scaled[more] = (scaled[more] + scaled[less]) - 1.0;

            if scaled[more] < 1.0 {
                small.push(more);
            } else {
                large.push(more);
            }
        }

        // Whatever is left over is a full column (up to rounding error), so it
        // always resolves to itself.
        for idx in small.into_iter().chain(large) {
            prob[idx] = 1.0;
        }

        let items = entries;

        debug_assert_eq!(prob.len(), len);
        debug_assert_eq!(alias.len(), len);
        debug_assert_eq!(items.len(), len);

        Self { alias, items, prob }
    }

    /// Randomly pick an item from the list.
    #[must_use]
    pub fn get<R: Rng + ?Sized>(&self, rng: &mut R) -> &T {
        let idx = self.get_idx(rng);
        &self.items[idx].0
    }

    /// Randomly pick an *index* from the list.
    /// This is like [`WeightedPicker::get`], but returns the index of the
    /// selected value instead of the value.
    ///
    /// You can use this function to save some space by passing a vec
    /// where `T` is `()`, if you want `usize` outputs, I guess.
    #[must_use]
    pub fn get_idx<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // Choose a column uniformly, then toss that column's biased coin to
        // decide between the column itself and its alias.
        let column = rng.gen_range(0..self.prob.len());
        if rng.gen::<f64>() < self.prob[column] {
            column
        } else {
            self.alias[column]
        }
    }

    /// Get the number of entries in the picker.
    #[must_use]
    #[allow(clippy::len_without_is_empty)] // it can never be empty
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Manually index into the picker's array.
    ///
    /// Returns `None` if `idx` is out of bounds.
    #[must_use]
    pub fn get_by_idx(&self, idx: usize) -> Option<&T> {
        self.items.get(idx).map(|(it, _)| it)
    }

    /// Manually index into the picker's array.
    /// You can use this to mutate entries once they've been created.
    ///
    /// Returns `None` if `idx` is out of bounds.
    ///
    /// Note there is no way to mutate *probabilities* after creation,
    /// nor any way to add or remove possible values.
    #[must_use]
    pub fn get_mut_by_idx(&mut self, idx: usize) -> Option<&mut T> {
        self.items.get_mut(idx).map(|(it, _)| it)
    }

    /// Convenience method for creating a WeightedPicker and then calling `get`,
    /// so you don't need to actually make the WeightedPicker.
    ///
    /// Returns the chosen item by value.
    ///
    /// # Panics
    ///
    /// Panics if `items` is empty.
    #[must_use]
    pub fn pick<R: Rng + ?Sized>(items: Vec<(T, f64)>, rng: &mut R) -> T {
        let mut picker = Self::new(items);
        let idx = picker.get_idx(rng);
        // Removing an entry leaves `prob`/`alias` out of step with `items`,
        // which is harmless here because the picker is dropped immediately.
        picker.items.swap_remove(idx).0
    }

    /**
    Iterate through the elements in order of probability, from least to most.

    This requires sorting internally, so is `O(n log(n))`.
    ```rust
    # use wicker::{WeightedPicker, iter::Iter};
    let picker = WeightedPicker::new(vec![
    ("hello", 0.3),
    ("wicker", 10.0),
    ("this", 3.0),
    ("world", 1.0),
    ("is", 5.0),
    ]);
    let v = picker.iter().copied().collect::<Vec<_>>();
    assert_eq!(v, &["hello", "world", "this", "is", "wicker"]);
    ```
    */
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(self)
    }
}

// Doctests don't show their `println!` output, so the same sampling demo is
// repeated here as a unit test where the picks can actually be seen.
#[test]
fn pick() {
    let rarities = vec![
        ("common", 10.0),
        ("uncommon", 5.0),
        ("rare", 2.0),
        ("legendary", 1.0),
        ("mythic", 0.1),
    ];
    let picker = WeightedPicker::new(rarities);

    let mut rng = rand::thread_rng();
    // Print ten independent draws.
    (0..10).for_each(|_| println!("- {}", picker.get(&mut rng)));
}