use super::Distribution;
use num::Complex;
use RustQuant_error::RustQuantError;
/// Bernoulli distribution: a single trial that yields `1` with probability `p`
/// and `0` with probability `1 - p`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bernoulli {
    /// Probability of success; [`Bernoulli::new`] checks that it lies in `[0, 1]`.
    p: f64,
}
impl Default for Bernoulli {
fn default() -> Self {
Self::new(0.5)
}
}
impl Bernoulli {
#[must_use]
pub fn new(probability: f64) -> Bernoulli {
assert!((0.0..=1.0).contains(&probability));
Bernoulli { p: probability }
}
}
impl Distribution for Bernoulli {
    /// Characteristic function `φ(t) = 1 - p + p·e^{it}`.
    fn cf(&self, t: f64) -> Complex<f64> {
        assert!((0.0..=1.0).contains(&self.p));
        let i: Complex<f64> = Complex::i();
        1.0 - self.p + self.p * (i * t).exp()
    }

    /// The distribution is discrete, so the "density" is the probability mass function.
    fn pdf(&self, x: f64) -> f64 {
        self.pmf(x)
    }

    /// Probability mass function `P(X = k) = p^k · (1 - p)^(1 - k)`.
    ///
    /// # Panics
    ///
    /// Panics unless `k` is exactly `0.0` or `1.0`.
    fn pmf(&self, k: f64) -> f64 {
        assert!((0.0..=1.0).contains(&self.p));
        assert!(k == 0.0 || k == 1.0);
        (self.p).powi(k as i32) * (1.0 - self.p).powi(1 - k as i32)
    }

    /// Cumulative distribution function.
    ///
    /// Returns `0` for `k < 0`, `1 - p` for `0 <= k < 1`, and `1` for `k >= 1`.
    /// A `NaN` argument yields `NaN`.
    fn cdf(&self, k: f64) -> f64 {
        assert!((0.0..=1.0).contains(&self.p));
        // Compare as floats: casting to an integer truncates toward zero, which
        // would map e.g. `-0.5` to `0` and wrongly report `1 - p`.
        if k.is_nan() {
            f64::NAN
        } else if k < 0.0 {
            0.0
        } else if k < 1.0 {
            1.0 - self.p
        } else {
            1.0
        }
    }

    /// Quantile function: returns `0` when the level `p` is below `1 - self.p`,
    /// and `1` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the level `p` is not in `[0, 1]`.
    fn inv_cdf(&self, p: f64) -> f64 {
        assert!((0.0..=1.0).contains(&p));
        if p < 1.0 - self.p {
            0.0
        } else {
            1.0
        }
    }

    /// Mean `E[X] = p`.
    fn mean(&self) -> f64 {
        self.p
    }

    /// Median: `0` when `p < 0.5`, otherwise `1`.
    ///
    /// At `p = 0.5` every point of `[0, 1]` is a median; this returns `1`.
    fn median(&self) -> f64 {
        if self.p < 0.5 {
            0.0
        } else {
            1.0
        }
    }

    /// Mode: `0` when `p <= 0.5`, otherwise `1`.
    ///
    /// At `p = 0.5` both outcomes are modes; this returns `0`.
    fn mode(&self) -> f64 {
        if self.p <= 0.5 {
            0.0
        } else {
            1.0
        }
    }

    /// Variance `p · (1 - p)`.
    fn variance(&self) -> f64 {
        self.p * (1.0 - self.p)
    }

    /// Skewness `(1 - 2p) / sqrt(p · (1 - p))`.
    ///
    /// The denominator is zero for `p = 0` or `p = 1`, so the result is non-finite there.
    fn skewness(&self) -> f64 {
        let p = self.p;
        ((1.0 - p) - p) / (p * (1.0 - p)).sqrt()
    }

