proportionate_selector/lib.rs
//! Proportionate selection from a discrete distribution.
//!
//! `proportionate_selector` allows sampling from an empirical discrete distribution
//! at runtime. Each sample is generated independently and has no coupling to previously
//! generated or future samples. This allows for quick and reliable sample generation from
//! some known discrete distribution.
7//!
8//! ## Use cases
9//!
//! * Multivariant A/B tests
11//! * Simple lootbox generation in games
12//! * Use in evolutionary algorithms
13//! * Help content promotion
14//! * Coupon code generation
15//! * and more...
16//!
17//! ## Example
18//!
//! Suppose we want to build a _very simple_ lootbox that hands out reward collectables
//! based on the rarity associated with each collectable, and we want to be able to
//! modify the _rarity_ of such collectables (thousands of possible items) at runtime.
22//!
23//! For example,
24//!
//! | Reward Item | Rarity | Probability of Occurrence (1/Rarity) |
26//! | ----------- | :----: | :---------------------------------: |
27//! | Reward A | 50 | (1/50) = 0.02 |
28//! | Reward B | 10 | (1/10) = 0.10 |
29//! | Reward C | 10 | (1/10) = 0.10 |
30//! | Reward D | 2 | (1/2) = 0.5 |
31//! | No Reward | 3.5714 | (1/3.5714) = 0.28 |
32//!
//! Note: `proportionate_selector` requires the probabilities to sum to 1 (a
//! tolerance of 0.001 is allowed). If you use a different ranking methodology,
//! normalize the probabilities before passing them to `proportionate_selector`;
//! in most cases you should be doing this anyway.
37//!
38//! ```rust
39//! use proportionate_selector::*;
40//!
41//! #[derive(PartialEq, Debug)]
42//! pub struct LootboxItem {
43//! pub id: i32,
//!     /// Likelihood of receiving this item from the lootbox.
//!     /// Rarity is the inverse of the likelihood of receiving
//!     /// this item.
//!     ///
//!     /// e.g. a rarity of 1 means the lootbox item will be generated
//!     /// more frequently than an item with a rarity of 100.
50//! pub rarity: f64,
51//! }
52//!
53//! impl Probability for LootboxItem {
54//! fn prob(&self) -> f64 {
//!         // rarity is modeled as 1 out of X occurrences, so
56//! // rarity of 20 has probability of 1/20.
57//! 1.0 / self.rarity
58//! }
59//! }
60//!
61//! let endOfLevel1Box = vec![
62//! LootboxItem {id: 0, rarity: 50.0}, // 2%
63//! LootboxItem {id: 1, rarity: 10.0}, // 10%
64//! LootboxItem {id: 2, rarity: 10.0}, // 10%
65//! LootboxItem {id: 3, rarity: 2.0}, // 50%
66//! LootboxItem {id: 4, rarity: 3.5714}, // 28%
67//! ];
68//!
69//! // create discrete distribution for sampling
70//! let epdf = DiscreteDistribution::new(&endOfLevel1Box, SamplingMethod::Linear).unwrap();
71//! let s = epdf.sample();
72//!
//! println!("{:?}", s);
74//! ```
75//!
76//! ## Benchmarks (+/- 5%)
77//!
78//! | Sampling | Time | Number of Items |
79//! | ------------ | :----: | --------------- |
80//! | Linear | 30 ns | 100 |
81//! | Linear | 6 us | 10,000 |
82//! | Linear | 486 us | 1,000,000 |
83//! | Cdf | 31 ns | 100 |
84//! | Cdf | 41 ns | 10,000 |
85//! | Cdf | 62 ns | 1,000,000 |
86//! | Stochastic | 315 ns | 100 |
87//! | Stochastic | 30 us | 10,000 |
88//! | Stochastic | 40 us | 1,000,000 |
89//!
//! Benchmarks were run on:
91//! ```text
92//! Model Name: Mac mini
93//! Model Identifier: Macmini9,1
94//! Chip: Apple M1
95//! Total Number of Cores: 8 (4 performance and 4 efficiency)
96//! Memory: 16 GB
97//! ```
98//!
99pub mod errors;
100pub mod util;
101
102use anyhow::{bail, Result};
103use errors::ProportionalSelectionErr::*;
104use rand::distributions::{Distribution, Uniform};
105use rand::rngs::SmallRng;
106use rand::Rng;
107use rand::SeedableRng;
108
109use util::*;
110
/// Sampling method used when drawing from a [`DiscreteDistribution`].
///
/// The methods differ in per-sample cost and in how much state is precomputed
/// when the distribution is built. `n` below is the number of items.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingMethod {
    /// Scans the items' weights left to right until the running total exceeds
    /// a uniformly drawn threshold.
    ///
    /// O(n) per sample in the worst case; O(n) extra memory for the weights.
    Linear,
    /// Precomputes the cumulative distribution function (CDF) once, then
    /// binary-searches it for every sample.
    ///
    /// O(log n) per sample; O(n) extra memory for the CDF.
    CumulativeDistributionFunction,
    /// Performs selection using stochastic acceptance: repeatedly picks an item
    /// uniformly at random and accepts it with probability `p / p_max`.
    ///
    /// Only the maximum probability is precomputed. The expected number of
    /// trials is `n * p_max` (for probabilities summing to 1), so sampling is
    /// O(1) on average when no single item dominates. There is no fixed upper
    /// bound on the number of trials (and random number draws) for an
    /// individual sample.
    ///
    /// Reference: <https://arxiv.org/abs/1109.3627>
    StochasticAcceptance,
}
132
/// An outcome of a discrete distribution that knows its own probability.
pub trait Probability {
    /// Returns the probability that this outcome occurs.
    ///
    /// The value must be non-negative and at most 1. The probabilities of all
    /// items handed to [`DiscreteDistribution::new`] must sum to 1; `new`
    /// checks this with a tolerance of 0.001.
    ///
    /// Samplers may call this more than once per item (the stochastic-acceptance
    /// sampler calls it on every trial), so it should be cheap and return the
    /// same value every time.
    fn prob(&self) -> f64;
}
139
140enum DistributionStore<'a, T: Probability> {
141 Frequency {
142 freq: Vec<f64>,
143 total: f64,
144 items: &'a Vec<T>,
145 },
146 Cdf {
147 cdf: Vec<f64>,
148 items: &'a Vec<T>,
149 },
150 MaxFrequency {
151 max_freq: f64,
152 items: &'a Vec<T>,
153 },
154}
155
/// Empirical discrete distribution over a borrowed collection of items.
///
/// Build one with [`DiscreteDistribution::new`] and draw items with
/// [`DiscreteDistribution::sample`]. The items are borrowed for `'a`, not
/// copied.
pub struct DiscreteDistribution<'a, T: Probability> {
    /// Precomputed state for the chosen [`SamplingMethod`], including the
    /// borrowed items.
    store: DistributionStore<'a, T>,
}
161
162impl<'_a, T: Probability> DiscreteDistribution<'_a, T> {
163 /// Returns DiscreteDistribution based on the selection method.
164 ///
165 /// ```rust
166 /// use proportionate_selector::*;
167 ///
168 /// #[derive(PartialEq)]
169 /// pub struct MultiVariantMarketingWebsiteItem {
170 /// pub id: i32,
171 /// /// Likelihood of recieve marketing website version.
172 /// /// Rarity represents inverse lilihood of recieveing
173 /// /// this item.
174 /// ///
175 /// /// e.g. rairity of 1, means website item item will be more
176 /// /// frequently generated as opposed to rarity of 100.
177 /// pub rarity: f64,
178 /// }
179 ///
180 /// impl Probability for MultiVariantMarketingWebsiteItem {
181 /// fn prob(&self) -> f64 {
182 /// // rarity is modeled as 1 out of X occurance, so
183 /// // rarity of 20 has probability of 1/20.
184 /// 1.0 / self.rarity
185 /// }
186 /// }
187 ///
188 /// let summer2020Launch = vec![
189 /// MultiVariantMarketingWebsiteItem {id: 0, rarity: 5.0}, // 20%
190 /// MultiVariantMarketingWebsiteItem {id: 1, rarity: 5.0}, // 20%
191 /// MultiVariantMarketingWebsiteItem {id: 2, rarity: 10.0}, // 10%
192 /// MultiVariantMarketingWebsiteItem {id: 3, rarity: 2.5}, // 40%
193 /// MultiVariantMarketingWebsiteItem {id: 4, rarity: 10.0}, // 40%
194 /// ];
195 ///
196 /// // create distribution for sampling
197 /// let epdf = DiscreteDistribution::new(&summer2020Launch, SamplingMethod::Linear);
198 /// assert!(epdf.is_ok());
199 /// ```
200 pub fn new(items: &'_a Vec<T>, method: SamplingMethod) -> Result<Self> {
201 // Must have definable discrete distribution.
202 let n = items.len();
203 if n <= 1 {
204 bail!(InsufficientProbabilitiesProvided { actual: n });
205 }
206
207 // Apply some tolerance to probability bounds.
208 const EPSILON: f64 = 0.001;
209 let total_p: f64 = items.iter().map(|i| i.prob()).sum();
210
211 // Impossible to perform sampling per occurance probabilitity, if
212 // occurance of all possible scenario is greater than 1 or less than 1.
213 if !(1.0 - EPSILON..=1.0 + EPSILON).contains(&total_p) {
214 bail!(SumOfAllProbabilitiesDoesNotEqualToOne { actual: total_p });
215 }
216
217 match method {
218 SamplingMethod::Linear => {
219 let freq = items
220 .iter()
221 .map(|item| item.prob() * 100.0)
222 .collect::<Vec<_>>();
223
224 let total = freq.iter().sum();
225
226 Ok(Self {
227 store: DistributionStore::Frequency { freq, total, items },
228 })
229 }
230
231 SamplingMethod::CumulativeDistributionFunction => {
232 let cdf = items
233 .iter()
234 .enumerate()
235 .scan(0.0, |acc, item| {
236 *acc += item.1.prob();
237 Some(*acc)
238 })
239 .collect::<Vec<_>>();
240
241 Ok(Self {
242 store: DistributionStore::Cdf { cdf, items },
243 })
244 }
245
246 SamplingMethod::StochasticAcceptance => {
247 let max_freq = items
248 .iter()
249 .map(|i| i.prob())
250 .max_by(|lhs, rhs| lhs.total_cmp(rhs))
251 .unwrap_or(0.0);
252
253 Ok(Self {
254 store: DistributionStore::MaxFrequency { max_freq, items },
255 })
256 }
257 }
258 }
259
260 /// Returns a sample based on discrete distribution.
261 ///
262 /// As invocation of sample() reaches large number (e.g. +infinity), the
263 /// difference between population (defined discrete distribution), and
264 /// distribution from generated sample diminishes.
265 ///
266 /// ```rust
267 /// use proportionate_selector::*;
268 ///
269 /// #[derive(PartialEq)]
270 /// pub struct LootboxItem {
271 /// pub id: i32,
272 /// /// Likelihood of recieve item from lootbox.
273 /// /// Rarity represents inverse lilihood of recieveing
274 /// /// this item.
275 /// ///
276 /// /// e.g. rairity of 1, means lootbox item will be more
277 /// /// frequently generated as opposed to rarity of 100.
278 /// pub rarity: f64,
279 /// }
280 ///
281 /// impl Probability for LootboxItem {
282 /// fn prob(&self) -> f64 {
283 /// // rarity is modeled as 1 out of X occurance, so
284 /// // rarity of 20 has probability of 1/20.
285 /// 1.0 / self.rarity
286 /// }
287 /// }
288 ///
289 /// let endOfLevel1Box = vec![
290 /// LootboxItem {id: 0, rarity: 5.0}, // 20%
291 /// LootboxItem {id: 1, rarity: 5.0}, // 20%
292 /// LootboxItem {id: 2, rarity: 10.0}, // 10%
293 /// LootboxItem {id: 3, rarity: 2.5}, // 40%
294 /// LootboxItem {id: 4, rarity: 10.0}, // 40%
295 /// ];
296 ///
297 /// // create discrete distribution for sampling
298 /// let epdf = DiscreteDistribution::new(&endOfLevel1Box, SamplingMethod::Linear).unwrap();
299 /// let s = epdf.sample();
300 ///
301 /// assert!(s.is_some());
302 /// ```
303 ///
304 pub fn sample(&'_a self) -> Option<&T> {
305 match &self.store {
306 DistributionStore::Cdf { cdf, items } => sample_cdf(cdf, items),
307 DistributionStore::Frequency { freq, total, items } => {
308 sample_linear(freq, total, items)
309 }
310 DistributionStore::MaxFrequency { max_freq, items } => {
311 sample_stochastic(max_freq, items)
312 }
313 }
314 }
315}
316
317fn sample_linear<'a, T: Probability>(freq: &[f64], total: &f64, items: &'a [T]) -> Option<&'a T> {
318 let mut rng = rand::thread_rng();
319 let total_n = convert(*total + 1.0);
320 let terminal = f64::from(Uniform::from(0..total_n).sample(&mut rng));
321 let mut acc: f64 = 0.0;
322
323 for (i, f) in freq.iter().enumerate() {
324 acc += *f;
325 if acc > terminal {
326 return items.get(i);
327 }
328 }
329
330 None
331}
332
333fn sample_cdf<'a, T: Probability>(cdf: &[f64], items: &'a [T]) -> Option<&'a T> {
334 let mut rng = rand::thread_rng();
335 let random = rng.gen();
336 items.get(bisect_left(cdf, &random))
337}
338
339fn sample_stochastic<'a, T: Probability>(max_freq: &f64, items: &'a [T]) -> Option<&'a T> {
340 let n = items.len();
341 let mut small_rng = SmallRng::from_entropy();
342 loop {
343 let i = small_rng.gen_range(0..n);
344 match items.get(i) {
345 None => return None,
346 Some(x) => {
347 let rand: f64 = small_rng.gen();
348 if rand < (x.prob() / max_freq) {
349 return Some(x);
350 }
351 continue;
352 }
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use kolmogorov_smirnov::test_f64;
    use std::collections::HashMap;

    use crate::*;
    use SamplingMethod::*;

    /// Relative frequency of each letter of the English alphabet.
    /// Reference: https://www3.nd.edu/~busiforc/handouts/cryptography/letterfrequencies.html
    const FIXTURE_ALPHABETS_PROBS: [(char, f64); 26] = [
        ('E', 0.111607),
        ('M', 0.030129),
        ('A', 0.084966),
        ('H', 0.030034),
        ('R', 0.075809),
        ('G', 0.024705),
        ('I', 0.075448),
        ('B', 0.020720),
        ('O', 0.071635),
        ('F', 0.018121),
        ('T', 0.069509),
        ('Y', 0.017779),
        ('N', 0.066544),
        ('W', 0.012899),
        ('S', 0.057351),
        ('K', 0.011016),
        ('L', 0.054893),
        ('V', 0.010074),
        ('C', 0.045388),
        ('X', 0.002902),
        ('U', 0.036308),
        ('Z', 0.002722),
        ('D', 0.033844),
        ('J', 0.001965),
        ('P', 0.031671),
        ('Q', 0.001962),
    ];

    impl Probability for (char, f64) {
        fn prob(&self) -> f64 {
            self.1
        }
    }

    /// Basic Monte Carlo simulation: draws `n` samples from `store` and counts
    /// how many times each letter was drawn.
    ///
    /// Panics if any draw returns `None`. A validated distribution must always
    /// yield an item, and silently discarding misses would hide sampler bugs.
    fn monte_carlo(store: DiscreteDistribution<(char, f64)>, n: usize) -> HashMap<char, f64> {
        let mut counter = HashMap::new();
        for _ in 0..n {
            let (letter, _) = store
                .sample()
                .expect("sample() returned None for a valid distribution");
            *counter.entry(*letter).or_insert(0.0) += 1.0;
        }
        counter
    }

    /// Asserts that the distribution generated by sample() matches the
    /// provided distribution, using the Kolmogorov-Smirnov test.
    ///
    /// Reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
    fn matches_expected_distribution(n: usize, method: SamplingMethod) {
        // Setup
        let mut abc_probs = FIXTURE_ALPHABETS_PROBS.to_vec();
        abc_probs.sort_by_key(|&(letter, _)| letter);
        let store = DiscreteDistribution::new(&abc_probs, method).unwrap();

        // Act
        let counts = monte_carlo(store, n);

        // Align observed frequencies with the expected ones letter by letter.
        // Letters never drawn (likely for rare letters at small `n`) count as
        // 0 instead of silently shrinking the observed sample.
        let expected_pdf: Vec<f64> = abc_probs.iter().map(|&(_, p)| p).collect();
        let obs_pdf: Vec<f64> = abc_probs
            .iter()
            .map(|(letter, _)| counts.get(letter).copied().unwrap_or(0.0) / n as f64)
            .collect();

        // Assert
        let ks_result = test_f64(&expected_pdf, &obs_pdf, 0.99);
        assert!(
            !ks_result.is_rejected,
            "Generated samples do not belong to expected distribution per KS test, for n={}, sampling={:#?} confidence={}",
            n,
            method,
            ks_result.confidence
        )
    }

    #[test]
    fn linear_sampling() {
        matches_expected_distribution(1_000, Linear);
        matches_expected_distribution(10_000, Linear);
        matches_expected_distribution(1_000_000, Linear);
    }

    #[test]
    fn cdf_sampling() {
        matches_expected_distribution(1_000, CumulativeDistributionFunction);
        matches_expected_distribution(10_000, CumulativeDistributionFunction);
        matches_expected_distribution(1_000_000, CumulativeDistributionFunction);
    }

    #[test]
    fn stochastic_sampling() {
        matches_expected_distribution(1_000, StochasticAcceptance);
        matches_expected_distribution(10_000, StochasticAcceptance);
        matches_expected_distribution(1_000_000, StochasticAcceptance);
    }
}