compute/distributions/bernoulli.rs

1#![allow(clippy::float_cmp)]
2use crate::distributions::*;
3
/// The [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution): a single
/// trial that yields `1` with probability `p` and `0` with probability `1 - p`.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// Success probability `p`; validated to lie in `[0, 1]` by `new` and `set_p`.
    p: f64,
}
10
11impl Bernoulli {
12    /// Create a new Bernoulli distribution with probability `p`.
13    ///
14    /// # Errors
15    /// Panics if p is not in [0, 1].
16    pub fn new(p: f64) -> Self {
17        if !(0. ..=1.).contains(&p) {
18            panic!("`p` must be in [0, 1].");
19        }
20        Bernoulli { p }
21    }
22    pub fn set_p(&mut self, p: f64) -> &mut Self {
23        if !(0. ..=1.).contains(&p) {
24            panic!("`p` must be in [0, 1].");
25        }
26        self.p = p;
27        self
28    }
29}
30
31impl Default for Bernoulli {
32    fn default() -> Self {
33        Self::new(0.5)
34    }
35}
36
37impl Distribution for Bernoulli {
38    type Output = f64;
39    /// Samples from the given Bernoulli distribution.
40    fn sample(&self) -> f64 {
41        if self.p == 1. {
42            return 1.;
43        } else if self.p == 0. {
44            return 0.;
45        }
46
47        if self.p > alea::f64() {
48            1.
49        } else {
50            0.
51        }
52    }
53}
54
55impl Distribution1D for Bernoulli {
56    fn update(&mut self, params: &[f64]) {
57        self.set_p(params[0]);
58    }
59}
60
61impl Discrete for Bernoulli {
62    /// Calculates the [probability mass
63    /// function](https://en.wikipedia.org/wiki/Probability_mass_function) for the given  Bernoulli
64    /// distribution at `x`.
65    ///
66    fn pmf(&self, k: i64) -> f64 {
67        if k == 0 {
68            1. - self.p
69        } else if k == 1 {
70            self.p
71        } else {
72            0.
73        }
74    }
75}
76
77impl Mean for Bernoulli {
78    type MeanType = f64;
79    /// Calculates the mean of the Bernoulli distribution, which is `p`.
80    fn mean(&self) -> f64 {
81        self.p
82    }
83}
84
85impl Variance for Bernoulli {
86    type VarianceType = f64;
87    /// Calculates the variance, given by `p*q = p(1-p)`.
88    fn var(&self) -> f64 {
89        self.p * (1. - self.p)
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, var};
    use approx_eq::assert_approx_eq;

    /// Empirical moments of a large sample should match the closed forms.
    #[test]
    fn test_bernoulli() {
        let data = Bernoulli::new(0.75).sample_n(1_000_000);
        assert!(data.iter().all(|x| *x == 0. || *x == 1.));
        assert_approx_eq!(0.75, mean(&data), 1e-2);
        assert_approx_eq!(0.75 * 0.25, var(&data), 1e-2);
        assert!(Bernoulli::default().pmf(2) == 0.);
        assert!(Bernoulli::default().pmf(0) == 0.5);
    }

    /// `p == 0` and `p == 1` must never produce the other outcome.
    #[test]
    fn degenerate_probabilities_are_deterministic() {
        assert!(Bernoulli::new(0.).sample_n(1_000).iter().all(|x| *x == 0.));
        assert!(Bernoulli::new(1.).sample_n(1_000).iter().all(|x| *x == 1.));
    }

    /// Exact checks of the closed-form pmf, mean, and variance.
    #[test]
    fn pmf_mean_and_var_match_closed_forms() {
        let b = Bernoulli::new(0.25);
        assert!(b.pmf(0) == 0.75);
        assert!(b.pmf(1) == 0.25);
        assert!(b.pmf(-1) == 0.);
        assert!(b.mean() == 0.25);
        assert!(b.var() == 0.25 * 0.75);
    }

    #[test]
    #[should_panic]
    fn new_rejects_p_above_one() {
        Bernoulli::new(1.5);
    }

    #[test]
    #[should_panic]
    fn new_rejects_nan() {
        Bernoulli::new(f64::NAN);
    }

    #[test]
    #[should_panic]
    fn set_p_rejects_negative() {
        Bernoulli::default().set_p(-0.1);
    }
}