    /// Excess kurtosis `(1 - 6p(1 - p)) / (p(1 - p))`.
    ///
    /// The denominator is zero for `p = 0` or `p = 1`, so the result is non-finite there.
    fn kurtosis(&self) -> f64 {
        let p = self.p;
        (1.0 - 6.0 * p * (1.0 - p)) / (p * (1.0 - p))
    }

    /// Shannon entropy in nats: `-(1 - p)·ln(1 - p) - p·ln(p)`.
    ///
    /// Uses the convention `0 · ln(0) = 0`, so the degenerate cases `p = 0` and
    /// `p = 1` have zero entropy rather than `NaN`.
    fn entropy(&self) -> f64 {
        let p = self.p;
        // `0.0 * f64::ln(0.0)` is `0 * -inf = NaN`; a degenerate distribution
        // carries no uncertainty, so short-circuit it.
        if p == 0.0 || p == 1.0 {
            return 0.0;
        }
        (p - 1.0) * (1.0 - p).ln() - p * p.ln()
    }

    /// Moment generating function `M(t) = 1 - p + p·e^t`.
    fn mgf(&self, t: f64) -> f64 {
        1.0 - self.p + self.p * f64::exp(t)
    }

    /// Draws exactly `n` independent variates, each `0.0` or `1.0`.
    ///
    /// # Errors
    ///
    /// Propagates the error from constructing the underlying `rand_distr::Bernoulli`.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    fn sample(&self, n: usize) -> Result<Vec<f64>, RustQuantError> {
        use rand::thread_rng;
        use rand_distr::{Bernoulli, Distribution};

        assert!(n > 0);

        let mut rng = thread_rng();
        let dist = Bernoulli::new(self.p)?;

        // Iterate over `0..n` rather than the vector's capacity: `with_capacity(n)`
        // only guarantees a capacity of *at least* `n`, which could overshoot.
        let variates: Vec<f64> = (0..n)
            .map(|_| if dist.sample(&mut rng) { 1.0 } else { 0.0 })
            .collect();

