use thiserror::Error;
use crate::error::SearchError;
use crate::search::{search_bounded_zero, search_monotone, SEARCH_BOUND};
use crate::special::beta_inc;
use crate::special::gamma_log;
use crate::traits::{Discrete, DiscreteCdf, Mean, Variance};
/// Negative binomial distribution over the number of failures `s` observed before
/// the `r`-th success, where each trial succeeds with probability `pr`.
///
/// Build it with [`NegativeBinomial::new`] or [`NegativeBinomial::try_new`], which
/// enforce `r >= 1` and `0 < pr <= 1`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NegativeBinomial {
/// Number of successes to wait for; validated to be at least 1.
r: u64,
/// Per-trial success probability; validated to lie in `(0, 1]`.
pr: f64,
}
/// Errors reported when constructing a [`NegativeBinomial`] or when validating the
/// arguments of its search and inverse functions.
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum NegativeBinomialError {
/// The success probability `pr` was not in `(0, 1]` (NaN and infinities are rejected too).
#[error("success probability {0} outside (0..1]")]
PrOutOfRange(f64),
/// The number of required successes `r` was zero.
#[error("`r` must be positive")]
RNotPositive,
/// The lower-tail (CDF) target probability `p` was not in `[0, 1]`.
#[error("probability {0} outside [0..1]")]
PNotInRange(f64),
/// The upper-tail (CCDF) target probability `q` was not in `[0, 1]`.
/// NOTE(review): the Display text is identical to `PNotInRange`; only the variant
/// distinguishes which argument was rejected.
#[error("probability {0} outside [0..1]")]
QNotInRange(f64),
/// `p` and `q` were each valid but `|p + q - 1|` exceeded `3 * f64::EPSILON`.
#[error("p ({p}) and q ({q}) are not complementary: |p + q - 1| > 3ε")]
PQSumNotOne { p: f64, q: f64 },
/// Failure propagated from the numerical root search in `crate::search`.
#[error(transparent)]
Search(#[from] SearchError),
}
impl NegativeBinomial {
    /// Creates a distribution that waits for `r` successes, each trial succeeding
    /// with probability `pr`.
    ///
    /// # Panics
    ///
    /// Panics if [`NegativeBinomial::try_new`] rejects the parameters (`r == 0`, or
    /// `pr` outside `(0, 1]`).
    #[inline]
    pub fn new(r: u64, pr: f64) -> Self {
        Self::try_new(r, pr).expect("invalid NegativeBinomial parameters")
    }

    /// Fallible constructor.
    ///
    /// # Errors
    ///
    /// * [`NegativeBinomialError::RNotPositive`] if `r == 0`.
    /// * [`NegativeBinomialError::PrOutOfRange`] if `pr` is not in `(0, 1]`; NaN and
    ///   infinities are rejected as well.
    #[inline]
    pub fn try_new(r: u64, pr: f64) -> Result<Self, NegativeBinomialError> {
        if r == 0 {
            return Err(NegativeBinomialError::RNotPositive);
        }
        if !(pr > 0.0 && pr <= 1.0 && pr.is_finite()) {
            return Err(NegativeBinomialError::PrOutOfRange(pr));
        }
        Ok(Self { r, pr })
    }

    /// Number of successes the distribution waits for.
    #[inline]
    pub const fn r(&self) -> u64 {
        self.r
    }

    /// Per-trial success probability.
    #[inline]
    pub const fn pr(&self) -> f64 {
        self.pr
    }

    /// Solves for a real-valued `r` such that the CDF at `s` equals `p`
    /// (equivalently, the CCDF equals `q`) for success probability `pr`.
    ///
    /// # Errors
    ///
    /// * [`NegativeBinomialError::PNotInRange`], [`NegativeBinomialError::QNotInRange`]
    ///   or [`NegativeBinomialError::PQSumNotOne`] if `p`/`q` are invalid.
    /// * [`NegativeBinomialError::PrOutOfRange`] if `pr` is not in `(0, 1]`.
    /// * [`NegativeBinomialError::Search`] if the root search fails.
    #[inline]
    pub fn search_r(p: f64, q: f64, pr: f64, s: u64) -> Result<f64, NegativeBinomialError> {
        check_pq(p, q)?;
        // NaN fails both comparisons, so it is rejected here too.
        if !(pr > 0.0 && pr <= 1.0) {
            return Err(NegativeBinomialError::PrOutOfRange(pr));
        }
        let sf = s as f64;
        // Match against whichever target probability is smaller; differencing against
        // a value close to 1 would throw away precision.
        let f = |r: f64| {
            let (cum, ccum) = beta_inc(r, sf + 1.0, pr, 1.0 - pr);
            if p <= q {
                cum - p
            } else {
                ccum - q
            }
        };
        // NOTE(review): argument order follows `search_monotone`'s signature
        // (range [0, SEARCH_BOUND], starting guess 5.0) — confirm against crate::search.
        Ok(search_monotone(
            0.0,
            SEARCH_BOUND,
            5.0,
            0.0,
            SEARCH_BOUND,
            f,
        )?)
    }

    /// Solves for the success probability `pr` such that, with `r` required
    /// successes, the CDF at `s` equals `p` (equivalently, the CCDF equals `q`).
    ///
    /// # Errors
    ///
    /// * [`NegativeBinomialError::PNotInRange`], [`NegativeBinomialError::QNotInRange`]
    ///   or [`NegativeBinomialError::PQSumNotOne`] if `p`/`q` are invalid.
    /// * [`NegativeBinomialError::RNotPositive`] if `r == 0` (same rule as `try_new`).
    /// * [`NegativeBinomialError::Search`] if the root search fails.
    #[inline]
    pub fn search_pr(p: f64, q: f64, r: u64, s: u64) -> Result<f64, NegativeBinomialError> {
        check_pq(p, q)?;
        // With zero required successes all mass sits at 0 regardless of `pr`, so the
        // CDF is constant in `pr` and there is no root to find.
        if r == 0 {
            return Err(NegativeBinomialError::RNotPositive);
        }
        let rf = r as f64;
        let sf = s as f64;
        if p <= q {
            // Solve for `pr` directly against the lower tail.
            let f = |pr: f64| {
                let (cum, _ccum) = beta_inc(rf, sf + 1.0, pr, 1.0 - pr);
                cum - p
            };
            Ok(search_bounded_zero(0.0, 1.0, f)?)
        } else {
            // Solve for `1 - pr` against the upper tail so small values of `1 - pr`
            // are resolved directly rather than as a difference from 1.
            let f = |ompr: f64| {
                let (_cum, ccum) = beta_inc(rf, sf + 1.0, 1.0 - ompr, ompr);
                ccum - q
            };
            let ompr = search_bounded_zero(0.0, 1.0, f)?;
            Ok(1.0 - ompr)
        }
    }
}
/// Validates that `p` is a probability in the closed interval `[0, 1]`.
///
/// # Errors
///
/// Returns [`NegativeBinomialError::PNotInRange`] carrying `p` when it lies outside
/// `[0, 1]` or is NaN.
#[inline]
fn check_p(p: f64) -> Result<(), NegativeBinomialError> {
    let in_unit_interval = p.is_finite() && (0.0..=1.0).contains(&p);
    if in_unit_interval {
        Ok(())
    } else {
        Err(NegativeBinomialError::PNotInRange(p))
    }
}
/// Validates that `q` is a probability in the closed interval `[0, 1]`.
///
/// # Errors
///
/// Returns [`NegativeBinomialError::QNotInRange`] carrying `q` when it lies outside
/// `[0, 1]` or is NaN.
#[inline]
fn check_q(q: f64) -> Result<(), NegativeBinomialError> {
    match q {
        ok if ok.is_finite() && (0.0..=1.0).contains(&ok) => Ok(()),
        bad => Err(NegativeBinomialError::QNotInRange(bad)),
    }
}
/// Validates a complementary pair of tail probabilities.
///
/// `p` is checked first, then `q`, and finally that they sum to 1 within
/// `3 * f64::EPSILON`.
///
/// # Errors
///
/// [`NegativeBinomialError::PNotInRange`], [`NegativeBinomialError::QNotInRange`], or
/// [`NegativeBinomialError::PQSumNotOne`], in that order of precedence.
#[inline]
fn check_pq(p: f64, q: f64) -> Result<(), NegativeBinomialError> {
    check_p(p)?;
    check_q(q)?;
    // Both values are finite here, so `drift` can never be NaN.
    let drift = (p + q - 1.0).abs();
    if drift <= 3.0 * f64::EPSILON {
        Ok(())
    } else {
        Err(NegativeBinomialError::PQSumNotOne { p, q })
    }
}
impl DiscreteCdf for NegativeBinomial {
    type Error = NegativeBinomialError;

    /// `P(X <= s)`: the lower-tail value returned by `beta_inc(r, s + 1, pr, 1 - pr)`.
    #[inline]
    fn cdf(&self, s: u64) -> f64 {
        let (cum, _) = beta_inc(self.r as f64, s as f64 + 1.0, self.pr, 1.0 - self.pr);
        cum
    }

    /// Complement of [`cdf`](DiscreteCdf::cdf), taken from the upper-tail value of
    /// `beta_inc` rather than computed as `1 - cdf(s)`, so small tails keep precision.
    #[inline]
    fn ccdf(&self, s: u64) -> f64 {
        let (_, ccum) = beta_inc(self.r as f64, s as f64 + 1.0, self.pr, 1.0 - self.pr);
        ccum
    }

    /// Returns the smallest `s` with `cdf(s) >= p`.
    ///
    /// Special cases:
    /// * `p == 0` → `0`.
    /// * `pr == 1` → `0` for every valid `p`, because all mass sits at zero.
    /// * `p == 1` (with `pr < 1`) → `u64::MAX`, standing in for +∞.
    /// * If the CDF never reaches `p` before the search bracket would overflow, the
    ///   result is also `u64::MAX`.
    ///
    /// # Errors
    ///
    /// [`NegativeBinomialError::PNotInRange`] if `p` is not in `[0, 1]`.
    #[inline]
    fn inverse_cdf(&self, p: f64) -> Result<u64, NegativeBinomialError> {
        check_p(p)?;
        // Degenerate distribution (`pr == 1`): cdf(0) is already 1.
        if p == 0.0 || self.pr == 1.0 {
            return Ok(0);
        }
        if p == 1.0 {
            return Ok(u64::MAX);
        }
        let pr = self.pr;
        let r = self.r as f64;
        let mean = r * (1.0 - pr) / pr;
        let sd = (mean / pr).sqrt();
        // Initial upper bracket: ten standard deviations past the mean. Float-to-int
        // `as` saturates, so huge or infinite guesses become `u64::MAX`.
        let mut hi = (mean + 10.0 * sd + 10.0).ceil() as u64;
        // Grow the bracket until cdf(hi) >= p, giving up once doubling would overflow.
        while self.cdf(hi) < p {
            if hi >= u64::MAX / 2 {
                return Ok(u64::MAX);
            }
            hi *= 2;
        }
        // Lower-bound binary search on [0, hi]; invariant: cdf(hi) >= p.
        let mut lo = 0u64;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if self.cdf(mid) < p {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        Ok(lo)
    }
}
impl NegativeBinomial {
    /// Solves for the real-valued number of failures `s` whose upper-tail
    /// probability equals `q`, i.e. inverts the complementary CDF.
    ///
    /// # Errors
    ///
    /// * [`NegativeBinomialError::QNotInRange`] if `q` is not in `[0, 1]`.
    /// * [`NegativeBinomialError::Search`] if the root search fails.
    #[inline]
    pub fn inverse_ccdf(&self, q: f64) -> Result<f64, NegativeBinomialError> {
        check_q(q)?;
        let p = 1.0 - q;
        let (r, pr) = (self.r as f64, self.pr);
        // Compare against the smaller tail so the target is not swamped by rounding.
        let use_lower_tail = p <= q;
        let residual = |s: f64| {
            let (lower, upper) = beta_inc(r, s + 1.0, pr, 1.0 - pr);
            if use_lower_tail {
                lower - p
            } else {
                upper - q
            }
        };
        let root = search_monotone(0.0, SEARCH_BOUND, 5.0, 0.0, SEARCH_BOUND, residual)?;
        Ok(root)
    }
}
impl Discrete for NegativeBinomial {
    /// Probability of observing exactly `s` failures, computed as `exp(ln_pmf(s))`.
    #[inline]
    fn pmf(&self, s: u64) -> f64 {
        self.ln_pmf(s).exp()
    }

    /// Natural log of the probability mass at `s`:
    /// `ln C(s + r - 1, s) + r·ln(pr) + s·ln(1 - pr)`.
    ///
    /// The failure term is omitted when `s == 0`. Otherwise `pr == 1` would evaluate
    /// `0 · ln(0) = 0 · (-∞) = NaN` instead of the correct contribution of 0.
    #[inline]
    fn ln_pmf(&self, s: u64) -> f64 {
        let rf = self.r as f64;
        let sf = s as f64;
        // Log binomial coefficient; assumes `gamma_log` is ln Γ — confirm in crate::special.
        let log_c = gamma_log(sf + rf) - gamma_log(sf + 1.0) - gamma_log(rf);
        let success_term = rf * self.pr.ln();
        // `ln_1p(-pr)` equals `ln(1 - pr)` but keeps full precision for small `pr`;
        // for `pr == 1` it yields -∞, giving pmf 0 for every `s > 0`.
        let failure_term = if s == 0 {
            0.0
        } else {
            sf * (-self.pr).ln_1p()
        };
        log_c + success_term + failure_term
    }
}
impl Mean for NegativeBinomial {
    /// Expected number of failures before the `r`-th success: `r · (1 - pr) / pr`.
    #[inline]
    fn mean(&self) -> f64 {
        let (r, pr) = (self.r as f64, self.pr);
        let fail = 1.0 - pr;
        r * fail / pr
    }
}
impl Variance for NegativeBinomial {
    /// Variance `r · (1 - pr) / pr²`, obtained by dividing the mean by `pr` once more.
    #[inline]
    fn variance(&self) -> f64 {
        let pr = self.pr;
        self.mean() / pr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `try_new` rejects `r == 0` and a zero success probability.
    #[test]
    fn rejects_invalid_parameters() {
        assert_eq!(
            NegativeBinomial::try_new(0, 0.5),
            Err(NegativeBinomialError::RNotPositive)
        );
        assert_eq!(
            NegativeBinomial::try_new(1, 0.0),
            Err(NegativeBinomialError::PrOutOfRange(0.0))
        );
    }

    /// Quantile at 0 is 0, and the log-pmf and moments are finite for ordinary parameters.
    #[test]
    fn inverse_zero_and_moments() {
        let dist = NegativeBinomial::new(5, 0.4);
        assert_eq!(dist.inverse_cdf(0.0).unwrap(), 0);
        for value in [dist.ln_pmf(3), dist.mean(), dist.variance()] {
            assert!(value.is_finite());
        }
    }

    /// `search_r` validates `p`/`q` before `pr`, and rejects a zero `pr`.
    #[test]
    fn search_helpers_reject_invalid_inputs() {
        let bad_p = NegativeBinomial::search_r(-0.1, 1.1, 0.5, 3);
        assert_eq!(bad_p, Err(NegativeBinomialError::PNotInRange(-0.1)));
        let bad_pr = NegativeBinomial::search_r(0.5, 0.5, 0.0, 3);
        assert_eq!(bad_pr, Err(NegativeBinomialError::PrOutOfRange(0.0)));
    }
}