use crate::factorial::Factorial;
/// Returns the number of ways to choose `r` items out of `n`, i.e. `n! / (r! * (n - r)!)`.
///
/// Factorials are looked up through `factorial`, which is expected to return `k!` for `k`.
/// A `u128` can represent factorials only up to `34!`, so what happens for larger `n`
/// depends on how the `Factorial` implementation deals with overflow.
///
/// # Panics
///
/// Panics if `r > n`.
pub fn combinations(n: u8, r: u8, factorial: &impl Factorial) -> u128 {
    assert!(n >= r, "n ({n}) < r ({r})");
    let n_factorial = factorial.get(n);
    let r_factorial = factorial.get(r);
    let rest_factorial = factorial.get(n - r);
    // Dividing one factorial at a time stays exact: n!/r! is an integer (a falling
    // product), and that product is itself divisible by (n - r)!.
    n_factorial / r_factorial / rest_factorial
}
/// Probability of exactly `r` successes in `n` independent trials that each succeed with
/// probability `p`: `C(n, r) * p^r * (1 - p)^(n - r)`.
///
/// The binomial coefficient is computed by [`combinations`] using `factorial`.
///
/// # Panics
///
/// Panics if `r > n`, if `p > 1.0` (this also catches a NaN `p`), or if `p < 0.0`.
pub fn binomial(n: u8, r: u8, p: f64, factorial: &impl Factorial) -> f64 {
    assert!(n >= r, "n ({n}) < r ({r})");
    assert!(p <= 1.0, "p ({p}) > 1.0");
    // Without this lower bound a negative "probability" silently yields a meaningless
    // (possibly negative) result.
    assert!(p >= 0.0, "p ({p}) < 0.0");
    let successes = p.powi(i32::from(r));
    let failures = (1.0 - p).powi(i32::from(n - r));
    combinations(n, r, factorial) as f64 * successes * failures
}
/// Probability of a specific count split over `n` independent trials with three outcomes.
///
/// Outcome 1 has probability `p_1`, outcome 2 has probability `p_2`, and outcome 3 takes
/// the remaining `p_3 = 1 - p_1 - p_2`. Returns the probability that outcome 1 occurs
/// exactly `r_1` times and outcome 2 exactly `r_2` times, with outcome 3 filling the other
/// `r_3 = n - r_1 - r_2` trials:
/// `n! / (r_1! r_2! r_3!) * p_1^r_1 * p_2^r_2 * p_3^r_3`.
///
/// Factorials come from `factorial`, which is expected to return `k!` for `k`.
///
/// # Panics
///
/// Panics if `r_1 + r_2 > n`, if `p_1 + p_2 > 1.0` (including a NaN sum), or if `p_1` or
/// `p_2` is negative.
pub fn trinomial(
    n: u8,
    r_1: u8,
    r_2: u8,
    p_1: f64,
    p_2: f64,
    factorial: &impl Factorial,
) -> f64 {
    // Sum in u16: `r_1 + r_2` can exceed u8::MAX, which would panic with an arithmetic
    // overflow in debug builds and silently wrap (defeating this check) in release builds.
    assert!(
        u16::from(r_1) + u16::from(r_2) <= u16::from(n),
        "r_1 ({r_1}) + r_2 ({r_2}) > n ({n})"
    );
    assert!(p_1 + p_2 <= 1.0, "p_1 ({p_1}) + p_2 ({p_2}) > 1.0");
    assert!(p_1 >= 0.0, "p_1 ({p_1}) < 0.0");
    assert!(p_2 >= 0.0, "p_2 ({p_2}) < 0.0");
    let p_3 = 1.0 - p_1 - p_2;
    // Cannot underflow: the first assertion guarantees r_1 + r_2 <= n.
    let r_3 = n - r_1 - r_2;
    // Each partial quotient is an exact integer, so sequential division loses nothing.
    let coefficient =
        factorial.get(n) / factorial.get(r_1) / factorial.get(r_2) / factorial.get(r_3);
    coefficient as f64
        * p_1.powi(i32::from(r_1))
        * p_2.powi(i32::from(r_2))
        * p_3.powi(i32::from(r_3))
}
/// Probability of a specific count split over `n` independent trials with four outcomes.
///
/// Outcomes 1, 2 and 3 have probabilities `p_1`, `p_2` and `p_3`. Outcome 4 takes the
/// remaining `p_4 = 1 - p_1 - p_2 - p_3`. Returns the probability that outcomes 1, 2 and 3
/// occur exactly `r_1`, `r_2` and `r_3` times, with outcome 4 filling the other
/// `r_4 = n - r_1 - r_2 - r_3` trials:
/// `n! / (r_1! r_2! r_3! r_4!) * p_1^r_1 * p_2^r_2 * p_3^r_3 * p_4^r_4`.
///
/// Factorials come from `factorial`, which is expected to return `k!` for `k`.
///
/// # Panics
///
/// Panics if `r_1 + r_2 + r_3 > n`, if `p_1 + p_2 + p_3 > 1.0` (including a NaN sum), or if
/// any of `p_1`, `p_2`, `p_3` is negative.
// The positional parameter list is this function's public interface; keep it for callers.
#[allow(clippy::too_many_arguments)]
pub fn quadranomial(
    n: u8,
    r_1: u8,
    r_2: u8,
    r_3: u8,
    p_1: f64,
    p_2: f64,
    p_3: f64,
    factorial: &impl Factorial,
) -> f64 {
    // Sum in u16: the counts can add up past u8::MAX, which would panic with an arithmetic
    // overflow in debug builds and silently wrap (defeating this check) in release builds.
    assert!(
        u16::from(r_1) + u16::from(r_2) + u16::from(r_3) <= u16::from(n),
        "r_1 ({r_1}) + r_2 ({r_2}) + r_3 ({r_3}) > n ({n})"
    );
    assert!(p_1 + p_2 + p_3 <= 1.0, "p_1 ({p_1}) + p_2 ({p_2}) + p_3 ({p_3}) > 1.0");
    assert!(p_1 >= 0.0, "p_1 ({p_1}) < 0.0");
    assert!(p_2 >= 0.0, "p_2 ({p_2}) < 0.0");
    assert!(p_3 >= 0.0, "p_3 ({p_3}) < 0.0");
    let p_4 = 1.0 - p_1 - p_2 - p_3;
    // Cannot underflow: the first assertion guarantees r_1 + r_2 + r_3 <= n.
    let r_4 = n - r_1 - r_2 - r_3;
    // Each partial quotient is an exact integer, so sequential division loses nothing.
    let coefficient = factorial.get(n)
        / factorial.get(r_1)
        / factorial.get(r_2)
        / factorial.get(r_3)
        / factorial.get(r_4);
    coefficient as f64
        * p_1.powi(i32::from(r_1))
        * p_2.powi(i32::from(r_2))
        * p_3.powi(i32::from(r_3))
        * p_4.powi(i32::from(r_4))
}
/// Joint probability that, over `n` independent trials, event A occurs exactly `r_1` times
/// and event B occurs exactly `r_2` times.
///
/// Every trial lands in one of four cells:
/// - A only, with probability `p_1`;
/// - B only, with probability `p_2`;
/// - both A and B, with probability `p_3`;
/// - neither, with the remaining `1 - p_1 - p_2 - p_3`.
///
/// The result sums [`quadranomial`] over every feasible number `i` of "both" trials, with
/// `r_1 - i` "A only" trials and `r_2 - i` "B only" trials.
///
/// # Panics
///
/// Panics if `r_1 > n`, if `r_2 > n`, if `p_1 + p_2 + p_3 > 1.0` (including a NaN sum), or
/// if any of `p_1`, `p_2`, `p_3` is negative.
pub fn bivariate_binomial(
    n: u8,
    r_1: u8,
    r_2: u8,
    p_1: f64,
    p_2: f64,
    p_3: f64,
    factorial: &impl Factorial,
) -> f64 {
    assert!(r_1 <= n, "r_1 ({r_1}) > n ({n})");
    assert!(r_2 <= n, "r_2 ({r_2}) > n ({n})");
    assert!(p_1 + p_2 + p_3 <= 1.0, "p_1 ({p_1}) + p_2 ({p_2}) + p_3 ({p_3}) > 1.0");
    assert!(p_1 >= 0.0, "p_1 ({p_1}) < 0.0");
    assert!(p_2 >= 0.0, "p_2 ({p_2}) < 0.0");
    assert!(p_3 >= 0.0, "p_3 ({p_3}) < 0.0");
    // No more than min(r_1, r_2) trials can contain both events.
    let max_both = r_1.min(r_2);
    // The trials used are (r_1 - i) + (r_2 - i) + i = r_1 + r_2 - i, which must not exceed n,
    // so i >= r_1 + r_2 - n. This is written as r_1 - (n - r_2) because the direct
    // `r_1 + r_2` overflows u8 once it passes 255. `n - r_2` is safe because r_2 <= n.
    let min_both = r_1.saturating_sub(n - r_2);
    (min_both..=max_both)
        .map(|both| quadranomial(n, r_1 - both, r_2 - both, both, p_1, p_2, p_3, factorial))
        .sum()
}
#[cfg(test)]
mod tests {
    use assert_float_eq::*;

