hyperopt/kernel/discrete/
binomial.rs1use std::{fmt::Debug, iter::Sum};
2
3use fastrand::Rng;
4use num_traits::{Float, FromPrimitive, One, Zero};
5
6use crate::{
7 iter::{range_inclusive, range_step_from},
8 kernel::Kernel,
9 traits::{
10 loopback::SelfSub,
11 shortcuts::{Additive, Multiplicative},
12 },
13 Density,
14 Sample,
15};
16
/// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) kernel over a
/// discrete parameter.
///
/// `P` is the parameter (count) type, `D` is the floating-point type used for the
/// success rate and for the computed probabilities.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Binomial<P, D> {
    /// Number of independent trials (distribution parameter).
    pub n: P,

    /// Success probability of a single trial (distribution parameter).
    pub p: D,
}
32
33impl<P, D> Binomial<P, D> {
34 fn pmf(&self, at: P) -> D
36 where
37 P: Copy + Into<D> + PartialOrd + Zero + One + SelfSub,
38 D: Float + Sum,
39 {
40 if self.p == D::one() {
41 if at == self.n { D::one() } else { D::zero() }
43 } else if self.p == D::zero() {
44 if at == P::zero() { D::one() } else { D::zero() }
46 } else if at <= self.n {
47 let log_binomial: D = range_inclusive(P::one(), at)
49 .map(|i| (self.n - at + i).into().ln() - i.into().ln())
50 .sum();
51 let log_pmf = log_binomial
52 + at.into() * self.p.ln()
53 + (self.n - at).into() * (D::one() - self.p).ln();
54 log_pmf.exp()
55 } else {
56 D::zero()
58 }
59 }
60
61 fn std(&self) -> D
63 where
64 P: Copy + Into<D>,
65 D: Float,
66 {
67 (self.n.into() * self.p * (D::one() - self.p)).sqrt()
68 }
69
70 fn inverse_cdf(&self, cdf: D) -> P
71 where
72 P: Copy + Into<D> + One + Zero + PartialOrd + SelfSub,
73 D: Copy + Float + Sum,
74 {
75 range_step_from(P::zero(), P::one())
76 .scan(D::zero(), |acc, at| {
77 *acc = *acc + self.pmf(at);
78 Some((at, *acc))
79 })
80 .find(|(_, acc)| *acc >= cdf)
81 .expect("there should be a next sample")
82 .0
83 }
84}
85
86impl<P, D> Density for Binomial<P, D>
87where
88 P: Copy + Into<D> + Zero + PartialOrd + One + SelfSub,
89 D: Float + Sum,
90{
91 type Param = P;
92 type Output = D;
93
94 fn density(&self, at: Self::Param) -> Self::Output {
95 self.pmf(at) / self.std()
96 }
97}
98
99impl<P, D> Sample for Binomial<P, D>
100where
101 P: Copy + Into<D> + One + Zero + PartialOrd + SelfSub,
102 D: Float + FromPrimitive + Sum,
103{
104 type Param = P;
105
106 fn sample(&self, rng: &mut Rng) -> Self::Param {
107 self.inverse_cdf(D::from_f64(rng.f64()).unwrap())
108 }
109}
110
111impl<P, D> Kernel for Binomial<P, D>
112where
113 Self: Density<Param = P, Output = D> + Sample<Param = P>,
114 P: Copy + Ord + Additive + Multiplicative + Into<D> + One,
115 D: Multiplicative,
116{
117 type Param = P;
118
119 fn new(location: P, std: P) -> Self {
120 let sigma_squared = (std * std).min(location - P::one());
132
133 #[allow(clippy::suspicious_operation_groupings)]
134 let n = (location * location / (location - sigma_squared)).max(P::one());
135 Self {
136 n,
137 p: Into::<D>::into(location) / Into::<D>::into(n),
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    #[test]
    fn pmf_ok() {
        let fair_five = Binomial { n: 5, p: 0.5 };
        assert_abs_diff_eq!(fair_five.pmf(2), 0.3125);

        let fair_twenty = Binomial { n: 20, p: 0.5 };
        assert_abs_diff_eq!(fair_twenty.pmf(10), 0.176_197, epsilon = 0.000_001);
        assert_abs_diff_eq!(fair_twenty.pmf(5), 0.014_786, epsilon = 0.000_001);

        // Above `n` there is no probability mass.
        let unsigned = Binomial { n: 20_u32, p: 0.5 };
        assert_abs_diff_eq!(unsigned.pmf(21_u32), 0.0);
    }

    #[test]
    fn pmf_corner_cases() {
        let never = Binomial { n: 1, p: 0.0 };
        assert_abs_diff_eq!(never.pmf(0), 1.0);
        assert_abs_diff_eq!(never.pmf(1), 0.0);

        let always = Binomial { n: 1, p: 1.0 };
        assert_abs_diff_eq!(always.pmf(0), 0.0);
        assert_abs_diff_eq!(always.pmf(1), 1.0);
    }

    #[test]
    fn inverse_cdf_ok() {
        let fair_twenty = Binomial { n: 20, p: 0.5 };
        assert_eq!(fair_twenty.inverse_cdf(0.588), 10);
        assert_eq!(fair_twenty.inverse_cdf(0.020_694), 5);

        let never = Binomial { n: 1, p: 0.0 };
        assert_eq!(never.inverse_cdf(1.0), 0);
    }

    #[test]
    fn std_ok() {
        let fair_twenty = Binomial { n: 20, p: 0.5 };
        assert_abs_diff_eq!(fair_twenty.std(), 2.23607, epsilon = 0.00001);
    }

    #[test]
    fn new_ok() {
        let Binomial { n, p } = Binomial::<_, f64>::new(5, 2);
        assert_eq!(n, 25);
        assert_abs_diff_eq!(p, 0.2);
    }

    #[test]
    fn new_bandwidth_overflow() {
        let Binomial { n, p } = Binomial::<_, f64>::new(2, 100);
        assert_eq!(n, 4);
        assert_abs_diff_eq!(p, 0.5);
    }
}