use crate::FloatScalar;
use crate::special::{lgamma, betainc};
use super::{DiscreteDistribution, StatsError};
/// Binomial distribution described by a trial count `n` and a per-trial
/// success probability `p`.
///
/// Instances are built through `Binomial::new`, which refuses a `p` below
/// zero or above one.
#[derive(Debug, Clone, Copy)]
pub struct Binomial<T> {
/// Number of trials; outcomes above this value have zero probability mass.
n: u64,
/// Probability that a single trial counts as a success.
p: T,
}
impl<T: FloatScalar> Binomial<T> {
    /// Creates a binomial distribution over `n` trials with success
    /// probability `p`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] unless `0 <= p <= 1`. A NaN
    /// `p` fails both comparisons and is therefore rejected as well.
    pub fn new(n: u64, p: T) -> Result<Self, StatsError> {
        // Deliberately a positive range test: every ordered comparison with
        // NaN is false, so the inverse form `p < 0 || p > 1` would accept NaN.
        let in_unit_interval = p >= T::zero() && p <= T::one();
        if in_unit_interval {
            Ok(Self { n, p })
        } else {
            Err(StatsError::InvalidParameter)
        }
    }
}
impl<T: FloatScalar> Binomial<T> {
    /// Draws a single value from the distribution.
    ///
    /// When `n <= 20`, or when either `n * p` or `n * (1 - p)` is below 5,
    /// the result is the number of the `n` draws from `rng.next_float` that
    /// fall below `p`. Otherwise one draw from `rng.next_normal` is scaled to
    /// mean `n * p` and standard deviation `sqrt(n * p * (1 - p))`, rounded,
    /// and capped at `n`; a value that cannot be converted to `u64` becomes 0.
    pub fn sample(&self, rng: &mut super::Rng) -> u64 {
        let trials = T::from(self.n).unwrap();
        let fail_prob = T::one() - self.p;
        let expected_hits = trials * self.p;
        let expected_misses = trials * fail_prob;
        let threshold = T::from(5.0).unwrap();

        // Direct simulation: one draw per trial, counting the successes.
        // NOTE(review): presumably `next_float` is uniform on [0, 1) — confirm.
        if self.n <= 20 || expected_hits < threshold || expected_misses < threshold {
            return (0..self.n).fold(0u64, |hits, _| {
                if rng.next_float::<T>() < self.p {
                    hits + 1
                } else {
                    hits
                }
            });
        }

        // Normal approximation.
        // NOTE(review): presumably `next_normal` is a standard normal draw — confirm.
        let spread = (expected_hits * fail_prob).sqrt();
        let draw = expected_hits + spread * rng.next_normal::<T>();
        draw.round().to_u64().map_or(0, |k| k.min(self.n))
    }

    /// Fills a fixed-size array with `K` successive results of [`Self::sample`],
    /// drawn in index order from the same `rng`.
    pub fn sample_array<const K: usize>(&self, rng: &mut super::Rng) -> [u64; K] {
        let mut draws = [0u64; K];
        for slot in &mut draws {
            *slot = self.sample(rng);
        }
        draws
    }
}
impl<T: FloatScalar> DiscreteDistribution<T> for Binomial<T> {
    /// Probability of observing exactly `k` successes.
    ///
    /// Returns zero when `k > n`; otherwise exponentiates [`Self::ln_pmf`].
    fn pmf(&self, k: u64) -> T {
        if k > self.n {
            return T::zero();
        }
        self.ln_pmf(k).exp()
    }

    /// Natural logarithm of the probability of exactly `k` successes.
    ///
    /// Returns negative infinity when `k > n`. Otherwise the result is the
    /// log binomial coefficient (via `lgamma`) plus `k * ln(p)` plus
    /// `(n - k) * ln(1 - p)`, where a term whose count is zero contributes
    /// zero. This keeps the degenerate cases `p == 0, k == 0` and
    /// `p == 1, k == n` finite (log-probability 0) instead of NaN.
    fn ln_pmf(&self, k: u64) -> T {
        if k > self.n {
            return T::neg_infinity();
        }
        let zero = T::zero();
        let one = T::one();
        let nf = T::from(self.n).unwrap();
        let kf = T::from(k).unwrap();

        // `0 * ln(0)` evaluates to `0 * -inf = NaN` in floating point, while
        // the mathematical limit needed here is 0, so skip the logarithm
        // whenever the count multiplying it is zero.
        let weighted_ln = |count: T, prob: T| {
            if count == zero {
                zero
            } else {
                count * prob.ln()
            }
        };

        let ln_choose = lgamma(nf + one) - lgamma(kf + one) - lgamma(nf - kf + one);
        ln_choose + weighted_ln(kf, self.p) + weighted_ln(nf - kf, one - self.p)
    }

    /// Probability of observing at most `k` successes.
    ///
    /// Returns one when `k >= n`. Otherwise evaluates
    /// `betainc(n - k, k + 1, 1 - p)` and yields NaN if that call does not
    /// produce a value.
    fn cdf(&self, k: u64) -> T {
        if k >= self.n {
            return T::one();
        }
        let one = T::one();
        // NOTE(review): assumes `betainc(a, b, x)` is the regularized
        // incomplete beta function I_x(a, b) — confirm against `crate::special`.
        let a = T::from(self.n - k).unwrap();
        let b = T::from(k + 1).unwrap();
        betainc(a, b, one - self.p).unwrap_or(T::nan())
    }

    /// Outcome associated with cumulative probability `p`.
    ///
    /// Returns 0 for `p <= 0` and `n` for `p >= 1`. Otherwise a starting
    /// guess is taken from the normal approximation (mean `n * p`, standard
    /// deviation `sqrt(n * p * (1 - p))`), clamped to `[0, n]` and floored,
    /// and handed to `super::discrete_quantile_search` together with
    /// [`Self::cdf`].
    fn quantile(&self, p: T) -> u64 {
        if p <= T::zero() {
            return 0;
        }
        if p >= T::one() {
            return self.n;
        }
        let nf = T::from(self.n).unwrap();
        let mean = nf * self.p;
        let std = (mean * (T::one() - self.p)).sqrt();
        let z = super::normal_quantile_standard(p);
        // Clamp in the float domain first so the integer conversion only sees
        // values inside the support; a failed conversion falls back to 0.
        let k0 = (mean + std * z)
            .max(T::zero())
            .min(nf)
            .floor()
            .to_u64()
            .unwrap_or(0);
        // NOTE(review): presumably the helper refines `k0` to the smallest `k`
        // with `cdf(k) >= p` — confirm against its definition.
        super::discrete_quantile_search(|k| self.cdf(k), p, k0)
    }

    /// Expected number of successes, `n * p`.
    fn mean(&self) -> T {
        T::from(self.n).unwrap() * self.p
    }

    /// Variance of the number of successes, `n * p * (1 - p)`.
    fn variance(&self) -> T {
        T::from(self.n).unwrap() * self.p * (T::one() - self.p)
    }
}