use crate::error::{StatsError, StatsResult};
use crate::utils::special_functions::ln_gamma;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
/// Parameters of a binomial distribution: a trial count and a success probability.
///
/// The probability is kept in the caller's numeric type `T`; `BinomialConfig::new`
/// validates it after converting to `f64` via `ToPrimitive::to_f64`. Both fields are
/// public, so a value built with a struct literal bypasses that validation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BinomialConfig<T>
where
T: ToPrimitive,
{
/// Number of trials. `BinomialConfig::new` rejects `0`.
pub n: u64,
/// Success probability of a single trial. `BinomialConfig::new` requires its `f64`
/// conversion to lie in `[0, 1]`.
pub p: T,
}
impl<T> BinomialConfig<T>
where
    T: ToPrimitive,
{
    /// Builds a validated configuration from a trial count and a success probability.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// * `StatsError::ConversionError` when `p` has no `f64` representation,
    /// * `StatsError::InvalidInput` when `n` is zero,
    /// * `StatsError::InvalidInput` when the converted `p` is not within `[0, 1]`
    ///   (a NaN probability falls in this case).
    pub fn new(n: u64, p: T) -> StatsResult<Self> {
        let probability = match p.to_f64() {
            Some(value) => value,
            None => {
                return Err(StatsError::ConversionError {
                    message: "BinomialConfig::new: Failed to convert p to f64".to_string(),
                })
            }
        };

        // Written as a positive range test so that NaN (which fails every comparison)
        // is treated as out of range.
        let within_unit_interval = probability >= 0.0 && probability <= 1.0;

        let rejection = if n == 0 {
            Some("BinomialConfig::new: n must be positive")
        } else if !within_unit_interval {
            Some("BinomialConfig::new: p must be between 0 and 1")
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(StatsError::InvalidInput {
                message: reason.to_string(),
            }),
            None => Ok(Self { n, p }),
        }
    }
}
/// Probability mass function of the binomial distribution, `P(X = k)` for
/// `X ~ Binomial(n, p)`.
///
/// The value is normally computed as `C(n, k) * exp(k ln p + (n - k) ln(1 - p))`.
/// For large `n` the coefficient can overflow to infinity while the exponential
/// underflows to zero, which would yield `inf * 0 = NaN`; in that situation the
/// whole expression is evaluated in log space instead.
///
/// # Arguments
///
/// * `k` - number of successes.
/// * `n` - number of trials; must be at least 1.
/// * `p` - success probability; its `f64` conversion must lie in `[0, 1]`.
///
/// # Errors
///
/// * `StatsError::ConversionError` when `p` cannot be converted to `f64`.
/// * `StatsError::InvalidInput` when `n` is zero, when `p` is outside `[0, 1]`, or
///   (reported by `combination`) when `k > n`.
#[inline]
pub fn pmf<T>(k: u64, n: u64, p: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "binomial_distribution::pmf: Failed to convert p to f64".to_string(),
    })?;
    if n == 0 {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::pmf: n must be positive".to_string(),
        });
    }
    if !((0.0..=1.0).contains(&p_64)) {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::pmf: p must be between 0 and 1".to_string(),
        });
    }

    // `combination` also rejects `k > n`, so `n - k` below cannot underflow.
    let combinations = combination(n, k)?;

    // Degenerate distributions: all mass sits on k = 0 (p = 0) or k = n (p = 1).
    // The logarithms below would be -inf for these inputs.
    if p_64 == 0.0 {
        return Ok(if k == 0 { combinations } else { 0.0 });
    }
    if p_64 == 1.0 {
        return Ok(if k == n { combinations } else { 0.0 });
    }

    let k_f64 = k as f64;
    let n_minus_k_f64 = (n - k) as f64;
    let log_prob = k_f64 * p_64.ln() + n_minus_k_f64 * (1.0 - p_64).ln();
    let prob = log_prob.exp();

    // Direct product is trustworthy only when neither factor has left the normal
    // f64 range: an infinite coefficient or a zero/subnormal power term would give
    // NaN or a badly rounded result.
    if combinations.is_finite() && prob >= f64::MIN_POSITIVE {
        return Ok(combinations * prob);
    }

    // Log-space fallback: ln C(n, k) = sum_{i=1..m} ln((n - i + 1) / i) with
    // m = min(k, n - k), using the symmetry C(n, k) = C(n, n - k).
    let smaller_side = k.min(n - k);
    let log_combinations: f64 = (1..=smaller_side)
        .map(|i| ((n - i + 1) as f64 / i as f64).ln())
        .sum();
    Ok((log_combinations + log_prob).exp())
}
/// Cumulative distribution function of the binomial distribution, `P(X <= k)` for
/// `X ~ Binomial(n, p)`.
///
/// The sum is built with the recurrence
/// `P(X = i + 1) = P(X = i) * (n - i) / (i + 1) * p / (1 - p)` starting from
/// `P(X = 0) = (1 - p)^n`. When that first term is too small to be a normal `f64`
/// the recurrence is run relative to it, with periodic rescaling, and the scale is
/// re-applied in log space at the end.
///
/// # Arguments
///
/// * `k` - upper bound on the number of successes; must not exceed `n`.
/// * `n` - number of trials; must be at least 1.
/// * `p` - success probability in `[0, 1]`.
///
/// # Errors
///
/// Returns `StatsError::InvalidInput` when `n` is zero, when `p` is outside
/// `[0, 1]` (including NaN), or when `k > n`.
#[inline]
pub fn cdf(k: u64, n: u64, p: f64) -> StatsResult<f64> {
    if n == 0 {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::cdf: n must be positive".to_string(),
        });
    }
    if !((0.0..=1.0).contains(&p)) {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::cdf: p must be between 0 and 1".to_string(),
        });
    }
    if k > n {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::cdf: k must be less than or equal to n".to_string(),
        });
    }

    // Degenerate distributions: all mass at 0 (p = 0) or at n (p = 1).
    if p == 0.0 {
        return Ok(1.0);
    }
    if p == 1.0 {
        return Ok(if k >= n { 1.0 } else { 0.0 });
    }

    let q = 1.0 - p;
    let ratio = p / q;
    // ln((1 - p)^n); `ln_1p` keeps precision when p is tiny.
    let log_pmf_0 = (n as f64) * (-p).ln_1p();

    // `powi` takes an i32 exponent, so it may only be used when `n` fits; a plain
    // `n as i32` cast would wrap to zero or a negative exponent for larger `n`.
    let pmf_0 = if n <= i32::MAX as u64 {
        q.powi(n as i32)
    } else {
        log_pmf_0.exp()
    };

    if pmf_0 >= f64::MIN_POSITIVE {
        // First term is a normal number: run the recurrence directly.
        let mut pmf_i = pmf_0;
        let mut cdf_sum = pmf_0;
        for i in 0..k {
            pmf_i *= ((n - i) as f64 / (i + 1) as f64) * ratio;
            cdf_sum += pmf_i;
        }
        return Ok(cdf_sum.clamp(0.0, 1.0));
    }

    // (1 - p)^n underflowed. Track every term relative to P(X = 0) so the
    // recurrence starts from 1 instead of 0, dividing the running values by
    // RESCALE_AT whenever the sum grows past it. One step multiplies by at most
    // about 2^117, so the values stay far below f64::MAX between rescales.
    const RESCALE_AT: f64 = 1e150;
    let mut rescale_count: u64 = 0;
    let mut term = 1.0_f64;
    let mut scaled_sum = 1.0_f64;
    for i in 0..k {
        term *= ((n - i) as f64 / (i + 1) as f64) * ratio;
        if term == 0.0 {
            // Terms only shrink to zero past the mode and stay zero afterwards.
            break;
        }
        scaled_sum += term;
        if scaled_sum > RESCALE_AT {
            term /= RESCALE_AT;
            scaled_sum /= RESCALE_AT;
            rescale_count += 1;
        }
    }

    let log_cdf = scaled_sum.ln() + rescale_count as f64 * RESCALE_AT.ln() + log_pmf_0;
    Ok(log_cdf.exp().clamp(0.0, 1.0))
}
/// Binomial coefficient `C(n, k)` evaluated in `f64`.
///
/// The product is taken over the shorter of `k` and `n - k`
/// (since `C(n, k) = C(n, n - k)`), multiplying and dividing one factor at a time.
/// The result is a floating-point value and becomes infinite once it exceeds the
/// `f64` range.
///
/// # Errors
///
/// Returns `StatsError::InvalidInput` when `k > n`.
#[inline]
fn combination(n: u64, k: u64) -> StatsResult<f64> {
    if k > n {
        return Err(StatsError::InvalidInput {
            message: "binomial_distribution::combination: k must be less than or equal to n"
                .to_string(),
        });
    }

    // Use the mirrored index when it needs fewer factors.
    let factors = if k > n / 2 { n - k } else { k };

    let mut product = 1.0_f64;
    for i in 1..=factors {
        product = product * (n - i + 1) as f64 / i as f64;
    }
    Ok(product)
}
/// Binomial distribution with `f64` success probability.
///
/// `Binomial::new` validates the parameters; the fields are public, so a value
/// built with a struct literal is not checked.
#[derive(Debug, Clone, Copy)]
pub struct Binomial {
/// Number of trials. `Binomial::new` rejects `0`.
pub n: u64,
/// Success probability of a single trial. `Binomial::new` requires `[0, 1]`.
pub p: f64,
}
impl Binomial {
pub fn new(n: u64, p: f64) -> StatsResult<Self> {
if n == 0 {
return Err(StatsError::InvalidInput {
message: "Binomial::new: n must be at least 1".to_string(),
});
}
if !(0.0..=1.0).contains(&p) {
return Err(StatsError::InvalidInput {
message: "Binomial::new: p must be in [0, 1]".to_string(),
});
}
Ok(Self { n, p })
}
pub fn fit(data: &[f64]) -> StatsResult<Self> {
if data.is_empty() {
return Err(StatsError::InvalidInput {
message: "Binomial::fit: data must not be empty".to_string(),
});
}
let n = data
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max)
.round() as u64;
let mean = data.iter().sum::<f64>() / data.len() as f64;
let p = if n == 0 { 0.5 } else { mean / n as f64 };
Self::new(n.max(1), p.clamp(0.0, 1.0))
}
}
impl crate::distributions::traits::DiscreteDistribution for Binomial {
    /// Human-readable distribution name.
    fn name(&self) -> &str {
        "Binomial"
    }

