use crate::error::{StatsError, StatsResult};
use crate::sampling::SampleableDistribution;
use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
use scirs2_core::random::prelude::*;
use scirs2_core::validation::{check_probabilities, check_probabilities_sum_to_one};
use scirs2_core::Rng;
use std::fmt::Debug;
/// Returns `n!` as an `f64`.
///
/// `0!` and `1!` are both `1.0`. The product is accumulated in floating
/// point, so it becomes `f64::INFINITY` once the true value no longer fits
/// in an `f64`.
#[allow(dead_code)]
fn factorial(n: u64) -> f64 {
    // An empty range (n < 2) leaves the seed value 1.0 untouched.
    (2..=n).fold(1.0, |product, factor| product * factor as f64)
}
/// Returns `n! / (xs[0]! * xs[1]! * ...)` as an `f64`.
///
/// When the entries of `xs` sum to `n` this is the multinomial coefficient.
///
/// The value is built as a running product instead of dividing two complete
/// factorials. Complete factorials overflow to infinity for arguments above
/// 170, which turns the quotient into `NaN`/`inf` even when the coefficient
/// itself is small. Every intermediate value of the running product is itself
/// a multinomial coefficient, so results that fit in 53 bits are exact.
#[allow(dead_code)]
fn multinomial_coef(n: u64, xs: &[u64]) -> f64 {
    let mut coef = 1.0_f64;
    // Number of items distributed so far across the categories already visited.
    let mut placed: u64 = 0;
    for &x in xs {
        for i in 1..=x {
            placed += 1;
            // Multiply before dividing: `coef * placed` is always divisible by `i`.
            coef = (coef * placed as f64) / i as f64;
        }
    }
    // `coef` now equals `placed! / prod(x!)`. Rescale by `n! / placed!` so the
    // function keeps returning `n! / prod(x!)` when the counts do not sum to `n`.
    if placed < n {
        for m in (placed + 1)..=n {
            coef *= m as f64;
        }
    } else if placed > n {
        for m in (n + 1)..=placed {
            coef /= m as f64;
        }
    }
    coef
}
/// Multinomial distribution: category counts obtained from `n` trials, each
/// trial selecting one category according to the probability vector `p`.
///
/// Both fields are public, so a value built with a struct literal bypasses the
/// validation performed by [`Multinomial::new`].
#[derive(Debug, Clone)]
pub struct Multinomial {
/// Number of trials; the counts of every sample sum to this value.
pub n: u64,
/// Probability of each category, one entry per category.
pub p: Array1<f64>,
}
impl Multinomial {
    /// Creates a multinomial distribution with `n` trials and category
    /// probabilities `p`.
    ///
    /// # Errors
    ///
    /// Returns an error when `check_probabilities` or
    /// `check_probabilities_sum_to_one` rejects `p`.
    pub fn new<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Self>
    where
        D: Data<Elem = f64>,
    {
        // `into_owned` reuses the buffer when the caller handed over an owned
        // array and only copies when `p` is a view or shared array.
        let p_owned = p.into_owned();
        check_probabilities(&p_owned, "Probabilities").map_err(StatsError::from)?;
        check_probabilities_sum_to_one(&p_owned, "Probabilities", None)
            .map_err(StatsError::from)?;
        Ok(Multinomial { n, p: p_owned })
    }

    /// Natural logarithm of `n!`.
    ///
    /// Summing logarithms keeps the result finite for every `n`, whereas
    /// `factorial(n).ln()` is infinite once `n` exceeds 170.
    fn ln_factorial(n: u64) -> f64 {
        (2..=n).fold(0.0, |acc, i| acc + (i as f64).ln())
    }

    /// Converts an observation into integer counts.
    ///
    /// Returns `None` when `x` is outside the support: wrong length, a
    /// non-finite, negative or non-integer entry, or counts that do not sum to
    /// `self.n`.
    fn validated_counts<D>(&self, x: &ArrayBase<D, Ix1>) -> Option<Vec<u64>>
    where
        D: Data<Elem = f64>,
    {
        // Largest distance above an integer that is still treated as that integer.
        const INTEGER_TOLERANCE: f64 = 1e-10;

        if x.len() != self.p.len() {
            return None;
        }
        let upper_bound = self.n as f64;
        let mut counts = Vec::with_capacity(x.len());
        let mut total: u64 = 0;
        for &val in x.iter() {
            // NaN fails every comparison and would otherwise be cast to 0, so it
            // is rejected explicitly. Infinite values and values above `n` can
            // never be part of a valid observation, and rejecting them keeps the
            // float-to-integer cast below from saturating.
            if !val.is_finite()
                || val < 0.0
                || val > upper_bound
                || (val - val.floor()).abs() > INTEGER_TOLERANCE
            {
                return None;
            }
            let count = val as u64;
            // A sum that overflows `u64` cannot equal `n` either.
            total = total.checked_add(count)?;
            counts.push(count);
        }
        (total == self.n).then_some(counts)
    }

