proportionate_selector/
lib.rs

1//! Proportionate selection from discrete distribution.
2//!
//! `proportionate_selector` allows sampling from an empirical discrete distribution
//! at runtime. Each sample is generated independently, and has no coupling to previously
//! generated or future samples. This allows for quick and reliable sample generation from
//! some known discrete distribution.
7//!
8//! ## Use cases
9//!
10//! * Multivariant a/b tests
11//! * Simple lootbox generation in games
12//! * Use in evolutionary algorithms
13//! * Help content promotion
14//! * Coupon code generation
15//! * and more...
16//!
17//! ## Example
18//!
//! Suppose we want to build _very simple_ lootbox reward collectables, based on
//! some rarity associated with the reward collectables. And we want to be able to
//! modify the _rarity_ of such collectables (thousands of possible items) at runtime.
22//!
23//! For example,
24//!
//! | Reward Item | Rarity | Probability of Occurrence (1/Rarity) |
//! | ----------- | :----: | :----------------------------------: |
27//! | Reward A    |   50   |     (1/50) = 0.02                   |
28//! | Reward B    |   10   |     (1/10) = 0.10                   |
29//! | Reward C    |   10   |     (1/10) = 0.10                   |
30//! | Reward D    |   2    |     (1/2) = 0.5                     |
31//! | No Reward   | 3.5714 |     (1/3.5714) = 0.28               |
32//!
//! Note: `proportionate_selector` requires that the sum of probabilities equals 1
//! (within a small tolerance). If, for some reason, you are using a different ranking
//! methodology, you can normalize the probabilities prior to using
//! `proportionate_selector`. In most cases, you should be doing this anyway.
37//!
38//! ```rust
39//! use proportionate_selector::*;
40//!
41//! #[derive(PartialEq, Debug)]
42//! pub struct LootboxItem {
43//!     pub id: i32,
//!     /// Likelihood of receiving this item from the lootbox.
//!     /// Rarity represents the inverse likelihood of receiving
//!     /// this item.
//!     ///
//!     /// e.g. a rarity of 1 means the lootbox item will be generated
//!     /// more frequently than one with a rarity of 100.
50//!     pub rarity: f64,
51//! }
52//!
53//! impl Probability for LootboxItem {
54//!     fn prob(&self) -> f64 {
//!         // rarity is modeled as 1 out of X occurrences, so
//!         // a rarity of 20 has a probability of 1/20.
57//!         1.0 / self.rarity
58//!     }
59//! }
60//!
61//! let endOfLevel1Box = vec![
62//!     LootboxItem {id: 0, rarity: 50.0},   // 2%
63//!     LootboxItem {id: 1, rarity: 10.0},   // 10%
64//!     LootboxItem {id: 2, rarity: 10.0},   // 10%
65//!     LootboxItem {id: 3, rarity: 2.0},    // 50%
66//!     LootboxItem {id: 4, rarity: 3.5714}, // 28%
67//! ];
68//!
69//! // create discrete distribution for sampling
70//! let epdf = DiscreteDistribution::new(&endOfLevel1Box, SamplingMethod::Linear).unwrap();
71//! let s = epdf.sample();
72//!
73//! println!("{:?}", epdf.sample());
74//! ```
75//!
76//! ## Benchmarks (+/- 5%)
77//!
78//! | Sampling     |  Time  | Number of Items |
79//! | ------------ | :----: | --------------- |
80//! | Linear       | 30 ns  | 100             |
81//! | Linear       | 6 us   | 10,000          |
82//! | Linear       | 486 us | 1,000,000       |
83//! | Cdf          | 31 ns  | 100             |
84//! | Cdf          | 41 ns  | 10,000          |
85//! | Cdf          | 62 ns  | 1,000,000       |
86//! | Stochastic   | 315 ns | 100             |
87//! | Stochastic   | 30 us  | 10,000          |
88//! | Stochastic   | 40 us  | 1,000,000       |
89//!
//! Benchmarks were run on:
91//! ```text
92//!   Model Name: Mac mini
93//!   Model Identifier: Macmini9,1
94//!   Chip: Apple M1
95//!   Total Number of Cores: 8 (4 performance and 4 efficiency)
96//!   Memory: 16 GB
97//! ```
98//!
99pub mod errors;
100pub mod util;
101
102use anyhow::{bail, Result};
103use errors::ProportionalSelectionErr::*;
104use rand::distributions::{Distribution, Uniform};
105use rand::rngs::SmallRng;
106use rand::Rng;
107use rand::SeedableRng;
108
109use util::*;
110
/// Sampling method to use when drawing from a discrete distribution.
///
/// The variants carry no data, so values can be copied, compared and hashed
/// freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingMethod {
    /// Performs a linear scan over the item weights.
    ///
    /// Worst case is O(n).
    Linear,
    /// Performs selection by building a cumulative distribution function (CDF)
    /// up front and searching it for each sample.
    ///
    /// Worst case is O(log n).
    CumulativeDistributionFunction,
    /// Performs selection using stochastic acceptance (rejection sampling):
    /// a uniformly chosen candidate is accepted with probability
    /// `prob / max_prob`, otherwise another candidate is drawn.
    ///
    /// Each attempt is O(1). The number of attempts is random and has no hard
    /// upper bound; when the probabilities sum to 1 the expected number of
    /// attempts is `n * max_prob`, which is close to 1 for near-uniform
    /// distributions.
    ///
    /// Reference: <https://arxiv.org/abs/1109.3627>
    StochasticAcceptance,
}
132
/// Implemented by items that can take part in proportionate selection.
pub trait Probability {
    /// Probability of this item occurring.
    ///
    /// Implementations are expected to return a non-negative value in the
    /// range 0 to 1.
    fn prob(&self) -> f64;
}
139
/// Pre-computed, method-specific state used to draw samples.
///
/// Each variant borrows the caller's items and keeps whatever derived data
/// its sampling method needs. The definition itself places no trait bound on
/// `T`; the code that builds and reads the store is what requires
/// `Probability`.
enum DistributionStore<'a, T> {
    /// State for linear-scan sampling.
    Frequency {
        /// Per-item weight, in the same order as `items`.
        freq: Vec<f64>,
        /// Sum of all entries in `freq`.
        total: f64,
        /// Borrowed items to select from.
        items: &'a Vec<T>,
    },
    /// State for cumulative-distribution sampling.
    Cdf {
        /// Running sum of item probabilities, in the same order as `items`.
        cdf: Vec<f64>,
        /// Borrowed items to select from.
        items: &'a Vec<T>,
    },
    /// State for stochastic-acceptance sampling.
    MaxFrequency {
        /// Largest probability found among `items`.
        max_freq: f64,
        /// Borrowed items to select from.
        items: &'a Vec<T>,
    },
}
155
156/// Represents empirical discrete distribution.
157pub struct DiscreteDistribution<'a, T: Probability> {
158    /// Stores distribution attributes for quick sample generation.
159    store: DistributionStore<'a, T>,
160}
161
162impl<'_a, T: Probability> DiscreteDistribution<'_a, T> {
163    /// Returns DiscreteDistribution based on the selection method.
164    ///
165    /// ```rust
166    /// use proportionate_selector::*;
167    ///
168    /// #[derive(PartialEq)]
169    /// pub struct MultiVariantMarketingWebsiteItem {
170    ///     pub id: i32,
171    ///     /// Likelihood of recieve marketing website version.
172    ///     /// Rarity represents inverse lilihood of recieveing
173    ///     /// this item.
174    ///     ///
175    ///     /// e.g. rairity of 1, means website item item will be more
176    ///     /// frequently generated as opposed to rarity of 100.
177    ///     pub rarity: f64,
178    /// }
179    ///
180    /// impl Probability for MultiVariantMarketingWebsiteItem {
181    ///     fn prob(&self) -> f64 {
182    ///         // rarity is modeled as 1 out of X occurance, so
183    ///         // rarity of 20 has probability of 1/20.
184    ///         1.0 / self.rarity
185    ///     }
186    /// }
187    ///
188    /// let summer2020Launch = vec![
189    ///     MultiVariantMarketingWebsiteItem {id: 0, rarity: 5.0}, // 20%
190    ///     MultiVariantMarketingWebsiteItem {id: 1, rarity: 5.0}, // 20%
191    ///     MultiVariantMarketingWebsiteItem {id: 2, rarity: 10.0}, // 10%
192    ///     MultiVariantMarketingWebsiteItem {id: 3, rarity: 2.5}, // 40%
193    ///     MultiVariantMarketingWebsiteItem {id: 4, rarity: 10.0}, // 40%
194    /// ];
195    ///
196    /// // create distribution for sampling
197    /// let epdf = DiscreteDistribution::new(&summer2020Launch, SamplingMethod::Linear);
198    /// assert!(epdf.is_ok());
199    /// ```
200    pub fn new(items: &'_a Vec<T>, method: SamplingMethod) -> Result<Self> {
201        // Must have definable discrete distribution.
202        let n = items.len();
203        if n <= 1 {
204            bail!(InsufficientProbabilitiesProvided { actual: n });
205        }
206
207        // Apply some tolerance to probability bounds.
208        const EPSILON: f64 = 0.001;
209        let total_p: f64 = items.iter().map(|i| i.prob()).sum();
210
211        // Impossible to perform sampling per occurance probabilitity, if
212        // occurance of all possible scenario is greater than 1 or less than 1.
213        if !(1.0 - EPSILON..=1.0 + EPSILON).contains(&total_p) {
214            bail!(SumOfAllProbabilitiesDoesNotEqualToOne { actual: total_p });
215        }
216
217        match method {
218            SamplingMethod::Linear => {
219                let freq = items
220                    .iter()
221                    .map(|item| item.prob() * 100.0)
222                    .collect::<Vec<_>>();
223
224                let total = freq.iter().sum();
225
226                Ok(Self {
227                    store: DistributionStore::Frequency { freq, total, items },
228                })
229            }
230
231            SamplingMethod::CumulativeDistributionFunction => {
232                let cdf = items
233                    .iter()
234                    .enumerate()
235                    .scan(0.0, |acc, item| {
236                        *acc += item.1.prob();
237                        Some(*acc)
238                    })
239                    .collect::<Vec<_>>();
240
241                Ok(Self {
242                    store: DistributionStore::Cdf { cdf, items },
243                })
244            }
245
246            SamplingMethod::StochasticAcceptance => {
247                let max_freq = items
248                    .iter()
249                    .map(|i| i.prob())
250                    .max_by(|lhs, rhs| lhs.total_cmp(rhs))
251                    .unwrap_or(0.0);
252
253                Ok(Self {
254                    store: DistributionStore::MaxFrequency { max_freq, items },
255                })
256            }
257        }
258    }
259
260    /// Returns a sample based on discrete distribution.
261    ///
262    /// As invocation of sample() reaches large number (e.g. +infinity), the
263    /// difference between population (defined discrete distribution), and
264    /// distribution from generated sample diminishes.
265    ///
266    /// ```rust
267    /// use proportionate_selector::*;
268    ///
269    /// #[derive(PartialEq)]
270    /// pub struct LootboxItem {
271    ///     pub id: i32,
272    ///     /// Likelihood of recieve item from lootbox.
273    ///     /// Rarity represents inverse lilihood of recieveing
274    ///     /// this item.
275    ///     ///
276    ///     /// e.g. rairity of 1, means lootbox item will be more
277    ///     /// frequently generated as opposed to rarity of 100.
278    ///     pub rarity: f64,
279    /// }
280    ///
281    /// impl Probability for LootboxItem {
282    ///     fn prob(&self) -> f64 {
283    ///         // rarity is modeled as 1 out of X occurance, so
284    ///         // rarity of 20 has probability of 1/20.
285    ///         1.0 / self.rarity
286    ///     }
287    /// }
288    ///
289    /// let endOfLevel1Box = vec![
290    ///     LootboxItem {id: 0, rarity: 5.0}, // 20%
291    ///     LootboxItem {id: 1, rarity: 5.0}, // 20%
292    ///     LootboxItem {id: 2, rarity: 10.0}, // 10%
293    ///     LootboxItem {id: 3, rarity: 2.5}, // 40%
294    ///     LootboxItem {id: 4, rarity: 10.0}, // 40%
295    /// ];
296    ///
297    /// // create discrete distribution for sampling
298    /// let epdf = DiscreteDistribution::new(&endOfLevel1Box, SamplingMethod::Linear).unwrap();
299    /// let s = epdf.sample();
300    ///
301    /// assert!(s.is_some());
302    /// ```
303    ///
304    pub fn sample(&'_a self) -> Option<&T> {
305        match &self.store {
306            DistributionStore::Cdf { cdf, items } => sample_cdf(cdf, items),
307            DistributionStore::Frequency { freq, total, items } => {
308                sample_linear(freq, total, items)
309            }
310            DistributionStore::MaxFrequency { max_freq, items } => {
311                sample_stochastic(max_freq, items)
312            }
313        }
314    }
315}
316
317fn sample_linear<'a, T: Probability>(freq: &[f64], total: &f64, items: &'a [T]) -> Option<&'a T> {
318    let mut rng = rand::thread_rng();
319    let total_n = convert(*total + 1.0);
320    let terminal = f64::from(Uniform::from(0..total_n).sample(&mut rng));
321    let mut acc: f64 = 0.0;
322
323    for (i, f) in freq.iter().enumerate() {
324        acc += *f;
325        if acc > terminal {
326            return items.get(i);
327        }
328    }
329
330    None
331}
332
333fn sample_cdf<'a, T: Probability>(cdf: &[f64], items: &'a [T]) -> Option<&'a T> {
334    let mut rng = rand::thread_rng();
335    let random = rng.gen();
336    items.get(bisect_left(cdf, &random))
337}
338
339fn sample_stochastic<'a, T: Probability>(max_freq: &f64, items: &'a [T]) -> Option<&'a T> {
340    let n = items.len();
341    let mut small_rng = SmallRng::from_entropy();
342    loop {
343        let i = small_rng.gen_range(0..n);
344        match items.get(i) {
345            None => return None,
346            Some(x) => {
347                let rand: f64 = small_rng.gen();
348                if rand < (x.prob() / max_freq) {
349                    return Some(x);
350                }
351                continue;
352            }
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use kolmogorov_smirnov::test_f64;
    use std::collections::HashMap;

    use crate::*;
    use SamplingMethod::*;

    /// Occurrence probabilities of the letters of the English alphabet.
    /// Reference: https://www3.nd.edu/~busiforc/handouts/cryptography/letterfrequencies.html
    const FIXTURE_ALPHABETS_PROBS: [(char, f64); 26] = [
        ('E', 0.111607),
        ('M', 0.030129),
        ('A', 0.084966),
        ('H', 0.030034),
        ('R', 0.075809),
        ('G', 0.024705),
        ('I', 0.075448),
        ('B', 0.020720),
        ('O', 0.071635),
        ('F', 0.018121),
        ('T', 0.069509),
        ('Y', 0.017779),
        ('N', 0.066544),
        ('W', 0.012899),
        ('S', 0.057351),
        ('K', 0.011016),
        ('L', 0.054893),
        ('V', 0.010074),
        ('C', 0.045388),
        ('X', 0.002902),
        ('U', 0.036308),
        ('Z', 0.002722),
        ('D', 0.033844),
        ('J', 0.001965),
        ('P', 0.031671),
        ('Q', 0.001962),
    ];

    /// Lets the `(letter, probability)` fixture tuples be sampled directly.
    impl Probability for (char, f64) {
        fn prob(&self) -> f64 {
            self.1
        }
    }

    /// Basic Monte Carlo simulation: draws `n` samples from `store` and counts
    /// how often each letter was returned. Draws that yield `None` are skipped.
    fn monte_carlo(store: DiscreteDistribution<(char, f64)>, n: usize) -> HashMap<char, f64> {
        let mut counter = HashMap::new();
        for picked in (0..n).filter_map(|_| store.sample()) {
            *counter.entry(picked.0).or_insert(0.0) += 1.0;
        }
        counter
    }

    /// Asserts that the distribution generated by `sample()` matches the
    /// provided distribution using the Kolmogorov-Smirnov test.
    ///
    /// Reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
    fn matches_expected_distribution(n: usize, method: SamplingMethod) {
        // Setup
        let mut abc_probs = FIXTURE_ALPHABETS_PROBS.to_vec();
        abc_probs.sort_by(|a, b| a.0.cmp(&b.0));
        // `SamplingMethod` is `Copy`, so it can be passed on and reused below.
        let store = DiscreteDistribution::new(&abc_probs, method).unwrap();

        // Act
        let mut obs: Vec<(char, f64)> = monte_carlo(store, n).into_iter().collect();

        // Order both sides by letter before extracting the frequencies.
        obs.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        let expected_pdf: Vec<f64> = abc_probs.iter().map(|e| e.1).collect();
        let obs_pdf: Vec<f64> = obs.iter().map(|o| o.1 / (n as f64)).collect();

        // Assert
        let ks_result = test_f64(&expected_pdf, &obs_pdf, 0.99);
        assert!(
            !ks_result.is_rejected,
            "Generated samples do not belong to expected distribution per KS test, for n={}, sampling={:#?} confidence={}",
            n,
            method,
            ks_result.confidence
        )
    }

    #[test]
    fn linear_sampling() {
        matches_expected_distribution(1_000, Linear);
        matches_expected_distribution(10_000, Linear);
        matches_expected_distribution(1_000_000, Linear);
    }

    #[test]
    fn cdf_sampling() {
        matches_expected_distribution(1_000, CumulativeDistributionFunction);
        matches_expected_distribution(10_000, CumulativeDistributionFunction);
        matches_expected_distribution(1_000_000, CumulativeDistributionFunction);
    }

    #[test]
    fn stochastic_sampling() {
        matches_expected_distribution(1_000, StochasticAcceptance);
        matches_expected_distribution(10_000, StochasticAcceptance);
        matches_expected_distribution(1_000_000, StochasticAcceptance);
    }
}