fast_bernoulli/
lib.rs

#![doc = include_str!("../README.md")]
#![deny(missing_debug_implementations, missing_docs)]
#![forbid(unsafe_code)]

// [What follows is another outstanding comment from Jim Blandy explaining why
// this technique works.]
//
// This comment should just read, "Generate skip counts with a geometric
// distribution", and leave everyone to go look that up and see why it's the
// right thing to do, if they don't know already.
//
// BUT IF YOU'RE CURIOUS, COMMENTS ARE FREE...
//
// Instead of generating a fresh random number for every trial, we can
// randomly generate a count of how many times we should return false before
// the next time we return true. We call this a "skip count". Once we've
// returned true, we generate a fresh skip count, and begin counting down
// again.
//
// Here's an awesome fact: by exercising a little care in the way we generate
// skip counts, we can produce results indistinguishable from those we would
// get "rolling the dice" afresh for every trial.
//
// In short, skip counts in Bernoulli trials of probability `P` obey a geometric
// distribution. If a random variable `X` is uniformly distributed from
// `[0..1)`, then `floor(log(X) / log(1-P))` has the appropriate geometric
// distribution for the skip counts.
//
// Why that formula?
//
// Suppose we're to return `true` with some probability `P`, say, `0.3`. Spread
// all possible futures along a line segment of length `1`. In portion `P` of
// those cases, we'll return true on the next call to `trial`; the skip count is
// 0. For the remaining portion `1-P` of cases, the skip count is `1` or more.
//
// ```
//    skip:             0                         1 or more
//             |------------------^-----------------------------------------|
// portion:            0.3                            0.7
//                      P                             1-P
// ```
//
// But the "1 or more" section of the line is subdivided the same way: *within
// that section*, in portion `P` the second call to `trial()` returns `true`, and
// in portion `1-P` it returns `false` a second time; the skip count is two or
// more. So we return `true` on the second call in proportion `0.7 * 0.3`, and
// skip at least the first two in proportion `0.7 * 0.7`.
//
// ```
//    skip:             0                1              2 or more
//             |------------------^------------^----------------------------|
// portion:            0.3           0.7 * 0.3          0.7 * 0.7
//                      P             (1-P)*P            (1-P)^2
// ```
//
// We can continue to subdivide:
//
// ```
// skip >= 0:  |------------------------------------------------- (1-P)^0 --|
// skip >= 1:  |                  ------------------------------- (1-P)^1 --|
// skip >= 2:  |                               ------------------ (1-P)^2 --|
// skip >= 3:  |                                 ^     ---------- (1-P)^3 --|
// skip >= 4:  |                                 .            --- (1-P)^4 --|
//                                               .
//                                               ^X, see below
// ```
//
// In other words, the likelihood of the next `n` calls to `trial` returning
// `false` is `(1-P)^n`. The longer a run we require, the more the likelihood
// drops. Further calls may return `false` too, but this is the probability
// we'll skip at least `n`.
//
// This is interesting, because we can pick a point along this line segment and
// see which skip count's range it falls within; the point `X` above, for
// example, is within the ">= 2" range, but not within the ">= 3" range, so it
// designates a skip count of `2`. So if we pick points on the line at random
// and use the skip counts they fall under, that will be indistinguishable from
// generating a fresh random number between `0` and `1` for each trial and
// comparing it to `P`.
//
// So to find the skip count for a point `X`, we must ask: To what whole power
// must we raise `1-P` such that we include `X`, but the next power would
// exclude it? This is exactly `floor(log(X) / log(1-P))`.
//
// Our algorithm is then, simply: When constructed, compute an initial skip
// count. Return `false` from `trial` that many times, and then compute a new
// skip count.
//
// For a call to `multi_trial(n)`, if the skip count is at least `n`, all `n`
// trials fail: return `false` and subtract `n` from the skip count. Otherwise,
// at least one of the `n` trials succeeds: return `true` and compute a new
// skip count. Since each trial is independent, it doesn't matter by how much
// `n` overshoots the skip count; we can actually compute a new skip count at
// *any* time without affecting the distribution. This is really beautiful.