    use super::*;
    use crate::factorial::Calculator;

    #[test]
    fn test_combinations() {
        assert_eq!(5, combinations(5, 1, &Calculator));
        assert_eq!(5, combinations(5, 4, &Calculator));
        assert_eq!(1, combinations(5, 5, &Calculator));
        assert_eq!(10, combinations(5, 3, &Calculator));
        assert_eq!(120, combinations(10, 3, &Calculator));
    }

    // The expected values here are exact binary fractions, so bit-exact comparison is safe.
    #[test]
    fn test_binomial() {
        assert_eq!(0.25, binomial(4, 1, 0.5, &Calculator));
        assert_eq!(0.421875, binomial(4, 1, 0.25, &Calculator));
        assert_eq!(0.375, binomial(4, 2, 0.5, &Calculator));
        assert_eq!(0.2109375, binomial(4, 2, 0.25, &Calculator));
        assert_eq!(0.046875, binomial(4, 3, 0.25, &Calculator));
        assert_eq!(0.421875, binomial(4, 3, 0.75, &Calculator));
    }

    /// With independent events the joint distribution must factor into two binomials.
    #[test]
    fn bivariate_binomial_independent() {
        fn test(n: u8, r_1: u8, r_2: u8, i_1: f64) {
            let i_2 = 1.0 - i_1;
            // Cell probabilities for independent A (prob i_1) and B (prob i_2):
            // A only, B only, and both.
            let p_1 = i_1 * (1.0 - i_2);
            let p_2 = i_2 * (1.0 - i_1);
            let p_3 = i_1 * i_2;
            println!(
                "testing with n={n}, r_1={r_1}, r_2={r_2}, i_1={i_1}, p_1={p_1}, p_2={p_2}, p_3={p_3}"
            );
            let independent_prob =
                binomial(n, r_1, i_1, &Calculator) * binomial(n, r_2, i_2, &Calculator);
            // With i_1 = 0.25 every intermediate is an exact binary fraction, so the two
            // computations agree bit for bit.
            assert_eq!(
                independent_prob,
                bivariate_binomial(n, r_1, r_2, p_1, p_2, p_3, &Calculator)
            );
        }
        for n in 0..=6 {
            for r_1 in 0..=n {
                for r_2 in 0..=n {
                    test(n, r_1, r_2, 0.25);
                }
            }
        }
    }

