probability/distribution/
bernoulli.rs1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A Bernoulli distribution over the two outcomes `0` (failure) and `1` (success).
///
/// The success probability, the failure probability, and their product are all
/// stored so that the variance, skewness, and kurtosis can read them directly
/// instead of recomputing `1 - p` or `p * q`.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// Probability of success, i.e. of observing `1`.
    p: f64,
    /// Probability of failure, i.e. of observing `0`; set to `1 - p`.
    q: f64,
    /// Cached product `p * q`.
    pq: f64,
}
15
16impl Bernoulli {
17 #[inline]
21 pub fn new(p: f64) -> Self {
22 should!(p > 0.0 && p < 1.0);
23 Bernoulli {
24 p,
25 q: 1.0 - p,
26 pq: p * (1.0 - p),
27 }
28 }
29
30 #[inline]
35 pub fn with_failure(q: f64) -> Self {
36 should!(q > 0.0 && q < 1.0);
37 Bernoulli {
38 p: 1.0 - q,
39 q,
40 pq: (1.0 - q) * q,
41 }
42 }
43
44 #[inline(always)]
46 pub fn p(&self) -> f64 {
47 self.p
48 }
49
50 #[inline(always)]
52 pub fn q(&self) -> f64 {
53 self.q
54 }
55}
56
57impl distribution::Discrete for Bernoulli {
58 #[inline]
59 fn mass(&self, x: u8) -> f64 {
60 if x == 0 {
61 self.q
62 } else if x == 1 {
63 self.p
64 } else {
65 0.0
66 }
67 }
68}
69
70impl distribution::Distribution for Bernoulli {
71 type Value = u8;
72
73 #[inline]
74 fn distribution(&self, x: f64) -> f64 {
75 if x < 0.0 {
76 0.0
77 } else if x < 1.0 {
78 self.q
79 } else {
80 1.0
81 }
82 }
83}
84
85impl distribution::Entropy for Bernoulli {
86 fn entropy(&self) -> f64 {
87 -self.q * self.q.ln() - self.p * self.p.ln()
88 }
89}
90
91impl distribution::Inverse for Bernoulli {
92 #[inline]
93 fn inverse(&self, p: f64) -> u8 {
94 should!((0.0..=1.0).contains(&p));
95 if p <= self.q {
96 0
97 } else {
98 1
99 }
100 }
101}
102
103impl distribution::Kurtosis for Bernoulli {
104 #[inline]
105 fn kurtosis(&self) -> f64 {
106 (1.0 - 6.0 * self.pq) / (self.pq)
107 }
108}
109
110impl distribution::Mean for Bernoulli {
111 #[inline]
112 fn mean(&self) -> f64 {
113 self.p
114 }
115}
116
117impl distribution::Median for Bernoulli {
118 fn median(&self) -> f64 {
119 use core::cmp::Ordering::*;
120 match self.p.partial_cmp(&self.q) {
121 Some(Less) => 0.0,
122 Some(Equal) => 0.5,
123 Some(Greater) => 1.0,
124 None => unreachable!(),
125 }
126 }
127}
128
129impl distribution::Modes for Bernoulli {
130 fn modes(&self) -> Vec<u8> {
131 use core::cmp::Ordering::*;
132 match self.p.partial_cmp(&self.q) {
133 Some(Less) => vec![0],
134 Some(Equal) => vec![0, 1],
135 Some(Greater) => vec![1],
136 None => unreachable!(),
137 }
138 }
139}
140
141impl distribution::Sample for Bernoulli {
142 #[inline]
143 fn sample<S>(&self, source: &mut S) -> u8
144 where
145 S: Source,
146 {
147 if source.read::<f64>() < self.q {
148 0
149 } else {
150 1
151 }
152 }
153}
154
155impl distribution::Skewness for Bernoulli {
156 #[inline]
157 fn skewness(&self) -> f64 {
158 (1.0 - 2.0 * self.p) / self.pq.sqrt()
159 }
160}
161
162impl distribution::Variance for Bernoulli {
163 #[inline]
164 fn variance(&self) -> f64 {
165 self.pq
166 }
167}
168
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // `new!(p)` builds from the success probability;
    // `new!(failure q)` builds from the failure probability.
    macro_rules! new(
        (failure $q:expr) => (Bernoulli::with_failure($q));
        ($p:expr) => (Bernoulli::new($p));
    );

    #[test]
    fn distribution() {
        let d = new!(0.25);
        let x = vec![-0.1, 0.0, 0.1, 0.25, 0.5, 1.0, 1.1];
        let p = vec![0.0, 0.75, 0.75, 0.75, 0.75, 1.0, 1.0];
        assert_eq!(
            &x.iter().map(|&x| d.distribution(x)).collect::<Vec<_>>(),
            &p
        );
    }

    #[test]
    fn entropy() {
        let d = vec![new!(0.25), new!(0.5), new!(0.75)];
        assert::close(
            &d.iter().map(|d| d.entropy()).collect::<Vec<_>>(),
            &vec![0.5623351446188083, 0.6931471805599453, 0.5623351446188083],
            1e-16,
        );
    }

    #[test]
    fn inverse() {
        let d = new!(0.25);
        let p = vec![0.0, 0.25, 0.5, 0.75, 0.75000000001, 1.0];
        let x = vec![0, 0, 0, 0, 1, 1];
        assert_eq!(&p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(), &x);
    }

    #[test]
    fn kurtosis() {
        assert_eq!(new!(0.5).kurtosis(), -2.0);
    }

    #[test]
    fn mass() {
        let d = new!(0.25);
        assert_eq!(
            &(0..3).map(|x| d.mass(x)).collect::<Vec<_>>(),
            &[0.75, 0.25, 0.0]
        );
    }

    #[test]
    fn mean() {
        assert_eq!(new!(0.5).mean(), 0.5);
    }

    #[test]
    fn median() {
        assert_eq!(new!(0.25).median(), 0.0);
        assert_eq!(new!(0.5).median(), 0.5);
        assert_eq!(new!(0.75).median(), 1.0);
    }

    #[test]
    fn modes() {
        assert_eq!(new!(0.25).modes(), vec![0]);
        assert_eq!(new!(0.5).modes(), vec![0, 1]);
        assert_eq!(new!(0.75).modes(), vec![1]);
    }

    #[test]
    fn sample() {
        // The previous check (sum of 100 draws <= 100) held for any output in
        // {0, 1}. Instead, verify every draw is 0 or 1 and that the empirical
        // success frequency is near p. With n = 10_000 the frequency has a
        // standard deviation of sqrt(0.25 * 0.75 / n) ~= 0.0043, so +-0.05 is
        // a very loose bound for a fixed seed.
        let d = new!(0.25);
        let mut rng = source::default(42);
        let n = 10_000;
        // Count in `usize`; accumulating in `u8` would overflow at this size.
        let successes = Independent(&d, &mut rng)
            .take(n)
            .inspect(|&x| assert!(x <= 1))
            .filter(|&x| x == 1)
            .count();
        let frequency = successes as f64 / n as f64;
        assert!(frequency > 0.2 && frequency < 0.3);
    }

    #[test]
    fn skewness() {
        assert_eq!(new!(0.5).skewness(), 0.0);
    }

    #[test]
    fn variance() {
        assert_eq!(new!(0.25).variance(), 0.1875);
    }

    #[test]
    fn with_failure() {
        // Exercises the otherwise-unused `failure` arm of `new!`; all values
        // below are exactly representable, so exact comparison is safe.
        let d = new!(failure 0.75);
        assert_eq!(d.p(), 0.25);
        assert_eq!(d.q(), 0.75);
        assert_eq!(d.variance(), 0.1875);
        assert_eq!(d.mass(1), new!(0.25).mass(1));
    }
}