        Ok(variates)
    }
}
#[cfg(test)]
mod tests_bernoulli {
    use super::*;
    use RustQuant_error::RustQuantError;
    use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON as EPS};

    #[test]
    fn test_bernoulli_functions() {
        // Degenerate distribution (p = 1): X is always 1.
        let dist = Bernoulli::new(1.0);
        let cf = dist.cf(1.0);
        assert_approx_equal!(cf.re, 0.540_302_305_868_139_8, EPS);
        assert_approx_equal!(cf.im, 0.841_470_984_807_896_5, EPS);

        // Fair coin (p = 0.5).
        let bernoulli = Bernoulli::new(0.5);

        let pmf = dist.pmf(1.0);
        assert_approx_equal!(pmf, 1.0, EPS);
        let pmf_zero = bernoulli.pmf(0.0);
        let pmf_one = bernoulli.pmf(1.0);
        assert_approx_equal!(pmf_zero, 0.5, EPS);
        assert_approx_equal!(pmf_one, 0.5, EPS);

        let cdf = dist.cdf(1.0);
        assert_approx_equal!(cdf, 1.0, EPS);
        let cdf_neg = bernoulli.cdf(-1.0);
        let cdf_zero = bernoulli.cdf(0.0);
        let cdf_half = bernoulli.cdf(0.5);
        let cdf_one = bernoulli.cdf(1.0);
        let cdf_two = bernoulli.cdf(2.0);
        assert_approx_equal!(cdf_neg, 0.0, EPS);
        assert_approx_equal!(cdf_zero, 0.5, EPS);
        assert_approx_equal!(cdf_half, 0.5, EPS);
        assert_approx_equal!(cdf_one, 1.0, EPS);
        assert_approx_equal!(cdf_two, 1.0, EPS);

        let mgf = bernoulli.mgf(1.0);
        assert_approx_equal!(mgf, 1.0 - 0.5 + 0.5 * 1_f64.exp(), EPS);

        // Compare the complex value component-wise with a tolerance rather than
        // exact equality, which would depend on bit-exact `exp`/`cos`/`sin` results.
        let cf = bernoulli.cf(1.0);
        assert_approx_equal!(cf.re, 1.0 - 0.5 + 0.5 * 1_f64.cos(), EPS);
        assert_approx_equal!(cf.im, 0.5 * 1_f64.sin(), EPS);
    }

    #[test]
    fn test_bernoulli_moments() {
        let bernoulli = Bernoulli::new(0.5);
        assert_approx_equal!(bernoulli.mean(), 0.5, EPS);
        assert_approx_equal!(bernoulli.variance(), 0.25, EPS);
        assert_approx_equal!(bernoulli.skewness(), 0.0, EPS);
        assert_approx_equal!(bernoulli.kurtosis(), -2.0, EPS);
    }

    #[test]
    fn test_bernoulli_entropy() {
        let bernoulli = Bernoulli::new(0.5);
        assert_approx_equal!(
            bernoulli.entropy(),
            -(0.5f64.ln() * 0.5 + (1.0 - 0.5_f64).ln() * (1.0 - 0.5)),
            EPS
        );
    }

    #[test]
    fn test_default() {
        let bernoulli = Bernoulli::default();
        assert_approx_equal!(bernoulli.p, 0.5, EPS);
    }

    #[test]
    #[should_panic(expected = "assertion failed: (0.0..=1.0).contains(&probability)")]
    fn test_new_invalid_probability_low() {
        let _ = Bernoulli::new(-0.5);
    }

    #[test]
    #[should_panic(expected = "assertion failed: (0.0..=1.0).contains(&probability)")]
    fn test_new_invalid_probability_high() {
        let _ = Bernoulli::new(1.5);
    }

    #[test]
    #[should_panic(expected = "assertion failed: k == 0.0 || k == 1.0")]
    fn test_pmf_invalid_input() {
        let bernoulli = Bernoulli::new(0.5);
        bernoulli.pmf(2.0);
    }

    #[test]
    fn test_cdf_negative_input() {
        let bernoulli = Bernoulli::new(0.5);
        let cdf_neg = bernoulli.cdf(-1.0);
        assert_approx_equal!(cdf_neg, 0.0, EPS);
    }

    #[test]
    fn test_cdf_positive_input() {
        let bernoulli = Bernoulli::new(0.5);
        let cdf_one = bernoulli.cdf(1.0);
        let cdf_two = bernoulli.cdf(2.0);
        assert_approx_equal!(cdf_one, 1.0, EPS);
        assert_approx_equal!(cdf_two, 1.0, EPS);
    }

    #[test]
    fn test_inv_cdf() {
        let bernoulli = Bernoulli::new(0.5);
        let inv_cdf_one = bernoulli.inv_cdf(0.5);
        let inv_cdf_two = bernoulli.inv_cdf(0.3);
        assert_approx_equal!(inv_cdf_one, 1.0, EPS);
        assert_approx_equal!(inv_cdf_two, 0.0, EPS);
    }

    #[test]
    fn test_median() {
        let bernoulli = Bernoulli::new(0.5);
        let median = bernoulli.median();
        assert_approx_equal!(median, 1.0, EPS);
    }

    #[test]
    fn test_mode() {
        let bernoulli = Bernoulli::new(0.5);
        let mode = bernoulli.mode();
        assert_approx_equal!(mode, 0.0, EPS);
    }

    #[test]
    #[should_panic(expected = "assertion failed: n > 0")]
    fn test_sample_zero_size() {
        let bernoulli = Bernoulli::new(0.5);
        let _ = bernoulli.sample(0);
    }

    #[test]
    fn test_sample_positive_size() -> Result<(), RustQuantError> {
        let bernoulli = Bernoulli::new(0.5);
        let sample = bernoulli.sample(100)?;
        assert_eq!(sample.len(), 100);
        // Every variate must be one of the two support points.
        for &value in &sample {
            assert!(value == 0.0 || (value - 1.0).abs() < EPS);
        }
        Ok(())
    }
}