use crate::DType;
use crate::stats::error::{StatsError, StatsResult};
use crate::stats::{DiscreteDistribution, Distribution};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Negative binomial distribution of the number of failures observed before the
/// `r`-th success, in independent trials that each succeed with probability `p`.
///
/// The shape `r` is held as an `f64`, so real-valued (non-integer) shapes can be
/// represented as well as integer success counts.
#[derive(Clone, Copy, Debug)]
pub struct NegativeBinomial {
    /// Number of successes / shape parameter.
    r: f64,
    /// Per-trial success probability, in `(0, 1]`.
    p: f64,
}
impl NegativeBinomial {
    /// Creates a negative binomial distribution with an integer number of successes `r`
    /// and per-trial success probability `p`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] if `r == 0` or if `p` is NaN or not in `(0, 1]`.
    pub fn new(r: u64, p: f64) -> StatsResult<Self> {
        Self::new_real(r as f64, p)
    }

    /// Creates a negative binomial distribution with a real-valued shape `r`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] if `r` is NaN, non-positive or infinite,
    /// or if `p` is NaN or not in `(0, 1]`.
    pub fn new_real(r: f64, p: f64) -> StatsResult<Self> {
        // NaN compares false against everything, so it must be rejected explicitly;
        // a bare `r <= 0.0` would let it through.
        if r.is_nan() || r <= 0.0 {
            return Err(StatsError::InvalidParameter {
                name: "r".to_string(),
                value: r,
                reason: "number of successes must be positive".to_string(),
            });
        }
        // An infinite shape makes every `lgamma` difference in the pmf NaN.
        if r.is_infinite() {
            return Err(StatsError::InvalidParameter {
                name: "r".to_string(),
                value: r,
                reason: "number of successes must be finite".to_string(),
            });
        }
        // Same NaN concern as above for the probability.
        if p.is_nan() || p <= 0.0 || p > 1.0 {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: p,
                reason: "probability must be in (0, 1]".to_string(),
            });
        }
        Ok(Self { r, p })
    }

    /// Shape parameter (number of successes).
    pub fn r(&self) -> f64 {
        self.r
    }

    /// Per-trial success probability.
    pub fn p(&self) -> f64 {
        self.p
    }
}
impl Distribution for NegativeBinomial {
    /// Mean number of failures, `r(1 − p) / p`.
    fn mean(&self) -> f64 {
        self.r * (1.0 - self.p) / self.p
    }

    /// Variance, `r(1 − p) / p²`.
    fn var(&self) -> f64 {
        self.r * (1.0 - self.p) / (self.p * self.p)
    }

    /// Shannon entropy in nats, `−Σₖ P(k)·ln P(k)`, by direct summation over the support.
    ///
    /// Terms are accumulated until the unsummed tail mass is below `1e-12`, or until
    /// `clamp(mean + 40·sd, 1e4, 1e7)` terms have been added. Every term is non-negative,
    /// so hitting the term cap (only for extremely large means) yields an underestimate.
    fn entropy(&self) -> f64 {
        const TAIL_MASS_TOL: f64 = 1e-12;

        // p == 1 puts all mass at k = 0: zero entropy.
        if self.p >= 1.0 {
            return 0.0;
        }

        // Scale the term budget with the spread of the distribution so wide
        // distributions are not truncated early, while still bounding runtime.
        let max_terms = (self.mean() + 40.0 * self.var().sqrt()).clamp(10_000.0, 1e7) as u64;

        let mut h = 0.0;
        let mut covered = 0.0;
        let mut k = 0_u64;
        while covered < 1.0 - TAIL_MASS_TOL && k < max_terms {
            // Use the log-pmf directly instead of re-taking `ln` of its exponential.
            let log_p = self.log_pmf(k);
            let p_k = log_p.exp();
            // Underflowed terms (p_k == 0) contribute nothing and would give 0 · (−inf).
            if p_k > 0.0 {
                h -= p_k * log_p;
            }
            covered += p_k;
            k += 1;
        }
        h
    }

    /// Median, taken as `ppf(0.5)`; falls back to the rounded mean if the quantile
    /// search returns an error.
    fn median(&self) -> f64 {
        self.ppf(0.5)
            .unwrap_or_else(|_| self.mean().round() as u64) as f64
    }

    /// Mode, `⌊(r − 1)(1 − p) / p⌋` for `r > 1`, otherwise `0`.
    ///
    /// When `(r − 1)(1 − p) / p` is an integer, that value minus one is also a mode;
    /// the larger one is returned.
    fn mode(&self) -> f64 {
        if self.r <= 1.0 {
            0.0
        } else {
            ((self.r - 1.0) * (1.0 - self.p) / self.p).floor()
        }
    }

    /// Skewness, `(2 − p) / √(r(1 − p))`. Infinite when `p == 1`.
    fn skewness(&self) -> f64 {
        (2.0 - self.p) / (self.r * (1.0 - self.p)).sqrt()
    }

