use crate::DType;
use crate::stats::discrete::log_binom;
use crate::stats::error::{StatsError, StatsResult};
use crate::stats::{DiscreteDistribution, Distribution};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Parameter set of a hypergeometric distribution.
///
/// The fields are private; values are normally obtained through
/// `Hypergeometric::new`, which validates them.
#[derive(Clone, Copy, Debug)]
pub struct Hypergeometric {
    /// Population size.
    pop_size: u64,
    /// Number of successes.
    num_success: u64,
    /// Number of draws.
    num_draws: u64,
}
impl Hypergeometric {
    /// Validates the three parameters and builds the distribution.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidParameter` for the first failing check, in
    /// this order: `pop_size` is zero, `num_success` exceeds `pop_size`,
    /// `num_draws` exceeds `pop_size`.
    pub fn new(pop_size: u64, num_success: u64, num_draws: u64) -> StatsResult<Self> {
        // Decide which parameter (if any) is rejected before building the error,
        // so the error value is constructed in exactly one place.
        let rejected = if pop_size == 0 {
            Some(("pop_size", pop_size, "population size must be positive"))
        } else if num_success > pop_size {
            Some((
                "num_success",
                num_success,
                "number of successes cannot exceed population size",
            ))
        } else if num_draws > pop_size {
            Some((
                "num_draws",
                num_draws,
                "number of draws cannot exceed population size",
            ))
        } else {
            None
        };

        match rejected {
            Some((name, value, reason)) => Err(StatsError::InvalidParameter {
                name: name.to_string(),
                value: value as f64,
                reason: reason.to_string(),
            }),
            None => Ok(Self {
                pop_size,
                num_success,
                num_draws,
            }),
        }
    }

    /// Population size.
    pub fn pop_size(&self) -> u64 {
        self.pop_size
    }

    /// Number of successes.
    pub fn num_success(&self) -> u64 {
        self.num_success
    }

    /// Number of draws.
    pub fn num_draws(&self) -> u64 {
        self.num_draws
    }

    /// Lower bound used by the probability functions:
    /// `num_draws - (pop_size - num_success)`, clamped at zero.
    pub fn min_val(&self) -> u64 {
        self.num_draws
            .saturating_sub(self.pop_size - self.num_success)
    }

    /// Upper bound used by the probability functions: the smaller of
    /// `num_draws` and `num_success`.
    pub fn max_val(&self) -> u64 {
        std::cmp::min(self.num_draws, self.num_success)
    }
}
impl Distribution for Hypergeometric {
    /// Mean, `n * K / N` with `n = num_draws`, `K = num_success`, `N = pop_size`.
    fn mean(&self) -> f64 {
        let n = self.num_draws as f64;
        let k = self.num_success as f64;
        let big_n = self.pop_size as f64;
        n * k / big_n
    }

    /// Variance, `n * K * (N - K) * (N - n) / (N^2 * (N - 1))`.
    ///
    /// Returns `0.0` when `N <= 1`, where the closed form would divide by zero.
    fn var(&self) -> f64 {
        let n = self.num_draws as f64;
        let k = self.num_success as f64;
        let big_n = self.pop_size as f64;
        if big_n <= 1.0 {
            return 0.0;
        }
        n * k * (big_n - k) * (big_n - n) / (big_n * big_n * (big_n - 1.0))
    }

    /// Entropy in nats, accumulated as `-p * ln(p)` over every value from
    /// `min_val()` to `max_val()`.
    ///
    /// Terms with `p <= 1e-300` are skipped so `ln` is never taken of a
    /// vanishing mass.
    fn entropy(&self) -> f64 {
        (self.min_val()..=self.max_val())
            .map(|k| self.pmf(k))
            .filter(|&p| p > 1e-300)
            .fold(0.0, |h, p| h - p * p.ln())
    }

    /// Median, taken as `ppf(0.5)`; falls back to the rounded mean if `ppf`
    /// reports an error.
    fn median(&self) -> f64 {
        self.ppf(0.5)
            .unwrap_or_else(|_| self.mean().round() as u64) as f64
    }