use rand::Rng;

/// Fast Bernoulli sampling: each event has equal probability of being sampled.
///
/// See the [crate-level documentation][crate] for more general
/// information.
///
/// # Example
///
/// ```
/// use fast_bernoulli::FastBernoulli;
/// use rand::Rng;
///
/// // Get the thread-local random number generator.
/// let mut rng = rand::thread_rng();
///
/// // Create a `FastBernoulli` instance that samples events with probability 1/20.
/// let mut bernoulli = FastBernoulli::new(0.05, &mut rng);
///
/// // Each time your event occurs, perform a Bernoulli trial to determine
/// // whether you should sample the event.
/// let on_my_event = || {
///     if bernoulli.trial(&mut rng) {
///         // Record the sample...
///     }
/// };
/// ```
#[derive(Debug, Clone, Copy)]
pub struct FastBernoulli {
    probability: f64,
    skip_count: u32,
}

impl FastBernoulli {
    /// Construct a new `FastBernoulli` instance that samples events with the
    /// given probability.
    ///
    /// # Panics
    ///
    /// The probability must be within the range `0.0 <= probability <= 1.0` and
    /// this method will panic if that is not the case.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let sample_one_in_a_hundred = FastBernoulli::new(0.01, &mut rng);
    /// ```
    pub fn new<R>(probability: f64, rng: &mut R) -> Self
    where
        R: Rng + ?Sized,
    {
        assert!(
            0.0 <= probability && probability <= 1.0,
            "`probability` must be in the range `0.0 <= probability <= 1.0`"
        );

        let mut bernoulli = FastBernoulli {
            probability,
            skip_count: 0,
        };
        bernoulli.reset_skip_count(rng);
        bernoulli
    }

    fn reset_skip_count<R>(&mut self, rng: &mut R)
    where
        R: Rng + ?Sized,
    {
        if self.probability == 0.0 {
            // Edge case: we will never sample any event.
            self.skip_count = u32::MAX;
        } else if self.probability == 1.0 {
            // Edge case: we will sample every event.
            self.skip_count = 0;
        } else {
            // Common case: we need to choose a new skip count using the
            // formula `floor(log(x) / log(1 - P))`, as explained in the
            // comment at the top of this file.
            let x: f64 = rng.gen_range(0.0..1.0);
            let skip_count = (x.ln() / (1.0 - self.probability).ln()).floor();
            debug_assert!(skip_count >= 0.0);
            self.skip_count = if skip_count <= (u32::MAX as f64) {
                skip_count as u32
            } else {
                // Clamp the skip count to `u32::MAX`. This can skew
                // sampling when we are sampling with a very low
                // probability, but it is better than any super-robust
                // alternative we have, such as representing skip counts
                // with big nums.
                u32::MAX
            };
        }
    }

    /// Perform a Bernoulli trial: returns `true` with the configured
    /// probability.
    ///
    /// Call this each time an event occurs to determine whether to sample the
    /// event.
    ///
    /// The lower the configured probability, the less overhead calling this
    /// function has.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut bernoulli = FastBernoulli::new(0.1, &mut rng);
    ///
    /// // Each time an event occurs, call `trial`...
    /// if bernoulli.trial(&mut rng) {
    ///     // ...and if it returns true, record a sample of this event.
    /// }
    /// ```
    pub fn trial<R>(&mut self, rng: &mut R) -> bool
    where
        R: Rng + ?Sized,
    {
        if self.skip_count > 0 {
            self.skip_count -= 1;
            return false;
        }

        self.reset_skip_count(rng);
        self.probability != 0.0
    }

