cogs_gamedev/
chance.rs

1//! Random and probabilistic things helpful for games.
2
3use itertools::{Either, Itertools};
4use rand::Rng;
5
/// It's often helpful to have weighted probabilities.
/// This struct serves as a sort of weighted bag; you can give it entries
/// with various weights, and then randomly sample them.
///
/// This is the way Minecraft loot tables work, if this sounds familiar.
///
/// The algorithm used is [Vose's Alias Method](https://www.keithschwarz.com/darts-dice-coins/)
/// (described at the bottom of that article). Sampling picks a column
/// uniformly at random and then tosses a biased coin to choose between that
/// column's own index and the column's alias; see
/// [`WeightedPicker::get_idx`]. It has O(n) creation and O(1) selection.
///
/// You can't edit the probabilities after you've created it: the lookup
/// tables are only ever built by the constructor.
#[derive(Debug, Clone)]
pub struct WeightedPicker<T> {
    /// Per column: the threshold a uniform random `f64` draw is compared
    /// against. When the draw is below it, sampling yields the column's own
    /// index.
    prob: Vec<f64>,
    /// Per column: the index yielded instead when the coin toss fails.
    alias: Vec<usize>,
    /// The values that can be picked; sampled indices index into this.
    items: Vec<T>,
}
23
24impl<T> WeightedPicker<T> {
25    /**
26    Initialize a WeightedPicker from the given
27    items and weights.
28
29    Panics if you pass it an empty Vec.
30
31    ```
32    # use cogs_gamedev::chance::WeightedPicker;
33
34    let picker = WeightedPicker::new(vec![
35        ("common", 10.0),
36        ("uncommon", 5.0),
37        ("rare", 2.0),
38        ("legendary", 1.0),
39        ("mythic", 0.1),
40    ]);
41
42    let mut rng = rand::thread_rng();
43    for _ in 0..10 {
44        println!("- {}", picker.get(&mut rng));
45    }
46
47    /*
48        A sample output:
49        - legendary
50        - rare
51        - uncommon
52        - common
53        - common
54        - rare
55        - uncommon
56        - common
57        - common
58        - uncommon
59    */
60    ```
61
62    */
63    pub fn new(entries: Vec<(T, f64)>) -> Self {
64        assert_ne!(entries.len(), 0, "Cannot use an empty vec!");
65
66        let total_weight: f64 = entries.iter().map(|(_, weight)| *weight).sum();
67        let len = entries.len();
68        let average = (len as f64).recip();
69
70        let (items, weights): (Vec<_>, Vec<_>) = entries.into_iter().unzip();
71
72        let (mut small, mut large): (Vec<_>, Vec<_>) = weights
73            .iter()
74            .enumerate()
75            .map(|(idx, weight)| {
76                let prob = weight / total_weight * len as f64;
77                (idx, prob)
78            })
79            .partition_map(|(idx, prob)| {
80                // true goes to small, false to large
81                if prob < average {
82                    Either::Left(idx)
83                } else {
84                    Either::Right(idx)
85                }
86            });
87
88        let mut alias = vec![0; len];
89        let mut prob = vec![0.0; len];
90
91        while !small.is_empty() && !large.is_empty() {
92            // what do you mean this is great rust code
93            let less = small.pop().unwrap();
94            let more = large.pop().unwrap();
95
96            prob[less] *= len as f64;
97            alias[less] = more;
98
99            let prev_more = prob[more];
100            let prev_less = prob[less];
101            prob[more] = prev_more + prev_less - average;
102
103            if prob[more] >= average {
104                large.push(more)
105            } else {
106                small.push(more);
107            }
108        }
109        while let Some(last) = small.pop() {
110            prob[last] = 1.0;
111        }
112        while let Some(last) = large.pop() {
113            prob[last] = 1.0;
114        }
115
116        debug_assert_eq!(prob.len(), len);
117        debug_assert_eq!(alias.len(), len);
118        debug_assert_eq!(items.len(), len);
119
120        Self { alias, items, prob }
121    }
122
123    /// Get an item from the list.
124    pub fn get<R: Rng + ?Sized>(&self, rng: &mut R) -> &T {
125        &self.items[self.get_idx(rng)]
126    }
127
128    /// Get an index into the internal list.
129    /// This is like [`WeightedPicker::get`], but returns the index of the
130    /// selected value instead of the value.
131    ///
132    /// You can use this function to save some space by passing a vec
133    /// where `T` is `()`, if you want `usize` outputs, I guess.
134    pub fn get_idx<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
135        let column = rng.gen_range(0..self.prob.len());
136        let coin_toss = rng.gen::<f64>() < self.prob[column];
137        if coin_toss {
138            column
139        } else {
140            self.alias[column]
141        }
142    }
143
144    /// Manually index into the picker's array.
145    pub fn get_by_idx(&self, idx: usize) -> Option<&T> {
146        self.items.get(idx)
147    }
148
149    /// Manually index into the picker's array.
150    /// You can use this to mutate entries once they've been created.
151    ///
152    /// Note there is no way to mutate probabilities after creation,
153    /// nor any way to add or remove possible values.
154    pub fn get_mut_by_idx(&mut self, idx: usize) -> Option<&mut T> {
155        self.items.get_mut(idx)
156    }
157
158    /// The same as creating a WeightedPicker and then calling `get`,
159    /// but you don't need to actually make the WeightedPicker.
160    pub fn pick<R: Rng + ?Sized>(items: Vec<(T, f64)>, rng: &mut R) -> T {
161        let mut wp = WeightedPicker::new(items);
162        let idx = wp.get_idx(rng);
163        // this would be unsound to use after removal,
164        // but fortunately we don't need to use it again
165        // not sure why i can't move out of it.
166        wp.items.remove(idx)
167    }
168}
169
// Doctests don't show their println output, so replicate the doc example
// here as a regular test.
#[test]
fn pick() {
    let loot_table = vec![
        ("common", 10.0),
        ("uncommon", 5.0),
        ("rare", 2.0),
        ("legendary", 1.0),
        ("mythic", 0.1),
    ];
    let picker = WeightedPicker::new(loot_table);

    let mut rng = rand::thread_rng();
    (0..10).for_each(|_| println!("- {}", picker.get(&mut rng)));
}
185}