use crate::error::{SpecialError, SpecialResult};
use std::collections::HashMap;
/// Natural logarithm of `n!`, accumulated as `ln 2 + ln 3 + ... + ln n`.
///
/// Returns `0.0` for `n <= 1`, since `0! = 1! = 1`.
fn log_factorial(n: usize) -> f64 {
    if n < 2 {
        return 0.0;
    }
    // Ascending fold keeps the rounding sequence fixed.
    (2..=n).fold(0.0, |total, term| total + (term as f64).ln())
}
/// Signed Stirling number of the first kind `s(n, k)`.
///
/// Uses the recurrence `s(i, j) = s(i-1, j-1) - (i-1) * s(i-1, j)` with
/// saturating `i64` arithmetic, so out-of-range magnitudes clamp instead of
/// wrapping.
///
/// Returns `1` for `(0, 0)` and `0` when exactly one argument is zero or
/// when `k > n`.
pub fn stirling_first(n: usize, k: usize) -> i64 {
    match (n, k) {
        (0, 0) => return 1,
        (0, _) | (_, 0) => return 0,
        _ if k > n => return 0,
        _ => {}
    }

    // `row` holds s(i-1, ·) on entry to each pass and s(i, ·) on exit.
    let mut row = vec![0i64; k + 1];
    row[0] = 1;
    for i in 1..=n {
        let factor = i as i64 - 1;
        // Walk right-to-left so `row[j - 1]` still belongs to the previous row.
        for j in (1..=i.min(k)).rev() {
            row[j] = row[j - 1].saturating_sub(factor.saturating_mul(row[j]));
        }
        // s(i, 0) is zero for every i >= 1.
        row[0] = 0;
    }
    row[k]
}
/// Unsigned Stirling number of the first kind `c(n, k)`.
///
/// Uses the recurrence `c(i, j) = (i-1) * c(i-1, j) + c(i-1, j-1)` with
/// saturating `u64` arithmetic, so values that do not fit come back as
/// `u64::MAX`.
///
/// Returns `1` for `(0, 0)` and `0` when exactly one argument is zero or
/// when `k > n`.
pub fn stirling_first_unsigned(n: usize, k: usize) -> u64 {
    if (n, k) == (0, 0) {
        return 1;
    }
    if n == 0 || k == 0 || k > n {
        return 0;
    }

    // Two buffers that trade places: `previous` is row i-1, `current` row i.
    let mut previous = vec![0u64; k + 1];
    let mut current = vec![0u64; k + 1];
    previous[0] = 1;
    for i in 1..=n {
        let weight = i as u64 - 1;
        // The recycled buffer may still carry the 1 from row 0 in slot 0.
        current[0] = 0;
        for j in 1..=i.min(k) {
            current[j] = weight
                .saturating_mul(previous[j])
                .saturating_add(previous[j - 1]);
        }
        std::mem::swap(&mut previous, &mut current);
    }
    previous[k]
}
/// Stirling number of the second kind `S(n, k)`.
///
/// Uses the recurrence `S(i, j) = j * S(i-1, j) + S(i-1, j-1)` with
/// saturating `u64` arithmetic, so values that do not fit come back as
/// `u64::MAX`.
///
/// Returns `1` for `(0, 0)` and `0` when exactly one argument is zero or
/// when `k > n`.
pub fn stirling_second(n: usize, k: usize) -> u64 {
    if n == 0 || k == 0 {
        // Only the (0, 0) corner is non-zero along either axis.
        return u64::from(n == k);
    }
    if k > n {
        return 0;
    }

    // `line` is updated in place from row i-1 to row i.
    let mut line = vec![0u64; k + 1];
    line[0] = 1;
    for i in 1..=n {
        // Descending order keeps `line[j - 1]` at its previous-row value.
        for j in (1..=i.min(k)).rev() {
            let scaled = (j as u64).saturating_mul(line[j]);
            line[j] = scaled.saturating_add(line[j - 1]);
        }
        // S(i, 0) is zero for every i >= 1.
        line[0] = 0;
    }
    line[k]
}
/// Bell number `B(n)`, read off the first column of the Bell triangle.
///
/// Each triangle row starts with the last entry of the row above, and every
/// further entry adds the entry to its left and the one above that. Additions
/// saturate, so values that do not fit in `u64` come back as `u64::MAX`.
pub fn bell_number(n: usize) -> u64 {
    // Row 0 of the triangle.
    let mut above: Vec<u64> = vec![1];
    for _ in 0..n {
        let mut below = Vec::with_capacity(above.len() + 1);
        // A new row opens with the final entry of the row above it.
        let mut running = above[above.len() - 1];
        below.push(running);
        for &upper in &above {
            running = running.saturating_add(upper);
            below.push(running);
        }
        above = below;
    }
    above[0]
}
/// Bernoulli number `B_n`, using the convention `B_1 = -1/2`.
///
/// Odd indices above 1 yield `0.0`. Even indices up to 12 come from a table
/// of closed forms. Larger even indices are computed with the recurrence
/// `B_m = -1/(m+1) * sum_{k<m} C(m+1, k) * B_k` in `f64` arithmetic.
pub fn bernoulli(n: usize) -> f64 {
    // Closed forms for B_0, B_2, ..., B_12, indexed by n / 2.
    const EVEN_VALUES: [f64; 7] = [
        1.0,
        1.0 / 6.0,
        -1.0 / 30.0,
        1.0 / 42.0,
        -1.0 / 30.0,
        5.0 / 66.0,
        -691.0 / 2730.0,
    ];

    if n == 1 {
        return -0.5;
    }
    if n % 2 == 1 {
        return 0.0;
    }
    if let Some(&value) = EVEN_VALUES.get(n / 2) {
        return value;
    }

    let mut table = vec![0.0f64; n + 1];
    table[0] = 1.0;
    table[1] = -0.5;
    // Odd slots above 1 keep their initial 0.0, so only even m needs work.
    for m in (2..=n).step_by(2) {
        let weighted = (0..m).fold(0.0, |acc, k| acc + binom_f64(m + 1, k) * table[k]);
        table[m] = -weighted / (m + 1) as f64;
    }
    table[n]
}
/// Binomial coefficient `C(n, k)` as an `f64`.
///
/// Returns `0.0` when `k > n`. The product runs over the smaller of `k` and
/// `n - k` factors, multiplying before dividing at every step.
fn binom_f64(n: usize, k: usize) -> f64 {
    if k > n {
        return 0.0;
    }
    // C(n, k) == C(n, n - k); take the shorter product.
    let steps = k.min(n - k);
    (0..steps).fold(1.0f64, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}
/// Euler (secant) number `E_n` as an `f64`.
///
/// Odd indices yield `0.0`. Even indices up to 10 come from a table. Larger
/// even indices are computed with the recurrence
/// `E_m = -sum_{k<m, k even} C(m, k) * E_k`.
pub fn euler_number(n: usize) -> f64 {
    // E_0, E_2, ..., E_10, indexed by n / 2.
    const EVEN_VALUES: [f64; 6] = [1.0, -1.0, 5.0, -61.0, 1385.0, -50521.0];

    if n % 2 != 0 {
        return 0.0;
    }
    if let Some(&value) = EVEN_VALUES.get(n / 2) {
        return value;
    }

    let mut table = vec![0.0f64; n + 1];
    table[0] = 1.0;
    for m in (2..=n).step_by(2) {
        let total = (0..m)
            .step_by(2)
            .fold(0.0, |acc, k| acc + binom_f64(m, k) * table[k]);
        table[m] = -total;
    }
    table[n]
}
/// Catalan number `C_n`, saturating at `u64::MAX`.
///
/// Uses the recurrence `C_k = C_{k-1} * (4k - 2) / (k + 1)`. The intermediate
/// product is held in `u128`: multiplying in `u64` and dividing afterwards
/// would corrupt every result whose product overflows, even when the quotient
/// itself still fits.
///
/// Returns the exact value while it fits in `u64` and `u64::MAX` for every
/// larger `n`.
pub fn catalan(n: usize) -> u64 {
    const CEILING: u128 = u64::MAX as u128;

    let mut value: u128 = 1;
    for k in 1..=(n as u128) {
        // The division is exact: C_{k-1} * (4k - 2) is a multiple of k + 1.
        value = value * (4 * k - 2) / (k + 1);
        if value > CEILING {
            // The sequence only grows from here, so every later term
            // overflows as well.
            return u64::MAX;
        }
    }
    value as u64
}
/// Rising factorial `x * (x + 1) * ... * (x + n - 1)`.
///
/// The empty product (`n == 0`) is `1.0`.
pub fn rising_factorial(x: f64, n: usize) -> f64 {
    (0..n).fold(1.0, |product, step| product * (x + step as f64))
}
/// Falling factorial `x * (x - 1) * ... * (x - n + 1)`.
///
/// The empty product (`n == 0`) is `1.0`.
pub fn falling_factorial(x: f64, n: usize) -> f64 {
    (0..n)
        .map(|step| x - step as f64)
        .fold(1.0, |product, factor| product * factor)
}
/// Multinomial coefficient `n! / (k_1! * k_2! * ... * k_m!)` as an `f64`.
///
/// The value is computed as `exp(ln n! - sum ln k_i!)`, so it is an
/// approximation rather than an exact integer.
///
/// # Errors
///
/// Returns `SpecialError::ValueError` when the entries of `ks` do not sum
/// to `n`.
pub fn multinomial(n: usize, ks: &[usize]) -> SpecialResult<f64> {
    let total: usize = ks.iter().sum();
    if total != n {
        let message = format!(
            "multinomial: k values sum to {sum} but n = {n}",
            sum = total
        );
        return Err(SpecialError::ValueError(message));
    }
    // Reached only when n == 0 as well: the empty coefficient is 1.
    if ks.is_empty() {
        return Ok(1.0);
    }

    let ln_numerator = log_factorial(n);
    let ln_denominator = ks.iter().map(|&k| log_factorial(k)).sum::<f64>();
    Ok((ln_numerator - ln_denominator).exp())
}
/// Returns the `n`-th prime, counting from `nth_prime(1) == 2`.
///
/// A sieve is run up to an estimated upper bound for the `n`-th prime. If
/// that bound turns out to hold fewer than `n` primes, the sieve is rerun
/// once with the bound doubled.
///
/// # Errors
///
/// * `SpecialError::ValueError` when `n == 0`.
/// * `SpecialError::ComputationError` when even the doubled sieve holds
///   fewer than `n` primes.
pub fn nth_prime(n: usize) -> SpecialResult<u64> {
    if n == 0 {
        return Err(SpecialError::ValueError(
            "nth_prime: n must be at least 1".to_string(),
        ));
    }

    // Every prime not exceeding `limit`, in ascending order.
    let primes_up_to = |limit: usize| -> Vec<u64> {
        sieve_of_eratosthenes(limit)
            .into_iter()
            .enumerate()
            .skip(2)
            .filter(|&(_, flagged)| flagged)
            .map(|(value, _)| value as u64)
            .collect()
    };

    let upper = if n >= 6 {
        // n * (ln n + ln ln n), padded by 30% plus a small constant.
        let f = n as f64;
        let ln_n = f.ln();
        let ln_ln_n = ln_n.ln().max(1.0);
        (f * (ln_n + ln_ln_n) * 1.3 + 3.0) as usize
    } else {
        20usize
    };

    if let Some(&found) = primes_up_to(upper).get(n - 1) {
        return Ok(found);
    }

    // Fallback: retry once with twice the range.
    primes_up_to(upper * 2)
        .get(n - 1)
        .copied()
        .ok_or_else(|| SpecialError::ComputationError("nth_prime: sieve too small".to_string()))
}
/// Sieve of Eratosthenes over `0..=limit`.
///
/// Returns a vector of length `limit + 1` whose entry `i` is `true` exactly
/// when `i` is prime. Indices 0 and 1 are always `false`, including for
/// `limit == 0`, where only index 0 exists.
fn sieve_of_eratosthenes(limit: usize) -> Vec<bool> {
    let mut is_prime = vec![true; limit + 1];
    // `take(2)` clears 0 and 1 without indexing past a one-element vector.
    for slot in is_prime.iter_mut().take(2) {
        *slot = false;
    }

    let mut i = 2;
    // `i <= limit / i` is `i * i <= limit` without the overflow risk.
    while i <= limit / i {
        if is_prime[i] {
            // Smaller multiples were already struck out by smaller primes.
            for multiple in ((i * i)..=limit).step_by(i) {
                is_prime[multiple] = false;
            }
        }
        i += 1;
    }
    is_prime
}
/// Jordan's totient function `J_k(n) = n^k * prod_{p | n} (1 - p^-k)`.
///
/// The distinct prime divisors of `n` are found by trial division and the
/// product is evaluated in `f64`.
///
/// Returns `0.0` for `n == 0`. Exponents above `i32::MAX` are clamped to
/// `i32::MAX` before being passed to `powi`, so huge `k` yields `1.0` for
/// `n == 1` and infinity otherwise instead of wrapping to a negative
/// exponent.
pub fn jordan_totient(n: u64, k: u32) -> f64 {
    if n == 0 {
        return 0.0;
    }

    // `powi` takes an `i32`; clamping also keeps `-exponent` from overflowing.
    let exponent = k.min(i32::MAX as u32) as i32;
    let mut result = (n as f64).powi(exponent);
    let mut remaining = n;
    let mut p = 2u64;
    // `p <= remaining / p` is `p * p <= remaining` without overflow for
    // cofactors close to `u64::MAX`.
    while p <= remaining / p {
        if remaining % p == 0 {
            while remaining % p == 0 {
                remaining /= p;
            }
            result *= 1.0 - (p as f64).powi(-exponent);
        }
        p += 1;
    }
    // Whatever is left above 1 is a single prime factor.
    if remaining > 1 {
        result *= 1.0 - (remaining as f64).powi(-exponent);
    }
    result
}
/// Number of integer partitions `p(n)`, saturating at `u64::MAX`.
///
/// Uses Euler's pentagonal number recurrence
/// `p(m) = sum_{k>=1} (-1)^(k+1) * [p(m - k(3k-1)/2) + p(m - k(3k+1)/2)]`.
/// Each sum is accumulated in `i128`, because the leading positive terms can
/// exceed `u64::MAX` before the negative terms bring the total back into
/// range.
///
/// Returns the exact value while it fits in `u64`. Because `p` never
/// decreases, the first value that does not fit makes the function return
/// `u64::MAX` immediately, without tabulating up to `n`.
pub fn partition(n: usize) -> u64 {
    // table[m] == p(m); every stored entry is exact.
    let mut table: Vec<u64> = vec![1];
    for m in 1..=n {
        let mut total: i128 = 0;
        let mut k: usize = 1;
        loop {
            let lower = k * (3 * k - 1) / 2;
            if lower > m {
                break;
            }
            let sign: i128 = if k % 2 == 1 { 1 } else { -1 };
            total += sign * i128::from(table[m - lower]);
            // k(3k+1)/2 is k(3k-1)/2 + k.
            let upper = lower + k;
            if upper <= m {
                total += sign * i128::from(table[m - upper]);
            }
            k += 1;
        }
        if total > i128::from(u64::MAX) {
            return u64::MAX;
        }
        table.push(total as u64);
    }
    table[n]
}
/// Bell numbers `B(0), B(1), ..., B(n)` in order.
///
/// The Bell triangle is built once, row by row, and the first entry of every
/// row is recorded, instead of rebuilding the whole triangle for each index.
/// Additions saturate, so entries that do not fit in `u64` are `u64::MAX`.
pub fn bell_numbers_table(n: usize) -> Vec<u64> {
    let mut table = Vec::with_capacity(n + 1);
    // Row 0 of the triangle; its first entry is B(0).
    let mut above: Vec<u64> = vec![1];
    table.push(above[0]);
    for _ in 0..n {
        let mut below = Vec::with_capacity(above.len() + 1);
        // A new row opens with the final entry of the row above it.
        let mut running = above[above.len() - 1];
        below.push(running);
        for &upper in &above {
            running = running.saturating_add(upper);
            below.push(running);
        }
        table.push(below[0]);
        above = below;
    }
    table
}
pub fn catalan_table(n: usize) -> Vec<u64> {
let mut table = Vec::with_capacity(n + 1);
for i in 0..=n {
table.push(catalan(i));
}
table
}
/// Prime factorisation of `n` by trial division.
///
/// Returns the prime factors in non-decreasing order, repeated according to
/// multiplicity. `0` and `1` yield an empty vector.
pub fn prime_factors_flat(mut n: u64) -> Vec<u64> {
    let mut factors = Vec::new();
    if n <= 1 {
        return factors;
    }

    let mut p = 2u64;
    // `p <= n / p` is `p * p <= n` without overflow when the remaining
    // cofactor is a prime close to `u64::MAX`.
    while p <= n / p {
        while n % p == 0 {
            factors.push(p);
            n /= p;
        }
        p += 1;
    }
    // Whatever is left above 1 is a single prime larger than every divisor tried.
    if n > 1 {
        factors.push(n);
    }
    factors
}
// Not referenced by the code visible in this file, hence the lint exemption.
#[allow(dead_code)]
/// Looks up `partition(n)` in `cache`, computing and storing it on a miss.
fn partition_memoized(n: usize, cache: &mut HashMap<usize, u64>) -> u64 {
    *cache.entry(n).or_insert_with(|| partition(n))
}
// Unit tests for the functions defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Tolerance-based float comparison macro from the `approx` crate.
use approx::assert_relative_eq;
// Signed Stirling numbers of the first kind: base cases, zero arguments and s(4, 2).
#[test]
fn test_stirling_first_signed() {
assert_eq!(stirling_first(0, 0), 1);
assert_eq!(stirling_first(1, 1), 1);
assert_eq!(stirling_first(4, 2), 11);
assert_eq!(stirling_first(4, 0), 0);
assert_eq!(stirling_first(0, 4), 0);
}
// Unsigned Stirling numbers of the first kind: base case, two interior values and k = 0.
#[test]
fn test_stirling_first_unsigned() {
assert_eq!(stirling_first_unsigned(0, 0), 1);
assert_eq!(stirling_first_unsigned(4, 2), 11);
assert_eq!(stirling_first_unsigned(5, 3), 35);
assert_eq!(stirling_first_unsigned(3, 0), 0);
}
// Stirling numbers of the second kind: base case, two interior values and zero arguments.
#[test]
fn test_stirling_second() {
assert_eq!(stirling_second(0, 0), 1);
assert_eq!(stirling_second(4, 2), 7);
assert_eq!(stirling_second(5, 3), 25);
assert_eq!(stirling_second(3, 0), 0);
assert_eq!(stirling_second(0, 3), 0);
}
// Bell numbers B(0) through B(5), plus B(10).
#[test]
fn test_bell_number() {
assert_eq!(bell_number(0), 1);
assert_eq!(bell_number(1), 1);
assert_eq!(bell_number(2), 2);
assert_eq!(bell_number(3), 5);
assert_eq!(bell_number(4), 15);
assert_eq!(bell_number(5), 52);
assert_eq!(bell_number(10), 115975);
}
// Bernoulli numbers B_0 through B_5; odd indices above 1 must be exactly zero.
#[test]
fn test_bernoulli() {
assert_eq!(bernoulli(0), 1.0);
assert_relative_eq!(bernoulli(1), -0.5, epsilon = 1e-14);
assert_relative_eq!(bernoulli(2), 1.0 / 6.0, epsilon = 1e-12);
assert_eq!(bernoulli(3), 0.0);
assert_relative_eq!(bernoulli(4), -1.0 / 30.0, epsilon = 1e-12);
assert_eq!(bernoulli(5), 0.0);
}
// Euler numbers: one odd index and the even indices 0, 2, 4 and 6.
#[test]
fn test_euler_number() {
assert_eq!(euler_number(0), 1.0);
assert_eq!(euler_number(1), 0.0);
assert_eq!(euler_number(2), -1.0);
assert_eq!(euler_number(4), 5.0);
assert_eq!(euler_number(6), -61.0);
}
// Catalan numbers C_0 through C_5, plus C_10.
#[test]
fn test_catalan() {
assert_eq!(catalan(0), 1);
assert_eq!(catalan(1), 1);
assert_eq!(catalan(2), 2);
assert_eq!(catalan(3), 5);
assert_eq!(catalan(4), 14);
assert_eq!(catalan(5), 42);
assert_eq!(catalan(10), 16796);
}
// Rising factorial: empty product, 4! via x = 1, and two further inputs.
#[test]
fn test_rising_factorial() {
assert_eq!(rising_factorial(1.0, 0), 1.0);
assert_relative_eq!(rising_factorial(1.0, 4), 24.0, epsilon = 1e-12);
assert_relative_eq!(rising_factorial(3.0, 4), 360.0, epsilon = 1e-12);
assert_relative_eq!(rising_factorial(0.5, 2), 0.75, epsilon = 1e-12);
}
// Falling factorial: empty product and two integer-valued inputs.
#[test]
fn test_falling_factorial() {
assert_eq!(falling_factorial(1.0, 0), 1.0);
assert_relative_eq!(falling_factorial(5.0, 3), 60.0, epsilon = 1e-12);
assert_relative_eq!(falling_factorial(4.0, 2), 12.0, epsilon = 1e-12);
}
// Multinomial coefficients: two valid inputs and one whose parts do not sum to n.
#[test]
fn test_multinomial() {
assert_relative_eq!(
multinomial(4, &[2, 1, 1]).expect("should succeed"),
12.0,
epsilon = 1e-8
);
assert_relative_eq!(
multinomial(6, &[3, 2, 1]).expect("should succeed"),
60.0,
epsilon = 1e-8
);
// 2 + 1 + 1 != 5, so this call must report an error.
assert!(multinomial(5, &[2, 1, 1]).is_err());
}
// n-th prime for n = 1, 2, 10 and 100 (1-based indexing).
#[test]
fn test_nth_prime() {
assert_eq!(nth_prime(1).expect("ok"), 2);
assert_eq!(nth_prime(2).expect("ok"), 3);
assert_eq!(nth_prime(10).expect("ok"), 29);
assert_eq!(nth_prime(100).expect("ok"), 541);
}
// Jordan totient: n = 1, a squarefree n with k = 1, and a prime power with k = 2.
#[test]
fn test_jordan_totient() {
assert_relative_eq!(jordan_totient(1, 1), 1.0, epsilon = 1e-10);
assert_relative_eq!(jordan_totient(6, 1), 2.0, epsilon = 1e-10);
assert_relative_eq!(jordan_totient(4, 2), 12.0, epsilon = 1e-10);
}
// Partition numbers for a handful of small n.
#[test]
fn test_partition() {
assert_eq!(partition(0), 1);
assert_eq!(partition(1), 1);
assert_eq!(partition(4), 5);
assert_eq!(partition(5), 7);
assert_eq!(partition(10), 42);
assert_eq!(partition(20), 627);
}
// Prime factorisation: n = 1, a repeated factor, a squarefree product, and a prime.
#[test]
fn test_prime_factors_flat() {
assert_eq!(prime_factors_flat(1), Vec::<u64>::new());
assert_eq!(prime_factors_flat(12), vec![2, 2, 3]);
assert_eq!(prime_factors_flat(30), vec![2, 3, 5]);
assert_eq!(prime_factors_flat(97), vec![97]);
}
}