    /// Mode, `floor((n + 1) * (K + 1) / (N + 2))`.
    fn mode(&self) -> f64 {
        let n = self.num_draws as f64;
        let k = self.num_success as f64;
        let big_n = self.pop_size as f64;
        (((n + 1.0) * (k + 1.0)) / (big_n + 2.0)).floor()
    }

    /// Skewness,
    /// `(N - 2K) * sqrt(N - 1) * (N - 2n) / (sqrt(n * K * (N - K) * (N - n)) * (N - 2))`.
    ///
    /// Returns NaN when `N <= 2` or when `n * K * (N - K) * (N - n)` is zero.
    /// The second guard mirrors `kurtosis`, which already returns NaN for the
    /// same degenerate parameters instead of dividing by zero.
    fn skewness(&self) -> f64 {
        let n = self.num_draws as f64;
        let k = self.num_success as f64;
        let big_n = self.pop_size as f64;
        if big_n <= 2.0 {
            return f64::NAN;
        }
        let spread = n * k * (big_n - k) * (big_n - n);
        if spread == 0.0 {
            return f64::NAN;
        }
        let numerator = (big_n - 2.0 * k) * (big_n - 1.0).sqrt() * (big_n - 2.0 * n);
        numerator / (spread.sqrt() * (big_n - 2.0))
    }

    /// Kurtosis from the closed-form ratio `a / b` computed below.
    ///
    /// Returns NaN when `N <= 3` or when the denominator `b` is zero.
    fn kurtosis(&self) -> f64 {
        let n = self.num_draws as f64;
        let k = self.num_success as f64;
        let big_n = self.pop_size as f64;
        if big_n <= 3.0 {
            return f64::NAN;
        }
        let spread = n * k * (big_n - k) * (big_n - n);
        let a = (big_n - 1.0)
            * big_n
            * big_n
            * (big_n * (big_n + 1.0) - 6.0 * k * (big_n - k) - 6.0 * n * (big_n - n))
            + 6.0 * spread * (5.0 * big_n - 6.0);
        let b = spread * (big_n - 2.0) * (big_n - 3.0);
        if b == 0.0 {
            return f64::NAN;
        }
        a / b
    }
}
impl DiscreteDistribution for Hypergeometric {
    /// Probability mass at `k`: `exp(log_pmf(k))`, or `0.0` when the
    /// log-mass is not finite.
    fn pmf(&self, k: u64) -> f64 {
        let log_p = self.log_pmf(k);
        if log_p.is_finite() { log_p.exp() } else { 0.0 }
    }

    /// Natural log of the probability mass at `k`.
    ///
    /// Returns negative infinity outside `[min_val(), max_val()]`. Inside that
    /// range the value is
    /// `log_binom(K, k) + log_binom(N - K, n - k) - log_binom(N, n)`.
    fn log_pmf(&self, k: u64) -> f64 {
        if k < self.min_val() || k > self.max_val() {
            return f64::NEG_INFINITY;
        }
        let big_n = self.pop_size;
        let big_k = self.num_success;
        let n = self.num_draws;
        log_binom(big_k, k) + log_binom(big_n - big_k, n - k) - log_binom(big_n, n)
    }

    /// Cumulative probability `P(X <= k)`, summed term by term from
    /// `min_val()`; the sum is capped at `1.0` to absorb rounding.
    fn cdf(&self, k: u64) -> f64 {
        if k >= self.max_val() {
            return 1.0;
        }
        let min_k = self.min_val();
        if k < min_k {
            return 0.0;
        }
        let total: f64 = (min_k..=k).map(|i| self.pmf(i)).sum();
        total.min(1.0)
    }

    /// Survival function, `1 - cdf(k)`.
    fn sf(&self, k: u64) -> f64 {
        1.0 - self.cdf(k)
    }

