use crate::stats::continuous::special;
use crate::stats::distribution::Distribution;
use crate::stats::error::{StatsError, StatsResult};
/// Multinomial distribution described by a trial count and a vector of
/// per-category probabilities.
///
/// The fields are private; `Multinomial::new` validates the probability
/// vector before storing it.
#[derive(Debug, Clone)]
pub struct Multinomial {
/// Number of trials.
n: u64,
/// Probability of each category, in category order.
p: Vec<f64>,
/// Number of categories; `new` sets this to `p.len()`.
k: usize,
}
impl Multinomial {
    /// Largest absolute deviation from 1.0 tolerated for the sum of `p`.
    const SUM_TOLERANCE: f64 = 1e-9;

    /// Creates a multinomial distribution with `n` trials and category
    /// probabilities `p`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] when `p` is empty, when any
    /// entry is NaN, infinite or negative, or when the entries do not sum to
    /// 1.0 within an absolute tolerance of `1e-9`.
    pub fn new(n: u64, p: Vec<f64>) -> StatsResult<Self> {
        if p.is_empty() {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: 0.0,
                reason: "probability vector cannot be empty".to_string(),
            });
        }

        for (i, &prob) in p.iter().enumerate() {
            // NaN compares false against everything, so it must be rejected
            // explicitly; otherwise it slips past both the sign check and the
            // sum check below.
            let reason = if !prob.is_finite() {
                "probability must be finite"
            } else if prob < 0.0 {
                "probability must be non-negative"
            } else {
                continue;
            };
            return Err(StatsError::InvalidParameter {
                name: format!("p[{}]", i),
                value: prob,
                reason: reason.to_string(),
            });
        }

        let sum: f64 = p.iter().sum();
        if (sum - 1.0).abs() > Self::SUM_TOLERANCE {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: sum,
                reason: "probabilities must sum to 1.0".to_string(),
            });
        }

        let k = p.len();
        Ok(Self { n, p, k })
    }

    /// Returns the number of trials.
    pub fn n(&self) -> u64 {
        self.n
    }

    /// Returns the category probabilities.
    pub fn p(&self) -> &[f64] {
        &self.p
    }

    /// Returns the number of categories.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Probability of observing the count vector `x`.
    ///
    /// Returns 0.0 when the counts do not add up to `n`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of categories.
    pub fn pmf(&self, x: &[u64]) -> f64 {
        // `log_pmf` yields -inf for infeasible outcomes and exp(-inf) == 0.0,
        // so the length and total checks do not need to be repeated here.
        self.log_pmf(x).exp()
    }

    /// Natural logarithm of the probability of observing the count vector `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when the counts do not add up to `n`, and
    /// also when a category with probability zero has a positive count.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of categories.
    pub fn log_pmf(&self, x: &[u64]) -> f64 {
        assert_eq!(x.len(), self.k, "x must have length k");

        // Checked addition: a total that overflows `u64` is certainly not
        // equal to `n`, and must neither panic nor wrap around to match it.
        let total = x.iter().try_fold(0u64, |acc, &xi| acc.checked_add(xi));
        if total != Some(self.n) {
            return f64::NEG_INFINITY;
        }

        // ln(n!) - sum_i ln(x_i!). The cast happens before the `+ 1.0` so
        // that `u64::MAX` cannot overflow the integer addition.
        let mut log_result = special::lgamma(self.n as f64 + 1.0);
        for &xi in x {
            log_result -= special::lgamma(xi as f64 + 1.0);
        }

        // + sum_i x_i * ln(p_i). Zero counts are skipped so that
        // 0 * ln(0) does not inject a NaN.
        for (&xi, &pi) in x.iter().zip(&self.p) {
            if xi > 0 {
                log_result += xi as f64 * pi.ln();
            }
        }
        log_result
    }

    /// Expected count of every category: `n * p[i]`.
    pub fn mean_vec(&self) -> Vec<f64> {
        let n_f = self.n as f64;
        self.p.iter().map(|&pi| n_f * pi).collect()
    }

    /// Covariance matrix of the counts as a `k x k` nested vector.
    ///
    /// Diagonal entries are `n * p[i] * (1 - p[i])`; off-diagonal entries are
    /// `-n * p[i] * p[j]`.
    pub fn cov_matrix(&self) -> Vec<Vec<f64>> {
        let n_f = self.n as f64;
        self.p
            .iter()
            .enumerate()
            .map(|(i, &pi)| {
                self.p
                    .iter()
                    .enumerate()
                    .map(|(j, &pj)| {
                        if i == j {
                            n_f * pi * (1.0 - pi)
                        } else {
                            -n_f * pi * pj
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Variance of the first category's count: `n * p[0] * (1 - p[0])`.
    fn var_first(&self) -> f64 {
        self.n as f64 * self.p[0] * (1.0 - self.p[0])
    }
}
/// Scalar summaries of the distribution.
///
/// Apart from `entropy`, every method here is computed from `n` and `p[0]`
/// only, i.e. it describes the count of the first category.
impl Distribution for Multinomial {
    /// Expected count of the first category: `n * p[0]`.
    fn mean(&self) -> f64 {
        self.n as f64 * self.p[0]
    }

    /// Variance of the first category's count: `n * p[0] * (1 - p[0])`.
    fn var(&self) -> f64 {
        self.var_first()
    }

    /// Entropy estimate built from `lgamma`, `digamma` and the category
    /// probabilities; categories with probability zero are skipped.
    ///
    /// Returns 0.0 when `n == 0`, because the only possible outcome is then
    /// the all-zero count vector.
    ///
    /// NOTE(review): the expression below is kept as found. Its derivation is
    /// not visible in this file, and for `n = 1` it appears not to reduce to
    /// `-sum p ln p`; confirm it against a reference before relying on it.
    fn entropy(&self) -> f64 {
        if self.n == 0 {
            // Without this guard `n_pi` is 0 and `0.0 * (ln 0 - 1.0)`
            // evaluates to `0 * -inf`, which is NaN.
            return 0.0;
        }
        let n_f = self.n as f64;
        let mut h = special::lgamma(n_f + 1.0);
        for &pi in self.p.iter() {
            if pi > 0.0 {
                let n_pi = n_f * pi;
                h -= n_pi * (n_pi.ln() - 1.0);
                h -= n_pi * special::digamma(n_pi + 1.0);
            }
        }
        for &pi in self.p.iter() {
            if pi > 0.0 {
                h += -pi * pi.ln();
            }
        }
        h
    }

    /// Returns `floor(n * p[0])`.
    fn median(&self) -> f64 {
        (self.n as f64 * self.p[0]).floor()
    }

    /// Returns `floor((n + 1) * p[0])`, capped at `n`.
    ///
    /// The cap matters when `p[0]` is 1.0: the uncapped value would be
    /// `n + 1`, a count that can never be observed in `n` trials.
    fn mode(&self) -> f64 {
        let n_f = self.n as f64;
        ((n_f + 1.0) * self.p[0]).floor().min(n_f)
    }

    /// Returns `(1 - 2 * p[0]) / sqrt(var)`, or 0.0 when the variance is zero.
    fn skewness(&self) -> f64 {
        let var = self.var_first();
        if var == 0.0 {
            return 0.0;
        }
        (1.0 - 2.0 * self.p[0]) / var.sqrt()
    }

    /// Returns `(1 - 6 * p[0] * (1 - p[0])) / var`, or 0.0 when the variance
    /// is zero.
    fn kurtosis(&self) -> f64 {
        let var = self.var_first();
        if var == 0.0 {
            return 0.0;
        }
        (1.0 - 6.0 * self.p[0] * (1.0 - self.p[0])) / var
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Panics unless `actual` lies strictly within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Valid parameters are stored as given; invalid vectors are rejected.
    #[test]
    fn test_multinomial_creation() {
        let dist = Multinomial::new(10, vec![0.3, 0.3, 0.4]).unwrap();
        assert_eq!(dist.n(), 10);
        assert_eq!(dist.k(), 3);
        assert_eq!(dist.p(), &[0.3, 0.3, 0.4]);

        // Sum below one, a negative entry, and an empty vector.
        let rejected: [Vec<f64>; 3] = [vec![0.3, 0.3, 0.3], vec![-0.1, 0.6, 0.5], vec![]];
        for probs in rejected.iter() {
            assert!(Multinomial::new(10, probs.clone()).is_err());
        }
    }

    /// Each expected count equals `n * p[i]`.
    #[test]
    fn test_multinomial_mean_vec() {
        let dist = Multinomial::new(10, vec![0.2, 0.3, 0.5]).unwrap();
        let means = dist.mean_vec();
        for (idx, want) in [2.0, 3.0, 5.0].iter().enumerate() {
            assert_close(means[idx], *want);
        }
    }

    /// Two equally likely categories: variance 2.5, covariance -2.5.
    #[test]
    fn test_multinomial_covariance() {
        let dist = Multinomial::new(10, vec![0.5, 0.5]).unwrap();
        let cov = dist.cov_matrix();
        assert_close(cov[0][0], 2.5);
        assert_close(cov[1][1], 2.5);
        assert_close(cov[0][1], -2.5);
        assert_close(cov[1][0], -2.5);
    }

    /// Hand-computed probabilities for three trials over three equal categories.
    #[test]
    fn test_multinomial_pmf() {
        let dist = Multinomial::new(3, vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]).unwrap();
        let cases: [([u64; 3], f64); 3] = [
            ([1, 1, 1], 6.0 / 27.0),
            ([3, 0, 0], 1.0 / 27.0),
            ([2, 1, 0], 3.0 / 27.0),
        ];
        for (counts, want) in cases.iter() {
            assert_close(dist.pmf(counts), *want);
        }
    }

    /// Counts that do not add up to `n` have probability zero.
    #[test]
    fn test_multinomial_pmf_invalid() {
        let dist = Multinomial::new(3, vec![0.5, 0.5]).unwrap();
        assert_close(dist.pmf(&[1, 1]), 0.0);
    }

    /// The probabilities of all feasible outcomes add up to one.
    #[test]
    fn test_multinomial_pmf_sums_to_one() {
        let dist = Multinomial::new(2, vec![0.4, 0.6]).unwrap();
        let outcomes: [[u64; 2]; 3] = [[0, 2], [1, 1], [2, 0]];
        let total: f64 = outcomes.iter().map(|counts| dist.pmf(counts)).sum();
        assert_close(total, 1.0);
    }

    /// `exp(log_pmf(x))` agrees with `pmf(x)`.
    #[test]
    fn test_multinomial_log_pmf() {
        let dist = Multinomial::new(3, vec![0.5, 0.5]).unwrap();
        let counts = [1, 2];
        let from_log = dist.log_pmf(&counts).exp();
        let direct = dist.pmf(&counts);
        assert_close(from_log, direct);
    }

    /// Mean, variance and skewness for n = 100 with a first-category probability of 0.3.
    #[test]
    fn test_multinomial_moments() {
        let dist = Multinomial::new(100, vec![0.3, 0.7]).unwrap();
        assert_close(dist.mean(), 30.0);
        assert_close(dist.var(), 21.0);
        assert_close(dist.skewness(), 0.4 / 21.0_f64.sqrt());
    }

    /// A category with probability zero can only ever have a zero count.
    #[test]
    fn test_multinomial_edge_case_zero_probability() {
        let dist = Multinomial::new(5, vec![1.0, 0.0]).unwrap();
        assert_close(dist.pmf(&[5, 0]), 1.0);
        assert_close(dist.pmf(&[4, 1]), 0.0);
    }

    /// Four trials over four equal categories, one count each: 4! / 4^4.
    #[test]
    fn test_multinomial_uniform() {
        let dist = Multinomial::new(4, vec![0.25, 0.25, 0.25, 0.25]).unwrap();
        assert_close(dist.pmf(&[1, 1, 1, 1]), 24.0 / 256.0);
    }
}