use crate::error::PramanaError;
use std::f64::consts::PI;
/// Computes `n!` exactly as a `u128`.
///
/// `34!` is the largest factorial that fits in a `u128`, so any larger `n` is
/// rejected before doing any multiplication.
///
/// # Errors
/// Returns [`PramanaError::ComputationError`] when `n > 34`, because the
/// result would overflow `u128`.
#[must_use = "returns the factorial"]
pub fn factorial(n: u64) -> Result<u128, PramanaError> {
    const MAX_FITTING_N: u64 = 34;

    if n > MAX_FITTING_N {
        return Err(PramanaError::ComputationError(format!(
            "factorial({n}) overflows u128"
        )));
    }

    // Defensive: the guard above already keeps the product in range, but a
    // checked multiply costs nothing and protects against a wrong bound.
    (2..=u128::from(n)).try_fold(1_u128, |product, factor| {
        product
            .checked_mul(factor)
            .ok_or_else(|| PramanaError::ComputationError(format!("factorial({n}) overflow")))
    })
}
/// Number of ordered arrangements of `r` items drawn from `n`, i.e. `n! / (n - r)!`.
///
/// Computed as the product `(n - r + 1) * ... * n`. The product is `1` when `r == 0`.
///
/// # Errors
/// - [`PramanaError::InvalidParameter`] if `r > n`.
/// - [`PramanaError::ComputationError`] if the result does not fit in a `u128`.
#[must_use = "returns the number of permutations"]
pub fn permutations(n: u64, r: u64) -> Result<u128, PramanaError> {
    if r > n {
        return Err(PramanaError::InvalidParameter(format!(
            "r ({r}) must be <= n ({n})"
        )));
    }
    // Compute the lower bound in u128: `n - r + 1` in u64 overflows when
    // `r == 0 && n == u64::MAX`. When `r == 0` the range is empty and the
    // product is 1.
    let low = u128::from(n - r) + 1;
    (low..=u128::from(n)).try_fold(1_u128, |acc, i| {
        acc.checked_mul(i).ok_or_else(|| {
            PramanaError::ComputationError(format!("permutations({n}, {r}) overflow"))
        })
    })
}
/// Number of unordered selections of `r` items drawn from `n`, i.e. `n! / (r! (n - r)!)`.
///
/// Uses the multiplicative formula over `k = min(r, n - r)` steps, which is valid by the
/// symmetry `C(n, r) == C(n, n - r)`. Before each multiplication the common factor between
/// the running value and the step's divisor is cancelled. As a result, every intermediate
/// value is exactly `C(n, i)`, and an overflow is reported only when the result itself does
/// not fit in a `u128`.
///
/// # Errors
/// - [`PramanaError::InvalidParameter`] if `r > n`.
/// - [`PramanaError::ComputationError`] if `C(n, r)` does not fit in a `u128`.
#[must_use = "returns the number of combinations"]
pub fn combinations(n: u64, r: u64) -> Result<u128, PramanaError> {
    /// Greatest common divisor (Euclid's algorithm).
    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            let rem = a % b;
            a = b;
            b = rem;
        }
        a
    }

    if r > n {
        return Err(PramanaError::InvalidParameter(format!(
            "r ({r}) must be <= n ({n})"
        )));
    }
    // Use a separate name so error messages still report the caller's `r`.
    let k = r.min(n - r);
    let mut result: u128 = 1;
    for i in 0..k {
        // Invariant: `result == C(n, i)`. Next is `result * (n - i) / (i + 1)`, which is
        // exact. After dividing out `g = gcd(result, i + 1)`, the remaining divisor `d` is
        // coprime to `result`, so it must divide `n - i`. The single multiplication then
        // yields exactly `C(n, i + 1)`.
        let numerator = u128::from(n - i);
        let denominator = u128::from(i + 1);
        let g = gcd(result, denominator);
        result /= g;
        let d = denominator / g;
        result = result.checked_mul(numerator / d).ok_or_else(|| {
            PramanaError::ComputationError(format!("combinations({n}, {r}) overflow"))
        })?;
    }
    Ok(result)
}
/// Stirling's approximation to `n!`: `sqrt(2 * pi * n) * (n / e)^n`.
///
/// Returns exactly `1.0` for `n == 0`, matching `0! = 1`. For large `n`, the `f64` result
/// may overflow to infinity.
#[must_use = "returns Stirling's approximation"]
#[inline]
pub fn stirling_approx(n: u64) -> f64 {
    match n {
        0 => 1.0,
        _ => {
            let x = n as f64;
            let prefactor = (2.0 * PI * x).sqrt();
            let growth = (x / std::f64::consts::E).powf(x);
            prefactor * growth
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorial() {
        assert_eq!(factorial(0).unwrap(), 1);
        assert_eq!(factorial(1).unwrap(), 1);
        assert_eq!(factorial(5).unwrap(), 120);
        assert_eq!(factorial(10).unwrap(), 3_628_800);
        assert_eq!(factorial(20).unwrap(), 2_432_902_008_176_640_000);
    }

    /// 34! is the largest factorial that fits in `u128`; make sure it is not rejected.
    #[test]
    fn test_factorial_largest_fitting() {
        let f33 = factorial(33).unwrap();
        assert_eq!(factorial(34).unwrap(), f33 * 34);
    }

    #[test]
    fn test_factorial_overflow() {
        assert!(factorial(35).is_err());
    }

    #[test]
    fn test_permutations() {
        assert_eq!(permutations(5, 3).unwrap(), 60);
        assert_eq!(permutations(10, 2).unwrap(), 90);
        assert_eq!(permutations(5, 0).unwrap(), 1);
        assert_eq!(permutations(5, 5).unwrap(), 120);
    }

    #[test]
    fn test_permutations_r_gt_n() {
        assert!(permutations(3, 5).is_err());
    }

    #[test]
    fn test_combinations() {
        assert_eq!(combinations(10, 3).unwrap(), 120);
        assert_eq!(combinations(5, 2).unwrap(), 10);
        assert_eq!(combinations(5, 0).unwrap(), 1);
        assert_eq!(combinations(5, 5).unwrap(), 1);
        assert_eq!(combinations(20, 10).unwrap(), 184_756);
    }

    /// C(n, r) == C(n, n - r).
    #[test]
    fn test_combinations_symmetry() {
        assert_eq!(combinations(100, 97).unwrap(), 161_700);
        assert_eq!(combinations(100, 3).unwrap(), 161_700);
    }

    #[test]
    fn test_combinations_r_gt_n() {
        assert!(combinations(3, 5).is_err());
    }

    #[test]
    fn test_stirling() {
        let approx = stirling_approx(10);
        let exact = 3_628_800.0;
        let relative_error = (approx - exact).abs() / exact;
        assert!(
            relative_error < 0.01,
            "relative error {relative_error} too large"
        );
    }

    #[test]
    fn test_stirling_zero() {
        assert!((stirling_approx(0) - 1.0).abs() < f64::EPSILON);
    }
}