hyperopt/kernel/discrete/
binomial.rs

1use std::{fmt::Debug, iter::Sum};
2
3use fastrand::Rng;
4use num_traits::{Float, FromPrimitive, One, Zero};
5
6use crate::{
7    iter::{range_inclusive, range_step_from},
8    kernel::Kernel,
9    traits::{
10        loopback::SelfSub,
11        shortcuts::{Additive, Multiplicative},
12    },
13    Density,
14    Sample,
15};
16
/// Discrete kernel built on top of the [binomial distribution][1].
///
/// The kernel density is the probability mass function divided by the standard deviation
/// of the distribution.
///
/// Beware that [`Binomial::density`] takes `O(at)` steps, which makes it rather slow
/// for large arguments.
///
/// [1]: https://en.wikipedia.org/wiki/Binomial_distribution
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Binomial<P, D> {
    /// Distribution parameter: how many independent experiments are run.
    pub n: P,

    /// Distribution parameter: the success rate of a single experiment.
    pub p: D,
}
32
33impl<P, D> Binomial<P, D> {
34    /// Probability mass function.
35    fn pmf(&self, at: P) -> D
36    where
37        P: Copy + Into<D> + PartialOrd + Zero + One + SelfSub,
38        D: Float + Sum,
39    {
40        if self.p == D::one() {
41            // The only possible outcome is `at == n`:
42            if at == self.n { D::one() } else { D::zero() }
43        } else if self.p == D::zero() {
44            // The only possible outcome is `at == 0`:
45            if at == P::zero() { D::one() } else { D::zero() }
46        } else if at <= self.n {
47            // lg(n choose k) = Σ ln(n + 1 - i) - ln(i)
48            let log_binomial: D = range_inclusive(P::one(), at)
49                .map(|i| (self.n - at + i).into().ln() - i.into().ln())
50                .sum();
51            let log_pmf = log_binomial
52                + at.into() * self.p.ln()
53                + (self.n - at).into() * (D::one() - self.p).ln();
54            log_pmf.exp()
55        } else {
56            // It is impossible to have more successes than experiments, hence the zero.
57            D::zero()
58        }
59    }
60
61    /// Standard deviation: √(p * (1 - p) / n).
62    fn std(&self) -> D
63    where
64        P: Copy + Into<D>,
65        D: Float,
66    {
67        (self.n.into() * self.p * (D::one() - self.p)).sqrt()
68    }
69
70    fn inverse_cdf(&self, cdf: D) -> P
71    where
72        P: Copy + Into<D> + One + Zero + PartialOrd + SelfSub,
73        D: Copy + Float + Sum,
74    {
75        range_step_from(P::zero(), P::one())
76            .scan(D::zero(), |acc, at| {
77                *acc = *acc + self.pmf(at);
78                Some((at, *acc))
79            })
80            .find(|(_, acc)| *acc >= cdf)
81            .expect("there should be a next sample")
82            .0
83    }
84}
85
86impl<P, D> Density for Binomial<P, D>
87where
88    P: Copy + Into<D> + Zero + PartialOrd + One + SelfSub,
89    D: Float + Sum,
90{
91    type Param = P;
92    type Output = D;
93
94    fn density(&self, at: Self::Param) -> Self::Output {
95        self.pmf(at) / self.std()
96    }
97}
98
99impl<P, D> Sample for Binomial<P, D>
100where
101    P: Copy + Into<D> + One + Zero + PartialOrd + SelfSub,
102    D: Float + FromPrimitive + Sum,
103{
104    type Param = P;
105
106    fn sample(&self, rng: &mut Rng) -> Self::Param {
107        self.inverse_cdf(D::from_f64(rng.f64()).unwrap())
108    }
109}
110
111impl<P, D> Kernel for Binomial<P, D>
112where
113    Self: Density<Param = P, Output = D> + Sample<Param = P>,
114    P: Copy + Ord + Additive + Multiplicative + Into<D> + One,
115    D: Multiplicative,
116{
117    type Param = P;
118
119    fn new(location: P, std: P) -> Self {
120        // Solving these for `p` and `n`:
121        // Bandwidth: σ = √(p(1-p)/n)
122        // Location: l = pn
123
124        // Getting:
125        // σ² = pn(1-p) = l(1-p)
126        // 1-p = σ²/l
127        // p = 1-(σ²/l)
128        // n = l/p = l/(1-(σ²/l)) = l/((l-σ²)/l) = l²/(l-σ²)
129
130        // Restrict bandwidth to avoid infinite `n` and/or negative `p`:
131        let sigma_squared = (std * std).min(location - P::one());
132
133        #[allow(clippy::suspicious_operation_groupings)]
134        let n = (location * location / (location - sigma_squared)).max(P::one());
135        Self {
136            n,
137            p: Into::<D>::into(location) / Into::<D>::into(n),
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Compares the mass function against reference values.
    #[test]
    fn pmf_ok() {
        let five_flips = Binomial { n: 5, p: 0.5 };
        assert_abs_diff_eq!(five_flips.pmf(2), 0.3125);

        let twenty_flips = Binomial { n: 20, p: 0.5 };
        assert_abs_diff_eq!(twenty_flips.pmf(10), 0.176_197, epsilon = 0.000_001);
        assert_abs_diff_eq!(twenty_flips.pmf(5), 0.014_786, epsilon = 0.000_001);

        // There cannot be more successes than experiments.
        let unsigned_flips = Binomial { n: 20_u32, p: 0.5 };
        assert_abs_diff_eq!(unsigned_flips.pmf(21_u32), 0.0);
    }

    /// Degenerate success rates put the entire mass on a single outcome.
    #[test]
    fn pmf_corner_cases() {
        let always_fails = Binomial { n: 1, p: 0.0 };
        assert_abs_diff_eq!(always_fails.pmf(0), 1.0);
        assert_abs_diff_eq!(always_fails.pmf(1), 0.0);

        let always_succeeds = Binomial { n: 1, p: 1.0 };
        assert_abs_diff_eq!(always_succeeds.pmf(0), 0.0);
        assert_abs_diff_eq!(always_succeeds.pmf(1), 1.0);
    }

    /// The inverse CDF returns the first outcome at which the cumulative mass is reached.
    #[test]
    fn inverse_cdf_ok() {
        let twenty_flips = Binomial { n: 20, p: 0.5 };
        assert_eq!(twenty_flips.inverse_cdf(0.588), 10);
        assert_eq!(twenty_flips.inverse_cdf(0.020_694), 5);

        let always_fails = Binomial { n: 1, p: 0.0 };
        assert_eq!(always_fails.inverse_cdf(1.0), 0);
    }

    /// √(20 * 0.5 * 0.5) = √5
    #[test]
    fn std_ok() {
        let twenty_flips = Binomial { n: 20, p: 0.5 };
        assert_abs_diff_eq!(twenty_flips.std(), 2.23607, epsilon = 0.00001);
    }

    /// Location 5 and bandwidth 2 give σ² = 4, n = 25 / (5 - 4) and p = 5 / 25.
    #[test]
    fn new_ok() {
        let Binomial { n, p } = Binomial::<_, f64>::new(5, 2);
        assert_eq!(n, 25);
        assert_abs_diff_eq!(p, 0.2);
    }

    /// An overly large bandwidth gets capped at σ² = location - 1.
    #[test]
    fn new_bandwidth_overflow() {
        let Binomial { n, p } = Binomial::<_, f64>::new(2, 100);
        assert_eq!(n, 4);
        assert_abs_diff_eq!(p, 0.5);
    }
}