use crate::error::{SpecialError, SpecialResult};
use crate::gamma::{gamma, gammaln};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use std::ops::{AddAssign, MulAssign, SubAssign};
/// Converts an `f64` constant into the generic float type `F`.
///
/// When `F::from` cannot represent the value, the result saturates to the
/// infinity of matching sign; a value that is neither positive nor negative
/// (zero or NaN) falls back to `F::zero()`.
#[inline(always)]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    match F::from(value) {
        Some(converted) => converted,
        None if value > 0.0 => F::infinity(),
        None if value < 0.0 => F::neg_infinity(),
        None => F::zero(),
    }
}
/// Upper bound on the number of terms the power-series evaluators in this file will sum.
const MAX_SERIES_TERMS: usize = 500;
/// Relative stopping tolerance: series summation stops once
/// `|term| < CONVERGENCE_TOL * |partial sum|`.
const CONVERGENCE_TOL: f64 = 1e-15;
/// Gauss hypergeometric function ₂F₁(a, b; c; z) for real arguments.
///
/// Evaluation strategy, chosen so that every power series is summed at an
/// argument of modulus at most ~0.9:
/// - `z == 0` → 1;
/// - `a` or `b` a non-positive integer → the terminating polynomial;
/// - `z` within 1e-14 of 1 → Gauss's summation theorem;
/// - `|z| <= 0.5` → direct power series;
/// - `z < -1` → 1/z analytic continuation;
/// - `-1 <= z < -0.5` → Pfaff transformation, whose argument `z/(z-1)` lies in (1/3, 1/2];
/// - `0.5 < z < 0.9` → direct power series (geometric ratio below 0.9);
/// - `0.9 <= z < 1` → Euler transformation;
/// - `z > 1` → 1/z analytic continuation (on the branch cut; see that routine).
///
/// # Errors
/// Returns `DomainError` when `c` is zero or a negative integer, `ValueError`
/// when an argument cannot be converted to `f64`, and propagates errors from
/// the evaluation routines.
#[allow(dead_code)]
pub fn hyp2f1_enhanced<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let a_f64 = a
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert a to f64".to_string()))?;
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let c_f64 = c
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert c to f64".to_string()))?;
    let z_f64 = z
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert z to f64".to_string()))?;
    if c_f64 <= 0.0 && c_f64.fract() == 0.0 {
        return Err(SpecialError::DomainError(format!(
            "c must not be 0 or negative integer, got {c_f64}"
        )));
    }
    if z == F::zero() {
        return Ok(F::one());
    }
    if (a_f64 <= 0.0 && a_f64.fract() == 0.0) || (b_f64 <= 0.0 && b_f64.fract() == 0.0) {
        return hyp2f1_terminating(a, b, c, z);
    }
    if (z_f64 - 1.0).abs() < 1e-14 {
        return hyp2f1_at_one(a, b, c);
    }
    let abs_z = z_f64.abs();
    if abs_z <= 0.5 {
        hyp2f1_series_accelerated(a, b, c, z)
    } else if (-1.0..0.0).contains(&z_f64) {
        // Pfaff only helps for negative z: z/(z-1) ∈ (1/3, 1/2] here. For positive
        // z it would map into (-∞, -1], where the series diverges.
        hyp2f1_pfaff_transform(a, b, c, z)
    } else if z_f64 > 0.0 && z_f64 < 0.9 {
        hyp2f1_series_accelerated(a, b, c, z)
    } else if z_f64 > 0.0 && z_f64 < 1.0 {
        hyp2f1_euler_transform(a, b, c, z)
    } else if z_f64 > 1.0 {
        hyp2f1_analytic_continuation_positive(a, b, c, z)
    } else {
        // z < -1 (and NaN, which propagates through the computation).
        hyp2f1_analytic_continuation_negative(a, b, c, z)
    }
}
/// Sums the ₂F₁ power series directly, term by term, at argument `z`.
///
/// Each term is obtained from the previous one via the ratio
/// `(a+k-1)(b+k-1) z / ((c+k-1) k)`. Summation stops as soon as a term is
/// smaller than `CONVERGENCE_TOL` relative to the running sum. If that never
/// happens within `MAX_SERIES_TERMS` terms, the recorded partial sums are
/// handed to `levin_u_transform` for sequence acceleration.
#[allow(dead_code)]
fn hyp2f1_series_accelerated<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let tol = const_f64::<F>(CONVERGENCE_TOL);
    // Every partial sum is kept so the accelerator can work on the whole sequence.
    let mut history: Vec<F> = Vec::with_capacity(MAX_SERIES_TERMS);
    history.push(F::one());
    let mut current = F::one();
    let mut total = F::one();
    for k in 1..MAX_SERIES_TERMS {
        let shift = const_f64::<F>((k - 1) as f64);
        let ratio_num = (a + shift) * (b + shift);
        let ratio_den = (c + shift) * const_f64::<F>(k as f64);
        current = current * ratio_num * z / ratio_den;
        total += current;
        history.push(total);
        if current.abs() < tol * total.abs() {
            return Ok(total);
        }
    }
    if history.len() <= 10 {
        return Ok(total);
    }
    levin_u_transform(&history)
}
/// Accelerates the convergence of a sequence of partial sums.
///
/// Despite the name, this is Wynn's epsilon algorithm, not Levin's
/// u-transform. The epsilon table is
/// `ε_{-1}^{(i)} = 0`, `ε_0^{(i)} = S_i`,
/// `ε_{k+1}^{(i)} = ε_{k-1}^{(i+1)} + 1 / (ε_k^{(i+1)} - ε_k^{(i)})`.
/// Only the even-order columns `ε_{2m}` approximate the limit; the odd-order
/// columns are auxiliary reciprocal differences and must never be returned.
///
/// The table is built one column at a time, keeping only the two previous
/// columns. From each even column, the entry built from the most recent
/// partial sums is taken as an estimate. The estimate that moved least from
/// the previous even-column estimate is returned, which avoids the rounding
/// blow-up of very high orders. If a difference vanishes (or is NaN), the table
/// has broken down because the sequence has effectively converged, and the
/// best estimate found so far is returned.
///
/// Fewer than four partial sums are returned unchanged (the last one, or zero
/// for an empty slice). Only finite estimates are ever accepted, so finite
/// input always yields a finite result.
#[allow(dead_code)]
fn levin_u_transform<F>(partial_sums: &[F]) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug,
{
    let n = partial_sums.len();
    let Some(&last) = partial_sums.last() else {
        return Ok(F::zero());
    };
    if n < 4 {
        return Ok(last);
    }
    let tiny = const_f64::<F>(1e-100);
    // `older` holds column k-2 and `newer` column k-1; start with ε_{-1} and ε_0.
    let mut older: Vec<F> = vec![F::zero(); n];
    let mut newer: Vec<F> = partial_sums.to_vec();
    let mut best = last;
    let mut best_change = F::infinity();
    let mut previous_estimate = last;
    for order in 1..n {
        let len = n - order;
        let mut column = Vec::with_capacity(len);
        for i in 0..len {
            let diff = newer[i + 1] - newer[i];
            if diff.is_nan() || diff.abs() < tiny {
                return Ok(best);
            }
            column.push(older[i + 1] + F::one() / diff);
        }
        if order.is_multiple_of(2) {
            // Even order: a genuine limit estimate. Its last entry uses the latest sums.
            let estimate = column[len - 1];
            let change = (estimate - previous_estimate).abs();
            if estimate.is_finite() && change <= best_change {
                best = estimate;
                best_change = change;
            }
            previous_estimate = estimate;
        }
        older = std::mem::replace(&mut newer, column);
    }
    Ok(best)
}
/// Evaluates ₂F₁(a, b; c; z) when the series terminates.
///
/// The series terminates when `a` (checked first) or `b` is zero or a negative
/// integer. In that case the function is a polynomial of degree `-a`
/// (respectively `-b`) in `z`, and exactly that many terms after the leading 1
/// are summed.
///
/// # Errors
/// Returns `ValueError` if neither `a` nor `b` is a non-positive integer, or if
/// either cannot be converted to `f64`.
#[allow(dead_code)]
fn hyp2f1_terminating<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let a_f64 = a
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert a to f64".to_string()))?;
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let is_nonpositive_int = |v: f64| v <= 0.0 && v.fract() == 0.0;
    let degree = if is_nonpositive_int(a_f64) {
        (-a_f64) as usize
    } else if is_nonpositive_int(b_f64) {
        (-b_f64) as usize
    } else {
        return Err(SpecialError::ValueError(
            "Not a terminating series".to_string(),
        ));
    };
    let mut total = F::one();
    let mut current = F::one();
    for k in 1..=degree {
        let shift = const_f64::<F>((k - 1) as f64);
        current = current * ((a + shift) * (b + shift)) * z
            / ((c + shift) * const_f64::<F>(k as f64));
        total += current;
    }
    Ok(total)
}
#[allow(dead_code)]
fn hyp2f1_at_one<F>(a: F, b: F, c: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign,
{
let c_minus_a_minus_b = c - a - b;
let cmab_f64 = c_minus_a_minus_b
.to_f64()
.ok_or_else(|| SpecialError::ValueError("Conversion failed".to_string()))?;
if cmab_f64 <= 0.0 {
return Err(SpecialError::DomainError(
"₂F₁(a,b;c;1) diverges when c - a - b ≤ 0".to_string(),
));
}
let log_gamma_c = gammaln(c);
let log_gamma_cmab = gammaln(c_minus_a_minus_b);
let log_gamma_cma = gammaln(c - a);
let log_gamma_cmb = gammaln(c - b);
let log_result = log_gamma_c + log_gamma_cmab - log_gamma_cma - log_gamma_cmb;
Ok(log_result.exp())
}
/// Evaluates ₂F₁ through the Pfaff transformation
/// `₂F₁(a, b; c; z) = (1 - z)^(-a) · ₂F₁(a, c - b; c; z / (z - 1))`.
///
/// The inner series is summed at `z / (z - 1)`. It is only useful when that
/// argument is small in modulus, which is the case for negative `z`.
#[allow(dead_code)]
fn hyp2f1_pfaff_transform<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let mapped_z = z / (z - F::one());
    let inner = hyp2f1_series_accelerated(a, c - b, c, mapped_z)?;
    let scale = (F::one() - z).powf(-a);
    Ok(scale * inner)
}
/// Evaluates ₂F₁ through Euler's transformation
/// `₂F₁(a, b; c; z) = (1 - z)^(c - a - b) · ₂F₁(c - a, c - b; c; z)`.
///
/// The argument of the inner series is unchanged. Only the upper parameters
/// differ, which changes how fast the terms decay.
#[allow(dead_code)]
fn hyp2f1_euler_transform<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let inner = hyp2f1_series_accelerated(c - a, c - b, c, z)?;
    let scale = (F::one() - z).powf(c - a - b);
    Ok(scale * inner)
}
/// Evaluates ₂F₁ with the 1/z connection formula, as dispatched for `z > 1`:
///
/// `Γ(c)Γ(b-a)/(Γ(b)Γ(c-a)) (-z)^(-a) ₂F₁(a, a-c+1; a-b+1; 1/z)`
/// `+ Γ(c)Γ(a-b)/(Γ(a)Γ(c-b)) (-z)^(-b) ₂F₁(b, b-c+1; b-a+1; 1/z)`.
///
/// For `z > 1` the base `-z` is negative, so `powf` with a non-integer
/// exponent yields NaN for `f64`. Real `z > 1` lies on the branch cut, where
/// ₂F₁ is complex-valued in general.
#[allow(dead_code)]
fn hyp2f1_analytic_continuation_positive<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let inv_z = F::one() / z;
    let minus_z = -z;
    // One term of the connection formula; the second term swaps the roles of a and b.
    let branch = |p: F, q: F| -> SpecialResult<F> {
        let coeff = gamma(c) * gamma(q - p) / (gamma(q) * gamma(c - p));
        let series = hyp2f1_series_accelerated(p, p - c + F::one(), p - q + F::one(), inv_z)?;
        Ok(coeff * minus_z.powf(-p) * series)
    };
    let first = branch(a, b)?;
    let second = branch(b, a)?;
    Ok(first + second)
}
#[allow(dead_code)]
fn hyp2f1_analytic_continuation_negative<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
let z_inv = F::one() / z;
let neg_z = -z;
let term1_coeff = gamma(c) * gamma(b - a) / (gamma(b) * gamma(c - a));
let term1_power = neg_z.powf(-a);
let term1_hyp = hyp2f1_series_accelerated(a, a - c + F::one(), a - b + F::one(), z_inv)?;
let term2_coeff = gamma(c) * gamma(a - b) / (gamma(a) * gamma(c - b));
let term2_power = neg_z.powf(-b);
let term2_hyp = hyp2f1_series_accelerated(b, b - c + F::one(), b - a + F::one(), z_inv)?;
Ok(term1_coeff * term1_power * term1_hyp + term2_coeff * term2_power * term2_hyp)
}
/// Confluent hypergeometric function ₁F₁(a; b; z) (Kummer's M) for real arguments.
///
/// Strategy:
/// - `z == 0` → 1;
/// - `a` zero or a negative integer → the terminating power series (a polynomial;
///   the asymptotic formula would divide by the pole of Γ(a));
/// - `z < 0` → Kummer's transformation `e^z ₁F₁(b - a; b; -z)`. This sums a
///   series at positive argument and avoids the catastrophic cancellation of the
///   alternating direct series;
/// - `z > 50` → large-`z` asymptotic expansion;
/// - otherwise → direct power series.
///
/// # Errors
/// Returns `DomainError` when `b` is zero or a negative integer, and
/// `ValueError` when an argument cannot be converted to `f64`.
#[allow(dead_code)]
pub fn hyp1f1_enhanced<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let a_f64 = a
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert a to f64".to_string()))?;
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let z_f64 = z
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert z to f64".to_string()))?;
    if b_f64 <= 0.0 && b_f64.fract() == 0.0 {
        return Err(SpecialError::DomainError(format!(
            "b must not be 0 or negative integer, got {b_f64}"
        )));
    }
    if z == F::zero() {
        return Ok(F::one());
    }
    if a_f64 <= 0.0 && a_f64.fract() == 0.0 {
        // Finite polynomial: the series terminates exactly after -a terms.
        return hyp1f1_series(a, b, z);
    }
    if z_f64 < 0.0 {
        let exp_z = z.exp();
        let transformed = hyp1f1_series(b - a, b, -z)?;
        return Ok(exp_z * transformed);
    }
    if z_f64 > 50.0 {
        return hyp1f1_asymptotic(a, b, z);
    }
    hyp1f1_series(a, b, z)
}
/// Sums the ₁F₁(a; b; z) power series directly.
///
/// Consecutive terms differ by the factor `(a+k-1) z / ((b+k-1) k)`. The loop
/// stops once a term is below `CONVERGENCE_TOL` relative to the running sum,
/// or after `MAX_SERIES_TERMS` terms, returning whatever has accumulated.
#[allow(dead_code)]
fn hyp1f1_series<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let tol = const_f64::<F>(CONVERGENCE_TOL);
    let mut total = F::one();
    let mut current = F::one();
    for k in 1..MAX_SERIES_TERMS {
        let shift = const_f64::<F>((k - 1) as f64);
        current = current * (a + shift) * z / ((b + shift) * const_f64::<F>(k as f64));
        total += current;
        if current.abs() < tol * total.abs() {
            return Ok(total);
        }
    }
    Ok(total)
}
#[allow(dead_code)]
fn hyp1f1_asymptotic<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign,
{
let gamma_b = gamma(b);
let gamma_a = gamma(a);
let exp_z = z.exp();
let z_power = z.powf(a - b);
let leading = gamma_b / gamma_a * exp_z * z_power;
let correction = (b - a) * (F::one() - a) / z;
Ok(leading * (F::one() + correction))
}
/// Confluent hypergeometric limit function ₀F₁(; a; z), summed as a power series.
///
/// Consecutive terms differ by `z / ((a+k-1) k)`. Summation stops once a term
/// is below `CONVERGENCE_TOL` relative to the running sum, or after
/// `MAX_SERIES_TERMS` terms.
///
/// # Errors
/// Returns `DomainError` when `a` is zero or a negative integer, and
/// `ValueError` when `a` cannot be converted to `f64`.
#[allow(dead_code)]
pub fn hyp0f1_enhanced<F>(a: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let a_f64 = match a.to_f64() {
        Some(value) => value,
        None => {
            return Err(SpecialError::ValueError(
                "Failed to convert a to f64".to_string(),
            ))
        }
    };
    let a_is_pole = a_f64 <= 0.0 && a_f64.fract() == 0.0;
    if a_is_pole {
        return Err(SpecialError::DomainError(format!(
            "a must not be 0 or negative integer, got {a_f64}"
        )));
    }
    if z == F::zero() {
        return Ok(F::one());
    }
    let tol = const_f64::<F>(CONVERGENCE_TOL);
    let mut total = F::one();
    let mut current = F::one();
    for k in 1..MAX_SERIES_TERMS {
        let shift = const_f64::<F>((k - 1) as f64);
        current = current * z / ((a + shift) * const_f64::<F>(k as f64));
        total += current;
        if current.abs() < tol * total.abs() {
            return Ok(total);
        }
    }
    Ok(total)
}
/// Regularized Gauss hypergeometric function ₂F₁(a, b; c; z) / Γ(c).
///
/// When `c` is within 1e-10 of zero or a negative integer (on either side),
/// the dedicated pole formula is used. Otherwise the ordinary ₂F₁ is divided by
/// Γ(c). The pole test uses the distance to the nearest integer, consistent
/// with the rounding done in `hyp2f1_regularized_at_pole`. The old
/// `fract().abs()` test missed values such as `-1.99999999999`.
///
/// # Errors
/// Returns `ValueError` when `c` cannot be converted to `f64`, and propagates
/// errors from `hyp2f1_enhanced`.
#[allow(dead_code)]
pub fn hyp2f1_regularized<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let c_f64 = c
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert c to f64".to_string()))?;
    let nearest = c_f64.round();
    if nearest <= 0.0 && (c_f64 - nearest).abs() < 1e-10 {
        return hyp2f1_regularized_at_pole(a, b, c, z);
    }
    let gamma_c = gamma(c);
    let hyp = hyp2f1_enhanced(a, b, c, z)?;
    Ok(hyp / gamma_c)
}
/// Regularized ₂F₁ at `c = -m` (m = 0, 1, 2, ...), where Γ(c) has a pole.
///
/// `₂F₁(a,b;c;z)/Γ(c) = Σ_n (a)_n (b)_n z^n / (Γ(c+n) n!)`. For `c = -m` the
/// terms with `n <= m` vanish, because Γ(n - m) is infinite. The remaining
/// terms use `Γ(n - m) = (n - m - 1)!`. The old code divided by `(c)_n`
/// instead, which is exactly zero for every `n > m`, so every term was skipped
/// and the result was always 0.
///
/// The leading term is `(a)_{m+1} (b)_{m+1} z^{m+1} / (m+1)!`. Subsequent terms
/// use the ratio `(a+n)(b+n) z / ((n+1)(n-m))`, which avoids the factorial and
/// Pochhammer overflow of computing each term from scratch. The series
/// converges for `|z| < 1`.
///
/// # Errors
/// Returns `ValueError` when `c` cannot be converted to `f64`.
#[allow(dead_code)]
fn hyp2f1_regularized_at_pole<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let c_f64 = c
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert c to f64".to_string()))?;
    let m = (-c_f64).round() as usize;
    let tol = const_f64::<F>(CONVERGENCE_TOL);
    // Leading surviving term, n = m + 1 (denominator Γ(1) = 1).
    let mut term = F::one();
    for k in 0..=m {
        let k_f = const_f64::<F>(k as f64);
        term = term * (a + k_f) * (b + k_f) * z / (k_f + F::one());
    }
    let mut sum = F::zero();
    for j in 0..MAX_SERIES_TERMS {
        let n = m + 1 + j;
        sum += term;
        // A zero term means a or b is a non-positive integer: the series has ended.
        if term == F::zero() || (j >= 10 && term.abs() < tol * sum.abs()) {
            break;
        }
        let n_f = const_f64::<F>(n as f64);
        let n_minus_m = const_f64::<F>((n - m) as f64);
        term = term * (a + n_f) * (b + n_f) * z / ((n_f + F::one()) * n_minus_m);
    }
    Ok(sum)
}
/// Rising factorial (Pochhammer symbol) `(a)_n = a (a+1) ... (a+n-1)`, with `(a)_0 = 1`.
#[allow(dead_code)]
fn pochhammer_n<F>(a: F, n: usize) -> F
where
    F: Float + FromPrimitive,
{
    let seed = if n == 0 { F::one() } else { a };
    (1..n).fold(seed, |product, i| product * (a + const_f64::<F>(i as f64)))
}
/// `n!` computed as a floating-point product; `0! = 1! = 1`.
///
/// The product overflows to infinity for large `n` (above 170 for `f64`).
#[allow(dead_code)]
fn factorial_n<F>(n: usize) -> F
where
    F: Float + FromPrimitive,
{
    (2..=n).fold(F::one(), |product, i| product * const_f64::<F>(i as f64))
}
/// Regularized confluent hypergeometric function ₁F₁(a; b; z) / Γ(b).
///
/// When `b` is within 1e-10 of zero or a negative integer (on either side),
/// the dedicated pole formula is used. Otherwise the ordinary ₁F₁ is divided by
/// Γ(b).
///
/// The former shortcut `|Γ(b)| < 1e-300 → 0` is removed because it is
/// backwards: a tiny Γ(b) makes the quotient large, not zero.
///
/// # Errors
/// Returns `ValueError` when `b` cannot be converted to `f64`, and propagates
/// errors from `hyp1f1_enhanced`.
#[allow(dead_code)]
pub fn hyp1f1_regularized<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let nearest = b_f64.round();
    if nearest <= 0.0 && (b_f64 - nearest).abs() < 1e-10 {
        return hyp1f1_regularized_at_pole(a, b, z);
    }
    let gamma_b = gamma(b);
    let hyp = hyp1f1_enhanced(a, b, z)?;
    Ok(hyp / gamma_b)
}
/// Regularized ₁F₁ at `b = -m` (m = 0, 1, 2, ...), where Γ(b) has a pole.
///
/// `₁F₁(a;b;z)/Γ(b) = Σ_n (a)_n z^n / (Γ(b+n) n!)`. For `b = -m` the terms with
/// `n <= m` vanish, and the rest use `Γ(n - m) = (n - m - 1)!`. Equivalently,
/// the result is `(a)_{m+1} z^{m+1}/(m+1)! · ₁F₁(a+m+1; m+2; z)`. The old code
/// divided by `(b)_n`, which is exactly zero for every `n > m`, so every term
/// was skipped and the result was always 0.
///
/// Terms after the leading one use the ratio `(a+n) z / ((n+1)(n-m))`.
///
/// # Errors
/// Returns `ValueError` when `b` cannot be converted to `f64`.
fn hyp1f1_regularized_at_pole<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let m = (-b_f64).round() as usize;
    let tol = const_f64::<F>(CONVERGENCE_TOL);
    // Leading surviving term, n = m + 1 (denominator Γ(1) = 1).
    let mut term = F::one();
    for k in 0..=m {
        let k_f = const_f64::<F>(k as f64);
        term = term * (a + k_f) * z / (k_f + F::one());
    }
    let mut sum = F::zero();
    for j in 0..MAX_SERIES_TERMS {
        let n = m + 1 + j;
        sum += term;
        // A zero term means a is a non-positive integer: the series has ended.
        if term == F::zero() || (j >= 5 && term.abs() < tol * sum.abs()) {
            break;
        }
        let n_f = const_f64::<F>(n as f64);
        let n_minus_m = const_f64::<F>((n - m) as f64);
        term = term * (a + n_f) * z / ((n_f + F::one()) * n_minus_m);
    }
    Ok(sum)
}
/// Whittaker function `M_{κ,μ}(z) = e^(-z/2) z^(μ+1/2) ₁F₁(μ - κ + 1/2; 1 + 2μ; z)`.
///
/// # Errors
/// Returns `DomainError` for `z <= 0` and `ValueError` when `z` cannot be
/// converted to `f64`. Errors from `hyp1f1_enhanced` (e.g. `1 + 2μ` a
/// non-positive integer) are propagated.
#[allow(dead_code)]
pub fn whittaker_m<F>(kappa: F, mu: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let Some(z_f64) = z.to_f64() else {
        return Err(SpecialError::ValueError(
            "Failed to convert z to f64".to_string(),
        ));
    };
    if z_f64 <= 0.0 {
        return Err(SpecialError::DomainError(
            "z must be > 0 for Whittaker M function".to_string(),
        ));
    }
    let half = const_f64::<F>(0.5);
    let upper = mu - kappa + half;
    let lower = const_f64::<F>(2.0) * mu + F::one();
    let kummer = hyp1f1_enhanced(upper, lower, z)?;
    let prefactor = (-z * half).exp() * z.powf(mu + half);
    Ok(prefactor * kummer)
}
/// Whittaker function `W_{κ,μ}(z) = e^(-z/2) z^(μ+1/2) U(μ - κ + 1/2, 1 + 2μ, z)`,
/// where `U` is Tricomi's confluent hypergeometric function
/// (`crate::hypergeometric::hyperu`).
///
/// # Errors
/// Returns `DomainError` for `z <= 0` and `ValueError` when `z` cannot be
/// converted to `f64`. Errors from `hyperu` are propagated.
#[allow(dead_code)]
pub fn whittaker_w<F>(kappa: F, mu: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
    let Some(z_f64) = z.to_f64() else {
        return Err(SpecialError::ValueError(
            "Failed to convert z to f64".to_string(),
        ));
    };
    if z_f64 <= 0.0 {
        return Err(SpecialError::DomainError(
            "z must be > 0 for Whittaker W function".to_string(),
        ));
    }
    let half = const_f64::<F>(0.5);
    let upper = mu - kappa + half;
    let lower = const_f64::<F>(2.0) * mu + F::one();
    let tricomi = crate::hypergeometric::hyperu(upper, lower, z)?;
    let prefactor = (-z * half).exp() * z.powf(mu + half);
    Ok(prefactor * tricomi)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    #[test]
    fn test_hyp2f1_enhanced_zero() {
        let result: f64 = hyp2f1_enhanced(1.0, 2.0, 3.0, 0.0).expect("test should succeed");
        assert_relative_eq!(result, 1.0, epsilon = 1e-14);
    }
    #[test]
    fn test_hyp2f1_enhanced_small_z() {
        // ₂F₁(1, 2; 3; z) = 2(-ln(1 - z) - z)/z², which is 8(ln 2 - 1/2) at z = 1/2.
        let result: f64 = hyp2f1_enhanced(1.0, 2.0, 3.0, 0.5).expect("test should succeed");
        assert_relative_eq!(result, 1.545177444479562, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp2f1_enhanced_negative_small_z() {
        // ₂F₁(1, 1; 2; z) = -ln(1 - z) / z
        let z = -0.5_f64;
        let result: f64 = hyp2f1_enhanced(1.0, 1.0, 2.0, z).expect("test should succeed");
        assert_relative_eq!(result, -(1.0 - z).ln() / z, epsilon = 1e-12);
    }
    #[test]
    fn test_hyp2f1_enhanced_at_one() {
        // Gauss: Γ(3)Γ(1.5) / (Γ(2.5)Γ(2)) = 2 / 1.5 = 4/3
        let result: f64 = hyp2f1_enhanced(0.5, 1.0, 3.0, 1.0).expect("test should succeed");
        assert_relative_eq!(result, 4.0 / 3.0, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp2f1_enhanced_terminating() {
        let result: f64 = hyp2f1_enhanced(-2.0, 3.0, 4.0, 0.5).expect("test should succeed");
        assert_relative_eq!(result, 0.4, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp1f1_enhanced_zero() {
        let result: f64 = hyp1f1_enhanced(1.0, 2.0, 0.0).expect("test should succeed");
        assert_relative_eq!(result, 1.0, epsilon = 1e-14);
    }
    #[test]
    fn test_hyp1f1_enhanced_small_z() {
        // ₁F₁(1; 2; z) = (e^z - 1) / z
        let z = 0.5_f64;
        let result: f64 = hyp1f1_enhanced(1.0, 2.0, z).expect("test should succeed");
        assert_relative_eq!(result, (z.exp() - 1.0) / z, epsilon = 1e-12);
    }
    #[test]
    fn test_hyp0f1_enhanced() {
        let result: f64 = hyp0f1_enhanced(1.0, 0.0).expect("test should succeed");
        assert_relative_eq!(result, 1.0, epsilon = 1e-14);
        // ₀F₁(; 1; 1) = Σ 1/(n!)² = I₀(2)
        let result2: f64 = hyp0f1_enhanced(1.0, 1.0).expect("test should succeed");
        let expected: f64 = (0..30u32)
            .map(|n| {
                let fact: f64 = (1..=n).map(f64::from).product();
                1.0 / (fact * fact)
            })
            .sum();
        assert_relative_eq!(result2, expected, epsilon = 1e-12);
    }
    #[test]
    fn test_levin_transform() {
        let partial_sums: Vec<f64> = vec![1.0, 1.5, 1.833, 2.083, 2.283, 2.45, 2.593];
        let result = levin_u_transform(&partial_sums).expect("test should succeed");
        assert!(result.is_finite());
    }
    #[test]
    fn test_hyp1f1_regularized_at_zero() {
        let result: f64 = hyp1f1_regularized(1.0, 2.0, 0.0).expect("should succeed");
        assert_relative_eq!(result, 1.0, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp1f1_regularized_known_value() {
        let result: f64 = hyp1f1_regularized(1.0, 1.0, 0.5).expect("should succeed");
        let direct: f64 = hyp1f1_enhanced(1.0, 1.0, 0.5).expect("should succeed");
        assert_relative_eq!(result, direct, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp1f1_regularized_large_b() {
        let result: f64 = hyp1f1_regularized(1.0, 10.0, 0.5).expect("should succeed");
        assert!(
            result.abs() < 1.0,
            "regularized should be small for large b: {result}"
        );
    }
    #[test]
    fn test_hyp1f1_regularized_b_half() {
        let result: f64 = hyp1f1_regularized(1.0, 0.5, 0.0).expect("should succeed");
        let expected = 1.0 / std::f64::consts::PI.sqrt();
        assert_relative_eq!(result, expected, epsilon = 1e-10);
    }
    #[test]
    fn test_hyp1f1_regularized_finite() {
        let result: f64 = hyp1f1_regularized(0.5, 3.0, 1.0).expect("should succeed");
        assert!(result.is_finite(), "regularized hyp1f1 should be finite");
    }
    #[test]
    fn test_whittaker_m_basic() {
        let result = whittaker_m(0.5_f64, 0.5, 1.0).expect("should succeed");
        assert!(
            result.is_finite(),
            "M_{{0.5,0.5}}(1) should be finite: {result}"
        );
    }
    #[test]
    fn test_whittaker_m_zero_kappa() {
        // M_{0,1/2}(z) = e^{-z/2} z ₁F₁(1; 2; z) = 2 sinh(z/2); at z = 2 that is 2 sinh(1).
        let result = whittaker_m(0.0_f64, 0.5, 2.0).expect("should succeed");
        assert_relative_eq!(result, 2.0 * 1.0_f64.sinh(), epsilon = 1e-12);
    }
    #[test]
    fn test_whittaker_m_large_z() {
        let result = whittaker_m(0.5_f64, 0.5, 10.0).expect("should succeed");
        assert!(result.is_finite());
    }
    #[test]
    fn test_whittaker_m_negative_z_error() {
        let result = whittaker_m(0.5_f64, 0.5, -1.0);
        assert!(result.is_err(), "negative z should error");
    }
    #[test]
    fn test_whittaker_m_positive() {
        // κ = 1, μ = 1/2 gives ₁F₁(0; 2; z) = 1, so M = e^{-z/2} z.
        let result = whittaker_m(1.0_f64, 0.5, 0.1).expect("should succeed");
        assert!(result > 0.0, "M should be positive for these params");
        assert_relative_eq!(result, 0.1 * (-0.05_f64).exp(), epsilon = 1e-14);
    }
    #[test]
    fn test_whittaker_w_basic() {
        let result = whittaker_w(0.5_f64, 0.5, 1.0).expect("should succeed");
        assert!(
            result.is_finite(),
            "W_{{0.5,0.5}}(1) should be finite: {result}"
        );
    }
    #[test]
    fn test_whittaker_w_large_z() {
        let result = whittaker_w(0.5_f64, 0.5, 10.0).expect("should succeed");
        assert!(result.is_finite());
        assert!(result.abs() < 1.0, "W should decay for large z: {result}");
    }
    #[test]
    fn test_whittaker_w_negative_z_error() {
        let result = whittaker_w(0.5_f64, 0.5, -1.0);
        assert!(result.is_err(), "negative z should error");
    }
    #[test]
    fn test_whittaker_w_moderate() {
        let result = whittaker_w(1.0_f64, 0.5, 2.0).expect("should succeed");
        assert!(result.is_finite());
    }
    #[test]
    fn test_whittaker_w_small_z() {
        let result = whittaker_w(0.5_f64, 0.5, 0.1).expect("should succeed");
        assert!(result.is_finite());
    }
}