    /// Number of parameters (`n` and `p`).
    fn num_params(&self) -> usize {
        2
    }

    /// `P(X = k)`; forwards to the free function `pmf` with this distribution's parameters.
    fn pmf(&self, k: u64) -> StatsResult<f64> {
        pmf(k, self.n, self.p)
    }

    /// Natural logarithm of `P(X = k)`.
    ///
    /// Returns `Ok(f64::NEG_INFINITY)` for outcomes with zero probability: `k > n`,
    /// `k > 0` when `p == 0`, and `k < n` when `p == 1`.
    fn logpmf(&self, k: u64) -> StatsResult<f64> {
        let trials = self.n;
        let prob = self.p;
        if k > trials {
            return Ok(f64::NEG_INFINITY);
        }
        let failures = trials - k;

        // ln C(n, k) = ln Γ(n + 1) − ln Γ(k + 1) − ln Γ(n − k + 1).
        let log_binom = ln_gamma((trials + 1) as f64)
            - ln_gamma((k + 1) as f64)
            - ln_gamma((failures + 1) as f64);

        // k · ln p, with the p = 0 case handled explicitly so that 0 · ln 0 is
        // treated as 0 rather than NaN.
        let success_term = if prob == 0.0 {
            if k != 0 {
                return Ok(f64::NEG_INFINITY);
            }
            0.0
        } else {
            k as f64 * prob.ln()
        };

        // (n − k) · ln(1 − p), with the mirrored special case for p = 1.
        let failure_term = if prob == 1.0 {
            if failures != 0 {
                return Ok(f64::NEG_INFINITY);
            }
            0.0
        } else {
            failures as f64 * (1.0 - prob).ln()
        };

        Ok(log_binom + success_term + failure_term)
    }

