use crate::error::{SpecialError, SpecialResult};
/// Rising factorial (Pochhammer symbol) `(a)_n = a (a+1) ⋯ (a+n-1)`, with `(a)_0 = 1`.
///
/// Multiplication stops as soon as the running product becomes non-finite, and
/// that non-finite value is returned unchanged.
fn pochhammer(a: f64, n: usize) -> f64 {
    let mut product = 1.0_f64;
    for factor in (0..n).map(|step| a + step as f64) {
        product *= factor;
        if !product.is_finite() {
            break;
        }
    }
    product
}
/// Falling factorial `x (x-1) ⋯ (x-n+1)`, equal to `1` when `n == 0`.
///
/// Stops multiplying once the running product leaves the finite range and
/// returns that non-finite value unchanged.
fn falling_factorial(x: f64, n: usize) -> f64 {
    let mut product = 1.0_f64;
    let mut step = 0usize;
    while step < n && product.is_finite() {
        product *= x - step as f64;
        step += 1;
    }
    product
}
/// `n!` as an `f64`, saturating to `f64::INFINITY` once the product overflows
/// (which happens for `n > 170`).
fn factorial_f64(n: usize) -> f64 {
    (2..=n)
        .try_fold(1.0_f64, |acc, i| Some(acc * i as f64).filter(|v| v.is_finite()))
        .unwrap_or(f64::INFINITY)
}
/// Binomial coefficient `C(n, k)` as an `f64`; zero when `k > n`.
///
/// Uses the multiplicative formula over the smaller of `k` and `n - k`. Each step
/// multiplies by `n - i` and then divides by `i + 1`, so after step `i` the running
/// value is (mathematically) `C(n, i + 1)`.
fn binomial_usize(n: usize, k: usize) -> f64 {
    if k > n {
        return 0.0;
    }
    // An empty range (k == 0 or k == n) folds to exactly 1.0.
    (0..k.min(n - k)).fold(1.0_f64, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}
/// Partial sum of the generalized hypergeometric series
/// `pFq(num_params; den_params; z) = Σ_k Π(a_i)_k / Π(b_j)_k · z^k / k!`
/// over at most `n_terms` terms (`k = 0 .. n_terms - 1`).
///
/// Each term is built from the previous one through the ratio
/// `Π(a_i + k) / Π(b_j + k) · z / (k + 1)`. Summation stops early when:
/// - a numerator parameter equals `-k`, so that term and every later one vanish
///   (the series has terminated), or
/// - the next term is zero or negligible relative to the running total.
///
/// Returns `Ok(0.0)` when `n_terms == 0`. Returns `Ok(NaN)` if a term becomes NaN
/// (e.g. NaN input), instead of a silently truncated sum.
///
/// # Errors
/// - `DomainError` if a denominator parameter `b_j + k` is (numerically) zero
///   while a term that will be summed is being formed.
/// - `OverflowError` if a term or the total overflows to infinity.
fn hypergeometric_terminated(
    num_params: &[f64],
    den_params: &[f64],
    z: f64,
    n_terms: usize,
) -> SpecialResult<f64> {
    if n_terms == 0 {
        return Ok(0.0);
    }
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k builds term k+1, so only `n_terms - 1` ratios are needed. Forming
    // one more would test a denominator for a term that is never summed, and could
    // raise a spurious error (e.g. 2F1(-n, b; -n; z) with n_terms = n + 1).
    for k in 0..n_terms - 1 {
        let kf = k as f64;
        // Natural termination: (a_i)_{k+1} = 0. Stop before a denominator that may
        // also reach zero at the same k.
        if num_params.iter().any(|&ai| ai + kf == 0.0) {
            break;
        }
        let mut ratio = z / (kf + 1.0);
        for &ai in num_params {
            ratio *= ai + kf;
        }
        for &bj in den_params {
            let denom = bj + kf;
            if denom.abs() < f64::MIN_POSITIVE {
                return Err(SpecialError::DomainError(format!(
                    "hypergeometric: denominator parameter hits non-positive integer at k={k}"
                )));
            }
            ratio /= denom;
        }
        let next_term = term * ratio;
        if next_term.is_nan() {
            return Ok(f64::NAN);
        }
        if next_term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "hypergeometric: term overflow at k={}",
                k + 1
            )));
        }
        if next_term == 0.0 || next_term.abs() < f64::EPSILON * total.abs() * 1e-6 {
            break;
        }
        total += next_term;
        term = next_term;
    }
    if total.is_infinite() {
        return Err(SpecialError::OverflowError(
            "hypergeometric: sum overflow".to_string(),
        ));
    }
    Ok(total)
}
/// Evaluates the Wilson polynomial `W_n(x²; a, b, c, d)` from its terminating
/// hypergeometric representation
///
/// `W_n = (a+b)_n (a+c)_n (a+d)_n · ₄F₃(-n, n+a+b+c+d-1, a+ix, a-ix; a+b, a+c, a+d; 1)`.
///
/// The complex pair is combined as `(a+ix+k)(a-ix+k) = (a+k)² + x²`, so everything
/// stays real. Note that `x` is the square root of the polynomial variable: pass
/// `x`, not `x²`.
///
/// # Errors
/// - `DomainError` if any of `a, b, c, d` is not positive.
/// - `OverflowError` if the prefactor, a series term, or the final product overflows.
///
/// A NaN `x` yields `Ok(NaN)` for `n > 0`, instead of a silently truncated sum.
pub fn wilson_polynomial(
    n: usize,
    a: f64,
    b: f64,
    c: f64,
    d: f64,
    x: f64,
) -> SpecialResult<f64> {
    if a <= 0.0 || b <= 0.0 || c <= 0.0 || d <= 0.0 {
        return Err(SpecialError::DomainError(
            "wilson_polynomial: parameters a,b,c,d must all be positive".to_string(),
        ));
    }
    if n == 0 {
        return Ok(1.0);
    }
    let nf = n as f64;
    // Second numerator parameter of the 4F3: n + a + b + c + d - 1. The degree n
    // is part of it; without it even W_1(0; 1,1,1,1) comes out as 5 instead of 4.
    let s = nf + a + b + c + d - 1.0;
    let prefactor = pochhammer(a + b, n) * pochhammer(a + c, n) * pochhammer(a + d, n);
    if !prefactor.is_finite() {
        return Err(SpecialError::OverflowError(
            "wilson_polynomial: prefactor overflow".to_string(),
        ));
    }
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k turns term k into term k+1. Only n ratios are needed because
    // the (-n)_k factor zeroes every term beyond k = n.
    for k in 0..n {
        let kf = k as f64;
        let ak = a + kf;
        let num = (kf - nf) * (s + kf) * (ak * ak + x * x);
        // Strictly positive for a, b, c, d > 0.
        let den = (a + b + kf) * (a + c + kf) * (a + d + kf) * (kf + 1.0);
        term *= num / den;
        if term.is_nan() {
            return Ok(f64::NAN);
        }
        if term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "wilson_polynomial: series term overflow at k={}",
                k + 1
            )));
        }
        total += term;
    }
    let value = prefactor * total;
    if value.is_infinite() {
        return Err(SpecialError::OverflowError(
            "wilson_polynomial: result overflow".to_string(),
        ));
    }
    Ok(value)
}
/// Evaluates the Racah polynomial `R_n(λ(x); α, β, γ, δ)` via its terminating series
///
/// `R_n = ₄F₃(-n, n+α+β+1, -x, x+γ+δ+1; α+1, β+δ+1, γ+1; 1)`.
///
/// No parameter constraints are checked up front. A denominator factor that
/// vanishes within the `n` ratios actually needed is reported as an error.
///
/// # Errors
/// - `DomainError` if a denominator factor is (numerically) zero.
/// - `OverflowError` if a series term or the sum overflows.
///
/// NaN in `x` or a parameter yields `Ok(NaN)` for `n > 0`, instead of a silently
/// truncated sum.
pub fn racah_polynomial(
    n: usize,
    alpha: f64,
    beta: f64,
    gamma: f64,
    delta: f64,
    x: f64,
) -> SpecialResult<f64> {
    if n == 0 {
        return Ok(1.0);
    }
    let nf = n as f64;
    let s = nf + alpha + beta + 1.0;
    let xp = x + gamma + delta + 1.0;
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k turns term k into term k+1; (-n)_k zeroes everything past k = n.
    for k in 0..n {
        let kf = k as f64;
        let num = (kf - nf) * (s + kf) * (kf - x) * (xp + kf);
        let den = (alpha + 1.0 + kf) * (beta + delta + 1.0 + kf) * (gamma + 1.0 + kf) * (kf + 1.0);
        if den.abs() < f64::MIN_POSITIVE {
            return Err(SpecialError::DomainError(format!(
                "racah_polynomial: denominator zero at k={k} (check parameter constraints)"
            )));
        }
        term *= num / den;
        if term.is_nan() {
            return Ok(f64::NAN);
        }
        if term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "racah_polynomial: series term overflow at k={}",
                k + 1
            )));
        }
        total += term;
    }
    if total.is_infinite() {
        return Err(SpecialError::OverflowError(
            "racah_polynomial: sum overflow".to_string(),
        ));
    }
    Ok(total)
}
/// Evaluates the Hahn polynomial `Q_n(x; α, β, N)` via its terminating series
///
/// `Q_n = ₃F₂(-n, n+α+β+1, -x; α+1, -N; 1)`.
///
/// Only the `n` term ratios the series needs are formed, so the `-N + k`
/// denominator is never zero for `n ≤ N`.
///
/// # Errors
/// - `DomainError` if `alpha <= -1`, `beta <= -1`, `n > n_max`, or a denominator
///   factor is (numerically) zero.
/// - `OverflowError` if a series term or the sum overflows.
///
/// NaN in `x`, `alpha` or `beta` yields `Ok(NaN)` for `n > 0`, instead of a silently
/// truncated sum.
pub fn hahn_polynomial(
    n: usize,
    alpha: f64,
    beta: f64,
    n_max: usize,
    x: f64,
) -> SpecialResult<f64> {
    if alpha <= -1.0 {
        return Err(SpecialError::DomainError(format!(
            "hahn_polynomial: alpha={alpha} must be > -1"
        )));
    }
    if beta <= -1.0 {
        return Err(SpecialError::DomainError(format!(
            "hahn_polynomial: beta={beta} must be > -1"
        )));
    }
    if n > n_max {
        return Err(SpecialError::DomainError(format!(
            "hahn_polynomial: n={n} must be ≤ N={n_max}"
        )));
    }
    if n == 0 {
        return Ok(1.0);
    }
    let nf = n as f64;
    let s = nf + alpha + beta + 1.0;
    let big_n_f = -(n_max as f64);
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k turns term k into term k+1; k < n ≤ N keeps -N + k non-zero.
    for k in 0..n {
        let kf = k as f64;
        let num = (kf - nf) * (s + kf) * (kf - x);
        let den = (alpha + 1.0 + kf) * (big_n_f + kf) * (kf + 1.0);
        if den.abs() < f64::MIN_POSITIVE {
            return Err(SpecialError::DomainError(format!(
                "hahn_polynomial: denominator zero at k={k}"
            )));
        }
        term *= num / den;
        // A NaN input poisons every later term: report NaN, not the partial sum.
        if term.is_nan() {
            return Ok(f64::NAN);
        }
        if term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "hahn_polynomial: series term overflow at k={}",
                k + 1
            )));
        }
        total += term;
    }
    if total.is_infinite() {
        return Err(SpecialError::OverflowError(
            "hahn_polynomial: sum overflow".to_string(),
        ));
    }
    Ok(total)
}
/// Evaluates the Krawtchouk polynomial `K_n(x; p, N) = ₂F₁(-n, -x; -N; 1/p)`.
///
/// Only the `n` term ratios the terminating series needs are formed, so the
/// `-N + k` denominator stays non-zero for `n ≤ N`.
///
/// # Errors
/// - `DomainError` if `p` is not in `(0, 1)`, `n_max == 0`, `n > n_max`, or a
///   denominator factor is (numerically) zero.
/// - `OverflowError` if a series term or the sum overflows.
///
/// A NaN `x` or `p` yields `Ok(NaN)` for `n > 0`, instead of a silently truncated sum.
pub fn krawtchouk_polynomial(
    n: usize,
    p: f64,
    n_max: usize,
    x: f64,
) -> SpecialResult<f64> {
    if p <= 0.0 || p >= 1.0 {
        return Err(SpecialError::DomainError(format!(
            "krawtchouk_polynomial: p={p} must be in (0,1)"
        )));
    }
    if n_max == 0 {
        return Err(SpecialError::DomainError(
            "krawtchouk_polynomial: N must be positive".to_string(),
        ));
    }
    if n > n_max {
        return Err(SpecialError::DomainError(format!(
            "krawtchouk_polynomial: n={n} must be ≤ N={n_max}"
        )));
    }
    if n == 0 {
        return Ok(1.0);
    }
    let nf = n as f64;
    let z = 1.0 / p;
    let big_n_f = -(n_max as f64);
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k turns term k into term k+1; k < n ≤ N keeps -N + k non-zero.
    for k in 0..n {
        let kf = k as f64;
        let num = (kf - nf) * (kf - x);
        let den = (big_n_f + kf) * (kf + 1.0);
        if den.abs() < f64::MIN_POSITIVE {
            return Err(SpecialError::DomainError(format!(
                "krawtchouk_polynomial: denominator zero at k={k}"
            )));
        }
        term *= num * z / den;
        // A NaN input poisons every later term: report NaN, not the partial sum.
        if term.is_nan() {
            return Ok(f64::NAN);
        }
        if term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "krawtchouk_polynomial: series term overflow at k={}",
                k + 1
            )));
        }
        total += term;
    }
    if total.is_infinite() {
        return Err(SpecialError::OverflowError(
            "krawtchouk_polynomial: sum overflow".to_string(),
        ));
    }
    Ok(total)
}
/// Evaluates the dual Hahn polynomial `R_n(λ(x); γ, δ, N)` via its terminating series
///
/// `R_n = ₃F₂(-n, -x, x+γ+δ+1; γ+1, -N; 1)`.
///
/// Like `hahn_polynomial` (which requires `α, β > -1`), this accepts the
/// `γ, δ > -1` branch of the parameter domain. Only the `n` ratios the series
/// needs are formed, so `-N + k` is never zero for `n ≤ N`.
///
/// # Errors
/// - `DomainError` if `gamma <= -1`, `delta <= -1`, `n > n_max`, or a denominator
///   factor is (numerically) zero.
/// - `OverflowError` if a series term or the sum overflows.
///
/// NaN in `x`, `gamma` or `delta` yields `Ok(NaN)` for `n > 0`, instead of a silently
/// truncated sum.
pub fn dual_hahn_polynomial(
    n: usize,
    gamma: f64,
    delta: f64,
    n_max: usize,
    x: f64,
) -> SpecialResult<f64> {
    if gamma <= -1.0 {
        return Err(SpecialError::DomainError(format!(
            "dual_hahn_polynomial: gamma={gamma} must be > -1"
        )));
    }
    if delta <= -1.0 {
        return Err(SpecialError::DomainError(format!(
            "dual_hahn_polynomial: delta={delta} must be > -1"
        )));
    }
    if n > n_max {
        return Err(SpecialError::DomainError(format!(
            "dual_hahn_polynomial: n={n} must be ≤ N={n_max}"
        )));
    }
    if n == 0 {
        return Ok(1.0);
    }
    let nf = n as f64;
    let xp = x + gamma + delta + 1.0;
    let big_n_f = -(n_max as f64);
    let mut total = 1.0f64;
    let mut term = 1.0f64;
    // Iteration k turns term k into term k+1; k < n ≤ N keeps -N + k non-zero.
    for k in 0..n {
        let kf = k as f64;
        let num = (kf - nf) * (kf - x) * (xp + kf);
        let den = (gamma + 1.0 + kf) * (big_n_f + kf) * (kf + 1.0);
        if den.abs() < f64::MIN_POSITIVE {
            return Err(SpecialError::DomainError(format!(
                "dual_hahn_polynomial: denominator zero at k={k}"
            )));
        }
        term *= num / den;
        // A NaN input poisons every later term: report NaN, not the partial sum.
        if term.is_nan() {
            return Ok(f64::NAN);
        }
        if term.is_infinite() {
            return Err(SpecialError::OverflowError(format!(
                "dual_hahn_polynomial: series term overflow at k={}",
                k + 1
            )));
        }
        total += term;
    }
    if total.is_infinite() {
        return Err(SpecialError::OverflowError(
            "dual_hahn_polynomial: sum overflow".to_string(),
        ));
    }
    Ok(total)
}
/// Discrete inner product of two Krawtchouk polynomials over `x = 0, …, N`:
///
/// `Σ_x C(N, x) pˣ (1-p)^(N-x) · K_{n1}(x; p, N) · K_{n2}(x; p, N)`.
///
/// The weight is the binomial(N, p) probability mass, under which the
/// polynomials are orthogonal, so the result should vanish (up to rounding)
/// when `n1 != n2`.
///
/// # Errors
/// Returns the first error from [`krawtchouk_polynomial`] (evaluated for `n1`
/// before `n2` at each `x`), e.g. `p` outside `(0, 1)`, `N == 0`, or a degree above `N`.
pub fn krawtchouk_inner_product(
    n1: usize,
    n2: usize,
    p: f64,
    n_max: usize,
) -> SpecialResult<f64> {
    let q = 1.0 - p;
    (0..=n_max).try_fold(0.0f64, |acc: f64, x: usize| -> SpecialResult<f64> {
        let point = x as f64;
        let weight = binomial_usize(n_max, x) * p.powi(x as i32) * q.powi((n_max - x) as i32);
        let left = krawtchouk_polynomial(n1, p, n_max, point)?;
        let right = krawtchouk_polynomial(n2, p, n_max, point)?;
        Ok(acc + weight * left * right)
    })
}
/// Discrete inner product of two Hahn polynomials over `x = 0, …, N`:
///
/// `Σ_x w(x) · Q_{n1}(x; α, β, N) · Q_{n2}(x; α, β, N)`,
/// with the Hahn weight `w(x) = C(α+x, x) · C(β+N-x, N-x)`.
///
/// For `n1 != n2` the result should vanish up to rounding (orthogonality).
///
/// # Errors
/// Propagates the first error from [`hahn_polynomial`] (e.g. `alpha`/`beta <= -1`
/// or a degree above `n_max`).
pub fn hahn_inner_product(
    n1: usize,
    n2: usize,
    alpha: f64,
    beta: f64,
    n_max: usize,
) -> SpecialResult<f64> {
    // `(a)_m / m! = Π_{i<m} (a+i)/(i+1)`, accumulated ratio by ratio. Computing
    // `(a)_m` and `m!` separately overflows both to infinity once m ≳ 170, which
    // turns the weight into inf/inf = NaN. The ratio form grows only polynomially.
    fn rising_over_factorial(a: f64, m: usize) -> f64 {
        (0..m).map(|i| (a + i as f64) / (i as f64 + 1.0)).product()
    }

    let mut sum = 0.0f64;
    for x in 0..=n_max {
        let xf = x as f64;
        let w = rising_over_factorial(alpha + 1.0, x)
            * rising_over_factorial(beta + 1.0, n_max - x);
        let h1 = hahn_polynomial(n1, alpha, beta, n_max, xf)?;
        let h2 = hahn_polynomial(n2, alpha, beta, n_max, xf)?;
        sum += w * h1 * h2;
    }
    Ok(sum)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wilson_n0() {
        let v = wilson_polynomial(0, 1.0, 1.0, 1.0, 1.0, 0.5).expect("wilson n=0");
        assert!((v - 1.0).abs() < 1e-14, "W_0 = 1: {v}");
    }

    #[test]
    fn test_wilson_n1_at_zero() {
        // W_1(x²) = (a+b)(a+c)(a+d) - (a+b+c+d)(a² + x²) = 8 - 4 = 4 here.
        let v = wilson_polynomial(1, 1.0, 1.0, 1.0, 1.0, 0.0).expect("wilson n=1 x=0");
        assert!(v.is_finite(), "W_1(0) finite: {v}");
        assert!((v - 4.0).abs() < 1e-10, "W_1(0;1,1,1,1) = {v}, expected 4");
    }

    #[test]
    fn test_wilson_positive_params() {
        let v = wilson_polynomial(2, 0.5, 0.5, 0.5, 0.5, 1.0).expect("wilson n=2");
        assert!(v.is_finite(), "W_2 finite: {v}");
    }

    #[test]
    fn test_wilson_invalid_params() {
        let r = wilson_polynomial(1, -1.0, 1.0, 1.0, 1.0, 0.5);
        assert!(r.is_err(), "negative a should fail");
    }

    #[test]
    fn test_racah_n0() {
        let v = racah_polynomial(0, 0.5, 0.5, 0.5, 0.5, 1.0).expect("racah n=0");
        assert!((v - 1.0).abs() < 1e-14, "R_0 = 1: {v}");
    }

    #[test]
    fn test_racah_n1() {
        let (alpha, beta, gamma, delta, x) = (1.0, 1.0, 1.0, 1.0, 1.0);
        let v = racah_polynomial(1, alpha, beta, gamma, delta, x).expect("racah n=1");
        let s = alpha + beta + 2.0;
        let d = (alpha + 1.0) * (beta + delta + 1.0) * (gamma + 1.0);
        let expected = 1.0 + (-1.0) * s * (-x) * (x + gamma + delta + 1.0) / d;
        assert!((v - expected).abs() < 1e-10, "R_1: {v} vs {expected}");
    }

    #[test]
    fn test_racah_finite() {
        let v = racah_polynomial(3, 2.0, 1.5, 0.5, 1.0, 2.0).expect("racah n=3");
        assert!(v.is_finite(), "R_3 finite: {v}");
    }

    #[test]
    fn test_hahn_n0() {
        let v = hahn_polynomial(0, 1.0, 1.0, 5, 2.0).expect("hahn n=0");
        assert!((v - 1.0).abs() < 1e-14, "Q_0 = 1: {v}");
    }

    #[test]
    fn test_hahn_n1() {
        let (alpha, beta, n_max, x) = (1.0, 1.0, 4usize, 2.0);
        let v = hahn_polynomial(1, alpha, beta, n_max, x).expect("hahn n=1");
        let s = alpha + beta + 2.0;
        let expected = 1.0 - s * x / ((alpha + 1.0) * n_max as f64);
        assert!((v - expected).abs() < 1e-10, "Q_1: {v} vs {expected}");
    }

    #[test]
    fn test_hahn_n_gt_n_max() {
        let r = hahn_polynomial(6, 1.0, 1.0, 5, 2.0);
        assert!(r.is_err(), "n > N should fail");
    }

    #[test]
    fn test_hahn_alpha_negative() {
        let r = hahn_polynomial(1, -1.5, 1.0, 5, 2.0);
        assert!(r.is_err(), "alpha <= -1 should fail");
    }

    #[test]
    fn test_hahn_orthogonality() {
        let inner = hahn_inner_product(0, 1, 1.0, 1.0, 5).expect("hahn inner product");
        assert!(inner.abs() < 1e-10, "Hahn Q_0 ⊥ Q_1: inner product = {inner}");
    }

    #[test]
    fn test_hahn_orthogonality_higher_degree() {
        for &(n1, n2) in &[(0usize, 2usize), (1, 2), (1, 3), (2, 4)] {
            let inner = hahn_inner_product(n1, n2, 1.0, 1.0, 5).expect("hahn inner product");
            assert!(inner.abs() < 1e-9, "Hahn Q_{n1} ⊥ Q_{n2}: inner product = {inner}");
        }
    }

    #[test]
    fn test_krawtchouk_n0() {
        let v = krawtchouk_polynomial(0, 0.5, 5, 2.0).expect("krawtchouk n=0");
        assert!((v - 1.0).abs() < 1e-14, "K_0 = 1: {v}");
    }

    #[test]
    fn test_krawtchouk_n1() {
        let (n, p, n_max, x) = (1usize, 0.5, 5usize, 2.0);
        let v = krawtchouk_polynomial(n, p, n_max, x).expect("krawtchouk n=1");
        let expected = 1.0 - x / (n_max as f64 * p);
        assert!((v - expected).abs() < 1e-10, "K_1: {v} vs {expected}");
    }

    #[test]
    fn test_krawtchouk_orthogonality() {
        let inner = krawtchouk_inner_product(1, 2, 0.5, 5).expect("krawtchouk inner product");
        assert!(inner.abs() < 1e-9, "Krawtchouk K_1 ⊥ K_2: inner product = {inner}");
    }

    #[test]
    fn test_krawtchouk_orthogonality_all_pairs() {
        let (p, n_max) = (0.3, 5usize);
        for n1 in 0..=n_max {
            for n2 in 0..n1 {
                let inner =
                    krawtchouk_inner_product(n1, n2, p, n_max).expect("krawtchouk inner product");
                assert!(inner.abs() < 1e-9, "Krawtchouk K_{n1} ⊥ K_{n2}: {inner}");
            }
        }
    }

    #[test]
    fn test_krawtchouk_invalid_p() {
        assert!(krawtchouk_polynomial(1, 0.0, 5, 2.0).is_err());
        assert!(krawtchouk_polynomial(1, 1.0, 5, 2.0).is_err());
    }

    #[test]
    fn test_krawtchouk_n_gt_n_max() {
        let r = krawtchouk_polynomial(6, 0.5, 5, 2.0);
        assert!(r.is_err(), "n > N should fail");
    }

    #[test]
    fn test_krawtchouk_symmetry() {
        // Self-duality: K_n(x; p, N) = K_x(n; p, N) for integers 0 ≤ n, x ≤ N,
        // because ₂F₁(-n, -x; -N; 1/p) is symmetric in n and x.
        let p = 0.3;
        let n_max = 5;
        for n in 0..=n_max {
            for x in 0..=n_max {
                let v = krawtchouk_polynomial(n, p, n_max, x as f64).expect("krawtchouk K_n(x)");
                let w = krawtchouk_polynomial(x, p, n_max, n as f64).expect("krawtchouk K_x(n)");
                assert!(v.is_finite(), "K_{n}({x}) = {v}");
                assert!(
                    (v - w).abs() <= 1e-12 * v.abs().max(1.0),
                    "K_{n}({x}) = {v} vs K_{x}({n}) = {w}"
                );
            }
        }
    }

    #[test]
    fn test_dual_hahn_n0() {
        let v = dual_hahn_polynomial(0, 1.0, 1.0, 5, 2.0).expect("dual_hahn n=0");
        assert!((v - 1.0).abs() < 1e-14, "R_0 = 1: {v}");
    }

    #[test]
    fn test_dual_hahn_n1() {
        let (gamma, delta, n_max, x) = (0.5, 1.5, 4usize, 2.0);
        let v = dual_hahn_polynomial(1, gamma, delta, n_max, x).expect("dual_hahn n=1");
        let expected = 1.0 - x * (x + gamma + delta + 1.0) / ((gamma + 1.0) * n_max as f64);
        assert!((v - expected).abs() < 1e-12, "R_1: {v} vs {expected}");
    }

    #[test]
    fn test_dual_hahn_finite() {
        let v = dual_hahn_polynomial(2, 0.5, 0.5, 5, 1.0).expect("dual_hahn n=2");
        assert!(v.is_finite(), "R_2 finite: {v}");
    }

    #[test]
    fn test_numeric_helpers() {
        assert_eq!(pochhammer(1.0, 5), 120.0);
        assert_eq!(pochhammer(3.0, 0), 1.0);
        assert_eq!(factorial_f64(5), 120.0);
        assert!(factorial_f64(171).is_infinite());
        assert_eq!(binomial_usize(5, 2), 10.0);
        assert_eq!(binomial_usize(5, 0), 1.0);
        assert_eq!(binomial_usize(5, 5), 1.0);
        assert_eq!(binomial_usize(3, 5), 0.0);
    }
}