    /// Probability mass at the count vector `x`.
    ///
    /// Returns `0.0` when `x` is outside the support (wrong length, negative,
    /// non-integer or non-finite entries, or counts that do not sum to `n`).
    ///
    /// The value is computed as `exp(logpmf(x))`, so a very large coefficient
    /// multiplied by a very small probability product cannot turn into
    /// `inf * 0 = NaN`.
    pub fn pmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        // exp(-inf) is exactly 0.0, which covers every out-of-support case.
        self.logpmf(x).exp()
    }

    /// Natural logarithm of the probability mass at the count vector `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when `x` is outside the support, and also
    /// when a category with zero probability has a positive count.
    pub fn logpmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        let Some(counts) = self.validated_counts(x) else {
            return f64::NEG_INFINITY;
        };
        let ln_denominator: f64 = counts
            .iter()
            .map(|&count| Self::ln_factorial(count))
            .sum();
        // Zero counts are skipped so that `0 * ln(0)` does not produce NaN.
        let ln_prob_term: f64 = counts
            .iter()
            .zip(self.p.iter())
            .filter(|&(&count, _)| count > 0)
            .map(|(&count, &prob)| count as f64 * prob.ln())
            .sum();
        Self::ln_factorial(self.n) - ln_denominator + ln_prob_term
    }

    /// Draws `size` independent samples.
    ///
    /// Each sample is a vector with one count per category; the counts of a
    /// sample sum to `n`. Every trial is drawn separately by inverting the
    /// cumulative distribution of `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is empty while `n > 0`, because there is no category to
    /// assign a trial to.
    pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
        let mut rng = thread_rng();
        let k = self.p.len();
        // Rounding can leave the cumulative sum marginally below 1. A draw that
        // lands in that gap belongs to the last category that can actually
        // occur, not to category 0.
        let fallback = self
            .p
            .iter()
            .enumerate()
            .filter(|&(_, &prob)| prob > 0.0)
            .map(|(index, _)| index)
            .last()
            .unwrap_or(0);
        let mut samples = Vec::with_capacity(size);
        for _ in 0..size {
            let mut counts = vec![0u64; k];
            for _ in 0..self.n {
                let u: f64 = rng.random();
                let mut cumulative = 0.0;
                let mut category = fallback;
                for (i, &prob) in self.p.iter().enumerate() {
                    cumulative += prob;
                    // Strict comparison: a zero-probability category spans an
                    // empty interval and must not be chosen even when `u` is 0.
                    if u < cumulative {
                        category = i;
                        break;
                    }
                }
                counts[category] += 1;
            }
            samples.push(Array1::from_iter(counts.iter().map(|&count| count as f64)));
        }
        Ok(samples)
    }

    /// Draws a single sample; see [`Multinomial::rvs`].
    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
        let mut samples = self.rvs(1)?;
        // `rvs(1)` returns exactly one sample; move it out instead of cloning.
        Ok(samples.swap_remove(0))
    }

    /// Expected count of each category, `n * p[i]`.
    pub fn mean(&self) -> Array1<f64> {
        let n_f64 = self.n as f64;
        self.p.mapv(|p_i| n_f64 * p_i)
    }

    /// Covariance matrix of the counts.
    ///
    /// Diagonal entries are `n * p[i] * (1 - p[i])`; off-diagonal entries are
    /// `-n * p[i] * p[j]`.
    pub fn cov(&self) -> scirs2_core::ndarray::Array2<f64> {
        let k = self.p.len();
        let n_f64 = self.n as f64;
        scirs2_core::ndarray::Array2::from_shape_fn((k, k), |(i, j)| {
            if i == j {
                n_f64 * self.p[i] * (1.0 - self.p[i])
            } else {
                -n_f64 * self.p[i] * self.p[j]
            }
        })
    }
}
/// Free-function shorthand for [`Multinomial::new`].
///
/// # Errors
///
/// Propagates whatever error [`Multinomial::new`] reports for `p`.
#[allow(dead_code)]
pub fn multinomial<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Multinomial>
where
    D: Data<Elem = f64>,
{
    let distribution = Multinomial::new(n, p)?;
    Ok(distribution)
}
impl SampleableDistribution<Array1<f64>> for Multinomial {
fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
self.rvs(size)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Valid parameters are stored as given; invalid probability vectors are rejected.
    #[test]
    fn test_multinomial_creation() {
        let trials = 10;
        let probs = array![0.2, 0.3, 0.5];
        let dist = Multinomial::new(trials, probs.clone()).expect("Operation failed");
        assert_eq!(dist.n, trials);
        assert_eq!(dist.p, probs);

        // Entries that add up to 1.1 instead of 1.
        let probs_bad_sum = array![0.2, 0.3, 0.6];
        assert!(Multinomial::new(trials, probs_bad_sum).is_err());

        // A negative entry, even though the total is 1.
        let probs_negative = array![0.2, -0.1, 0.9];
        assert!(Multinomial::new(trials, probs_negative).is_err());
    }