    /// Smallest `k` in `[min_val(), max_val()]` whose cumulative probability
    /// reaches `prob`.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidParameter` when `prob` is outside `[0, 1]`
    /// (NaN included, since it is not contained in the range).
    fn ppf(&self, prob: f64) -> StatsResult<u64> {
        if !(0.0..=1.0).contains(&prob) {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: prob,
                reason: "probability must be in [0, 1]".to_string(),
            });
        }
        let min_k = self.min_val();
        let max_k = self.max_val();
        if prob == 0.0 {
            return Ok(min_k);
        }
        if prob >= 1.0 {
            return Ok(max_k);
        }
        let mut cumulative = 0.0;
        for k in min_k..=max_k {
            cumulative += self.pmf(k);
            if cumulative >= prob {
                return Ok(k);
            }
        }
        // Rounding can leave the running sum marginally below `prob`.
        Ok(max_k)
    }

    /// Element-wise probability mass: `exp` of `log_pmf_tensor`.
    fn pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_pmf = self.log_pmf_tensor(k, client)?;
        client.exp(&log_pmf)
    }

    /// Element-wise log probability mass.
    ///
    /// Every element of `k` is floored first. Each binomial coefficient is
    /// expanded as `ln C(a, b) = lgamma(a + 1) - lgamma(b + 1) - lgamma(a - b + 1)`;
    /// the `k`-independent normaliser `ln C(N, n)` is taken from the scalar
    /// `log_binom`, exactly as `log_pmf` does.
    ///
    /// No support mask is applied here: elements outside
    /// `[min_val(), max_val()]` reach `lgamma` with non-positive arguments.
    /// NOTE(review): the value produced for such elements depends on the
    /// backend's `lgamma` — confirm whether callers need an explicit mask.
    fn log_pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let draws = self.num_draws as f64;
        let successes = self.num_success as f64;
        let failures = (self.pop_size - self.num_success) as f64;

        let k_floor = client.floor(k)?;
        let neg_k = client.mul_scalar(&k_floor, -1.0)?;
        let shape = k_floor.shape();

        // ln C(K, k) = lgamma(K + 1) - lgamma(k + 1) - lgamma(K - k + 1)
        let lgamma_successes =
            client.lgamma(&client.fill(shape, successes + 1.0, k_floor.dtype())?)?;
        let lgamma_k = client.lgamma(&client.add_scalar(&k_floor, 1.0)?)?;
        let lgamma_successes_left =
            client.lgamma(&client.add_scalar(&neg_k, successes + 1.0)?)?;
        let log_c_successes = client.sub(&lgamma_successes, &lgamma_k)?;
        let log_c_successes = client.sub(&log_c_successes, &lgamma_successes_left)?;

        // ln C(N - K, n - k)
        //   = lgamma(N - K + 1) - lgamma(n - k + 1) - lgamma(N - K - n + k + 1)
        // The last argument grows with k, so it is built from `k_floor`, not `neg_k`.
        let lgamma_failures =
            client.lgamma(&client.fill(shape, failures + 1.0, k_floor.dtype())?)?;
        let lgamma_failures_drawn = client.lgamma(&client.add_scalar(&neg_k, draws + 1.0)?)?;
        let lgamma_failures_left =
            client.lgamma(&client.add_scalar(&k_floor, failures - draws + 1.0)?)?;
        let log_c_failures = client.sub(&lgamma_failures, &lgamma_failures_drawn)?;
        let log_c_failures = client.sub(&log_c_failures, &lgamma_failures_left)?;

        // ln C(N, n) does not depend on k, so it is subtracted as a scalar.
        let log_c_total = log_binom(self.pop_size, self.num_draws);
        let numerator = client.add(&log_c_successes, &log_c_failures)?;
        client.sub_scalar(&numerator, log_c_total)
    }

    /// Element-wise cumulative probability using a normal approximation.
    ///
    /// Computes `0.5 * erfc(-(floor(k) - mean) / (std * sqrt(2)))` with the
    /// distribution's own mean and standard deviation; no continuity
    /// correction is applied.
    ///
    /// NOTE(review): when the variance is zero this divides by zero — confirm
    /// whether degenerate parameters can reach this path.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let mean = self.mean();
        let std_dev = self.var().sqrt();
        let centered = client.sub_scalar(&k_floor, mean)?;
        let z = client.div_scalar(&centered, std_dev)?;
        let z_scaled = client.mul_scalar(&z, -std::f64::consts::FRAC_1_SQRT_2)?;
        let erfc_val = client.erfc(&z_scaled)?;
        client.mul_scalar(&erfc_val, 0.5)
    }

    /// Element-wise survival function, `1 - cdf_tensor(k)`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let cdf = self.cdf_tensor(k, client)?;
        // (-cdf) - (-1) == 1 - cdf, expressed with the scalar ops available.
        client.sub_scalar(&client.mul_scalar(&cdf, -1.0)?, -1.0)
    }

    /// Element-wise quantile using the inverse of the normal approximation.
    ///
    /// Computes `mean + std * sqrt(2) * erfinv(2p - 1)` and clamps the result
    /// to `[min_val(), max_val()]`. The result is not rounded to an integer.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let mean = self.mean();
        let std_dev = self.var().sqrt();
        let two_p_minus_1 = client.sub_scalar(&client.mul_scalar(p, 2.0)?, 1.0)?;
        let erfinv_val = client.erfinv(&two_p_minus_1)?;
        let z = client.mul_scalar(&erfinv_val, std::f64::consts::SQRT_2)?;
        let scaled = client.mul_scalar(&z, std_dev)?;
        let result = client.add_scalar(&scaled, mean)?;
        let min_k = self.min_val() as f64;
        let max_k = self.max_val() as f64;
        client.clamp(&result, min_k, max_k)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Fixture shared by most tests: population 20, 7 successes, 12 draws.
    fn reference() -> Hypergeometric {
        Hypergeometric::new(20, 7, 12).unwrap()
    }

    #[test]
    fn test_hypergeometric_creation() {
        assert!(Hypergeometric::new(20, 7, 12).is_ok());
        assert!(Hypergeometric::new(0, 0, 0).is_err());
        assert!(Hypergeometric::new(10, 15, 5).is_err());
        assert!(Hypergeometric::new(10, 5, 15).is_err());
    }

    #[test]
    fn test_hypergeometric_bounds() {
        let dist = reference();
        assert_eq!(dist.min_val(), 0);
        assert_eq!(dist.max_val(), 7);
    }

    #[test]
    fn test_hypergeometric_pmf() {
        let dist = reference();
        // The masses over the whole range add up to one.
        let mut total = 0.0;
        for k in dist.min_val()..=dist.max_val() {
            total += dist.pmf(k);
        }
        assert!((total - 1.0).abs() < TOL);
        // One past the upper bound carries no mass.
        assert!(dist.pmf(8).abs() < TOL);
    }

    #[test]
    fn test_hypergeometric_cdf() {
        let dist = reference();
        assert!((dist.cdf(dist.max_val()) - 1.0).abs() < TOL);
        // The CDF never decreases across the range.
        let mut previous = 0.0;
        for value in (dist.min_val()..=dist.max_val()).map(|k| dist.cdf(k)) {
            assert!(value >= previous);
            previous = value;
        }
    }

    #[test]
    fn test_hypergeometric_mean() {
        let dist = reference();
        assert!((dist.mean() - 4.2).abs() < TOL);
    }

    #[test]
    fn test_hypergeometric_variance() {
        let dist = reference();
        let expected = 12.0 * 7.0 * 13.0 * 8.0 / (400.0 * 19.0);
        assert!((dist.var() - expected).abs() < TOL);
    }

    #[test]
    fn test_hypergeometric_ppf() {
        let dist = reference();
        let (lowest, highest) = (dist.min_val(), dist.max_val());
        assert_eq!(dist.ppf(0.0).unwrap(), lowest);
        assert_eq!(dist.ppf(1.0).unwrap(), highest);
        // Round-tripping through the CDF lands on k, or on k + 1 when
        // rounding pushes the running sum just below the target.
        for k in lowest..=highest {
            let recovered = dist.ppf(dist.cdf(k)).unwrap();
            assert!(recovered == k || recovered == k + 1);
        }
    }

    #[test]
    fn test_hypergeometric_mode() {
        let dist = reference();
        let peak = dist.pmf(dist.mode() as u64);
        for k in dist.min_val()..=dist.max_val() {
            assert!(dist.pmf(k) <= peak + TOL);
        }
    }

    #[test]
    fn test_hypergeometric_extreme() {
        // Every item is a success: all five draws succeed.
        let all_success = Hypergeometric::new(10, 10, 5).unwrap();
        assert!((all_success.pmf(5) - 1.0).abs() < TOL);
        // No item is a success: zero successes is certain.
        let no_success = Hypergeometric::new(10, 0, 5).unwrap();
        assert!((no_success.pmf(0) - 1.0).abs() < TOL);
    }
}