    /// Perform `n` Bernoulli trials at once.
    ///
    /// This is semantically equivalent to calling the `trial()` method `n`
    /// times and returning `true` if any of those calls returned `true`, but
    /// runs in `O(1)` time instead of `O(n)` time.
    ///
    /// What is this good for? In some applications, some events are "bigger"
    /// than others. For example, large memory allocations are more significant
    /// than small memory allocations. Perhaps we'd like to imagine that we're
    /// drawing allocations from a stream of bytes, and performing a separate
    /// Bernoulli trial on every byte from the stream. We can accomplish this by
    /// calling `multi_trial(s)` for the number of bytes `s`, and sampling the
    /// event if that call returns true.
    ///
    /// Of course, this style of sampling needs to be paired with analysis and
    /// presentation that makes the "size" of the event apparent, lest trials
    /// with large values for `n` appear to be indistinguishable from those with
    /// small values for `n`, despite being potentially much more likely to be
    /// sampled.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut byte_sampler = FastBernoulli::new(0.05, &mut rng);
    ///
    /// // When we observe a ten-byte `malloc` event...
    /// if byte_sampler.multi_trial(10, &mut rng) {
    ///     // ... and `multi_trial` returns `true`, we sample it.
    ///     record_malloc_sample(10);
    /// }
    ///
    /// // And when we observe a 1024-byte `malloc` event...
    /// if byte_sampler.multi_trial(1024, &mut rng) {
    ///     // ... and `multi_trial` returns `true`, we sample this larger
    ///     // allocation.
    ///     record_malloc_sample(1024);
    /// }
    /// # fn record_malloc_sample(_: u32) {}
    /// ```
    pub fn multi_trial<R>(&mut self, n: u32, rng: &mut R) -> bool
    where
        R: Rng + ?Sized,
    {
        // All `n` trials fail exactly when the skip count covers them all;
        // note the `<=`: if `n` equals the skip count, every one of the `n`
        // trials is skipped and the *next* trial would be the success.
        if n <= self.skip_count {
            self.skip_count -= n;
            return false;
        }

        self.reset_skip_count(rng);
        self.probability != 0.0
    }

    /// Get the probability with which events are sampled.
    ///
    /// This is a number between `0.0` and `1.0`.
    ///
    /// This is the same value that was passed to `FastBernoulli::new` when
    /// constructing this instance.
    #[inline]
    pub fn probability(&self) -> f64 {
        self.probability
    }

    /// How many events will be skipped before the next event is sampled?
    ///
    /// When `self.probability() == 0.0`, no event is ever sampled; in that
    /// case this method returns `u32::MAX`, standing in for the logically
    /// correct answer of infinity.
    ///
    /// # Example
    ///
    /// ```
    /// use rand::Rng;
    /// use fast_bernoulli::FastBernoulli;
    ///
    /// let mut rng = rand::thread_rng();
    /// let mut bernoulli = FastBernoulli::new(0.1, &mut rng);
    ///
    /// // Get the number of upcoming events that will not be sampled.
    /// let skip_count = bernoulli.skip_count();
    ///
    /// // That many events will not be sampled.
    /// for _ in 0..skip_count {
    ///     assert!(!bernoulli.trial(&mut rng));
    /// }
    ///
    /// // The next event will be sampled.
    /// assert!(bernoulli.trial(&mut rng));
    /// ```
    #[inline]
    pub fn skip_count(&self) -> u32 {
        self.skip_count
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn expected_number_of_samples() {
        let mut rng = rand::thread_rng();

        let probability = 0.01;
        let events = 10_000;
        let expected = (events as f64) * probability;
        let error_tolerance = expected * 0.25;

        let mut bernoulli = FastBernoulli::new(probability, &mut rng);

        let mut num_sampled = 0;
        for _ in 0..events {
            if bernoulli.trial(&mut rng) {
                num_sampled += 1;
            }
        }

        let min = (expected - error_tolerance) as u32;
        let max = (expected + error_tolerance) as u32;
        assert!(
            min <= num_sampled && num_sampled <= max,
            "expected ~{} samples, found {} (acceptable range is {} to {})",
            expected,
            num_sampled,
            min,
            max,
        );
    }
}