1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
use itertools::{Either, Itertools};
use rand::Rng;

/// It's often helpful to have weighted probabilities.
/// This struct serves as a sort of weighted bag: you give it entries
/// with various weights, and then randomly sample them.
///
/// This is the way Minecraft loot tables work, if this sounds familiar.
///
/// The algorithm used is [Vose's Alias Method](https://www.keithschwarz.com/darts-dice-coins/)
/// (scroll to the bottom). Each entry owns one "column". Sampling picks a
/// column uniformly at random, then flips a biased coin (`prob[column]`) to
/// decide between returning that column itself and returning `alias[column]`.
/// The method advertises O(n) creation and O(1) selection.
///
/// You can't edit the probabilities after you've created it, because the
/// `prob`/`alias` tables are built once, up front, in `new`.
#[derive(Debug, Clone)]
pub struct WeightedPicker<T> {
    /// Per column: the threshold a uniform `f64` in `[0, 1)` must be below for
    /// the column to return its own index instead of its alias.
    prob: Vec<f64>,
    /// Per column: the index returned when that column's coin toss fails.
    alias: Vec<usize>,
    /// The items, in the same order as the entries passed to `new`.
    items: Vec<T>,
}

impl<T> WeightedPicker<T> {
    /**
    Initialize a WeightedPicker from the given
    items and weights.

    Weights are relative: they don't need to sum to 1. Each item is picked
    with probability `weight / sum_of_all_weights`, and a weight of `0.0`
    means the item is never picked.

    # Panics

    - If you pass it an empty Vec.
    - If any weight is negative, NaN, or infinite.
    - If the weights don't have a positive, finite sum (e.g. all zero).

    ```
    # use cogs_gamedev::chance::WeightedPicker;

    let picker = WeightedPicker::new(vec![
        ("common", 10.0),
        ("uncommon", 5.0),
        ("rare", 2.0),
        ("legendary", 1.0),
        ("mythic", 0.1),
    ]);

    let mut rng = rand::thread_rng();
    for _ in 0..10 {
        println!("- {}", picker.get(&mut rng));
    }

    /*
        A sample output:
        - legendary
        - rare
        - uncommon
        - common
        - common
        - rare
        - uncommon
        - common
        - common
        - uncommon
    */
    ```

    */
    pub fn new(entries: Vec<(T, f64)>) -> Self {
        assert_ne!(entries.len(), 0, "Cannot use an empty vec!");

        let (items, weights): (Vec<T>, Vec<f64>) = entries.into_iter().unzip();
        let len = items.len();

        assert!(
            weights.iter().all(|w| w.is_finite() && *w >= 0.0),
            "Weights must be finite and non-negative!"
        );
        let total_weight: f64 = weights.iter().sum();
        assert!(
            total_weight.is_finite() && total_weight > 0.0,
            "Weights must have a positive, finite sum!"
        );

        // Scale the weights so that their *average* is exactly 1.0. A column
        // below 1.0 has spare room; it gets topped up by aliasing into a
        // column that is at or above 1.0.
        let scale = len as f64 / total_weight;
        let mut prob: Vec<f64> = weights.iter().map(|w| w * scale).collect();
        // Every column aliases to itself until paired; columns that finish
        // with prob == 1.0 never consult their alias anyway.
        let mut alias: Vec<usize> = (0..len).collect();

        // Left = under-full columns ("small"), Right = full/over-full ("large").
        let (mut small, mut large): (Vec<usize>, Vec<usize>) = (0..len).partition_map(|idx| {
            if prob[idx] < 1.0 {
                Either::Left(idx)
            } else {
                Either::Right(idx)
            }
        });

        // Peek first and only pop once both lists are known to be non-empty,
        // so an index is never dropped without being finalized.
        while let (Some(&less), Some(&more)) = (small.last(), large.last()) {
            small.pop();
            large.pop();

            // `less` is now final: it keeps itself with probability
            // prob[less], and otherwise falls through to `more`.
            alias[less] = more;
            // `more` donated (1 - prob[less]) of its mass to fill `less`.
            // Written as (a + b) - 1 for numerical stability, per Vose.
            prob[more] = (prob[more] + prob[less]) - 1.0;

            if prob[more] < 1.0 {
                small.push(more);
            } else {
                large.push(more);
            }
        }

        // In exact arithmetic every remaining column is exactly 1.0; anything
        // else left over is floating-point drift, so pin it to 1.0.
        for idx in small.into_iter().chain(large) {
            prob[idx] = 1.0;
        }

        debug_assert_eq!(prob.len(), len);
        debug_assert_eq!(alias.len(), len);
        debug_assert_eq!(items.len(), len);

        Self { alias, items, prob }
    }

    /// Get a randomly chosen item from the list, according to the weights.
    pub fn get<R: Rng + ?Sized>(&self, rng: &mut R) -> &T {
        &self.items[self.get_idx(rng)]
    }

    /// Get a randomly chosen index into the internal list, according to the weights.
    ///
    /// You can use this function to save some space by passing a vec
    /// where `T` is `()`, if you want `usize` outputs, I guess.
    pub fn get_idx<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // Pick a column uniformly, then a biased coin decides between the
        // column itself and its alias.
        let column = rng.gen_range(0..self.prob.len());
        let coin_toss = rng.gen::<f64>() < self.prob[column];
        if coin_toss {
            column
        } else {
            self.alias[column]
        }
    }

    /// The same as creating a WeightedPicker and then calling `get`,
    /// but you don't need to actually make the WeightedPicker, and you
    /// get the item back by value.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`WeightedPicker::new`].
    pub fn pick<R: Rng + ?Sized>(items: Vec<(T, f64)>, rng: &mut R) -> T {
        let mut wp = WeightedPicker::new(items);
        let idx = wp.get_idx(rng);
        // `swap_remove` reorders `items`, which would desync it from the
        // `prob`/`alias` tables. That's fine here because `wp` is dropped
        // right away, and it's O(1) instead of `remove`'s O(n).
        wp.items.swap_remove(idx)
    }
}

// Output from doc examples isn't visible during `cargo test`, so this mirrors
// the `WeightedPicker::new` example (run with `--nocapture` to see samples) and
// checks that every sample is actually one of the input items.
#[test]
fn pick() {
    let entries = vec![
        ("common", 10.0),
        ("uncommon", 5.0),
        ("rare", 2.0),
        ("legendary", 1.0),
        ("mythic", 0.1),
    ];
    let names: Vec<&str> = entries.iter().map(|(name, _)| *name).collect();
    let picker = WeightedPicker::new(entries.clone());

    let mut rng = rand::thread_rng();
    for _ in 0..10 {
        let item = picker.get(&mut rng);
        println!("- {}", item);
        assert!(names.contains(item), "sampled unknown item {}", item);
    }

    // The one-shot helper must also hand back one of the inputs.
    let one = WeightedPicker::pick(entries, &mut rng);
    assert!(names.contains(&one), "picked unknown item {}", one);
}

// An entry with zero weight must never be sampled.
#[test]
fn zero_weight_never_picked() {
    let picker = WeightedPicker::new(vec![("never", 0.0), ("always", 1.0)]);
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        assert_eq!(*picker.get(&mut rng), "always");
    }
}