    /// Excess kurtosis, `6/r + p² / (r(1 − p))`. Infinite when `p == 1`.
    fn kurtosis(&self) -> f64 {
        (6.0 / self.r) + (self.p * self.p) / (self.r * (1.0 - self.p))
    }
}
impl DiscreteDistribution for NegativeBinomial {
    /// Probability mass `P(X = k)`, computed as `exp(log_pmf(k))`.
    fn pmf(&self, k: u64) -> f64 {
        self.log_pmf(k).exp()
    }

    /// Log-probability
    /// `ln P(X = k) = lnΓ(k + r) − lnΓ(k + 1) − lnΓ(r) + r·ln p + k·ln(1 − p)`.
    fn log_pmf(&self, k: u64) -> f64 {
        use super::super::continuous::special::lgamma;
        let r = self.r;
        let p = self.p;
        let k_f = k as f64;
        // The failure term is exactly zero at k = 0; evaluating it there would give
        // `0 · ln(0) = NaN` when p == 1. `ln_1p(-p)` is more accurate than
        // `(1 - p).ln()` for small p.
        let failure_term = if k == 0 { 0.0 } else { k_f * (-p).ln_1p() };
        lgamma(k_f + r) - lgamma(k_f + 1.0) - lgamma(r) + r * p.ln() + failure_term
    }

    /// Cumulative probability `P(X ≤ k) = I_p(r, k + 1)` (regularized incomplete beta).
    fn cdf(&self, k: u64) -> f64 {
        use super::super::continuous::special::betainc;
        betainc(self.r, k as f64 + 1.0, self.p)
    }

    /// Survival function `P(X > k) = I_{1−p}(k + 1, r)`.
    ///
    /// Uses the symmetry `1 − I_x(a, b) = I_{1−x}(b, a)` rather than `1 − cdf(k)`, so
    /// small upper-tail probabilities are not lost to cancellation.
    fn sf(&self, k: u64) -> f64 {
        use super::super::continuous::special::betainc;
        if self.p >= 1.0 {
            // All mass sits at k = 0, so nothing lies above any k ≥ 0.
            return 0.0;
        }
        betainc(k as f64 + 1.0, self.r, 1.0 - self.p)
    }

