use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, FloatConst, NumCast};
use scirs2_core::random::prelude::*;
use std::cmp;
/// Hypergeometric distribution: the number of success items obtained when
/// drawing `n_draws` items without replacement from a population of
/// `n_population` items that contains `n_success` success items.
///
/// Values are reported shifted by `loc`, i.e. the support is
/// `loc + k` for the integer counts `k` the distribution can take.
#[derive(Debug, Clone)]
pub struct Hypergeometric<F: Float> {
/// Total number of items in the population.
n_population: usize,
/// Number of items in the population that count as a success.
n_success: usize,
/// Number of items drawn without replacement.
n_draws: usize,
/// Location shift: subtracted from inputs of `pmf`/`cdf`, added to samples and the mean.
loc: F,
}
impl<F: Float + NumCast + FloatConst + std::fmt::Display> Hypergeometric<F> {
    /// Creates a hypergeometric distribution.
    ///
    /// # Arguments
    ///
    /// * `n_population` - total number of items in the population (must be > 0)
    /// * `n_success` - number of success items in the population
    /// * `ndraws` - number of items drawn without replacement
    /// * `loc` - location shift applied to the support
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidArgument` when `n_population` is zero, or
    /// when `n_success` or `ndraws` exceeds `n_population`.
    pub fn new(n_population: usize, n_success: usize, ndraws: usize, loc: F) -> StatsResult<Self> {
        if n_population == 0 {
            return Err(StatsError::InvalidArgument(
                "Population size must be positive".to_string(),
            ));
        }
        if n_success > n_population {
            return Err(StatsError::InvalidArgument(
                "Number of success states cannot exceed population size".to_string(),
            ));
        }
        if ndraws > n_population {
            return Err(StatsError::InvalidArgument(
                "Number of draws cannot exceed population size".to_string(),
            ));
        }
        Ok(Hypergeometric {
            n_population,
            n_success,
            n_draws: ndraws,
            loc,
        })
    }

    /// Returns the inclusive range `(lowest, highest)` of success counts that
    /// have non-zero probability (before the `loc` shift).
    fn support_bounds(&self) -> (usize, usize) {
        let n_failure = self.n_population - self.n_success;
        (
            // Once the failures are exhausted, every further draw is a success.
            self.n_draws.saturating_sub(n_failure),
            cmp::min(self.n_draws, self.n_success),
        )
    }

    /// Probability of exactly `k` successes, for `k` inside `support_bounds()`.
    ///
    /// `ln_total` must be `ln_binomial(n_population, n_draws)`; it is passed in
    /// so that callers summing many terms compute it only once.
    fn pmf_in_support(&self, k: usize, ln_total: f64) -> f64 {
        debug_assert!({
            let (lowest, highest) = self.support_bounds();
            lowest <= k && k <= highest
        });
        let ln_pmf = ln_binomial(self.n_success, k)
            + ln_binomial(self.n_population - self.n_success, self.n_draws - k)
            - ln_total;
        ln_pmf.exp()
    }

    /// Probability mass function evaluated at `x`.
    ///
    /// Returns zero when `x - loc` is not a non-negative integer, is NaN or
    /// infinite, or lies outside the support; otherwise returns
    /// `C(K, k) * C(N - K, n - k) / C(N, n)` computed in log space.
    pub fn pmf(&self, x: F) -> F {
        let shifted = (x - self.loc).to_f64().unwrap_or(f64::NAN);
        // `fract()` of an infinity is NaN, so infinities are rejected here too.
        if shifted.is_nan() || shifted < 0.0 || shifted.fract() != 0.0 {
            return F::zero();
        }
        let (lowest, highest) = self.support_bounds();
        // Compare as floats first so the integer cast below is always in range.
        if shifted < lowest as f64 || shifted > highest as f64 {
            return F::zero();
        }
        let k = shifted as usize;
        let ln_total = ln_binomial(self.n_population, self.n_draws);
        F::from(self.pmf_in_support(k, ln_total)).unwrap_or_else(F::zero)
    }

    /// Cumulative distribution function `P(X <= x)`.
    ///
    /// Returns zero when `x - loc` is NaN or below the support, exactly one
    /// when it is at or above the top of the support, and otherwise the sum of
    /// the probability masses up to `floor(x - loc)` (clamped to one).
    pub fn cdf(&self, x: F) -> F {
        let shifted = (x - self.loc).to_f64().unwrap_or(f64::NAN);
        if shifted.is_nan() {
            return F::zero();
        }
        let (lowest, highest) = self.support_bounds();
        if shifted < lowest as f64 {
            return F::zero();
        }
        // At or beyond the largest attainable count all mass is included, so
        // return one exactly instead of a rounded sum of terms.
        if shifted >= highest as f64 {
            return F::one();
        }
        // Here lowest <= shifted < highest, so the cast cannot overflow.
        let top = shifted.floor() as usize;
        let ln_total = ln_binomial(self.n_population, self.n_draws);
        // Sum over the integer counts directly; shifting each count by `loc`
        // and back could round to a different (or non-integer) value.
        let total: f64 = (lowest..=top)
            .map(|k| self.pmf_in_support(k, ln_total))
            .sum();
        F::from(total.min(1.0)).unwrap_or_else(F::zero)
    }

    /// Draws `size` random variates by simulating the draws one at a time.
    ///
    /// Each variate is the number of successes among `n_draws` sequential
    /// draws without replacement, plus `loc`.
    pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
        let mut rng = thread_rng();
        let mut samples: Array1<F> = Array1::zeros(size);
        for sample in samples.iter_mut() {
            let mut successes = 0usize;
            let mut success_remaining = self.n_success;
            // Population size seen by each successive draw: N, N - 1, ...
            for population_remaining in (1..=self.n_population).rev().take(self.n_draws) {
                let p_success = success_remaining as f64 / population_remaining as f64;
                // The uniform draw is >= 0, so this can only succeed while
                // `success_remaining` is still positive.
                if rng.random_range(0.0..1.0) < p_success {
                    successes += 1;
                    success_remaining -= 1;
                }
            }
            *sample = F::from(successes).expect("Failed to convert to float") + self.loc;
        }
        Ok(samples)
    }

    /// Mean of the distribution: `n_draws * n_success / n_population + loc`.
    pub fn mean(&self) -> F {
        let mean_val = (self.n_draws as f64) * (self.n_success as f64) / (self.n_population as f64);
        F::from(mean_val).expect("Failed to convert to float") + self.loc
    }

    /// Variance of the distribution: `n * p * (1 - p) * (N - n) / (N - 1)`
    /// with `p = n_success / n_population`. The `loc` shift does not enter.
    pub fn var(&self) -> F {
        // A population of one item would divide by zero below.
        if self.n_population <= 1 {
            return F::zero();
        }
        let n_draws = self.n_draws as f64;
        let k = self.n_success as f64;
        let n = self.n_population as f64;
        let p = k / n;
        let variance = n_draws * p * (1.0 - p) * (n - n_draws) / (n - 1.0);
        F::from(variance).expect("Failed to convert to float")
    }

    /// Standard deviation: the square root of `var()`.
    pub fn std(&self) -> F {
        self.var().sqrt()
    }
}
/// Natural logarithm of the binomial coefficient `C(n, k)`.
///
/// Returns negative infinity when `k > n` (the coefficient is zero) and
/// exactly `0.0` when `k` is `0` or `n` (the coefficient is one). Otherwise
/// the value is assembled from sums of logarithms, `ln n! - ln k! - ln (n-k)!`.
#[allow(dead_code)]
fn ln_binomial(n: usize, k: usize) -> f64 {
    // ln(m!) accumulated as ln 1 + ln 2 + ... + ln m.
    let ln_factorial = |m: usize| -> f64 { (1..=m).map(|i| (i as f64).ln()).sum::<f64>() };

    if k > n {
        f64::NEG_INFINITY
    } else if k == 0 || k == n {
        0.0
    } else {
        // C(n, k) == C(n, n - k); evaluate with the smaller of the two.
        let smaller = k.min(n - k);
        ln_factorial(n) - ln_factorial(smaller) - ln_factorial(n - smaller)
    }
}
/// Convenience constructor for a [`Hypergeometric`] distribution.
///
/// Forwards all arguments to `Hypergeometric::new` and returns its result
/// unchanged, including any error it reports.
#[allow(dead_code)]
pub fn hypergeom<F: Float + NumCast + FloatConst + std::fmt::Display>(
    n_population: usize,
    n_success: usize,
    n_draws: usize,
    loc: F,
) -> StatsResult<Hypergeometric<F>> {
    Hypergeometric::<F>::new(n_population, n_success, n_draws, loc)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Distribution used by most tests: population 20, 7 successes, 12 draws.
    fn reference(loc: f64) -> Hypergeometric<f64> {
        Hypergeometric::new(20, 7, 12, loc).expect("Operation failed")
    }

    #[test]
    fn test_hypergeometric_creation() {
        let hyper = reference(0.0);
        assert_eq!(hyper.n_population, 20);
        assert_eq!(hyper.n_success, 7);
        assert_eq!(hyper.n_draws, 12);
        assert_eq!(hyper.loc, 0.0);

        // Empty population, too many successes, too many draws.
        assert!(Hypergeometric::<f64>::new(0, 5, 10, 0.0).is_err());
        assert!(Hypergeometric::<f64>::new(20, 21, 10, 0.0).is_err());
        assert!(Hypergeometric::<f64>::new(20, 7, 21, 0.0).is_err());
    }

    #[test]
    fn test_hypergeometric_pmf() {
        let hyper = reference(0.0);
        assert_relative_eq!(hyper.pmf(0.0), 0.0001031991744066048, epsilon = 1e-10);
        for &(x, expected) in &[
            (3.0, 0.1986584107327147),
            (4.0, 0.3575851393188869),
            (7.0, 0.0102167182662539),
        ] {
            assert_relative_eq!(hyper.pmf(x), expected, epsilon = 1e-6);
        }

        // Below the support, above the support, and a non-integer point.
        for &x in &[-1.0, 8.0, 0.5] {
            assert_eq!(hyper.pmf(x), 0.0);
        }

        let shifted_hyper = reference(2.0);
        assert_relative_eq!(
            shifted_hyper.pmf(2.0),
            0.0001031991744066048,
            epsilon = 1e-10
        );
        assert_relative_eq!(shifted_hyper.pmf(5.0), 0.1986584107327147, epsilon = 1e-6);
    }

    #[test]
    fn test_hypergeometric_cdf() {
        let hyper = reference(0.0);
        assert_relative_eq!(hyper.cdf(0.0), 0.0001031991744066048, epsilon = 1e-10);
        for &(x, expected) in &[(3.0, 0.2507739938080501), (4.0, 0.608359133126937)] {
            assert_relative_eq!(hyper.cdf(x), expected, epsilon = 1e-6);
        }
        assert_eq!(hyper.cdf(7.0), 1.0);
        assert_eq!(hyper.cdf(-1.0), 0.0);
        assert_eq!(hyper.cdf(20.0), 1.0);

        let shifted_hyper = reference(2.0);
        assert_relative_eq!(
            shifted_hyper.cdf(2.0),
            0.0001031991744066048,
            epsilon = 1e-10
        );
        assert_relative_eq!(shifted_hyper.cdf(5.0), 0.2507739938080501, epsilon = 1e-6);
    }

    #[test]
    fn test_hypergeometric_stats() {
        let hyper = reference(0.0);
        assert_relative_eq!(hyper.mean(), 4.2, epsilon = 1e-10);
        assert_relative_eq!(hyper.var(), 1.1494736842105262, epsilon = 1e-10);
        assert_relative_eq!(hyper.std(), 1.0721351053904196, epsilon = 1e-10);

        // A location shift moves the mean but leaves the variance alone.
        let shifted_hyper = reference(3.0);
        assert_relative_eq!(shifted_hyper.mean(), 7.2, epsilon = 1e-10);
        assert_relative_eq!(shifted_hyper.var(), 1.1494736842105262, epsilon = 1e-10);
    }

    #[test]
    fn test_hypergeometric_rvs() {
        let hyper = Hypergeometric::<f64>::new(100, 40, 20, 0.0).expect("Operation failed");
        let samples = hyper.rvs(1000).expect("Operation failed");
        assert_eq!(samples.len(), 1000);

        for &sample in samples.iter() {
            assert!(sample.fract() == 0.0);
            assert!(sample >= 0.0);
            assert!(sample <= 20.0);
        }

        // Expected mean is 20 * 40 / 100 = 8.
        let mean = samples.sum() / samples.len() as f64;
        assert!((mean - 8.0).abs() < 0.5);
    }

    #[test]
    fn test_hypergeometric_edge_cases() {
        // No success items in the population: always zero successes.
        let hyper_no_success = Hypergeometric::new(20, 0, 10, 0.0).expect("Operation failed");
        assert_eq!(hyper_no_success.pmf(0.0), 1.0);
        assert_eq!(hyper_no_success.pmf(1.0), 0.0);
        assert_eq!(hyper_no_success.mean(), 0.0);
        assert_eq!(hyper_no_success.var(), 0.0);

        // Nothing drawn: always zero successes.
        let hyper_no_draws = Hypergeometric::new(20, 10, 0, 0.0).expect("Operation failed");
        assert_eq!(hyper_no_draws.pmf(0.0), 1.0);
        assert_eq!(hyper_no_draws.pmf(1.0), 0.0);
        assert_eq!(hyper_no_draws.mean(), 0.0);
        assert_eq!(hyper_no_draws.var(), 0.0);

        // Every item is a success: every draw succeeds.
        let hyper_all_success = Hypergeometric::new(20, 20, 10, 0.0).expect("Operation failed");
        assert_eq!(hyper_all_success.pmf(10.0), 1.0);
        assert_eq!(hyper_all_success.pmf(9.0), 0.0);
        assert_eq!(hyper_all_success.mean(), 10.0);
        assert_eq!(hyper_all_success.var(), 0.0);
    }

    #[test]
    fn test_ln_binomial() {
        assert_relative_eq!(ln_binomial(5, 2).exp(), 10.0, epsilon = 1e-10);
        assert_relative_eq!(ln_binomial(10, 5).exp(), 252.0, epsilon = 1e-10);
        for &(n, k) in &[(5, 0), (5, 5), (0, 0)] {
            assert_eq!(ln_binomial(n, k).exp(), 1.0);
        }
        assert!(ln_binomial(5, 6) < 0.0);
    }
}