rstat/univariate/bernoulli.rs

use crate::{
    fitting::{Likelihood, Score, MLE},
    statistics::{FisherInformation, Modes, Quantiles, ShannonEntropy, UnivariateMoments},
    univariate::binomial::Binomial,
    Convolution,
    DiscreteDistribution,
    Distribution,
    Probability,
};
use ndarray::Array2;
use spaces::discrete::Binary;
use std::fmt;

params! {
    #[derive(Copy)]
    Params {
        p: Probability<>
    }
}

/// Gradient of the Bernoulli log-likelihood with respect to the success probability `p`.
pub struct Grad {
    pub p: f64,
}

/// Bernoulli distribution over the binary support {false, true}, parameterised by the
/// success probability `p`.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli {
    pub(crate) params: Params,

    q: Probability,
    variance: f64,
}

impl Bernoulli {
    pub fn new(p: f64) -> Result<Bernoulli, failure::Error> {
        Params::new(p).map(|ps| Bernoulli {
            q: !ps.p,
            params: ps,
            variance: p * (1.0 - p),
        })
    }

    pub fn new_unchecked(p: f64) -> Bernoulli { Params::new_unchecked(p).into() }
}

impl From<Params> for Bernoulli {
    fn from(params: Params) -> Bernoulli {
        Bernoulli {
            q: !params.p,
            variance: params.p.0 * (1.0 - params.p.0),

            params,
        }
    }
}

impl Distribution for Bernoulli {
    type Support = Binary;
    type Params = Params;

    fn support(&self) -> Binary { Binary }

    fn params(&self) -> Params { self.params }

    fn cdf(&self, k: &bool) -> Probability {
        if *k {
            Probability::one()
        } else {
            // P(X <= false) = P(X = false) = 1 - p.
            self.q
        }
    }

    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
        rng.gen_bool(self.params.p.unwrap())
    }
}

impl DiscreteDistribution for Bernoulli {
    fn pmf(&self, k: &bool) -> Probability {
        match k {
            true => self.params.p,
            false => self.q,
        }
    }
}

impl UnivariateMoments for Bernoulli {
    fn mean(&self) -> f64 { self.params.p.into() }

    fn variance(&self) -> f64 { self.variance }

    fn skewness(&self) -> f64 { (1.0 - 2.0 * self.params.p.unwrap()) / self.variance.sqrt() }

    // For Bernoulli(p) with q = 1 - p: kurtosis = 1/(pq) - 3, excess kurtosis = 1/(pq) - 6.
    fn kurtosis(&self) -> f64 { 1.0 / self.variance - 3.0 }

    fn excess_kurtosis(&self) -> f64 { 1.0 / self.variance - 6.0 }
}

impl Quantiles for Bernoulli {
    // Generalised inverse CDF: 0 for probabilities up to q = 1 - p, and 1 above that.
    fn quantile(&self, p: Probability) -> f64 {
        if p.unwrap() <= self.q.unwrap() { 0.0 } else { 1.0 }
    }

    fn median(&self) -> f64 {
        match self.params.p.unwrap() {
            p if (p - 0.5).abs() < 1e-7 => 0.5,
            p if p < 0.5 => 0.0,
            _ => 1.0,
        }
    }
}

impl Modes for Bernoulli {
    fn modes(&self) -> Vec<bool> {
        use std::cmp::Ordering::*;

        match self.params.p.partial_cmp(&self.q) {
            Some(Less) => vec![false],
            Some(Equal) => vec![false, true],
            // When p > q, the more probable outcome, and hence the mode, is `true`.
            Some(Greater) => vec![true],
            None => unreachable!(),
        }
    }
}

impl ShannonEntropy for Bernoulli {
    fn shannon_entropy(&self) -> f64 {
        let p = self.params.p.unwrap();
        let q = self.q.unwrap();

        if q.abs() < 1e-7 || (q - 1.0).abs() < 1e-7 {
            0.0
        } else {
            -q * q.ln() - p * p.ln()
        }
    }
}

impl FisherInformation for Bernoulli {
    fn fisher_information(&self) -> Array2<f64> { Array2::from_elem((1, 1), 1.0 / self.variance) }
}

impl Likelihood for Bernoulli {
    fn log_likelihood(&self, samples: &[bool]) -> f64 {
        samples.iter().map(|x| self.log_pmf(x)).sum()
    }
}

impl Score for Bernoulli {
    type Grad = Grad;

    fn score(&self, samples: &[bool]) -> Grad {
        // Score = d/dp of the log-likelihood: 1/p for each `true` sample and
        // -1/(1 - p) for each `false` sample.
        Grad {
            p: samples
                .iter()
                .map(|&x| if x { 1.0 / self.params.p.unwrap() } else { -1.0 / self.q.unwrap() })
                .sum(),
        }
    }
}

impl MLE for Bernoulli {
    fn fit_mle(xs: &[bool]) -> Result<Self, failure::Error> {
        // The MLE of p is the sample mean: the fraction of `true` observations.
        let n = xs.len() as f64;
        let p = xs.iter().fold(0, |acc, &x| acc + x as u64) as f64 / n;

        Bernoulli::new(p)
    }
}

impl Convolution<Bernoulli> for Bernoulli {
    type Output = Binomial;

    fn convolve(self, rv: Bernoulli) -> Result<Binomial, failure::Error> {
        // The sum of two independent Bernoulli variables sharing the same p is Binomial(2, p).
        let p1 = self.params.p;
        let p2 = rv.params.p;

        assert_constraint!(p1 == p2)?;

        Ok(Binomial::new_unchecked(2, self.params.p))
    }
}

impl fmt::Display for Bernoulli {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Ber({})", self.params.p) }
}
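
// A minimal, illustrative test module sketching how the API above might be exercised.
// This is a sketch rather than part of the upstream file: the trait paths follow the
// imports at the top of this module, and the sample data and tolerances are arbitrary.
#[cfg(test)]
mod tests {
    use super::Bernoulli;
    use crate::{fitting::MLE, statistics::UnivariateMoments, DiscreteDistribution};

    #[test]
    fn moments_and_pmf() {
        let dist = Bernoulli::new(0.7).unwrap();

        // Mean and variance of Bernoulli(0.7): 0.7 and 0.7 * 0.3 = 0.21.
        assert!((dist.mean() - 0.7).abs() < 1e-9);
        assert!((dist.variance() - 0.21).abs() < 1e-9);

        // The pmf assigns p to `true` and 1 - p to `false`.
        assert!((dist.pmf(&true).unwrap() - 0.7).abs() < 1e-9);
        assert!((dist.pmf(&false).unwrap() - 0.3).abs() < 1e-9);
    }

    #[test]
    fn mle_matches_sample_mean() {
        // Three successes out of four observations => p_hat = 0.75.
        let samples = [true, true, false, true];
        let fitted = Bernoulli::fit_mle(&samples).unwrap();

        assert!((fitted.mean() - 0.75).abs() < 1e-9);
    }
}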