random_picker/
picker.rs

1use crate::*;
2use rand::{rngs::OsRng, RngCore};
3use std::hash::Hash;
4
5/// Generator of groups of random items of type `T` with different probabilities.
6/// According to the configuration, items in each group can be either
7/// repetitive or non-repetitive.
8pub struct Picker<T: Clone + Eq + Hash, R: RngCore> {
9    rng: R,
10
11    table: Vec<(T, f64)>,
12    grid: Vec<f64>,
13    grid_width: f64,
14    repetitive: bool,
15
16    table_picked: Vec<bool>,    // used in `pick_indexes()`, size: table.len()
17    picked_indexes: Vec<usize>, // read it after calling `pick_indexes()`
18}
19
20impl<T: Clone + Eq + Hash> Picker<T, OsRng> {
21    /// Builds the `Picker` with given configuration, using the OS random source.
22    pub fn build(conf: Config<T>) -> Result<Self, Error> {
23        Picker::build_with_rng(conf, OsRng)
24    }
25}
26
27impl<T: Clone + Eq + Hash, R: RngCore> Picker<T, R> {
28    /// Builds the `Picker` with given configuration and the given random source.
29    pub fn build_with_rng(conf: Config<T>, rng: R) -> Result<Self, Error> {
30        let table_len = conf.table.len();
31        let mut picker = Self {
32            rng,
33            table: Vec::with_capacity(table_len),
34            grid: Vec::with_capacity(table_len),
35            grid_width: 0.,
36            repetitive: conf.repetitive,
37            table_picked: Vec::with_capacity(table_len),
38            picked_indexes: Vec::with_capacity(table_len),
39        };
40        picker.configure(conf)?;
41        Ok(picker)
42    }
43
44    /// Applies new configuration.
45    pub fn configure(&mut self, conf: Config<T>) -> Result<(), Error> {
46        self.table = conf.vec_table()?;
47        let table_len = self.table.len();
48
49        self.grid.clear();
50        self.grid.reserve(table_len);
51        let mut cur = 0.;
52        for (_, val) in &self.table {
53            cur += val;
54            self.grid.push(cur);
55        }
56        self.grid_width = *self.grid.last().unwrap();
57
58        self.repetitive = conf.repetitive;
59
60        self.table_picked.resize(table_len, false);
61        self.picked_indexes.reserve(table_len);
62
63        Ok(())
64    }
65
66    /// Returns the size of the weight table that contains all possible choices (p > 0).
67    ///
68    /// ```
69    /// use random_picker::Picker;
70    /// let mut conf: random_picker::Config<String> = "
71    ///     a = 0; b = 1; c = 1.1
72    /// ".parse().unwrap();
73    /// let picker = Picker::build(conf.clone()).unwrap();
74    /// assert_eq!(picker.table_len(), 2);
75    /// conf.append_str("b = 0; c = 0");
76    /// assert!(Picker::build(conf).is_err());
77    /// ```
78    #[inline(always)]
79    pub fn table_len(&self) -> usize {
80        self.table.len()
81    }
82
83    /// Picks `amount` of items and returns the group of items.
84    /// `amount` must not exceed `table_len()`.
85    #[inline(always)]
86    pub fn pick(&mut self, amount: usize) -> Result<Vec<T>, Error> {
87        self.pick_indexes(amount)?;
88        Ok(self
89            .picked_indexes
90            .iter()
91            .map(|&i| self.item_key(i))
92            .collect())
93    }
94
95    /// Picks `dest.len()` of items and writes them into `dest` (avoids allocation).
96    /// Length of `dest` must not exceed `table_len()`.
97    #[inline]
98    pub fn write_to(&mut self, dest: &mut [T]) -> Result<(), Error> {
99        self.pick_indexes(dest.len())?;
100        for (i, k) in dest.iter_mut().enumerate() {
101            *k = self.item_key(self.picked_indexes[i]);
102        }
103        Ok(())
104    }
105
106    /// Evaluates probabilities of existences of table items in each group
107    /// of length `amount`, by generating groups of items for `test_times`.
108    ///
109    /// ```
110    /// use random_picker::*;
111    /// let mut conf: Config<String> = "
112    ///     a=856; b=139; c=297; d=378; e=1304;
113    ///     f=289; g=199; h=528; i=627; j=  13;
114    ///     k= 42; l=339; m=249; n=707; o= 797;
115    ///     p=199; q= 12; r=677; s=607; t=1045;
116    ///     u=249; v= 92; w=149; x= 17; y= 199; z=8;
117    /// ".parse().unwrap();
118    /// assert_eq!(conf.repetitive, false);
119    /// assert_eq!(conf.table.len(), 26);
120    /// let table_probs = conf.calc_probabilities(3).unwrap();
121    ///
122    /// let mut picker = Picker::build(conf.clone()).unwrap();
123    /// let table_freqs = picker.test_freqs(3, 1_000_000).unwrap();
124    /// for (k, v) in table_freqs.iter() {
125    ///     assert!((*v - *table_probs.get(k).unwrap()).abs() < 0.005);
126    /// }
127    ///
128    /// conf.append_str("repetitive = true");
129    /// assert_eq!(conf.repetitive, true);
130    /// let table_probs = conf.calc_probabilities(3).unwrap();;
131    ///
132    /// let mut picker = Picker::build_with_rng(conf, rand::thread_rng()).unwrap();
133    /// let table_freqs = picker.test_freqs(3, 1_000_000).unwrap();
134    /// for (k, v) in table_freqs.iter() {
135    ///     assert!((*v - *table_probs.get(k).unwrap()).abs() < 0.005);
136    /// }
137    /// ```
138    pub fn test_freqs(&mut self, amount: usize, test_times: usize) -> Result<Table<T>, Error> {
139        if test_times == 0 {
140            return Ok(self.table.iter().map(|(k, _)| (k.clone(), 0.)).collect());
141        }
142
143        let mut tbl_freq = vec![0_usize; self.table_len()];
144        if !self.repetitive {
145            for _ in 0..test_times {
146                self.pick_indexes(amount)?;
147                for &idx in &self.picked_indexes {
148                    tbl_freq[idx] += 1;
149                }
150            }
151        } else {
152            let mut tbl_picked = vec![false; self.table_len()];
153            for _ in 0..test_times {
154                tbl_picked.fill(false);
155                self.pick_indexes(amount)?;
156                for &idx in &self.picked_indexes {
157                    if !tbl_picked[idx] {
158                        tbl_freq[idx] += 1;
159                        tbl_picked[idx] = true;
160                    }
161                }
162            }
163        }
164
165        let test_times = test_times as f64;
166        let table = tbl_freq
167            .iter()
168            .enumerate()
169            .map(|(i, &v)| (self.item_key(i), v as f64 / test_times))
170            .collect();
171        Ok(table)
172    }
173
174    /// Picks `amount` of indexes and replaces values in `self.picked_indexes`.
175    #[inline]
176    fn pick_indexes(&mut self, amount: usize) -> Result<(), Error> {
177        if !self.repetitive && amount > self.table_len() {
178            return Err(Error::InvalidAmount);
179        }
180        self.picked_indexes.clear();
181
182        self.table_picked.fill(false);
183        while self.picked_indexes.len() < amount {
184            let i = self.pick_index()?;
185            if !self.repetitive {
186                if self.table_picked[i] {
187                    continue;
188                }
189                self.table_picked[i] = true;
190            }
191            self.picked_indexes.push(i);
192        }
193        Ok(())
194    }
195
196    #[inline(always)]
197    fn pick_index(&mut self) -> Result<usize, Error> {
198        let mut bytes = [0u8; 4];
199        self.rng
200            .try_fill_bytes(&mut bytes)
201            .map_err(Error::RandError)?;
202
203        let val = (u32::from_ne_bytes(bytes) as f64) / (u32::MAX as f64) * self.grid_width;
204        for (i, &v) in self.grid.iter().enumerate() {
205            if val <= v {
206                return Ok(i);
207            };
208        }
209
210        Ok(self.table_len() - 1) // almost impossible
211    }
212
213    #[inline(always)]
214    fn item_key(&self, i: usize) -> T {
215        self.table[i].0.clone()
216    }
217}