probability/distribution/
bernoulli.rs

1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A Bernoulli distribution over the two outcomes `0` and `1`.
///
/// Besides the success probability, the complementary failure probability and
/// the product of the two are stored so that they do not have to be recomputed
/// on every query.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli {
    // Probability of the outcome `1` (success).
    p: f64,
    // Probability of the outcome `0` (failure).
    q: f64,
    // Cached product `p * q`.
    pq: f64,
}
15
16impl Bernoulli {
17    /// Create a Bernoulli distribution with success probability `p`.
18    ///
19    /// It should hold that `p > 0` and `p < 1`.
20    #[inline]
21    pub fn new(p: f64) -> Self {
22        should!(p > 0.0 && p < 1.0);
23        Bernoulli {
24            p,
25            q: 1.0 - p,
26            pq: p * (1.0 - p),
27        }
28    }
29
30    /// Create a Bernoulli distribution with failure probability `q`.
31    ///
32    /// It should hold that `q > 0` and `q < 1`. This constructor is preferable
33    /// when `q` is very small.
34    #[inline]
35    pub fn with_failure(q: f64) -> Self {
36        should!(q > 0.0 && q < 1.0);
37        Bernoulli {
38            p: 1.0 - q,
39            q,
40            pq: (1.0 - q) * q,
41        }
42    }
43
44    /// Return the success probability.
45    #[inline(always)]
46    pub fn p(&self) -> f64 {
47        self.p
48    }
49
50    /// Return the failure probability.
51    #[inline(always)]
52    pub fn q(&self) -> f64 {
53        self.q
54    }
55}
56
57impl distribution::Discrete for Bernoulli {
58    #[inline]
59    fn mass(&self, x: u8) -> f64 {
60        if x == 0 {
61            self.q
62        } else if x == 1 {
63            self.p
64        } else {
65            0.0
66        }
67    }
68}
69
70impl distribution::Distribution for Bernoulli {
71    type Value = u8;
72
73    #[inline]
74    fn distribution(&self, x: f64) -> f64 {
75        if x < 0.0 {
76            0.0
77        } else if x < 1.0 {
78            self.q
79        } else {
80            1.0
81        }
82    }
83}
84
85impl distribution::Entropy for Bernoulli {
86    fn entropy(&self) -> f64 {
87        -self.q * self.q.ln() - self.p * self.p.ln()
88    }
89}
90
91impl distribution::Inverse for Bernoulli {
92    #[inline]
93    fn inverse(&self, p: f64) -> u8 {
94        should!((0.0..=1.0).contains(&p));
95        if p <= self.q {
96            0
97        } else {
98            1
99        }
100    }
101}
102
103impl distribution::Kurtosis for Bernoulli {
104    #[inline]
105    fn kurtosis(&self) -> f64 {
106        (1.0 - 6.0 * self.pq) / (self.pq)
107    }
108}
109
110impl distribution::Mean for Bernoulli {
111    #[inline]
112    fn mean(&self) -> f64 {
113        self.p
114    }
115}
116
117impl distribution::Median for Bernoulli {
118    fn median(&self) -> f64 {
119        use core::cmp::Ordering::*;
120        match self.p.partial_cmp(&self.q) {
121            Some(Less) => 0.0,
122            Some(Equal) => 0.5,
123            Some(Greater) => 1.0,
124            None => unreachable!(),
125        }
126    }
127}
128
129impl distribution::Modes for Bernoulli {
130    fn modes(&self) -> Vec<u8> {
131        use core::cmp::Ordering::*;
132        match self.p.partial_cmp(&self.q) {
133            Some(Less) => vec![0],
134            Some(Equal) => vec![0, 1],
135            Some(Greater) => vec![1],
136            None => unreachable!(),
137        }
138    }
139}
140
141impl distribution::Sample for Bernoulli {
142    #[inline]
143    fn sample<S>(&self, source: &mut S) -> u8
144    where
145        S: Source,
146    {
147        if source.read::<f64>() < self.q {
148            0
149        } else {
150            1
151        }
152    }
153}
154
155impl distribution::Skewness for Bernoulli {
156    #[inline]
157    fn skewness(&self) -> f64 {
158        (1.0 - 2.0 * self.p) / self.pq.sqrt()
159    }
160}
161
162impl distribution::Variance for Bernoulli {
163    #[inline]
164    fn variance(&self) -> f64 {
165        self.pq
166    }
167}
168
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // Shorthand for building a distribution either from a success probability
    // (`new!(p)`) or from a failure probability (`new!(failure q)`).
    macro_rules! new(
        (failure $q:expr) => (Bernoulli::with_failure($q));
        ($p:expr) => (Bernoulli::new($p));
    );

    #[test]
    fn distribution() {
        let d = new!(0.25);
        // Pairs of an evaluation point and the expected cumulative probability.
        let cases = vec![
            (-0.1, 0.0),
            (0.0, 0.75),
            (0.1, 0.75),
            (0.25, 0.75),
            (0.5, 0.75),
            (1.0, 1.0),
            (1.1, 1.0),
        ];
        for &(x, expected) in cases.iter() {
            assert_eq!(d.distribution(x), expected);
        }
    }

    #[test]
    fn entropy() {
        let actual: Vec<f64> = [0.25, 0.5, 0.75]
            .iter()
            .map(|&p| new!(p).entropy())
            .collect();
        let expected = vec![0.5623351446188083, 0.6931471805599453, 0.5623351446188083];
        assert::close(&actual, &expected, 1e-16);
    }

    #[test]
    fn inverse() {
        let d = new!(0.25);
        // Pairs of a probability and the expected outcome.
        let cases = vec![
            (0.0, 0),
            (0.25, 0),
            (0.5, 0),
            (0.75, 0),
            (0.75000000001, 1),
            (1.0, 1),
        ];
        for &(p, expected) in cases.iter() {
            assert_eq!(d.inverse(p), expected);
        }
    }

    #[test]
    fn kurtosis() {
        let d = new!(0.5);
        assert_eq!(d.kurtosis(), -2.0);
    }

    #[test]
    fn mass() {
        let d = new!(0.25);
        let masses: Vec<_> = (0..3).map(|x| d.mass(x)).collect();
        assert_eq!(&masses, &[0.75, 0.25, 0.0]);
    }

    #[test]
    fn mean() {
        let d = new!(0.5);
        assert_eq!(d.mean(), 0.5);
    }

    #[test]
    fn median() {
        // Pairs of a success probability and the expected median.
        let cases = vec![(0.25, 0.0), (0.5, 0.5), (0.75, 1.0)];
        for &(p, expected) in cases.iter() {
            assert_eq!(new!(p).median(), expected);
        }
    }

    #[test]
    fn modes() {
        // Pairs of a success probability and the expected set of modes.
        let cases = vec![(0.25, vec![0]), (0.5, vec![0, 1]), (0.75, vec![1])];
        for (p, expected) in cases {
            assert_eq!(new!(p).modes(), expected);
        }
    }

    #[test]
    fn sample() {
        let d = new!(0.25);
        let mut generator = source::default(42);
        // Count the successes among one hundred independent draws.
        let successes = Independent(&d, &mut generator)
            .take(100)
            .fold(0, |total, outcome| total + outcome);
        assert!(successes <= 100);
    }

    #[test]
    fn skewness() {
        let d = new!(0.5);
        assert_eq!(d.skewness(), 0.0);
    }

    #[test]
    fn variance() {
        let d = new!(0.25);
        assert_eq!(d.variance(), 0.1875);
    }
}