use thiserror::Error;
use crate::error::SearchError;
use crate::search::{search_bounded_zero, search_monotone, SEARCH_BOUND};
use crate::special::beta_inc;
use crate::special::gamma_log;
use crate::traits::{Discrete, DiscreteCdf, Mean, Variance};
/// Binomial distribution: the number of successes in `n` independent trials, each
/// succeeding with probability `pr`.
///
/// Fields are private; values are built through `Binomial::new` / `Binomial::try_new`,
/// which reject a `pr` that is NaN or outside `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Binomial {
    // Number of trials.
    n: u64,
    // Per-trial success probability, validated to lie in [0, 1] by `try_new`.
    pr: f64,
}
/// Errors produced when constructing a [`Binomial`] or running one of its inverse searches.
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum BinomialError {
    /// The success probability `pr` was NaN or outside `[0, 1]`.
    #[error("success probability {0} outside [0..1]")]
    PrOutOfRange(f64),
    /// More successes `s` were requested than there are trials `n` (raised by `search_pr`).
    #[error("number of successes {s} exceeds the number of trials {n}")]
    SuccessesExceedTrials { s: u64, n: u64 },
    /// The cumulative (lower-tail) probability `p` was NaN or outside `[0, 1]`.
    #[error("probability {0} outside [0..1]")]
    PNotInRange(f64),
    /// The complementary (upper-tail) probability `q` was NaN or outside `[0, 1]`.
    #[error("probability {0} outside [0..1]")]
    QNotInRange(f64),
    /// `p` and `q` were individually valid but `|p + q - 1|` exceeded `3 * f64::EPSILON`.
    #[error("p ({p}) and q ({q}) are not complementary: |p + q - 1| > 3ε")]
    PQSumNotOne { p: f64, q: f64 },
    /// A numerical root search failed; wraps the underlying [`SearchError`].
    #[error(transparent)]
    Search(#[from] SearchError),
}
impl Binomial {
    /// Creates a binomial distribution with `n` trials and success probability `pr`.
    ///
    /// # Panics
    ///
    /// Panics if `pr` is NaN or outside `[0, 1]`; use [`Binomial::try_new`] to receive a
    /// [`BinomialError`] instead.
    #[inline]
    pub fn new(n: u64, pr: f64) -> Self {
        Self::try_new(n, pr).unwrap()
    }

    /// Fallible constructor for a binomial distribution with `n` trials and success
    /// probability `pr`.
    ///
    /// # Errors
    ///
    /// Returns [`BinomialError::PrOutOfRange`] if `pr` is NaN or outside `[0, 1]`.
    #[inline]
    pub fn try_new(n: u64, pr: f64) -> Result<Self, BinomialError> {
        if !(0.0..=1.0).contains(&pr) || !pr.is_finite() {
            return Err(BinomialError::PrOutOfRange(pr));
        }
        Ok(Self { n, pr })
    }

    /// Number of trials.
    #[inline]
    pub const fn n(&self) -> u64 {
        self.n
    }

    /// Per-trial success probability.
    #[inline]
    pub const fn pr(&self) -> f64 {
        self.pr
    }

    /// Searches for the (real-valued) number of trials at which the CDF at `s`
    /// successes equals `p`, or equivalently the complementary CDF equals `q`.
    ///
    /// The smaller of `p` and `q` is used as the search target. The search covers
    /// `[0, SEARCH_BOUND]` starting from `5.0`. For a trial count `<= s` the CDF is
    /// taken to be exactly 1 (complementary CDF exactly 0).
    ///
    /// # Errors
    ///
    /// - `PNotInRange`, `QNotInRange` or `PQSumNotOne` if `(p, q)` is not a valid
    ///   complementary pair;
    /// - `PrOutOfRange` if `pr` is NaN or outside `[0, 1]`;
    /// - `Search` if the root search fails.
    #[inline]
    pub fn search_trials(p: f64, q: f64, pr: f64, s: u64) -> Result<f64, BinomialError> {
        check_pq(p, q)?;
        if !(0.0..=1.0).contains(&pr) || !pr.is_finite() {
            return Err(BinomialError::PrOutOfRange(pr));
        }
        let sf = s as f64;
        let f = |n: f64| {
            if sf >= n {
                return if p <= q { 1.0 - p } else { -q };
            }
            let (sf_bin, cdf_bin) = beta_inc(sf + 1.0, n - sf, pr, 1.0 - pr);
            if p <= q {
                cdf_bin - p
            } else {
                sf_bin - q
            }
        };
        Ok(search_monotone(
            0.0,
            SEARCH_BOUND,
            5.0,
            0.0,
            SEARCH_BOUND,
            f,
        )?)
    }

    /// Searches for the success probability at which the CDF at `s` successes out of
    /// `n` trials equals `p`, or equivalently the complementary CDF equals `q`.
    ///
    /// When `q < p`, the search runs over `1 - pr` and evaluates the upper tail
    /// directly. This avoids forming `1 - (1 - pr)` for probabilities near 1.
    ///
    /// When `s == n` the CDF is identically 1 (complementary CDF identically 0). That
    /// case is evaluated exactly, as `cumbin` does, rather than passing a zero shape
    /// parameter to `beta_inc`. For a target `p < 1` no root exists, and
    /// `search_bounded_zero` is left to report it (presumably as
    /// [`BinomialError::Search`]).
    ///
    /// # Errors
    ///
    /// - `PNotInRange`, `QNotInRange` or `PQSumNotOne` if `(p, q)` is not a valid
    ///   complementary pair;
    /// - `SuccessesExceedTrials` if `s > n`;
    /// - `Search` if the root search fails.
    #[inline]
    pub fn search_pr(p: f64, q: f64, n: u64, s: u64) -> Result<f64, BinomialError> {
        check_pq(p, q)?;
        if s > n {
            return Err(BinomialError::SuccessesExceedTrials { s, n });
        }
        let nf = n as f64;
        let sf = s as f64;
        // With every trial a success, `n - s == 0` would be an invalid beta shape.
        let all_successes = s == n;
        if p <= q {
            let f = |pr: f64| {
                let cdf_bin = if all_successes {
                    1.0
                } else {
                    beta_inc(sf + 1.0, nf - sf, pr, 1.0 - pr).1
                };
                cdf_bin - p
            };
            Ok(search_bounded_zero(0.0, 1.0, f)?)
        } else {
            let f = |ompr: f64| {
                let sf_bin = if all_successes {
                    0.0
                } else {
                    beta_inc(sf + 1.0, nf - sf, 1.0 - ompr, ompr).0
                };
                sf_bin - q
            };
            let ompr = search_bounded_zero(0.0, 1.0, f)?;
            Ok(1.0 - ompr)
        }
    }
}
/// Accepts `p` only if it is a finite probability in `[0, 1]`.
///
/// NaN fails the range test, so it is rejected as well.
#[inline]
fn check_p(p: f64) -> Result<(), BinomialError> {
    if (0.0..=1.0).contains(&p) && p.is_finite() {
        Ok(())
    } else {
        Err(BinomialError::PNotInRange(p))
    }
}
/// Accepts `q` only if it is a finite probability in `[0, 1]`.
///
/// NaN fails the range test, so it is rejected as well.
#[inline]
fn check_q(q: f64) -> Result<(), BinomialError> {
    let in_unit_interval = (0.0..=1.0).contains(&q);
    match in_unit_interval && q.is_finite() {
        true => Ok(()),
        false => Err(BinomialError::QNotInRange(q)),
    }
}
/// Validates a complementary pair of probabilities `(p, q)`.
///
/// `p` is checked first, then `q`, and finally that `p + q` equals 1 to within
/// `3 * f64::EPSILON`.
#[inline]
fn check_pq(p: f64, q: f64) -> Result<(), BinomialError> {
    check_p(p)?;
    check_q(q)?;
    // Both inputs are finite here, so `deviation` is never NaN.
    let deviation = (p + q - 1.0).abs();
    if deviation <= 3.0 * f64::EPSILON {
        Ok(())
    } else {
        Err(BinomialError::PQSumNotOne { p, q })
    }
}
/// Returns `(cdf, ccdf)` of the binomial distribution with `n` trials and success
/// probability `pr`, evaluated at `s` successes.
///
/// For `s >= n` the result is exactly `(1.0, 0.0)`. Otherwise both values come from
/// `beta_inc` with shape parameters `(s + 1, n - s)` at `pr`. Its two outputs are
/// swapped into `(cdf, ccdf)` order; this assumes `beta_inc` returns
/// `(I_pr(a, b), 1 - I_pr(a, b))`, to be confirmed against its definition.
fn cumbin(s: u64, n: u64, pr: f64) -> (f64, f64) {
    if s < n {
        let successes = s as f64;
        let trials = n as f64;
        let (upper_tail, lower_tail) =
            beta_inc(successes + 1.0, trials - successes, pr, 1.0 - pr);
        (lower_tail, upper_tail)
    } else {
        (1.0, 0.0)
    }
}
impl DiscreteCdf for Binomial {
    type Error = BinomialError;

    /// Cumulative probability at `s`: the first component of `cumbin`, exactly 1 for
    /// `s >= n`.
    #[inline]
    fn cdf(&self, s: u64) -> f64 {
        cumbin(s, self.n, self.pr).0
    }

    /// Complementary cumulative probability at `s`: the second component of `cumbin`,
    /// exactly 0 for `s >= n`.
    #[inline]
    fn ccdf(&self, s: u64) -> f64 {
        cumbin(s, self.n, self.pr).1
    }

    /// Returns the smallest `s` in `0..=n` with `cdf(s) >= p`, found by bisection.
    ///
    /// `p == 0` maps to 0. `p == 1` maps to the top of the support: `n` in general,
    /// but 0 when `pr == 0`. In that case all probability mass sits on 0, so the CDF
    /// already reaches 1 there. This keeps `inverse_cdf(1.0)` consistent with every
    /// other quantile of that degenerate distribution.
    ///
    /// # Errors
    ///
    /// Returns [`BinomialError::PNotInRange`] if `p` is NaN or outside `[0, 1]`.
    #[inline]
    fn inverse_cdf(&self, p: f64) -> Result<u64, BinomialError> {
        check_p(p)?;
        if p == 0.0 {
            return Ok(0);
        }
        if p == 1.0 {
            return Ok(if self.pr == 0.0 { 0 } else { self.n });
        }
        // Invariant: cdf(hi) >= p (cdf(n) == 1 > p), and every s < lo has cdf(s) < p.
        let mut lo = 0u64;
        let mut hi = self.n;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if self.cdf(mid) < p {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        Ok(lo)
    }
}
impl Binomial {
    /// Searches for the (real-valued) number of successes `s` at which the
    /// complementary CDF equals `q`, i.e. the CDF equals `1 - q`.
    ///
    /// `s` is treated as continuous. The search runs over `[0, n]` starting at `5.0`,
    /// targeting whichever of `1 - q` and `q` is smaller. For `s >= n` the CDF is taken
    /// to be exactly 1 and the complementary CDF exactly 0.
    ///
    /// # Errors
    ///
    /// Returns `BinomialError::QNotInRange` if `q` is NaN or outside `[0, 1]`, and
    /// `BinomialError::Search` if the root search fails.
    // NOTE(review): the start value 5.0 lies above the upper bound `n` when n < 5;
    // confirm that `search_monotone` tolerates a start point outside its bounds.
    #[inline]
    pub fn inverse_ccdf(&self, q: f64) -> Result<f64, BinomialError> {
        check_q(q)?;
        let trials = self.n as f64;
        let pr = self.pr;
        let p = 1.0 - q;
        let target_lower_tail = p <= q;
        let objective = |s: f64| {
            if s >= trials {
                // At or past the last support point: cdf = 1, ccdf = 0.
                return if target_lower_tail { 1.0 - p } else { 0.0 - q };
            }
            let (upper, lower) = beta_inc(s + 1.0, trials - s, pr, 1.0 - pr);
            if target_lower_tail {
                lower - p
            } else {
                upper - q
            }
        };
        Ok(search_monotone(0.0, trials, 5.0, 0.0, trials, objective)?)
    }
}
impl Discrete for Binomial {
    /// Probability mass at `s`; exactly 0 for `s > n`.
    #[inline]
    fn pmf(&self, s: u64) -> f64 {
        if s > self.n {
            return 0.0;
        }
        self.ln_pmf(s).exp()
    }

    /// Natural logarithm of the probability mass at `s`; `-inf` for `s > n`.
    ///
    /// Computed as `ln C(n, s) + s ln(pr) + (n - s) ln(1 - pr)`. The degenerate cases
    /// `pr == 0` and `pr == 1` are handled explicitly, using the convention
    /// `0 * ln(0) = 0`.
    #[inline]
    fn ln_pmf(&self, s: u64) -> f64 {
        if s > self.n {
            return f64::NEG_INFINITY;
        }
        let n = self.n as f64;
        let sf = s as f64;
        let pr = self.pr;
        // C(n, 0) == C(n, n) == 1 exactly; don't depend on gamma_log terms cancelling.
        let log_c = if s == 0 || s == self.n {
            0.0
        } else {
            gamma_log(n + 1.0) - gamma_log(sf + 1.0) - gamma_log(n - sf + 1.0)
        };
        let log_pr = if pr == 0.0 {
            if s == 0 {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        } else {
            sf * pr.ln()
        };
        let log_q = if pr == 1.0 {
            if s == self.n {
                0.0
            } else {
                f64::NEG_INFINITY
            }
        } else {
            // ln(1 - pr) via ln_1p: `1.0 - pr` rounds to 1 for tiny pr, which would
            // zero out this term even when (n - s) * pr is significant.
            (n - sf) * (-pr).ln_1p()
        };
        log_c + log_pr + log_q
    }
}
impl Mean for Binomial {
    /// Expected number of successes, `n * pr`.
    #[inline]
    fn mean(&self) -> f64 {
        let trials = self.n as f64;
        trials * self.pr
    }
}
impl Variance for Binomial {
    /// Variance of the number of successes, `n * pr * (1 - pr)`.
    #[inline]
    fn variance(&self) -> f64 {
        let failure = 1.0 - self.pr;
        // Evaluated as (n * pr) * (1 - pr), the same association as before.
        let expected = self.n as f64 * self.pr;
        expected * failure
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_invalid_inputs() {
        // Constructor rejects a negative success probability.
        assert_eq!(
            Binomial::try_new(10, -0.1),
            Err(BinomialError::PrOutOfRange(-0.1))
        );
        // `p` is validated before `q` and before the complementarity check.
        assert_eq!(
            Binomial::search_trials(-0.1, 1.1, 0.5, 3),
            Err(BinomialError::PNotInRange(-0.1))
        );
        // NaN never compares equal, so inspect the payload instead of using assert_eq!.
        match Binomial::search_trials(0.5, 0.5, f64::NAN, 3) {
            Err(BinomialError::PrOutOfRange(x)) => assert!(x.is_nan()),
            other => panic!("expected PrOutOfRange(NaN), got {:?}", other),
        }
        assert_eq!(
            Binomial::search_pr(0.5, 0.5, 3, 4),
            Err(BinomialError::SuccessesExceedTrials { s: 4, n: 3 })
        );
    }

    #[cfg(not(miri))]
    #[test]
    fn edge_and_moment_cases() {
        let dist = Binomial::new(10, 0.3);
        assert_eq!(dist.inverse_cdf(0.0), Ok(0));
        // Outside the support [0, n].
        assert_eq!(dist.pmf(11), 0.0);
        assert_eq!(dist.ln_pmf(11), f64::NEG_INFINITY);
        // Degenerate distributions put all of their mass on a single point.
        assert_eq!(Binomial::new(10, 0.0).ln_pmf(0), 0.0);
        assert_eq!(Binomial::new(10, 1.0).ln_pmf(10), 0.0);
        // Moments: n * pr and n * pr * (1 - pr).
        assert_eq!(dist.mean(), 3.0);
        assert!((dist.variance() - 2.1).abs() < 1e-15);
    }
}