    #[test]
    fn bivariate_binomial_dependent() {
        // 0.3 and the derived "neither" probability 0.2 are not exact binary fractions, so
        // compare with a tolerance rather than bit-for-bit.
        assert_float_absolute_eq!(0.04, bivariate_binomial(2, 0, 0, 0.25, 0.25, 0.3, &Calculator));
        assert_float_absolute_eq!(0.10, bivariate_binomial(2, 1, 0, 0.25, 0.25, 0.3, &Calculator));
        assert_float_absolute_eq!(0.10, bivariate_binomial(2, 0, 1, 0.25, 0.25, 0.3, &Calculator));
        assert_float_absolute_eq!(0.245, bivariate_binomial(2, 1, 1, 0.25, 0.25, 0.3, &Calculator));
    }

    /// Summing out B must leave a plain binomial for A, which occurs in a trial with
    /// probability p_1 + p_3. This holds whether or not the events are independent.
    #[test]
    fn bivariate_binomial_marginal_is_binomial() {
        let (p_1, p_2, p_3) = (0.25, 0.25, 0.3);
        for n in 0..=6 {
            for r_1 in 0..=n {
                let marginal: f64 = (0..=n)
                    .map(|r_2| bivariate_binomial(n, r_1, r_2, p_1, p_2, p_3, &Calculator))
                    .sum();
                assert_float_absolute_eq!(binomial(n, r_1, p_1 + p_3, &Calculator), marginal);
            }
        }
    }

