use crate::error::{StatsError, StatsResult};
/// Computes `n!`, the product `1 * 2 * ... * n`, as a `u64`.
///
/// Both `0!` and `1!` evaluate to `1`.
///
/// # Errors
///
/// Returns an error built with `StatsError::invalid_input` as soon as an
/// intermediate product no longer fits in a `u64`; `20!` is the largest
/// factorial that fits, so every `n > 20` fails.
pub fn factorial(n: u64) -> StatsResult<u64> {
    // For n <= 1 the range `2..=n` is empty and the seed value 1 is returned
    // untouched; otherwise `checked_mul` yields `None` on the first overflow,
    // which stops the fold early.
    let product = (2..=n).try_fold(1u64, |running, factor| running.checked_mul(factor));
    product.ok_or_else(|| {
        StatsError::invalid_input(format!(
            "factorial({}) overflows u64 (max supported: factorial(20))",
            n
        ))
    })
}
/// Counts the ordered arrangements of `k` items drawn from `n`:
/// `P(n, k) = n! / (n - k)!`.
///
/// The value is computed as the falling product `n * (n - 1) * ... * (n - k + 1)`,
/// so `n!` is never formed and large `n` with small `k` works. `P(n, 0)` is `1`.
///
/// # Errors
///
/// Returns an error built with `StatsError::invalid_input` when `k > n`, or when
/// the result does not fit in a `u64`.
pub fn permutation(n: u64, k: u64) -> StatsResult<u64> {
    if k > n {
        return Err(StatsError::invalid_input(format!(
            "k ({}) cannot be greater than n ({})",
            k, n
        )));
    }
    // Walk the offsets 0..k instead of the range (n - k + 1)..=n: that range
    // evaluates n + 1 when k == 0, which overflows for n == u64::MAX. Here
    // `n - i` cannot underflow because i < k <= n, and each factor is >= 1, so
    // `checked_mul` only fails when the final product really exceeds u64.
    (0..k)
        .try_fold(1u64, |acc, i| acc.checked_mul(n - i))
        .ok_or_else(|| {
            StatsError::invalid_input(format!(
                "permutation({}, {}) overflows u64",
                n, k
            ))
        })
}
pub fn combination(n: u64, k: u64) -> StatsResult<u64> {
if k > n {
return Err(StatsError::invalid_input(format!(
"k ({}) cannot be greater than n ({})",
k, n
)));
}
let k = if k > n - k { n - k } else { k };
Ok((1..=k).fold(1, |acc, x| acc * (n - x + 1) / x))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` failed, and that it failed with the
    /// `StatsError::InvalidInput` variant specifically.
    fn assert_invalid_input(result: StatsResult<u64>) {
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_factorial() {
        // (n, n!) pairs, ending at the largest factorial representable in u64.
        let cases: [(u64, u64); 5] = [
            (0, 1),
            (1, 1),
            (5, 120),
            (10, 3_628_800),
            (20, 2_432_902_008_176_640_000),
        ];
        for &(n, expected) in &cases {
            assert_eq!(factorial(n).unwrap(), expected);
        }
    }

    #[test]
    fn test_factorial_overflow() {
        assert_invalid_input(factorial(21));
    }

    #[test]
    fn test_permutation_valid() {
        // (n, k, P(n, k)) triples.
        let cases: [(u64, u64, u64); 4] = [(5, 3, 60), (5, 5, 120), (5, 0, 1), (10, 3, 720)];
        for &(n, k, expected) in &cases {
            assert_eq!(permutation(n, k).unwrap(), expected);
        }
    }

    #[test]
    fn test_permutation_invalid() {
        assert_invalid_input(permutation(5, 10));
    }

    #[test]
    fn test_combination_valid() {
        // (n, k, C(n, k)) triples.
        let cases: [(u64, u64, u64); 4] = [(5, 3, 10), (5, 5, 1), (5, 0, 1), (10, 3, 120)];
        for &(n, k, expected) in &cases {
            assert_eq!(combination(n, k).unwrap(), expected);
        }
    }

    #[test]
    fn test_combination_invalid() {
        assert_invalid_input(combination(5, 10));
    }

    #[test]
    fn test_combination_symmetry() {
        // Each (n, k) must agree with its mirror (n, n - k).
        let pairs: [(u64, u64); 2] = [(10, 3), (20, 5)];
        for &(n, k) in &pairs {
            assert_eq!(combination(n, k).unwrap(), combination(n, n - k).unwrap());
        }
    }

    #[test]
    fn test_combination_k_greater_than_n_minus_k() {
        let (n, k) = (10u64, 8u64);
        let direct = combination(n, k).unwrap();
        let mirrored = combination(n, n - k).unwrap();
        assert_eq!(
            direct, mirrored,
            "C(n, k) should equal C(n, n-k) when k > n-k"
        );
        assert_eq!(direct, 45u64, "C(10, 8) should equal C(10, 2) = 45");
    }
}