qfall_math/utils/sample/
binomial.rs

// Copyright © 2025 Niklas Siemer
//
// This file is part of qFALL-math.
//
// qFALL-math is free software: you can redistribute it and/or modify it under
// the terms of the Mozilla Public License Version 2.0 as published by the
// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
//! This module provides core functionality to sample integers according to the
//! binomial distribution.
11
12use crate::{error::MathError, integer::Z, rational::Q};
13use rand::rngs::ThreadRng;
14use rand_distr::{Binomial, Distribution};
15
16/// Enables sampling a [`Z`] according to the binomial distribution `Bin(n, p)`.
17///
18/// Attributes:
19/// - `distr`: defines the binomial distribution with parameters `n` and `p` to sample from
20/// - `rng`: defines the [`ThreadRng`] that's used to sample from `distr`
21///
22/// # Examples
23/// ```
24/// use qfall_math::utils::sample::binomial::BinomialSampler;
25/// let n = 2;
26/// let p = 0.5;
27///
28/// let mut bin_sampler = BinomialSampler::init(n, p).unwrap();
29///
30/// let sample = bin_sampler.sample();
31///
32/// assert!(0 <= sample);
33/// assert!(sample <= n);
34/// ```
35pub struct BinomialSampler {
36    distr: Binomial,
37    rng: ThreadRng,
38}
39
40impl BinomialSampler {
41    /// Initializes the [`BinomialSampler`] with
42    /// - `distr` as the binomial distribution with `n` tries and success probability `p` for each try, and
43    /// - `rng` as a fresh [`ThreadRng`].
44    ///
45    /// Parameters:
46    /// - `n`: specifies the number of tries
47    /// - `p`: specifies the success probability
48    ///
49    /// Returns a [`BinomialSampler`] or a [`MathError`] if `n < 0`,
50    /// `p ∉ (0,1)`, or `n` does not fit into an [`i64`].
51    ///
52    /// # Examples
53    /// ```
54    /// use qfall_math::utils::sample::binomial::BinomialSampler;
55    /// let n = 2;
56    /// let p = 0.5;
57    ///
58    /// let mut bin_sampler = BinomialSampler::init(n, p).unwrap();
59    /// ```
60    ///
61    /// # Errors and Failures
62    /// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
63    ///   if `n < 0`.
64    /// - Returns a [`MathError`] of type [`InvalidInterval`](MathError::InvalidInterval)
65    ///   if `p ∉ (0,1)`.
66    /// - Returns a [`MathError`] of type [`ConversionError`](MathError::ConversionError)
67    ///   if `n` does not fit into an [`i64`].
68    pub fn init(n: impl Into<Z>, p: impl Into<Q>) -> Result<Self, MathError> {
69        let n = n.into();
70        let p = p.into();
71
72        if p <= Q::ZERO || p >= Q::ONE {
73            return Err(MathError::InvalidInterval(format!(
74                "p (the probability of success for binomial sampling) must be chosen between 0 and 1. \
75                Currently it is {p}. \
76                Hence, the interval to sample from is invalid and contains only exactly one number."
77            )));
78        }
79        if n < Z::ZERO {
80            return Err(MathError::InvalidIntegerInput(format!(
81                "n (the number of trials for binomial sampling) must be no smaller than 0. Currently it is {n}."
82            )));
83        }
84
85        let n = i64::try_from(n)? as u64;
86        let p = f64::from(&p);
87
88        let distr = Binomial::new(n, p).unwrap();
89        let rng = rand::rng();
90
91        Ok(Self { distr, rng })
92    }
93
94    /// Samples a [`Z`] according to the binomial distribution `Bin(n, p)`.
95    ///
96    /// # Examples
97    /// ```
98    /// use qfall_math::utils::sample::binomial::BinomialSampler;
99    /// let n = 2;
100    /// let p = 0.5;
101    ///
102    /// let mut bin_sampler = BinomialSampler::init(n, p).unwrap();
103    ///
104    /// let sample = bin_sampler.sample();
105    ///
106    /// assert!(0 <= sample);
107    /// assert!(sample <= n);
108    /// ```
109    pub fn sample(&mut self) -> Z {
110        Z::from(self.distr.sample(&mut self.rng))
111    }
112}
113
#[cfg(test)]
mod test_binomial_sampler {
    use super::{BinomialSampler, Q, Z};

    /// Checks whether the range is kept,
    /// i.e. if any result is at least 0 and smaller than or equal to `n`.
    #[test]
    fn keeps_range() {
        let n = 16;
        let p = 0.5;
        let mut bin_sampler = BinomialSampler::init(n, p).unwrap();

        for _ in 0..16 {
            let sample = bin_sampler.sample();
            // `sample` is a `Z`, hence the lower bound has to be checked explicitly
            assert!(sample >= 0);
            assert!(sample <= n);
        }
    }

    /// Checks that `n = 0` is a valid number of trials and always yields 0.
    #[test]
    fn zero_trials() {
        let mut bin_sampler = BinomialSampler::init(0, 0.5).unwrap();

        for _ in 0..16 {
            assert_eq!(Z::ZERO, bin_sampler.sample());
        }
    }

    /// Roughly checks that the collected samples are distributed according to the binomial distribution.
    #[test]
    fn distribution() {
        let n = 2;
        let p = 0.5;
        let mut bin_sampler = BinomialSampler::init(n, p).unwrap();

        let mut counts = [0; 3];
        // count sampled instances
        for _ in 0..1000 {
            let sample = u64::try_from(bin_sampler.sample()).unwrap() as usize;
            counts[sample] += 1;
        }

        let expl_text = String::from("This test can fail with probability close to 0. 
        It fails if the sampled occurrences do not look like a typical binomial random distribution. 
        If this happens, rerun the tests several times and check whether this issue comes up again.");

        // Check that the sampled occurrences roughly look
        // like a binomial distribution
        assert!(counts[0] > 100, "{expl_text}");
        assert!(counts[0] < 400, "{expl_text}");
        assert!(counts[1] > 300, "{expl_text}");
        assert!(counts[1] < 700, "{expl_text}");
        assert!(counts[2] > 100, "{expl_text}");
        assert!(counts[2] < 400, "{expl_text}");
    }

    /// Checks whether invalid choices for n result in an error.
    #[test]
    fn invalid_n() {
        let p = 0.5;

        assert!(BinomialSampler::init(&Z::MINUS_ONE, p).is_err());
        assert!(BinomialSampler::init(&Z::from(i64::MIN), p).is_err());
        // non-negative, but does not fit into an `i64`
        assert!(BinomialSampler::init(Z::from(u64::MAX), p).is_err());
    }

    /// Checks whether invalid choices for p result in an error.
    #[test]
    fn invalid_p() {
        let n = 2;

        assert!(BinomialSampler::init(n, &Q::MINUS_ONE).is_err());
        assert!(BinomialSampler::init(n, &Q::ZERO).is_err());
        assert!(BinomialSampler::init(n, &Q::ONE).is_err());
        assert!(BinomialSampler::init(n, &Q::from(5)).is_err());
    }
}