use crate::error::{SpecialError, SpecialResult};
use crate::gamma::gamma;
use std::f64::consts::PI;
/// Computes the factorial `n!` as an `f64`.
///
/// For `n <= 20` the product `1 · 2 · … · n` is accumulated directly; every such
/// factorial is exactly representable in `f64`, so the result is exact. Larger `n`
/// delegates to `gamma(n + 1)`.
///
/// # Errors
/// Never returns an error; the `Result` keeps the signature uniform with the rest of
/// the module.
#[allow(dead_code)]
pub fn factorial(n: u32) -> SpecialResult<f64> {
    if n <= 20 {
        Ok((1..=n).map(f64::from).product())
    } else {
        // Convert before adding: `n + 1` in `u32` overflows for `n == u32::MAX`.
        Ok(gamma(f64::from(n) + 1.0))
    }
}
/// Computes the double factorial `n!! = n · (n − 2) · (n − 4) · …`, ending at 1 or 2.
///
/// By convention `0!! = 1` (the empty product). The result is exact as long as the
/// product fits in the 53-bit `f64` mantissa, and saturates to `f64::INFINITY` once it
/// exceeds the `f64` range.
///
/// # Errors
/// Never returns an error; the `Result` keeps the signature uniform with the rest of
/// the module.
#[allow(dead_code)]
pub fn double_factorial(n: u32) -> SpecialResult<f64> {
    let mut result = 1.0_f64;
    // n, n - 2, n - 4, ... down to 1 (odd n) or 2 (even n); empty for n == 0.
    for term in (1..=n).rev().step_by(2) {
        result *= f64::from(term);
        // Multiplying `inf` by further positive terms keeps it `inf`; stop instead of
        // spinning through up to ~2^31 more iterations for huge `n`.
        if result.is_infinite() {
            break;
        }
    }
    Ok(result)
}
/// Alias for [`double_factorial`]: forwards `n` unchanged and returns its result.
///
/// # Errors
/// Propagates any error from [`double_factorial`].
#[allow(dead_code)]
pub fn factorial2(n: u32) -> SpecialResult<f64> {
    let value = double_factorial(n)?;
    Ok(value)
}
#[allow(dead_code)]
pub fn factorialk(n: u32, k: u32) -> SpecialResult<f64> {
if k == 0 {
return Err(crate::SpecialError::ValueError(
"k must be positive".to_string(),
));
}
if n == 0 {
return Ok(1.0);
}
let mut result = 1.0;
let mut i = n;
while i > 0 {
result *= i as f64;
i = i.saturating_sub(k);
}
Ok(result)
}
#[allow(dead_code)]
pub fn binomial(n: u32, k: u32) -> SpecialResult<f64> {
if k > n {
return Ok(0.0);
}
if k == 0 || k == n {
return Ok(1.0);
}
let k = if k > n - k { n - k } else { k };
if n <= 30 {
let mut result = 1.0;
for i in 0..k {
result = result * (n - i) as f64 / (i + 1) as f64;
}
Ok(result)
} else {
let n_fact = gamma((n + 1) as f64);
let k_fact = gamma((k + 1) as f64);
let nk_fact = gamma((n - k + 1) as f64);
Ok(n_fact / (k_fact * nk_fact))
}
}
#[allow(dead_code)]
pub fn permutations(n: u32, k: u32) -> SpecialResult<f64> {
if k > n {
return Ok(0.0);
}
if k == 0 {
return Ok(1.0);
}
if n <= 30 {
let mut result = 1.0;
for i in 0..k {
result *= (n - i) as f64;
}
Ok(result)
} else {
let n_fact = gamma((n + 1) as f64);
let nk_fact = gamma((n - k + 1) as f64);
Ok(n_fact / nk_fact)
}
}
/// Alias for [`permutations`]: forwards `n` and `k` unchanged and returns its result.
///
/// # Errors
/// Propagates any error from [`permutations`].
#[allow(dead_code)]
pub fn perm(n: u32, k: u32) -> SpecialResult<f64> {
    let count = permutations(n, k)?;
    Ok(count)
}
/// Computes the unsigned Stirling number of the first kind `c(n, k)`.
///
/// Uses the recurrence `c(i, j) = (i − 1) · c(i − 1, j) + c(i − 1, j − 1)`, with
/// `c(0, 0) = 1` and `c(n, k) = 0` when exactly one of `n`, `k` is zero or `k > n`.
/// Only a single row of the table is kept, so memory is `O(k)` rather than
/// `O(n · k)`; time remains `O(n · k)`. Values beyond the `f64` range become
/// `f64::INFINITY`.
///
/// # Errors
/// Never returns an error; the `Result` keeps the signature uniform with the rest of
/// the module.
#[allow(dead_code)]
pub fn stirling_first(n: u32, k: u32) -> SpecialResult<f64> {
    if n == 0 && k == 0 {
        return Ok(1.0);
    }
    if n == 0 || k == 0 || k > n {
        return Ok(0.0);
    }
    let k = k as usize;
    // row[j] holds c(i, j) for the most recently completed row i; starts as row 0.
    let mut row = vec![0.0_f64; k + 1];
    row[0] = 1.0;
    for i in 1..=n as usize {
        let factor = (i - 1) as f64;
        // Walk j downwards so row[j] and row[j - 1] still hold row i − 1's values.
        for j in (1..=i.min(k)).rev() {
            row[j] = factor * row[j] + row[j - 1];
        }
        // c(i, 0) == 0 for every i >= 1.
        row[0] = 0.0;
    }
    Ok(row[k])
}
/// Computes the Stirling number of the second kind `S(n, k)`: the number of ways to
/// partition an `n`-element set into `k` non-empty blocks.
///
/// Uses the recurrence `S(i, j) = j · S(i − 1, j) + S(i − 1, j − 1)`, with
/// `S(0, 0) = 1` and `S(n, k) = 0` when exactly one of `n`, `k` is zero or `k > n`.
/// Only a single row of the table is kept, so memory is `O(k)` rather than
/// `O(n · k)`; time remains `O(n · k)`. Values beyond the `f64` range become
/// `f64::INFINITY`.
///
/// # Errors
/// Never returns an error; the `Result` keeps the signature uniform with the rest of
/// the module.
#[allow(dead_code)]
pub fn stirling_second(n: u32, k: u32) -> SpecialResult<f64> {
    if n == 0 && k == 0 {
        return Ok(1.0);
    }
    if n == 0 || k == 0 || k > n {
        return Ok(0.0);
    }
    let k = k as usize;
    // row[j] holds S(i, j) for the most recently completed row i; starts as row 0.
    let mut row = vec![0.0_f64; k + 1];
    row[0] = 1.0;
    for i in 1..=n as usize {
        // Walk j downwards so row[j] and row[j - 1] still hold row i − 1's values.
        for j in (1..=i.min(k)).rev() {
            row[j] = j as f64 * row[j] + row[j - 1];
        }
        // S(i, 0) == 0 for every i >= 1.
        row[0] = 0.0;
    }
    Ok(row[k])
}
/// Alias for [`stirling_second`]: forwards `n` and `k` unchanged and returns its
/// result.
///
/// # Errors
/// Propagates any error from [`stirling_second`].
#[allow(dead_code)]
pub fn stirling2(n: u32, k: u32) -> SpecialResult<f64> {
    let value = stirling_second(n, k)?;
    Ok(value)
}
#[allow(dead_code)]
pub fn bell_number(n: u32) -> SpecialResult<f64> {
if n == 0 {
return Ok(1.0);
}
let mut result = 0.0;
for k in 0..=n {
result += stirling_second(n, k)?;
}
Ok(result)
}
#[allow(dead_code)]
pub fn bernoulli_number(n: u32) -> SpecialResult<f64> {
if n == 0 {
return Ok(1.0);
}
if n == 1 {
return Ok(-0.5);
}
if n > 1 && n % 2 == 1 {
return Ok(0.0); }
match n {
2 => return Ok(1.0 / 6.0),
4 => return Ok(-1.0 / 30.0),
6 => return Ok(1.0 / 42.0),
8 => return Ok(-1.0 / 30.0),
10 => return Ok(5.0 / 66.0),
12 => return Ok(-691.0 / 2730.0),
_ => {} }
let mut bernoulli = vec![0.0; (n + 1) as usize];
bernoulli[0] = 1.0;
if n >= 1 {
bernoulli[1] = -0.5;
}
for m in 2..=(n as usize) {
if m % 2 == 1 {
bernoulli[m] = 0.0;
continue;
}
let mut sum = 0.0;
for (k, &bernoulli_k) in bernoulli.iter().enumerate().take(m) {
let binom_coeff = binomial((m + 1) as u32, k as u32)?;
sum += binom_coeff * bernoulli_k;
}
bernoulli[m] = -sum / (m + 1) as f64;
}
Ok(bernoulli[n as usize])
}
#[allow(dead_code)]
pub fn euler_number(n: u32) -> SpecialResult<f64> {
if n % 2 == 1 {
return Ok(0.0); }
match n {
0 => return Ok(1.0),
2 => return Ok(-1.0),
4 => return Ok(5.0),
6 => return Ok(-61.0),
8 => return Ok(1385.0),
10 => return Ok(-50521.0),
_ => {} }
if n > 100 {
return euler_number_asymptotic(n as i32);
} else if n > 20 {
return euler_number_improved_recurrence(n as i32);
}
euler_number_standard_recurrence(n as i32)
}
/// Evaluates `E_n` with the recurrence `E_m = −Σ_{k even, k < m} C(m, k) · E_k`,
/// starting from `E_0 = 1`.
///
/// A table indexed by `m` is filled at even positions only (odd slots stay `0.0`), and
/// the entry at index `n` is returned.
///
/// # Errors
/// Propagates any error from [`binomial`].
#[allow(dead_code)]
fn euler_number_standard_recurrence(n: i32) -> SpecialResult<f64> {
    let size = (n + 1) as usize;
    let last = n as usize;
    let mut table = vec![0.0; size];
    table[0] = 1.0;
    let mut m = 2;
    while m <= last {
        let total = (0..m).step_by(2).try_fold(0.0, |acc, k| {
            binomial(m as u32, k as u32).map(|coeff| acc + coeff * table[k])
        })?;
        table[m] = -total;
        m += 2;
    }
    Ok(table[last])
}
#[allow(dead_code)]
fn euler_number_improved_recurrence(n: i32) -> SpecialResult<f64> {
if n % 2 == 1 {
return Ok(0.0); }
let mut _prev_eulr = [0.0; 2]; _prev_eulr[0] = 1.0;
if n == 0 {
return Ok(1.0);
}
let mut euler_cache = vec![0.0; (n / 2 + 1) as usize];
euler_cache[0] = 1.0;
for m in (2..=n).step_by(2) {
let m_idx = (m / 2) as usize;
let mut sum = 0.0;
for k in (0..m).step_by(2) {
let k_idx = (k / 2) as usize;
let binom_coeff = efficient_binomial(m as u32, k as u32)?;
sum += binom_coeff * euler_cache[k_idx];
}
euler_cache[m_idx] = -sum;
}
Ok(euler_cache[(n / 2) as usize])
}
/// Evaluates `E_n` from the identity `|E_n| = 2^(n+2) · n! · β(n+1) / π^(n+1)` (even
/// `n`), where `β` is the Dirichlet beta function; the sign of `E_n` is `(−1)^(n/2)`.
///
/// The prefactor is accumulated as `(4/π) · Π_{i=1..n} (2i/π)`, so no intermediate
/// such as `n!` overflows before the result does. `β(n+1)` is summed until its terms
/// are negligible; for large `n` it is 1 to double precision. The only approximation
/// is therefore floating-point rounding. Returns `0.0` for odd `n` and `1.0` for
/// `n == 0`.
///
/// # Errors
/// - [`SpecialError::ValueError`] for negative `n`.
/// - [`SpecialError::OverflowError`] when `|E_n|` exceeds the `f64` range.
#[allow(dead_code)]
fn euler_number_asymptotic(n: i32) -> SpecialResult<f64> {
    if n < 0 {
        return Err(SpecialError::ValueError(
            "n must be non-negative".to_string(),
        ));
    }
    if n % 2 == 1 {
        return Ok(0.0);
    }
    if n == 0 {
        // β(1) = π/4 cancels the 4/π prefactor exactly (and its series converges too
        // slowly to sum).
        return Ok(1.0);
    }

    let mut magnitude = 4.0 / PI;
    for i in 1..=n {
        magnitude *= 2.0 * f64::from(i) / PI;
        if magnitude.is_infinite() {
            return Err(SpecialError::OverflowError(
                "Euler number too large to represent as f64".to_string(),
            ));
        }
    }

    // Reaching here means the product stayed finite, so n is small and n + 1 cannot
    // overflow. β(s) = Σ_{j>=0} (−1)^j / (2j + 1)^s.
    let exponent = n + 1;
    let mut beta = 1.0_f64;
    let mut j: i32 = 1;
    loop {
        let term = f64::from(2 * j + 1).powi(-exponent);
        if term < f64::EPSILON * 1e-2 {
            break;
        }
        beta += if j % 2 == 1 { -term } else { term };
        j += 1;
    }

    let sign = if (n / 2) % 2 == 0 { 1.0 } else { -1.0 };
    Ok(sign * magnitude * beta)
}
/// Computes the binomial coefficient `C(n, k)`, reporting overflow as an error instead
/// of returning an infinite value.
///
/// Returns `0.0` when `k > n` and `1.0` when `k == 0` or `k == n`. The coefficient is
/// built with the multiplicative formula `C(n, i + 1) = C(n, i) · (n − i) / (i + 1)`
/// over the smaller of `k` and `n − k`. It is exact whenever the intermediate products
/// fit in 53 bits, and stays finite for every `n` whose result is representable.
///
/// # Errors
/// Returns [`SpecialError::OverflowError`] when the coefficient exceeds the `f64`
/// range.
#[allow(dead_code)]
fn efficient_binomial(n: u32, k: u32) -> SpecialResult<f64> {
    if k > n {
        return Ok(0.0);
    }
    // Avoid ln(Γ(n + 1)): Γ(n + 1) is already infinite for n > 170, which made the log
    // difference inf − inf = NaN and slipped past the overflow check.
    let k_use = k.min(n - k);
    let mut result = 1.0_f64;
    for i in 0..k_use {
        let numerator = f64::from(n - i);
        let denominator = f64::from(i + 1);
        let scaled = result * numerator;
        result = if scaled.is_finite() {
            // Exact while the product fits in the mantissa (i + 1 divides it evenly).
            scaled / denominator
        } else {
            // Divide first near the top of the f64 range to avoid a spurious overflow.
            result / denominator * numerator
        };
        if result.is_infinite() {
            return Err(SpecialError::OverflowError(
                "Binomial coefficient too large".to_string(),
            ));
        }
    }
    Ok(result)
}
/// Alias for [`binomial`]: forwards `n` and `k` unchanged and returns its result.
///
/// # Errors
/// Propagates any error from [`binomial`].
#[allow(dead_code)]
pub fn comb(n: u32, k: u32) -> SpecialResult<f64> {
    let value = binomial(n, k)?;
    Ok(value)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Unwraps a value produced by one of the functions under test, failing the test
    /// with the shared message if an error was returned.
    fn value<E: std::fmt::Debug>(result: Result<f64, E>) -> f64 {
        result.expect("test/example should not fail")
    }

    #[test]
    fn test_factorial() {
        for (n, expected) in [(0, 1.0), (1, 1.0), (5, 120.0), (10, 3628800.0)] {
            assert_eq!(value(factorial(n)), expected);
        }
        assert_relative_eq!(value(factorial(15)), 1307674368000.0, epsilon = 1.0);
    }

    #[test]
    fn test_double_factorial() {
        let cases = [(0, 1.0), (1, 1.0), (2, 2.0), (5, 15.0), (6, 48.0), (8, 384.0)];
        for (n, expected) in cases {
            assert_eq!(value(double_factorial(n)), expected);
        }
    }

    #[test]
    fn test_factorial2() {
        for (n, expected) in [(0, 1.0), (5, 15.0), (6, 48.0)] {
            assert_eq!(value(factorial2(n)), expected);
        }
        assert_eq!(value(factorial2(8)), value(double_factorial(8)));
    }

    #[test]
    fn test_factorialk() {
        let cases = [
            (0, 1, 1.0),
            (8, 3, 80.0),
            (5, 2, 15.0),
            (6, 2, 48.0),
            (9, 4, 45.0),
        ];
        for (n, k, expected) in cases {
            assert_eq!(value(factorialk(n, k)), expected);
        }
        assert_eq!(value(factorialk(5, 1)), value(factorial(5)));
        assert!(factorialk(5, 0).is_err());
    }

    #[test]
    fn test_binomial() {
        let cases = [
            (5, 2, 10.0),
            (10, 3, 120.0),
            (7, 0, 1.0),
            (7, 7, 1.0),
            (5, 10, 0.0),
        ];
        for (n, k, expected) in cases {
            assert_eq!(value(binomial(n, k)), expected);
        }
        // Symmetry: C(n, k) == C(n, n - k).
        assert_eq!(value(binomial(10, 3)), value(binomial(10, 7)));
    }

    #[test]
    fn test_permutations() {
        let cases = [(5, 2, 20.0), (10, 3, 720.0), (7, 0, 1.0), (5, 10, 0.0)];
        for (n, k, expected) in cases {
            assert_eq!(value(permutations(n, k)), expected);
        }
    }

    #[test]
    fn test_perm() {
        for (n, k, expected) in [(5, 2, 20.0), (10, 3, 720.0)] {
            assert_eq!(value(perm(n, k)), expected);
        }
        assert_eq!(value(perm(7, 3)), value(permutations(7, 3)));
    }

    #[test]
    fn test_stirling_first() {
        let cases = [
            (0, 0, 1.0),
            (4, 2, 11.0),
            (5, 3, 35.0),
            (3, 0, 0.0),
            (0, 3, 0.0),
        ];
        for (n, k, expected) in cases {
            assert_eq!(value(stirling_first(n, k)), expected);
        }
    }

    #[test]
    fn test_stirling_second() {
        let cases = [
            (0, 0, 1.0),
            (4, 2, 7.0),
            (5, 3, 25.0),
            (3, 0, 0.0),
            (0, 3, 0.0),
        ];
        for (n, k, expected) in cases {
            assert_eq!(value(stirling_second(n, k)), expected);
        }
    }

    #[test]
    fn test_stirling2() {
        for (n, k, expected) in [(4, 2, 7.0), (5, 3, 25.0)] {
            assert_eq!(value(stirling2(n, k)), expected);
        }
        assert_eq!(value(stirling2(6, 3)), value(stirling_second(6, 3)));
    }

    #[test]
    fn test_bell_number() {
        let expected = [1.0, 1.0, 2.0, 5.0, 15.0, 52.0];
        for (n, &bell) in expected.iter().enumerate() {
            assert_eq!(value(bell_number(n as u32)), bell);
        }
    }

    #[test]
    fn test_bernoulli_number() {
        // B_0 is exactly one and odd indices >= 3 vanish exactly.
        for (n, expected) in [(0, 1.0), (3, 0.0), (5, 0.0)] {
            assert_eq!(value(bernoulli_number(n)), expected);
        }
        for (n, expected) in [(1, -0.5), (2, 1.0 / 6.0), (4, -1.0 / 30.0)] {
            assert_relative_eq!(value(bernoulli_number(n)), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_euler_number() {
        let cases = [(0, 1.0), (1, 0.0), (2, -1.0), (3, 0.0), (4, 5.0), (6, -61.0)];
        for (n, expected) in cases {
            assert_eq!(value(euler_number(n)), expected);
        }
    }
}