random_pick/
lib.rs

1/*!
2# Random Pick
3Pick an element from a slice randomly by given weights.
4
5## Examples
6
7```rust
8enum Prize {
9    Legendary,
10    Rare,
11    Enchanted,
12    Common,
13}
14
15let prize_list = [Prize::Legendary, Prize::Rare, Prize::Enchanted, Prize::Common]; // available prizes
16
17let slice = &prize_list;
18let weights = [1, 5, 15, 30]; // a scale of chance of picking each kind of prize
19
20let n = 1000000;
21let mut counter = [0usize; 4];
22
23for _ in 0..n {
24    let picked_item = random_pick::pick_from_slice(slice, &weights).unwrap();
25
26    match picked_item {
27        Prize::Legendary=>{
28            counter[0] += 1;
29           }
30        Prize::Rare=>{
31            counter[1] += 1;
32        }
33        Prize::Enchanted=>{
34            counter[2] += 1;
35        }
36        Prize::Common=>{
37            counter[3] += 1;
38        }
39    }
40}
41
42println!("{}", counter[0]); // Should be close to 20000
43println!("{}", counter[1]); // Should be close to 100000
44println!("{}", counter[2]); // Should be close to 300000
45println!("{}", counter[3]); // Should be close to 600000
46```
47
The length of the slice is usually a positive integral multiple of the length of `weights`; each weight then applies to an equally sized group of consecutive elements.
49
50If you have multiple slices, you don't need to use extra space to concat them, just use the `pick_from_multiple_slices` function, instead of `pick_from_slice`.
51
Besides picking a single element from a slice or slices, you can also pick several elements at once with the `pick_multiple_from_slice` and `pick_multiple_from_multiple_slices` functions. Their overhead is lower than that of calling the single-pick functions repeatedly in a loop.
53*/
54
55use rand::{rng, Rng};
56
/// Largest raw value drawn when selecting a weight.
///
/// The weighted generators draw a number from `0..=MAX_NUMBER` and scale the weights so that
/// their sum spans that same interval.
const MAX_NUMBER: usize = usize::MAX;
58
59/// Pick an element from a slice randomly by given weights.
60pub fn pick_from_slice<'a, T>(slice: &'a [T], weights: &'a [usize]) -> Option<&'a T> {
61    let slice_len = slice.len();
62
63    let index = gen_usize_with_weights(slice_len, weights)?;
64
65    Some(&slice[index])
66}
67
68/// Pick an element from multiple slices randomly by given weights.
69pub fn pick_from_multiple_slices<'a, T>(slices: &[&'a [T]], weights: &'a [usize]) -> Option<&'a T> {
70    let len: usize = slices.iter().map(|slice| slice.len()).sum();
71
72    let mut index = gen_usize_with_weights(len, weights)?;
73
74    for slice in slices {
75        let len = slice.len();
76
77        if index < len {
78            return Some(&slice[index]);
79        } else {
80            index -= len;
81        }
82    }
83
84    None
85}
86
87/// Pick multiple elements from a slice randomly by given weights.
88pub fn pick_multiple_from_slice<'a, T>(
89    slice: &'a [T],
90    weights: &'a [usize],
91    count: usize,
92) -> Vec<&'a T> {
93    let slice_len = slice.len();
94
95    gen_multiple_usize_with_weights(slice_len, weights, count)
96        .iter()
97        .map(|&index| &slice[index])
98        .collect()
99}
100
101/// Pick multiple elements from multiple slices randomly by given weights.
102pub fn pick_multiple_from_multiple_slices<'a, T>(
103    slices: &[&'a [T]],
104    weights: &'a [usize],
105    count: usize,
106) -> Vec<&'a T> {
107    let len: usize = slices.iter().map(|slice| slice.len()).sum();
108
109    gen_multiple_usize_with_weights(len, weights, count)
110        .iter()
111        .map(|index| {
112            let mut index = *index;
113
114            let mut s = slices[0];
115
116            for slice in slices {
117                let len = slice.len();
118
119                if index < len {
120                    s = slice;
121                    break;
122                } else {
123                    index -= len;
124                }
125            }
126
127            &s[index]
128        })
129        .collect()
130}
131
132/// Get a usize value by given weights.
133pub fn gen_usize_with_weights(high: usize, weights: &[usize]) -> Option<usize> {
134    let weights_len = weights.len();
135
136    if weights_len == 0 || high == 0 {
137        return None;
138    } else if weights_len == 1 {
139        if weights[0] == 0 {
140            return None;
141        }
142
143        return Some(rng().random_range(0..high));
144    } else {
145        let mut weights_sum = 0f64;
146        let mut max_weight = 0;
147
148        for w in weights.iter().copied() {
149            weights_sum += w as f64;
150            if w > max_weight {
151                max_weight = w;
152            }
153        }
154
155        if max_weight == 0 {
156            return None;
157        }
158
159        let mut rng = rng();
160
161        let index_scale = (high as f64) / (weights_len as f64);
162
163        let weights_scale = (MAX_NUMBER as f64) / weights_sum;
164
165        let rnd = rng.random_range(0..=MAX_NUMBER) as f64;
166
167        let mut temp = 0f64;
168
169        for (i, w) in weights.iter().copied().enumerate() {
170            temp += (w as f64) * weights_scale;
171            if temp > rnd {
172                let index = ((i as f64) * index_scale) as usize;
173
174                return Some(rng.random_range(index..((((i + 1) as f64) * index_scale) as usize)));
175            }
176        }
177    }
178
179    None
180}
181
182/// Get multiple usize values by given weights.
183pub fn gen_multiple_usize_with_weights(high: usize, weights: &[usize], count: usize) -> Vec<usize> {
184    let mut result: Vec<usize> = Vec::with_capacity(count);
185
186    let weights_len = weights.len();
187
188    if weights_len > 0 && high > 0 {
189        if weights_len == 1 {
190            if weights[0] != 0 {
191                let mut rng = rng();
192
193                for _ in 0..count {
194                    result.push(rng.random_range(0..high));
195                }
196            }
197        } else {
198            let mut weights_sum = 0f64;
199            let mut max_weight = 0;
200
201            for w in weights.iter().copied() {
202                weights_sum += w as f64;
203                if w > max_weight {
204                    max_weight = w;
205                }
206            }
207
208            if max_weight > 0 {
209                let index_scale = (high as f64) / (weights_len as f64);
210
211                let weights_scale = (MAX_NUMBER as f64) / weights_sum;
212
213                let mut rng = rng();
214
215                for _ in 0..count {
216                    let rnd = rng.random_range(0..=MAX_NUMBER) as f64;
217
218                    let mut temp = 0f64;
219
220                    for (i, w) in weights.iter().copied().enumerate() {
221                        temp += (w as f64) * weights_scale;
222                        if temp > rnd {
223                            let index = ((i as f64) * index_scale) as usize;
224
225                            result.push(
226                                rng.random_range(
227                                    index..((((i + 1) as f64) * index_scale) as usize),
228                                ),
229                            );
230                            break;
231                        }
232                    }
233                }
234            }
235        }
236    }
237
238    result
239}