    /// `P(X <= k)`; forwards to the free function `cdf` with this distribution's parameters.
    fn cdf(&self, k: u64) -> StatsResult<f64> {
        cdf(k, self.n, self.p)
    }

    /// Expected value `n · p`.
    fn mean(&self) -> f64 {
        let trials = self.n as f64;
        trials * self.p
    }

    /// Variance `n · p · (1 − p)`.
    fn variance(&self) -> f64 {
        let trials = self.n as f64;
        trials * self.p * (1.0 - self.p)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed probabilities.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `result` is `Err(StatsError::InvalidInput { .. })`.
    fn assert_invalid_input<T: std::fmt::Debug>(result: StatsResult<T>) {
        assert!(
            matches!(result, Err(StatsError::InvalidInput { .. })),
            "expected StatsError::InvalidInput, got {:?}",
            result
        );
    }

    // ---- pmf ---------------------------------------------------------------

    #[test]
    fn test_binomial_pmf() {
        let n = 10;
        let p = 0.5;
        let k = 5;
        let result = pmf(k, n, p).unwrap();
        // C(10, 5) / 2^10 = 252 / 1024.
        assert!(
            (result - 0.24609375).abs() < TOLERANCE,
            "PMF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_pmf_large_values_n() {
        let n = 2_200_000_000u64;
        let k = 5u64;
        let p = 0.5;
        // The inputs are valid, so the call must succeed with a proper probability.
        let val = pmf(k, n, p).expect("valid parameters must not be rejected");
        assert!(
            val.is_finite() && (0.0..=1.0).contains(&val),
            "PMF should be a finite probability for large n, got {}",
            val
        );
    }

    #[test]
    fn test_binomial_pmf_large_values_k() {
        let n = 2u64;
        let k = 2_200_000_000_000u64;
        let p = 0.5;
        // k exceeds n, which is an input error rather than a numeric edge case.
        assert_invalid_input(pmf(k, n, p));
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_zero() {
        let result = pmf(0, 10, 0.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_greater_than_zero() {
        let result = pmf(5, 10, 0.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_equals_n() {
        let result = pmf(10, 10, 1.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_less_than_n() {
        let result = pmf(5, 10, 1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_n_zero() {
        assert_invalid_input(pmf(0, 0, 0.5));
    }

    #[test]
    fn test_binomial_pmf_p_out_of_range() {
        assert_invalid_input(pmf(5, 10, 1.5));
    }

    #[test]
    fn test_binomial_pmf_k_greater_than_n() {
        assert_invalid_input(pmf(15, 10, 0.5));
    }

    // ---- cdf ---------------------------------------------------------------

    #[test]
    fn test_binomial_cdf() {
        let n = 10;
        let p = 0.5;
        let k = 5;
        let result = cdf(k, n, p).unwrap();
        // (1 + 10 + 45 + 120 + 210 + 252) / 2^10 = 638 / 1024.
        assert!(
            (result - 0.623046875).abs() < TOLERANCE,
            "CDF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_cdf_k_greater_than_n() {
        assert_invalid_input(cdf(15, 10, 0.5));
    }

    #[test]
    fn test_binomial_cdf_n_zero() {
        assert_invalid_input(cdf(5, 0, 0.5));
    }

    #[test]
    fn test_binomial_cdf_p_out_of_range() {
        assert_invalid_input(cdf(5, 10, 1.5));
    }

    // ---- BinomialConfig ----------------------------------------------------

    #[test]
    fn test_binomial_config_new_valid() {
        let config = BinomialConfig::new(10, 0.5).unwrap();
        assert_eq!(config.n, 10);
        assert_eq!(config.p, 0.5);
    }

    #[test]
    fn test_binomial_config_new_n_one() {
        let config = BinomialConfig::new(1, 0.5).unwrap();
        assert_eq!(config.n, 1);
    }

    #[test]
    fn test_binomial_config_new_n_zero() {
        assert_invalid_input(BinomialConfig::new(0, 0.5));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_negative() {
        assert_invalid_input(BinomialConfig::new(10, -0.1));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_above_one() {
        assert_invalid_input(BinomialConfig::new(10, 1.1));
    }

    #[test]
    fn test_binomial_config_new_p_zero() {
        let config = BinomialConfig::new(10, 0.0);
        assert!(config.is_ok());
    }

    #[test]
    fn test_binomial_config_new_p_one() {
        let config = BinomialConfig::new(10, 1.0);
        assert!(config.is_ok());
    }

    // ---- combination -------------------------------------------------------

    #[test]
    fn test_binomial_combination_symmetry() {
        let n = 10u64;
        let k = 8u64;
        let result1 = combination(n, k).unwrap();
        let result2 = combination(n, n - k).unwrap();
        assert_eq!(result1, result2);
        assert_eq!(result1, 45.0);
    }

    #[test]
    fn test_binomial_combination_k_greater_than_n() {
        assert_invalid_input(combination(10, 15));
    }

    #[test]
    fn test_binomial_combination_k_equals_n() {
        let result = combination(10, 10).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_combination_k_zero() {
        let result = combination(10, 0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_combination_k_exactly_n_over_2() {
        let n = 10u64;
        let k = 5u64;
        let result = combination(n, k).unwrap();
        assert_eq!(result, 252.0);
    }

    #[test]
    fn test_binomial_combination_k_just_over_n_over_2() {
        let n = 10u64;
        let k = 6u64;
        let result1 = combination(n, k).unwrap();
        let result2 = combination(n, n - k).unwrap();
        assert_eq!(result1, result2);
        assert_eq!(result1, 210.0);
    }
}