// rstat/univariate/bernoulli.rs

1use crate::{
2    fitting::{Likelihood, Score, MLE},
3    statistics::{FisherInformation, Modes, Quantiles, ShannonEntropy, UnivariateMoments},
4    univariate::binomial::Binomial,
5    Convolution,
6    DiscreteDistribution,
7    Distribution,
8    Probability,
9};
10use ndarray::Array2;
11use spaces::discrete::Binary;
12use std::fmt;
13
// Parameters of the Bernoulli distribution, generated by the crate's
// `params!` macro (which presumably also provides `Params::new` and
// `Params::new_unchecked`, both used below).
//
// `p` is the success probability: the mass assigned to `true` by `pmf`.
// The `Probability<>` spelling appears to be part of the macro's input
// syntax and is intentionally left untouched.
params! {
    #[derive(Copy)]
    Params {
        p: Probability<>
    }
}
20
/// Score (gradient of the log-likelihood) of a Bernoulli distribution.
///
/// This is the `Grad` type returned by `Score::score` for `Bernoulli`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Grad {
    /// Partial derivative of the log-likelihood with respect to the success
    /// probability `p`.
    pub p: f64,
}
24
25#[derive(Debug, Clone, Copy)]
26pub struct Bernoulli {
27    pub(crate) params: Params,
28
29    q: Probability,
30    variance: f64,
31}
32
33impl Bernoulli {
34    pub fn new(p: f64) -> Result<Bernoulli, failure::Error> {
35        Params::new(p).map(|ps| Bernoulli {
36            q: !ps.p,
37            params: ps,
38            variance: p * (1.0 - p),
39        })
40    }
41
42    pub fn new_unchecked(p: f64) -> Bernoulli { Params::new_unchecked(p).into() }
43}
44
45impl From<Params> for Bernoulli {
46    fn from(params: Params) -> Bernoulli {
47        Bernoulli {
48            q: !params.p,
49            variance: params.p.0 * (1.0 - params.p.0),
50
51            params,
52        }
53    }
54}
55
56impl Distribution for Bernoulli {
57    type Support = Binary;
58    type Params = Params;
59
60    fn support(&self) -> Binary { Binary }
61
62    fn params(&self) -> Params { self.params }
63
64    fn cdf(&self, k: &bool) -> Probability {
65        if *k {
66            Probability::one()
67        } else {
68            Probability::zero()
69        }
70    }
71
72    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
73        rng.gen_bool(self.params.p.unwrap())
74    }
75}
76
77impl DiscreteDistribution for Bernoulli {
78    fn pmf(&self, k: &bool) -> Probability {
79        match k {
80            true => self.params.p,
81            false => self.q,
82        }
83    }
84}
85
86impl UnivariateMoments for Bernoulli {
87    fn mean(&self) -> f64 { self.params.p.into() }
88
89    fn variance(&self) -> f64 { self.variance }
90
91    fn skewness(&self) -> f64 { (1.0 - 2.0 * self.params.p.unwrap()) / self.variance.sqrt() }
92
93    fn kurtosis(&self) -> f64 { 1.0 / self.variance - 6.0 }
94
95    fn excess_kurtosis(&self) -> f64 { 1.0 / self.variance - 9.0 }
96}
97
98impl Quantiles for Bernoulli {
99    fn quantile(&self, _: Probability) -> f64 { unimplemented!() }
100
101    fn median(&self) -> f64 {
102        match self.params.p.unwrap() {
103            p if (p - 0.5).abs() < 1e-7 => 0.5,
104            p if (p < 0.5) => 0.0,
105            _ => 1.0,
106        }
107    }
108}
109
110impl Modes for Bernoulli {
111    fn modes(&self) -> Vec<bool> {
112        use std::cmp::Ordering::*;
113
114        match self.params.p.partial_cmp(&self.q) {
115            Some(Less) => vec![false],
116            Some(Equal) => vec![false, true],
117            Some(Greater) => vec![false],
118            None => unreachable!(),
119        }
120    }
121}
122
123impl ShannonEntropy for Bernoulli {
124    fn shannon_entropy(&self) -> f64 {
125        let p = self.params.p.unwrap();
126        let q = self.q.unwrap();
127
128        if q.abs() < 1e-7 || (q - 1.0).abs() < 1e-7 {
129            0.0
130        } else {
131            -q * q.ln() - p * p.ln()
132        }
133    }
134}
135
136impl FisherInformation for Bernoulli {
137    fn fisher_information(&self) -> Array2<f64> { Array2::from_elem((1, 1), 1.0 / self.variance) }
138}
139
140impl Likelihood for Bernoulli {
141    fn log_likelihood(&self, samples: &[bool]) -> f64 {
142        samples.into_iter().map(|x| self.log_pmf(x)).sum()
143    }
144}
145
146impl Score for Bernoulli {
147    type Grad = Grad;
148
149    fn score(&self, samples: &[bool]) -> Grad {
150        Grad {
151            p: samples.into_iter().map(|x| 1.0 / self.pmf(x)).sum(),
152        }
153    }
154}
155
156impl MLE for Bernoulli {
157    fn fit_mle(xs: &[bool]) -> Result<Self, failure::Error> {
158        let n = xs.len() as f64;
159        let p = xs.iter().fold(0, |acc, &x| acc + x as u64) as f64 / n;
160
161        Bernoulli::new(p)
162    }
163}
164
165impl Convolution<Bernoulli> for Bernoulli {
166    type Output = Binomial;
167
168    fn convolve(self, rv: Bernoulli) -> Result<Binomial, failure::Error> {
169        let p1 = self.params.p;
170        let p2 = rv.params.p;
171
172        assert_constraint!(p1 == p2)?;
173
174        Ok(Binomial::new_unchecked(2, self.params.p))
175    }
176}
177
178impl fmt::Display for Bernoulli {
179    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Ber({})", self.params.p) }
180}