    /// The pmf matches hand-computed values and is zero outside the support.
    #[test]
    fn test_multinomial_pmf() {
        let trials = 5;
        let probs = array![0.5, 0.5];
        let dist = Multinomial::new(trials, probs).expect("Operation failed");

        let split = array![2.0, 3.0];
        let split_expected = 0.3125;
        assert_relative_eq!(dist.pmf(&split), split_expected, epsilon = 1e-10);

        let one_sided = array![5.0, 0.0];
        let one_sided_expected = 0.03125;
        assert_relative_eq!(dist.pmf(&one_sided), one_sided_expected, epsilon = 1e-10);

        // Counts add up to 4 rather than 5.
        let wrong_total = array![2.0, 2.0];
        assert_eq!(dist.pmf(&wrong_total), 0.0);

        // Fractional counts.
        let fractional = array![2.5, 2.5];
        assert_eq!(dist.pmf(&fractional), 0.0);

        // Three categories supplied for a two-category distribution.
        let wrong_len = array![2.0, 3.0, 0.0];
        assert_eq!(dist.pmf(&wrong_len), 0.0);
    }

    /// The log-pmf agrees with the pmf and is negative infinity outside the support.
    #[test]
    fn test_multinomial_logpmf() {
        let trials = 5;
        let probs = array![0.5, 0.5];
        let dist = Multinomial::new(trials, probs).expect("Operation failed");

        let observed = array![2.0, 3.0];
        let log_mass = dist.logpmf(&observed);
        let mass = dist.pmf(&observed);
        assert_relative_eq!(log_mass.exp(), mass, epsilon = 1e-10);

        // Counts add up to 4 rather than 5.
        let wrong_total = array![2.0, 2.0];
        assert_eq!(dist.logpmf(&wrong_total), f64::NEG_INFINITY);
    }

    /// The mean equals `n * p` in every category.
    #[test]
    fn test_multinomial_mean() {
        let trials = 10;
        let probs = array![0.2, 0.3, 0.5];
        let dist = Multinomial::new(trials, probs).expect("Operation failed");

        let actual = dist.mean();
        let expected = array![2.0, 3.0, 5.0];
        for (got, want) in actual.iter().take(3).zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    /// The covariance has the expected diagonal and off-diagonal entries and is symmetric.
    #[test]
    fn test_multinomial_cov() {
        let trials = 10;
        let probs = array![0.2, 0.3, 0.5];
        let dist = Multinomial::new(trials, probs).expect("Operation failed");
        let cov = dist.cov();

        // Diagonal: n * p_i * (1 - p_i).
        assert_relative_eq!(cov[[0, 0]], 10.0 * 0.2 * 0.8, epsilon = 1e-10);
        assert_relative_eq!(cov[[1, 1]], 10.0 * 0.3 * 0.7, epsilon = 1e-10);
        assert_relative_eq!(cov[[2, 2]], 10.0 * 0.5 * 0.5, epsilon = 1e-10);

        // Off-diagonal: -n * p_i * p_j.
        assert_relative_eq!(cov[[0, 1]], -10.0 * 0.2 * 0.3, epsilon = 1e-10);
        assert_relative_eq!(cov[[0, 2]], -10.0 * 0.2 * 0.5, epsilon = 1e-10);
        assert_relative_eq!(cov[[1, 2]], -10.0 * 0.3 * 0.5, epsilon = 1e-10);

        // Symmetry.
        for (i, j) in [(0, 1), (0, 2), (1, 2)] {
            assert_relative_eq!(cov[[j, i]], cov[[i, j]], epsilon = 1e-10);
        }
    }

    /// Samples have the right shape and total, and their average is near `n * p`.
    #[test]
    fn test_multinomial_rvs() {
        let trials = 100;
        let probs = array![0.2, 0.3, 0.5];
        let dist = Multinomial::new(trials, probs.clone()).expect("Operation failed");

        let draw_count = 100;
        let draws = dist.rvs(draw_count).expect("Operation failed");
        assert_eq!(draws.len(), draw_count);

        for draw in &draws {
            assert_eq!(draw.len(), 3);
            let total: f64 = draw.sum();
            assert_eq!(total, trials as f64);
        }

        let mut running_total = array![0.0, 0.0, 0.0];
        for draw in &draws {
            running_total += draw;
        }
        let empirical_mean = running_total / draw_count as f64;
        let expected_mean = array![20.0, 30.0, 50.0];
        for idx in 0..3 {
            assert!((empirical_mean[idx] - expected_mean[idx]).abs() < 5.0);
        }
    }

    /// A single draw has one entry per category and sums to `n`.
    #[test]
    fn test_multinomial_rvs_single() {
        let trials = 10;
        let probs = array![0.2, 0.3, 0.5];
        let dist = Multinomial::new(trials, probs).expect("Operation failed");

        let draw = dist.rvs_single().expect("Operation failed");
        assert_eq!(draw.len(), 3);
        let total: f64 = draw.sum();
        assert_eq!(total, trials as f64);
    }

    /// The coefficient helper reproduces known multinomial coefficients.
    #[test]
    fn test_multinomial_coef() {
        let two_way = multinomial_coef(5, &[2, 3]);
        assert_eq!(two_way, 10.0);

        let three_way = multinomial_coef(8, &[3, 2, 3]);
        assert_eq!(three_way, 560.0);
    }
}