    #[test]
    fn test_trinomial() {
        assert_eq!(0.25, trinomial(4, 1, 3, 0.5, 0.5, &Calculator));
        assert_eq!(0.421875, trinomial(4, 1, 3, 0.25, 0.75, &Calculator));
        assert_eq!(0.375, trinomial(4, 2, 2, 0.5, 0.5, &Calculator));
        assert_eq!(0.2109375, trinomial(4, 2, 2, 0.25, 0.75, &Calculator));
        assert_eq!(0.046875, trinomial(4, 3, 1, 0.25, 0.75, &Calculator));
        assert_eq!(0.421875, trinomial(4, 3, 1, 0.75, 0.25, &Calculator));
        assert_eq!(0.25, trinomial(2, 1, 0, 0.25, 0.25, &Calculator));
        assert_eq!(0.015625, trinomial(3, 3, 0, 0.25, 0.5, &Calculator));
        assert_eq!(0.125, trinomial(3, 0, 3, 0.25, 0.5, &Calculator));
        assert_eq!(0.015625, trinomial(3, 0, 0, 0.25, 0.5, &Calculator));
        assert_eq!(0.09375, trinomial(3, 2, 1, 0.25, 0.5, &Calculator));
        assert_eq!(0.0, trinomial(4, 0, 2, 0.5, 0.5, &Calculator));
    }

    #[test]
    fn test_quadranomial() {
        assert_eq!(0.25, quadranomial(4, 1, 3, 0, 0.5, 0.5, 0.0, &Calculator));
        assert_eq!(0.421875, quadranomial(4, 1, 3, 0, 0.25, 0.75, 0.0, &Calculator));
        assert_eq!(0.25, quadranomial(2, 1, 0, 1, 0.25, 0.25, 0.5, &Calculator));
        assert_eq!(0.015625, quadranomial(3, 3, 0, 0, 0.25, 0.5, 0.25, &Calculator));
        assert_eq!(0.125, quadranomial(3, 0, 3, 0, 0.25, 0.5, 0.0, &Calculator));
        assert_eq!(0.015625, quadranomial(3, 0, 0, 3, 0.25, 0.5, 0.25, &Calculator));
        assert_eq!(0.09375, quadranomial(3, 2, 1, 0, 0.25, 0.5, 0.25, &Calculator));
        assert_eq!(0.0, quadranomial(4, 0, 2, 2, 0.5, 0.5, 0.0, &Calculator));
    }
}