fast_bernoulli/lib.rs
#![doc = include_str!("../README.md")]
#![deny(missing_debug_implementations, missing_docs)]
#![forbid(unsafe_code)]

// [What follows is another outstanding comment from Jim Blandy explaining why
// this technique works.]
//
// This comment should just read, "Generate skip counts with a geometric
// distribution", and leave everyone to go look that up and see why it's the
// right thing to do, if they don't know already.
//
// BUT IF YOU'RE CURIOUS, COMMENTS ARE FREE...
//
// Instead of generating a fresh random number for every trial, we can
// randomly generate a count of how many times we should return false before
// the next time we return true. We call this a "skip count". Once we've
// returned true, we generate a fresh skip count, and begin counting down
// again.
//
// Here's an awesome fact: by exercising a little care in the way we generate
// skip counts, we can produce results indistinguishable from those we would
// get "rolling the dice" afresh for every trial.
//
// In short, skip counts in Bernoulli trials of probability `P` obey a geometric
// distribution. If a random variable `X` is uniformly distributed from
// `[0..1)`, then `floor(log(X) / log(1-P))` has the appropriate geometric
// distribution for the skip counts.
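// (Concretely, a skip count of exactly `n` means `n` failures followed by a
// success, which happens with probability `(1-P)^n * P`: the geometric
// distribution on `{0, 1, 2, ...}`.)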
//
// Why that formula?
//
// Suppose we're to return `true` with some probability `P`, say, `0.3`. Spread
// all possible futures along a line segment of length `1`. In portion `P` of
// those cases, we'll return true on the next call to `trial`; the skip count is
// 0. For the remaining portion `1-P` of cases, the skip count is `1` or more.
//
// ```
//    skip:             0                         1 or more
//             |------------------^-----------------------------------------|
//    portion:        0.3                            0.7
//                     P                             1-P
// ```
//
// But the "1 or more" section of the line is subdivided the same way: *within
// that section*, in portion `P` the second call to `trial()` returns `true`, and
// in portion `1-P` it returns `false` a second time; the skip count is two or
// more. So we return `true` on the second call in proportion `0.7 * 0.3`, and
// skip at least the first two in proportion `0.7 * 0.7`.
//
// ```
//    skip:             0                1              2 or more
//             |------------------^------------^----------------------------|
//    portion:        0.3           0.7 * 0.3          0.7 * 0.7
//                     P             (1-P)*P            (1-P)^2
// ```
//
// We can continue to subdivide:
//
// ```
//    skip >= 0: |------------------------------------------------- (1-P)^0 --|
//    skip >= 1: |                  ------------------------------- (1-P)^1 --|
//    skip >= 2: |                               ------------------ (1-P)^2 --|
//    skip >= 3: |                                   ^   ---------- (1-P)^3 --|
//    skip >= 4: |                                   .          --- (1-P)^4 --|
//                                                   .
//                                                   ^X, see below
// ```
//
// In other words, the likelihood of the next `n` calls to `trial` returning
// `false` is `(1-P)^n`. The longer a run we require, the more the likelihood
// drops. Further calls may return `false` too, but this is the probability
// we'll skip at least `n`.
//
// This is interesting, because we can pick a point along this line segment and
// see which skip count's range it falls within; the point `X` above, for
// example, is within the ">= 2" range, but not within the ">= 3" range, so it
// designates a skip count of `2`. So if we pick points on the line at random
// and use the skip counts they fall under, that will be indistinguishable from
// generating a fresh random number between `0` and `1` for each trial and
// comparing it to `P`.
//
// So to find the skip count for a point `X`, we must ask: To what whole power
// must we raise `1-P` such that we include `X`, but the next power would
// exclude it? This is exactly `floor(log(X) / log(1-P))`.
//
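// For example, with `P = 0.3` and `X = 0.5`, `log(0.5) / log(0.7)` is roughly
// `-0.693 / -0.357 ≈ 1.94`, so the skip count is `floor(1.94) = 1`. That agrees
// with the ranges above: `0.5` is at most `0.7^1 = 0.7`, so the skip count is at
// least `1`, but it is greater than `0.7^2 = 0.49`, so it is not `2` or more.
//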
// Our algorithm is then, simply: When constructed, compute an initial skip
// count. Return `false` from `trial` that many times, and then compute a new
// skip count.
//
// For a call to `multi_trial(n)`, if `n` is no greater than the skip count,
// return `false` and subtract `n` from the skip count. Otherwise, return `true`
// and compute a new skip count. Since each trial is independent, it doesn't
// matter by how much `n` overshoots the skip count; we can actually compute a
// new skip count at *any* time without affecting the distribution. This is
// really beautiful.
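//
// The indistinguishability claim is easy to check empirically. The following is
// a rough sketch, not part of this crate's API: it compares the hit rate of
// per-trial dice rolls against the skip-count scheme implemented below, using
// the same `rand` calls (`thread_rng`, `gen_range`) this file already relies on.
//
// ```
// use fast_bernoulli::FastBernoulli;
// use rand::Rng;
//
// fn main() {
//     let mut rng = rand::thread_rng();
//     let p = 0.3;
//     let trials = 1_000_000;
//
//     // "Roll the dice" afresh for every trial.
//     let naive_hits = (0..trials)
//         .filter(|_| rng.gen_range(0.0..1.0) < p)
//         .count();
//
//     // Count down geometrically distributed skip counts instead.
//     let mut bernoulli = FastBernoulli::new(p, &mut rng);
//     let skip_hits = (0..trials).filter(|_| bernoulli.trial(&mut rng)).count();
//
//     // Both rates hover around `p` (here 0.3), up to sampling noise.
//     println!("naive:       {}", naive_hits as f64 / trials as f64);
//     println!("skip counts: {}", skip_hits as f64 / trials as f64);
// }
// ```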

use rand::Rng;

/// Fast Bernoulli sampling: each event has equal probability of being sampled.
///
/// See the [crate-level documentation][crate] for more general
/// information.
///
/// # Example
///
/// ```
/// use fast_bernoulli::FastBernoulli;
/// use rand::Rng;
///
/// // Get the thread-local random number generator.
/// let mut rng = rand::thread_rng();
///
/// // Create a `FastBernoulli` instance that samples events with probability 1/20.
/// let mut bernoulli = FastBernoulli::new(0.05, &mut rng);
///
/// // Each time your event occurs, perform a Bernoulli trial to determine whether
/// // you should sample the event or not.
/// let on_my_event = || {
///     if bernoulli.trial(&mut rng) {
///         // Record the sample...
///     }
/// };
/// ```
#[derive(Debug, Clone, Copy)]
pub struct FastBernoulli {
    probability: f64,
    skip_count: u32,
}

impl FastBernoulli {
    /// Construct a new `FastBernoulli` instance that samples events with the
    /// given probability.
    ///
    /// # Panics
    ///
    /// The probability must be within the range `0.0 <= probability <= 1.0` and
    /// this method will panic if that is not the case.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let sample_one_in_a_hundred = FastBernoulli::new(0.01, &mut rng);
    /// ```
    pub fn new<R>(probability: f64, rng: &mut R) -> Self
    where
        R: Rng + ?Sized,
    {
        assert!(
            0.0 <= probability && probability <= 1.0,
            "`probability` must be in the range `0.0 <= probability <= 1.0`"
        );

        let mut bernoulli = FastBernoulli {
            probability,
            skip_count: 0,
        };
        bernoulli.reset_skip_count(rng);
        bernoulli
    }

    fn reset_skip_count<R>(&mut self, rng: &mut R)
    where
        R: Rng + ?Sized,
    {
        if self.probability == 0.0 {
            // Edge case: we will never sample any event.
            self.skip_count = u32::MAX;
        } else if self.probability == 1.0 {
            // Edge case: we will sample every event.
            self.skip_count = 0;
        } else {
            // Common case: we need to choose a new skip count using the
            // formula `floor(log(x) / log(1 - P))`, as explained in the
            // comment at the top of this file.
            let x: f64 = rng.gen_range(0.0..1.0);
            let skip_count = (x.ln() / (1.0 - self.probability).ln()).floor();
            debug_assert!(skip_count >= 0.0);
            self.skip_count = if skip_count <= (u32::MAX as f64) {
                skip_count as u32
            } else {
                // Clamp the skip count to `u32::MAX`. This can skew sampling
                // when the probability is very low, but it is preferable to
                // the more robust alternatives, such as representing skip
                // counts with bignums.
                u32::MAX
            };
        }
    }

    /// Perform a Bernoulli trial: returns `true` with the configured
    /// probability.
    ///
    /// Call this each time an event occurs to determine whether to sample the
    /// event.
    ///
    /// The lower the configured probability, the less overhead calling this
    /// function has.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut bernoulli = FastBernoulli::new(0.1, &mut rng);
    ///
    /// // Each time an event occurs, call `trial`...
    /// if bernoulli.trial(&mut rng) {
    ///     // ...and if it returns true, record a sample of this event.
    /// }
    /// ```
    pub fn trial<R>(&mut self, rng: &mut R) -> bool
    where
        R: Rng + ?Sized,
    {
        if self.skip_count > 0 {
            self.skip_count -= 1;
            return false;
        }

        self.reset_skip_count(rng);
        self.probability != 0.0
    }

    /// Perform `n` Bernoulli trials at once.
    ///
    /// This is semantically equivalent to calling the `trial()` method `n`
    /// times and returning `true` if any of those calls returned `true`, but
    /// runs in `O(1)` time instead of `O(n)` time.
    ///
    /// What is this good for? In some applications, some events are "bigger"
    /// than others. For example, large memory allocations are more significant
    /// than small memory allocations. Perhaps we'd like to imagine that we're
    /// drawing allocations from a stream of bytes, and performing a separate
    /// Bernoulli trial on every byte from the stream. We can accomplish this by
    /// calling `multi_trial(s)` for the number of bytes `s`, and sampling the
    /// event if that call returns true.
    ///
    /// Of course, this style of sampling needs to be paired with analysis and
    /// presentation that makes the "size" of the event apparent, lest trials
    /// with large values of `n` appear indistinguishable from those with small
    /// values of `n`, despite being potentially much more likely to be sampled.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut byte_sampler = FastBernoulli::new(0.05, &mut rng);
    ///
    /// // When we observe a ten-byte `malloc` event...
    /// if byte_sampler.multi_trial(10, &mut rng) {
    ///     // ... if `multi_trial` returns `true` then we sample it.
    ///     record_malloc_sample(10);
    /// }
    ///
    /// // And when we observe a 1024-byte `malloc` event...
    /// if byte_sampler.multi_trial(1024, &mut rng) {
    ///     // ... if `multi_trial` returns `true` then we sample this larger
    ///     // allocation.
    ///     record_malloc_sample(1024);
    /// }
    /// # fn record_malloc_sample(_: u32) {}
    /// ```
    pub fn multi_trial<R>(&mut self, n: u32, rng: &mut R) -> bool
    where
        R: Rng + ?Sized,
    {
        // If all `n` trials fall within the current skip, none of them is a
        // hit; just consume `n` skips.
        if n <= self.skip_count {
            self.skip_count -= n;
            return false;
        }

        self.reset_skip_count(rng);
        self.probability != 0.0
    }

    /// Get the probability with which events are sampled.
    ///
    /// This is a number between `0.0` and `1.0`.
    ///
    /// This is the same value that was passed to `FastBernoulli::new` when
    /// constructing this instance.
    #[inline]
    pub fn probability(&self) -> f64 {
        self.probability
    }

    /// How many events will be skipped before the next event is sampled?
    ///
    /// When `self.probability() == 0.0`, this method's return value is not
    /// meaningful: logically, the skip count is infinite.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut bernoulli = FastBernoulli::new(0.1, &mut rng);
    ///
    /// // Get the number of upcoming events that will not be sampled.
    /// let skip_count = bernoulli.skip_count();
    ///
    /// // That many events will not be sampled.
    /// for _ in 0..skip_count {
    ///     assert!(!bernoulli.trial(&mut rng));
    /// }
    ///
    /// // The next event will be sampled.
    /// assert!(bernoulli.trial(&mut rng));
    /// ```
    #[inline]
    pub fn skip_count(&self) -> u32 {
        self.skip_count
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn expected_number_of_samples() {
        let mut rng = rand::thread_rng();

        let probability = 0.01;
        let events = 10_000;
        let expected = (events as f64) * probability;
        let error_tolerance = expected * 0.25;

        let mut bernoulli = FastBernoulli::new(probability, &mut rng);

        let mut num_sampled = 0;
        for _ in 0..events {
            if bernoulli.trial(&mut rng) {
                num_sampled += 1;
            }
        }

        let min = (expected - error_tolerance) as u32;
        let max = (expected + error_tolerance) as u32;
        assert!(
            min <= num_sampled && num_sampled <= max,
            "expected ~{} samples, found {} (acceptable range is {} to {})",
            expected,
            num_sampled,
            min,
            max,
        );
    }
}