random_picker/picker.rs
1use crate::*;
2use rand::{rngs::OsRng, RngCore};
3use std::hash::Hash;
4
5/// Generator of groups of random items of type `T` with different probabilities.
6/// According to the configuration, items in each group can be either
7/// repetitive or non-repetitive.
8pub struct Picker<T: Clone + Eq + Hash, R: RngCore> {
9 rng: R,
10
11 table: Vec<(T, f64)>,
12 grid: Vec<f64>,
13 grid_width: f64,
14 repetitive: bool,
15
16 table_picked: Vec<bool>, // used in `pick_indexes()`, size: table.len()
17 picked_indexes: Vec<usize>, // read it after calling `pick_indexes()`
18}
19
20impl<T: Clone + Eq + Hash> Picker<T, OsRng> {
21 /// Builds the `Picker` with given configuration, using the OS random source.
22 pub fn build(conf: Config<T>) -> Result<Self, Error> {
23 Picker::build_with_rng(conf, OsRng)
24 }
25}
26
27impl<T: Clone + Eq + Hash, R: RngCore> Picker<T, R> {
28 /// Builds the `Picker` with given configuration and the given random source.
29 pub fn build_with_rng(conf: Config<T>, rng: R) -> Result<Self, Error> {
30 let table_len = conf.table.len();
31 let mut picker = Self {
32 rng,
33 table: Vec::with_capacity(table_len),
34 grid: Vec::with_capacity(table_len),
35 grid_width: 0.,
36 repetitive: conf.repetitive,
37 table_picked: Vec::with_capacity(table_len),
38 picked_indexes: Vec::with_capacity(table_len),
39 };
40 picker.configure(conf)?;
41 Ok(picker)
42 }
43
44 /// Applies new configuration.
45 pub fn configure(&mut self, conf: Config<T>) -> Result<(), Error> {
46 self.table = conf.vec_table()?;
47 let table_len = self.table.len();
48
49 self.grid.clear();
50 self.grid.reserve(table_len);
51 let mut cur = 0.;
52 for (_, val) in &self.table {
53 cur += val;
54 self.grid.push(cur);
55 }
56 self.grid_width = *self.grid.last().unwrap();
57
58 self.repetitive = conf.repetitive;
59
60 self.table_picked.resize(table_len, false);
61 self.picked_indexes.reserve(table_len);
62
63 Ok(())
64 }
65
66 /// Returns the size of the weight table that contains all possible choices (p > 0).
67 ///
68 /// ```
69 /// use random_picker::Picker;
70 /// let mut conf: random_picker::Config<String> = "
71 /// a = 0; b = 1; c = 1.1
72 /// ".parse().unwrap();
73 /// let picker = Picker::build(conf.clone()).unwrap();
74 /// assert_eq!(picker.table_len(), 2);
75 /// conf.append_str("b = 0; c = 0");
76 /// assert!(Picker::build(conf).is_err());
77 /// ```
78 #[inline(always)]
79 pub fn table_len(&self) -> usize {
80 self.table.len()
81 }
82
83 /// Picks `amount` of items and returns the group of items.
84 /// `amount` must not exceed `table_len()`.
85 #[inline(always)]
86 pub fn pick(&mut self, amount: usize) -> Result<Vec<T>, Error> {
87 self.pick_indexes(amount)?;
88 Ok(self
89 .picked_indexes
90 .iter()
91 .map(|&i| self.item_key(i))
92 .collect())
93 }
94
95 /// Picks `dest.len()` of items and writes them into `dest` (avoids allocation).
96 /// Length of `dest` must not exceed `table_len()`.
97 #[inline]
98 pub fn write_to(&mut self, dest: &mut [T]) -> Result<(), Error> {
99 self.pick_indexes(dest.len())?;
100 for (i, k) in dest.iter_mut().enumerate() {
101 *k = self.item_key(self.picked_indexes[i]);
102 }
103 Ok(())
104 }
105
106 /// Evaluates probabilities of existences of table items in each group
107 /// of length `amount`, by generating groups of items for `test_times`.
108 ///
109 /// ```
110 /// use random_picker::*;
111 /// let mut conf: Config<String> = "
112 /// a=856; b=139; c=297; d=378; e=1304;
113 /// f=289; g=199; h=528; i=627; j= 13;
114 /// k= 42; l=339; m=249; n=707; o= 797;
115 /// p=199; q= 12; r=677; s=607; t=1045;
116 /// u=249; v= 92; w=149; x= 17; y= 199; z=8;
117 /// ".parse().unwrap();
118 /// assert_eq!(conf.repetitive, false);
119 /// assert_eq!(conf.table.len(), 26);
120 /// let table_probs = conf.calc_probabilities(3).unwrap();
121 ///
122 /// let mut picker = Picker::build(conf.clone()).unwrap();
123 /// let table_freqs = picker.test_freqs(3, 1_000_000).unwrap();
124 /// for (k, v) in table_freqs.iter() {
125 /// assert!((*v - *table_probs.get(k).unwrap()).abs() < 0.005);
126 /// }
127 ///
128 /// conf.append_str("repetitive = true");
129 /// assert_eq!(conf.repetitive, true);
130 /// let table_probs = conf.calc_probabilities(3).unwrap();;
131 ///
132 /// let mut picker = Picker::build_with_rng(conf, rand::thread_rng()).unwrap();
133 /// let table_freqs = picker.test_freqs(3, 1_000_000).unwrap();
134 /// for (k, v) in table_freqs.iter() {
135 /// assert!((*v - *table_probs.get(k).unwrap()).abs() < 0.005);
136 /// }
137 /// ```
138 pub fn test_freqs(&mut self, amount: usize, test_times: usize) -> Result<Table<T>, Error> {
139 if test_times == 0 {
140 return Ok(self.table.iter().map(|(k, _)| (k.clone(), 0.)).collect());
141 }
142
143 let mut tbl_freq = vec![0_usize; self.table_len()];
144 if !self.repetitive {
145 for _ in 0..test_times {
146 self.pick_indexes(amount)?;
147 for &idx in &self.picked_indexes {
148 tbl_freq[idx] += 1;
149 }
150 }
151 } else {
152 let mut tbl_picked = vec![false; self.table_len()];
153 for _ in 0..test_times {
154 tbl_picked.fill(false);
155 self.pick_indexes(amount)?;
156 for &idx in &self.picked_indexes {
157 if !tbl_picked[idx] {
158 tbl_freq[idx] += 1;
159 tbl_picked[idx] = true;
160 }
161 }
162 }
163 }
164
165 let test_times = test_times as f64;
166 let table = tbl_freq
167 .iter()
168 .enumerate()
169 .map(|(i, &v)| (self.item_key(i), v as f64 / test_times))
170 .collect();
171 Ok(table)
172 }
173
174 /// Picks `amount` of indexes and replaces values in `self.picked_indexes`.
175 #[inline]
176 fn pick_indexes(&mut self, amount: usize) -> Result<(), Error> {
177 if !self.repetitive && amount > self.table_len() {
178 return Err(Error::InvalidAmount);
179 }
180 self.picked_indexes.clear();
181
182 self.table_picked.fill(false);
183 while self.picked_indexes.len() < amount {
184 let i = self.pick_index()?;
185 if !self.repetitive {
186 if self.table_picked[i] {
187 continue;
188 }
189 self.table_picked[i] = true;
190 }
191 self.picked_indexes.push(i);
192 }
193 Ok(())
194 }
195
196 #[inline(always)]
197 fn pick_index(&mut self) -> Result<usize, Error> {
198 let mut bytes = [0u8; 4];
199 self.rng
200 .try_fill_bytes(&mut bytes)
201 .map_err(Error::RandError)?;
202
203 let val = (u32::from_ne_bytes(bytes) as f64) / (u32::MAX as f64) * self.grid_width;
204 for (i, &v) in self.grid.iter().enumerate() {
205 if val <= v {
206 return Ok(i);
207 };
208 }
209
210 Ok(self.table_len() - 1) // almost impossible
211 }
212
213 #[inline(always)]
214 fn item_key(&self, i: usize) -> T {
215 self.table[i].0.clone()
216 }
217}