compute/distributions/
bernoulli.rs1#![allow(clippy::float_cmp)]
2use crate::distributions::*;
3
/// A Bernoulli distribution: a single trial that yields `1` with probability
/// `p` and `0` with probability `1 - p`.
///
/// Construct it with [`Bernoulli::new`], which validates that `p` lies in `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bernoulli {
    /// Probability of success (an outcome of `1`).
    p: f64,
}
10
11impl Bernoulli {
12 pub fn new(p: f64) -> Self {
17 if !(0. ..=1.).contains(&p) {
18 panic!("`p` must be in [0, 1].");
19 }
20 Bernoulli { p }
21 }
22 pub fn set_p(&mut self, p: f64) -> &mut Self {
23 if !(0. ..=1.).contains(&p) {
24 panic!("`p` must be in [0, 1].");
25 }
26 self.p = p;
27 self
28 }
29}
30
31impl Default for Bernoulli {
32 fn default() -> Self {
33 Self::new(0.5)
34 }
35}
36
37impl Distribution for Bernoulli {
38 type Output = f64;
39 fn sample(&self) -> f64 {
41 if self.p == 1. {
42 return 1.;
43 } else if self.p == 0. {
44 return 0.;
45 }
46
47 if self.p > alea::f64() {
48 1.
49 } else {
50 0.
51 }
52 }
53}
54
55impl Distribution1D for Bernoulli {
56 fn update(&mut self, params: &[f64]) {
57 self.set_p(params[0]);
58 }
59}
60
61impl Discrete for Bernoulli {
62 fn pmf(&self, k: i64) -> f64 {
67 if k == 0 {
68 1. - self.p
69 } else if k == 1 {
70 self.p
71 } else {
72 0.
73 }
74 }
75}
76
77impl Mean for Bernoulli {
78 type MeanType = f64;
79 fn mean(&self) -> f64 {
81 self.p
82 }
83}
84
85impl Variance for Bernoulli {
86 type VarianceType = f64;
87 fn var(&self) -> f64 {
89 self.p * (1. - self.p)
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, var};
    use approx_eq::assert_approx_eq;

    /// Empirical check: samples are 0/1 and their moments match p = 0.75.
    #[test]
    fn test_bernoulli() {
        let data = Bernoulli::new(0.75).sample_n(1e6 as usize);
        assert!(data.iter().all(|&x| x == 0. || x == 1.));
        assert_approx_eq!(0.75, mean(&data), 1e-2);
        assert_approx_eq!(0.75 * 0.25, var(&data), 1e-2);
        assert!(Bernoulli::default().pmf(2) == 0.);
        assert!(Bernoulli::default().pmf(0) == 0.5);
    }

    /// p = 0 and p = 1 must always produce the same outcome.
    #[test]
    fn degenerate_probabilities_are_deterministic() {
        assert!(Bernoulli::new(0.).sample_n(100).iter().all(|&x| x == 0.));
        assert!(Bernoulli::new(1.).sample_n(100).iter().all(|&x| x == 1.));
    }

    /// pmf, mean and variance agree with their closed forms.
    #[test]
    fn pmf_mean_and_var_match_closed_form() {
        let b = Bernoulli::new(0.3);
        assert!(b.pmf(1) == 0.3);
        assert!(b.pmf(-1) == 0.);
        assert!(b.mean() == 0.3);
        assert!(b.var() == 0.3 * (1. - 0.3));
    }

    #[test]
    #[should_panic]
    fn new_rejects_out_of_range_p() {
        Bernoulli::new(1.5);
    }

    #[test]
    #[should_panic]
    fn set_p_rejects_nan() {
        Bernoulli::default().set_p(f64::NAN);
    }
}