    /// Smallest `k` with `cdf(k) >= prob`, found by exponential bracketing then bisection.
    ///
    /// `prob == 1` maps to `u64::MAX`. So does any quantile whose bracket would have to
    /// grow beyond 2^53: above that, `f64` can no longer represent every integer `k`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] if `prob` is NaN or outside `[0, 1]`.
    fn ppf(&self, prob: f64) -> StatsResult<u64> {
        // `contains` is false for NaN, so NaN is rejected here too.
        if !(0.0..=1.0).contains(&prob) {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: prob,
                reason: "probability must be in [0, 1]".to_string(),
            });
        }
        if prob == 0.0 {
            return Ok(0);
        }
        if prob == 1.0 {
            return Ok(u64::MAX);
        }

        const SEARCH_LIMIT: u64 = 1 << 53;

        // Grow an upper bound from a mean + 10·sd starting guess until it brackets `prob`.
        let sd = self.var().sqrt();
        let mut high = (self.mean() + 10.0 * sd).max(100.0) as u64;
        while self.cdf(high) < prob {
            if high >= SEARCH_LIMIT {
                return Ok(u64::MAX);
            }
            high = high.saturating_mul(2);
        }

        // Bisection for the first k whose cdf reaches `prob`.
        let mut low = 0_u64;
        while low < high {
            let mid = low + (high - low) / 2;
            if self.cdf(mid) < prob {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        Ok(low)
    }

    /// Element-wise `P(X = k)`; each element of `k` is floored before evaluation.
    fn pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_pmf = self.log_pmf_tensor(k, client)?;
        client.exp(&log_pmf)
    }

    /// Element-wise `ln P(X = k)`; each element of `k` is floored before evaluation.
    ///
    /// NOTE(review): when `p == 1`, `ln(1 − p) = −inf` and the `k = 0` entries evaluate
    /// `0 · (−inf) = NaN`. The scalar `log_pmf` special-cases this; the tensor path does not yet.
    fn log_pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        use super::super::continuous::special::lgamma;

        let ln_p = self.p.ln();
        let ln_q = (-self.p).ln_1p();
        // lnΓ(r) is constant across elements, so evaluate it once on the host.
        let lgamma_r = lgamma(self.r);

        let k_floor = client.floor(k)?;
        let k_plus_r = client.add_scalar(&k_floor, self.r)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let lgamma_k_plus_r = client.lgamma(&k_plus_r)?;
        let lgamma_k_plus_1 = client.lgamma(&k_plus_1)?;

        // ln C(k + r − 1, k) = lnΓ(k + r) − lnΓ(k + 1) − lnΓ(r)
        let log_binom_coeff = client.sub(&lgamma_k_plus_r, &lgamma_k_plus_1)?;
        let log_binom_coeff = client.sub_scalar(&log_binom_coeff, lgamma_r)?;

        let r_times_ln_p = self.r * ln_p;
        let k_times_ln_q = client.mul_scalar(&k_floor, ln_q)?;
        let result = client.add_scalar(&log_binom_coeff, r_times_ln_p)?;
        client.add(&result, &k_times_ln_q)
    }

    /// Element-wise `P(X ≤ k) = I_p(r, ⌊k⌋ + 1)`.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let shape = k_floor.shape();
        let r_tensor = client.fill(shape, self.r, k_floor.dtype())?;
        let p_tensor = client.fill(shape, self.p, k_floor.dtype())?;
        client.betainc(&r_tensor, &k_plus_1, &p_tensor)
    }

    /// Element-wise `P(X > k) = I_{1−p}(⌊k⌋ + 1, r)`.
    ///
    /// Evaluated directly through beta symmetry rather than as `1 − cdf`, matching `sf`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        if self.p >= 1.0 {
            // Point mass at 0: the survival probability is zero everywhere on the support.
            return client.fill(k.shape(), 0.0, k.dtype());
        }
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let shape = k_floor.shape();
        let r_tensor = client.fill(shape, self.r, k_floor.dtype())?;
        let q_tensor = client.fill(shape, 1.0 - self.p, k_floor.dtype())?;
        client.betainc(&k_plus_1, &r_tensor, &q_tensor)
    }

    /// Approximate element-wise quantiles via a normal approximation,
    /// `mean + sd · √2 · erfinv(2p − 1)`, clamped below at 0.
    ///
    /// Unlike the scalar `ppf`, this is not an exact discrete inverse. Results are
    /// real-valued, not rounded to a support point, and can be inaccurate for small `r`
    /// or strongly skewed parameterizations.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let mean = self.mean();
        let sd = self.var().sqrt();
        // Standard-normal quantile: z = √2 · erfinv(2p − 1).
        let two_p_minus_1 = client.sub_scalar(&client.mul_scalar(p, 2.0)?, 1.0)?;
        let erfinv_val = client.erfinv(&two_p_minus_1)?;
        let z = client.mul_scalar(&erfinv_val, std::f64::consts::SQRT_2)?;
        let scaled = client.mul_scalar(&z, sd)?;
        let result = client.add_scalar(&scaled, mean)?;
        client.clamp(&result, 0.0, f64::INFINITY)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_negative_binomial_creation() {
        // (r, p, constructor should succeed)
        let cases: [(u64, f64, bool); 4] = [
            (5, 0.5, true),
            (0, 0.5, false),
            (5, 0.0, false),
            (5, 1.5, false),
        ];
        for (r, p, expect_ok) in cases {
            assert_eq!(NegativeBinomial::new(r, p).is_ok(), expect_ok);
        }
    }

    #[test]
    fn test_negative_binomial_pmf() {
        let dist = NegativeBinomial::new(3, 0.5).unwrap();
        assert!(close(dist.pmf(0), 0.125, 1e-10));
        assert!(close(dist.pmf(1), 0.1875, 1e-10));
        let total = (0..100_u64).map(|k| dist.pmf(k)).sum::<f64>();
        assert!(close(total, 1.0, 1e-6));
    }

    #[test]
    fn test_negative_binomial_cdf() {
        let dist = NegativeBinomial::new(3, 0.5).unwrap();
        assert!(close(dist.cdf(0), dist.pmf(0), 1e-10));
        // The CDF must be non-decreasing across the first support points.
        for k in 0..2_u64 {
            assert!(dist.cdf(k) <= dist.cdf(k + 1));
        }
        assert!(dist.cdf(100) > 0.999);
    }

    #[test]
    fn test_negative_binomial_mean() {
        let dist = NegativeBinomial::new(5, 0.5).unwrap();
        assert!(close(dist.mean(), 5.0, 1e-10));
    }

    #[test]
    fn test_negative_binomial_variance() {
        let dist = NegativeBinomial::new(5, 0.5).unwrap();
        assert!(close(dist.var(), 10.0, 1e-10));
    }

    #[test]
    fn test_negative_binomial_ppf() {
        let dist = NegativeBinomial::new(5, 0.5).unwrap();
        assert_eq!(dist.ppf(0.0).unwrap(), 0);
        for k in [0, 1, 5, 10] {
            let recovered = dist.ppf(dist.cdf(k)).unwrap();
            assert!((k..=k + 1).contains(&recovered));
        }
    }

    #[test]
    fn test_geometric_special_case() {
        // r = 1 reduces to the geometric distribution with mean (1 − p) / p.
        let dist = NegativeBinomial::new(1, 0.3).unwrap();
        assert!(close(dist.mean(), 0.7 / 0.3, 1e-10));
    }
}