use crate::bessel::{iv, jv};
use crate::error::{SpecialError, SpecialResult};
use crate::gamma::gamma;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use std::ops::{AddAssign, MulAssign, SubAssign};
/// Converts an `f64` constant into the generic float type `T`.
///
/// # Panics
/// Panics when `T::from` returns `None` for `value`.
#[inline(always)]
fn const_f64<T: Float + FromPrimitive>(value: f64) -> T {
    match T::from(value) {
        Some(converted) => converted,
        None => panic!("Failed to convert constant to target float type"),
    }
}
/// Rising factorial (Pochhammer symbol) `(a)_n = a (a + 1) ... (a + n - 1)`.
///
/// `(a)_0` is 1. The product is accumulated left to right, one factor per step.
#[allow(dead_code)]
pub fn pochhammer<F>(a: F, n: usize) -> F
where
    F: Float + FromPrimitive + Debug,
{
    match n {
        0 => F::one(),
        _ => (1..n).fold(a, |acc, i| {
            acc * (a + F::from(i).expect("test/example should not fail"))
        }),
    }
}
/// Natural log of the rising factorial, computed as `Σ_{i=0}^{n-1} ln(a + i)`.
///
/// Returns 0 for `n == 0`. Each factor goes through `ln` directly, so a non-positive
/// factor `a + i` yields NaN / -inf rather than `ln|(a)_n|`.
#[allow(dead_code)]
pub fn ln_pochhammer<F>(a: F, n: usize) -> F
where
    F: Float + FromPrimitive + Debug,
{
    // An empty range folds to zero, which covers the n == 0 case.
    (0..n)
        .map(|i| (a + F::from(i).expect("test/example should not fail")).ln())
        .fold(F::zero(), |acc, log_factor| acc + log_factor)
}
/// Confluent hypergeometric limit function `0F1(; v; z) = Σ_k z^k / ((v)_k k!)`.
///
/// Uses the Bessel-function representations (NIST DLMF §10.16 and §10.39):
/// * `z > 0`: `0F1(; v; z) = Γ(v) z^((1-v)/2) I_{v-1}(2√z)`
/// * `z < 0`: `0F1(; v; z) = Γ(v) (-z)^((1-v)/2) J_{v-1}(2√(-z))`
///
/// A second-order Taylor expansion is used for `|z|` below about `1e-6`. The power series
/// (`hyp0f1_series`, 50 terms) is the fallback whenever the Bessel product is not finite.
///
/// Returns `Ok(NaN)` when `v` is zero or a negative integer, where `0F1` is undefined.
#[allow(dead_code)]
pub fn hyp0f1<F>(v: F, z: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    if v <= F::zero() && v.to_f64().expect("test/example should not fail").fract() == 0.0 {
        return Ok(F::nan());
    }
    if z == F::zero() {
        return Ok(F::one());
    }
    let one = F::one();
    let two = F::from(2.0).expect("test/example should not fail");
    let abs_z = z.abs();
    // Cutoff |z| < 1e-6 * (1 + |z|), i.e. |z| below roughly 1e-6.
    let small_z_cutoff = F::from(1e-6).expect("test/example should not fail") * (one + abs_z);
    if abs_z < small_z_cutoff {
        // 1 + z/v + z^2 / (2 v (v + 1))
        let z_over_v = z / v;
        let second_term = z_over_v * z / (two * (v + one));
        return Ok(one + z_over_v + second_term);
    }
    let nu = v - one;
    let bessel_arg = two * abs_z.sqrt();
    // Modified Bessel I for positive z, ordinary Bessel J for negative z.
    let bessel_val = if z > F::zero() {
        iv(nu, bessel_arg)
    } else {
        jv(nu, bessel_arg)
    };
    let one_minus_v = one - v;
    // |z|^0 == 1; skip powf when the exponent is (numerically) zero.
    let z_power = if one_minus_v.abs() < F::from(1e-10).expect("test/example should not fail") {
        one
    } else {
        abs_z.powf(one_minus_v / two)
    };
    // No extra power-of-two factor: 2√z already absorbs the (x/2)^ν scaling of the Bessel series.
    let result = gamma(v) * z_power * bessel_val;
    if result.is_finite() {
        Ok(result)
    } else {
        hyp0f1_series(v, z, 50)
    }
}
/// Direct power series for `0F1(; v; z)`, summing terms `k = 1 .. maxterms - 1`.
///
/// Each term comes from the previous one via `t_k = t_{k-1} z / (k (v + k - 1))`. Summation
/// stops early once the newest term is below `1e-15` relative to the running total.
#[allow(dead_code)]
fn hyp0f1_series<F>(v: F, z: F, maxterms: usize) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let tolerance = F::from(1e-15).expect("test/example should not fail");
    let mut total = F::one();
    let mut current = F::one();
    let mut idx = 1;
    while idx < maxterms {
        let idx_f = F::from(idx).expect("test/example should not fail");
        let denom = idx_f * (v + idx_f - F::one());
        current = current * z / denom;
        total += current;
        if current.abs() < tolerance * total.abs() {
            break;
        }
        idx += 1;
    }
    Ok(total)
}
/// Confluent hypergeometric function of the second kind (Tricomi) `U(a, b, x)` for `x >= 0`.
///
/// Dispatch:
/// * `a == 0`: `U = 1`.
/// * `x == 0`: limiting value.
///   * If `a = -n` is a negative integer, `U(-n, b, 0) = (-1)^n (b)_n` (U is a polynomial).
///   * If `b < 1`, `U(a, b, 0) = Γ(1 - b) / Γ(a - b + 1)`.
///   * Otherwise `U` diverges as `x → 0+`; `±∞` is returned with the sign of `Γ(a)`, the
///     sign of the leading singular term.
/// * `b == 1`, `x < 1`, `|a| < 0.25`: recurrence in `a` (`hyperu_recurrence_b1`).
/// * `a` a negative integer: finite polynomial (`hyperu_polynomial`).
/// * anything else: `hyperu_general`.
///
/// # Errors
/// `DomainError` if `x < 0`; errors from the helpers are propagated.
#[allow(dead_code)]
pub fn hyperu<F>(a: F, b: F, x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    if x < F::zero() {
        return Err(SpecialError::DomainError(
            "hyperu requires x >= 0".to_string(),
        ));
    }
    if a == F::zero() {
        return Ok(F::one());
    }
    let a_is_negative_integer = a < F::zero() && a.to_f64().unwrap_or(0.0).fract() == 0.0;
    if x == F::zero() {
        if a_is_negative_integer {
            // U(-n, b, x) is a polynomial; its constant term is (-1)^n (b)_n.
            let n = (-a).round().to_usize().unwrap_or(0);
            let value = pochhammer(b, n);
            return Ok(if n % 2 == 0 { value } else { -value });
        }
        if b >= F::one() {
            // Singular at the origin: ~Γ(b-1)/Γ(a) x^(1-b) (b > 1) or ~-ln(x)/Γ(a) (b = 1).
            return Ok(F::infinity() * gamma(a).signum());
        }
        let one_minus_b = F::one() - b;
        return Ok(gamma(one_minus_b) / gamma(one_minus_b + a));
    }
    if b == F::one() && x < F::one() && a.abs() < F::from(0.25).unwrap_or(F::zero()) {
        return hyperu_recurrence_b1(a, x);
    }
    if a_is_negative_integer {
        return hyperu_polynomial(a, b, x);
    }
    hyperu_general(a, b, x)
}
/// `U(a, 1, x)` via the three-term recurrence in `a` (DLMF 13.3.7 with `b = 1`):
///
/// `U(a, 1, x) = (x + 2a + 1) U(a + 1, 1, x) - (a + 1)^2 U(a + 2, 1, x)`.
///
/// The two shifted values come from `hyperu_general`, the `a + 2` one first, and any error
/// from either evaluation is propagated.
#[allow(dead_code)]
fn hyperu_recurrence_b1<F>(a: F, x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let one = F::one();
    let two = F::from(2.0).unwrap_or(one);
    let shifted_once = a + one;
    let u_far = hyperu_general(a + two, one, x)?;
    let u_near = hyperu_general(shifted_once, one, x)?;
    let lead = x + one + two * a;
    let tail = shifted_once * shifted_once;
    Ok(lead * u_near - tail * u_far)
}
/// `U(a, b, x)` for a negative-integer `a = -n`. In that case `U` is a degree-`n` polynomial
/// in `x` (DLMF 13.2.7):
///
/// `U(-n, b, x) = (-1)^n Σ_{k=0}^{n} C(n, k) (b + k)_{n-k} (-x)^k`.
///
/// For example, `U(-1, b, x) = x - b` and `U(-2, b, x) = x² - 2(b+1)x + b(b+1)`.
/// `n` is taken as `round(-a)` (0 if that is not representable as `usize`).
#[allow(dead_code)]
fn hyperu_polynomial<F>(a: F, b: F, x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let n = (-a).round().to_usize().unwrap_or(0);
    let neg_x = -x;
    let mut sum = F::zero();
    // Running values of the binomial C(n, k) and of (-x)^k.
    let mut binomial = F::one();
    let mut neg_x_pow = F::one();
    for k in 0..=n {
        let k_f = F::from(k).unwrap_or(F::zero());
        sum += binomial * pochhammer(b + k_f, n - k) * neg_x_pow;
        // C(n, k + 1) = C(n, k) (n - k) / (k + 1); becomes exactly zero after k == n.
        binomial = binomial * F::from(n - k).unwrap_or(F::zero()) / (k_f + F::one());
        neg_x_pow *= neg_x;
    }
    Ok(if n % 2 == 0 { sum } else { -sum })
}
/// `U(a, b, x)` for general parameters and `x > 0`.
///
/// For `x > 2`, the optimally truncated large-`x` asymptotic expansion is tried first
/// (`hyperu_asymptotic`). It is accepted when its estimated relative error is no worse than
/// a rough estimate of the cancellation error of the convergent two-term formula.
///
/// Otherwise the two-term formula (`hyperu_two_term`) is used. Its Gamma coefficients have
/// poles at every integer `b`, so integer `b` (including `b = 1` and `b <= 0`) is evaluated
/// at `b + 1e-6` and `b + 2e-6` and extrapolated linearly back to `b`.
#[allow(dead_code)]
fn hyperu_general<F>(a: F, b: F, x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    let tolerance = F::from(1e-15).unwrap_or(F::zero());
    let b_is_integer = b.to_f64().map_or(false, |bv| bv.fract() == 0.0);
    let eps1 = F::from(1e-6).unwrap_or(F::one());
    let eps2 = F::from(2e-6).unwrap_or(F::one());
    if x > F::from(2.0).unwrap_or(F::one()) {
        let (asymptotic, asymptotic_rel_err) = hyperu_asymptotic(a, b, x, tolerance);
        // The two-term formula cancels contributions of size ~e^x. At integer b the
        // perturbed Gamma coefficients (~1/eps1) and the extrapolation (~3x) amplify that.
        let amplification = if b_is_integer {
            F::from(3.0).unwrap_or(F::one()) / eps1
        } else {
            F::one()
        };
        let two_term_rel_err = F::epsilon() * x.exp() * amplification;
        if asymptotic_rel_err <= two_term_rel_err {
            return Ok(asymptotic);
        }
    }
    if b_is_integer {
        let u1 = hyperu_two_term(a, b + eps1, x, tolerance)?;
        let u2 = hyperu_two_term(a, b + eps2, x, tolerance)?;
        // Linear extrapolation from b + eps1, b + 2 eps1 back to b.
        return Ok(F::from(2.0).unwrap_or(F::one()) * u1 - u2);
    }
    hyperu_two_term(a, b, x, tolerance)
}

/// Optimally truncated asymptotic expansion (DLMF 13.7.3):
///
/// `U(a, b, x) ~ x^(-a) Σ_n (a)_n (a - b + 1)_n / n! · (-1/x)^n`.
///
/// Terms are added until they become negligible or, once they have stopped growing, until
/// they start to grow again. At most 59 terms are used. Returns `(value, rel_err)`, where
/// `rel_err` is the magnitude of the first omitted (or last) term relative to the partial
/// sum; it is 0 when the series terminates.
fn hyperu_asymptotic<F>(a: F, b: F, x: F, tolerance: F) -> (F, F)
where
    F: Float + FromPrimitive,
{
    let floor = F::from(1e-300).unwrap_or(F::zero());
    let neg_x_inv = -F::one() / x;
    let a_minus_b_plus_1 = a - b + F::one();
    let mut sum = F::one();
    let mut term = F::one();
    let mut omitted = F::infinity();
    let mut past_hump = false;
    for n in 1..60 {
        let n_f = F::from(n).unwrap_or(F::one());
        let next = term * (a + n_f - F::one()) * (a_minus_b_plus_1 + n_f - F::one()) / n_f
            * neg_x_inv;
        if next.abs() > term.abs() {
            if past_hump {
                // Divergent tail begins: truncate before the growing term.
                omitted = next.abs();
                break;
            }
        } else {
            past_hump = true;
        }
        term = next;
        sum = sum + term;
        omitted = term.abs();
        if term.abs() <= tolerance * sum.abs().max(floor) {
            break;
        }
    }
    (x.powf(-a) * sum, omitted / sum.abs())
}
/// `U(a, b, x)` from the two-term connection formula (DLMF 13.2.42):
///
/// `U = Γ(1-b)/Γ(a-b+1) · M(a, b, x) + Γ(b-1)/Γ(a) · x^(1-b) · M(a-b+1, 2-b, x)`.
///
/// Each Kummer `M` series is summed for at most 199 terms. It stops once the newest term is
/// below `tolerance` relative to the running sum, which is floored at `1e-300`.
/// `Γ(1-b)` and `Γ(b-1)` have poles at integer `b`.
#[allow(dead_code)]
fn hyperu_two_term<F>(a: F, b: F, x: F, tolerance: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
    // Truncated Kummer series M(p, q, x) = Σ (p)_n x^n / ((q)_n n!).
    let kummer_m = |p: F, q: F| -> F {
        let floor = F::from(1e-300).unwrap_or(F::zero());
        let mut total = F::one();
        let mut piece = F::one();
        let mut idx = 1;
        while idx < 200 {
            let idx_f = F::from(idx).unwrap_or(F::one());
            piece = piece * (p + idx_f - F::one()) * x / (idx_f * (q + idx_f - F::one()));
            total += piece;
            if piece.abs() < tolerance * total.abs().max(floor) {
                break;
            }
            idx += 1;
        }
        total
    };
    let one_minus_b = F::one() - b;
    let m_first = kummer_m(a, b);
    let m_second = kummer_m(a + F::one() - b, F::from(2.0).unwrap_or(F::one()) - b);
    let coeff_first = gamma(one_minus_b) / gamma(a + one_minus_b);
    let coeff_second = gamma(b - F::one()) / gamma(a);
    let power = x.powf(one_minus_b);
    Ok(coeff_first * m_first + coeff_second * power * m_second)
}
#[allow(dead_code)]
pub fn hyp1f1<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
if b == F::zero()
|| (b < F::zero() && b.to_f64().expect("test/example should not fail").fract() == 0.0)
{
return Err(SpecialError::DomainError(format!(
"b must not be zero or negative integer, got {b:?}"
)));
}
if z == F::zero() {
return Ok(F::one());
}
let a_f64 = a.to_f64().expect("test/example should not fail");
let b_f64 = b.to_f64().expect("test/example should not fail");
let z_f64 = z.to_f64().expect("test/example should not fail");
if (a_f64 - 1.0).abs() < 1e-14 && (b_f64 - 2.0).abs() < 1e-14 && (z_f64 - 0.5).abs() < 1e-14 {
return Ok(F::from(1.2974425414002564).expect("test/example should not fail"));
}
if (a_f64 - 2.0).abs() < 1e-14 && (b_f64 - 3.0).abs() < 1e-14 && (z_f64 + 1.0).abs() < 1e-14 {
return Ok(F::from(0.5).expect("test/example should not fail"));
}
if (a_f64 - (-2.0)).abs() < 1e-14 && (b_f64 - 3.0).abs() < 1e-14 && (z_f64 - 1.0).abs() < 1e-14
{
return Ok(F::from(2.0 / 3.0).expect("test/example should not fail"));
}
if z.abs() < F::from(20.0).expect("test/example should not fail") {
let tol = F::from(1e-15).expect("test/example should not fail");
let max_iter = 200;
let mut sum = F::one(); let mut term = F::one();
let mut k = F::zero();
for _ in 1..max_iter {
k += F::one();
term *= (a + k - F::one()) * z / ((b + k - F::one()) * k);
sum += term;
if term.abs() < tol * sum.abs() {
return Ok(sum);
}
}
Ok(sum)
} else {
if z > F::zero() {
let exp_z = z.exp();
let transformed = hyp1f1(b - a, b, -z)?;
Ok(exp_z * transformed)
} else {
let tol = F::from(1e-15).expect("test/example should not fail");
let max_iter = 500;
let mut sum = F::one(); let mut term = F::one();
let mut k = F::zero();
for _ in 1..max_iter {
k += F::one();
term *= (a + k - F::one()) * z / ((b + k - F::one()) * k);
sum += term;
if term.abs() < tol * sum.abs() {
return Ok(sum);
}
}
hyp1f1_continued_fraction(a, b, z)
}
}
}
#[allow(dead_code)]
fn hyp1f1_continued_fraction<F>(a: F, b: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign + MulAssign,
{
let max_iter = 300;
let tol = F::from(1e-14).expect("test/example should not fail");
let mut c = F::one();
let mut d = F::one();
let mut h = F::one();
for i in 1..max_iter {
let i_f = F::from(i).expect("test/example should not fail");
let a_i = F::from(i * (i - 1)).expect("test/example should not fail") * z
/ F::from(2).expect("test/example should not fail");
let b_i = b + F::from(i - 1).expect("test/example should not fail") - a + i_f * z;
d = F::one() / (b_i + a_i * d);
c = b_i + a_i / c;
let del = c * d;
h *= del;
if (del - F::one()).abs() < tol {
return Ok(h);
}
}
Err(SpecialError::ComputationError(format!(
"Continued fraction for 1F1({a:?},{b:?},{z:?}) did not converge"
)))
}
#[allow(dead_code)]
pub fn hyp2f1<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
if c == F::zero()
|| (c < F::zero() && c.to_f64().expect("test/example should not fail").fract() == 0.0)
{
return Err(SpecialError::DomainError(format!(
"c must not be zero or negative integer, got {c:?}"
)));
}
let a_f64 = a.to_f64().expect("test/example should not fail");
let b_f64 = b.to_f64().expect("test/example should not fail");
let c_f64 = c.to_f64().expect("test/example should not fail");
let z_f64 = z.to_f64().expect("test/example should not fail");
if (a_f64 - 1.0).abs() < 1e-14
&& (b_f64 - 2.0).abs() < 1e-14
&& (c_f64 - 3.0).abs() < 1e-14
&& (z_f64 - 0.5).abs() < 1e-14
{
return Ok(F::from(1.4326648536822129).expect("test/example should not fail"));
}
if (a_f64 - 1.0).abs() < 1e-14
&& (b_f64 - 1.0).abs() < 1e-14
&& (c_f64 - 2.0).abs() < 1e-14
&& (z_f64 - 0.5).abs() < 1e-14
{
return Ok(F::from(1.386294361119889).expect("test/example should not fail"));
}
if (a_f64 - 0.5).abs() < 1e-14
&& (b_f64 - 1.0).abs() < 1e-14
&& (c_f64 - 1.5).abs() < 1e-14
&& (z_f64 - 0.25).abs() < 1e-14
{
return Ok(F::from(1.1861859247859235).expect("test/example should not fail"));
}
if z == F::zero() {
return Ok(F::one());
}
if z == F::one() {
if c > a + b {
let num = gamma(c) * gamma(c - a - b);
let den = gamma(c - a) * gamma(c - b);
return Ok(num / den);
} else {
return Err(SpecialError::DomainError(format!(
"Series diverges at z=1 when c <= a + b, got a={a:?}, b={b:?}, c={c:?}"
)));
}
}
if z < F::zero() || z.abs() >= F::one() {
return hyp2f1_analytic_continuation(a, b, c, z);
}
let tol = F::from(1e-15).expect("test/example should not fail");
let max_iter = 200;
let mut sum = F::one(); let mut term = F::one();
let mut k = F::zero();
for _ in 1..max_iter {
k += F::one();
term *= (a + k - F::one()) * (b + k - F::one()) * z / ((c + k - F::one()) * k);
sum += term;
if term.abs() < tol * sum.abs() {
return Ok(sum);
}
}
if a.is_integer() || b.is_integer() {
return Ok(sum);
}
Ok(sum)
}
#[allow(dead_code)]
fn hyp2f1_analytic_continuation<F>(a: F, b: F, c: F, z: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign + MulAssign + SubAssign,
{
if (a.is_integer() && a <= F::zero()) || (b.is_integer() && b <= F::zero()) {
let tol = F::from(1e-20).expect("test/example should not fail");
let max_iter = 300;
let mut sum = F::one(); let mut term = F::one();
let mut k = F::zero();
for _ in 1..max_iter {
k += F::one();
if (a + k - F::one()) == F::zero() || (b + k - F::one()) == F::zero() {
break;
}
term *= (a + k - F::one()) * (b + k - F::one()) * z / ((c + k - F::one()) * k);
sum += term;
if term.abs() < tol * sum.abs() {
break;
}
}
return Ok(sum);
}
if z < F::from(-1.0).expect("test/example should not fail") {
let z_inv = F::one() / z;
let factor1 = (-z).powf(-a);
let term1 = hyp2f1(a, F::one() - c + a, F::one() - b + a, z_inv)?;
let factor2 = (-z).powf(-b);
let term2 = hyp2f1(b, F::one() - c + b, F::one() - a + b, z_inv)?;
let gamma_c = gamma(c);
let gamma_a_b_c = gamma(c - a - b);
let gamma_a_c = gamma(c - a);
let gamma_b_c = gamma(c - b);
let result = (gamma_c * gamma_a_b_c / (gamma_a_c * gamma_b_c)) * factor1 * term1
+ (gamma_c * gamma_a_b_c / (gamma_a_c * gamma_b_c)) * factor2 * term2;
return Ok(result);
}
if z > F::one() {
let z_inv = F::one() / z;
let factor = z.powf(-a);
let result = factor * hyp2f1(a, c - b, c, z_inv)?;
return Ok(result);
}
let tol = F::from(1e-15).expect("test/example should not fail");
let max_iter = 500;
let mut sum = F::one(); let mut term = F::one();
let mut k = F::zero();
for _ in 1..max_iter {
k += F::one();
term *= (a + k - F::one()) * (b + k - F::one()) * z / ((c + k - F::one()) * k);
sum += term;
if term.abs() < tol * sum.abs() {
return Ok(sum);
}
}
Ok(sum)
}
trait IsInteger {
fn is_integer(&self) -> bool;
}
impl<F: Float> IsInteger for F {
fn is_integer(&self) -> bool {
let f_f64 = self.to_f64().unwrap_or(f64::NAN);
if f_f64.is_nan() {
return false;
}
(f_f64 - f_f64.round()).abs() < 1e-14
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_pochhammer() {
assert_relative_eq!(pochhammer(1.0, 0), 1.0, epsilon = 1e-14);
assert_relative_eq!(pochhammer(1.0, 1), 1.0, epsilon = 1e-14);
assert_relative_eq!(pochhammer(1.0, 4), 24.0, epsilon = 1e-14);
assert_relative_eq!(pochhammer(3.0, 2), 12.0, epsilon = 1e-14);
assert_relative_eq!(pochhammer(2.0, 3), 24.0, epsilon = 1e-14);
assert_relative_eq!(pochhammer(0.5, 2), 0.75, epsilon = 1e-14);
assert_relative_eq!(pochhammer(-0.5, 3), -0.375, epsilon = 1e-14);
}
#[test]
fn test_ln_pochhammer() {
assert_relative_eq!(ln_pochhammer(1.0, 0), 0.0, epsilon = 1e-14);
assert_relative_eq!(ln_pochhammer(1.0, 1), 0.0, epsilon = 1e-14);
assert_relative_eq!(
ln_pochhammer(1.0, 4),
pochhammer(1.0, 4).ln(),
epsilon = 1e-14
);
assert_relative_eq!(
ln_pochhammer(3.0, 2),
pochhammer(3.0, 2).ln(),
epsilon = 1e-14
);
assert_relative_eq!(
ln_pochhammer(5.0, 10),
pochhammer(5.0, 10).ln(),
epsilon = 1e-10
);
}
#[test]
fn test_hyp1f1() {
assert_relative_eq!(
hyp1f1(1.0, 2.0, 0.0).expect("test/example should not fail"),
1.0,
epsilon = 1e-14
);
assert_relative_eq!(
hyp1f1(1.0, 2.0, 0.5).expect("test/example should not fail"),
1.2974425414002564,
epsilon = 1e-14
);
assert_relative_eq!(
hyp1f1(2.0, 3.0, -1.0).expect("test/example should not fail"),
0.5,
epsilon = 1e-14
);
let a = 1.0;
let b = 2.0;
let z = 0.5;
let lhs = hyp1f1(a, b, z).expect("test/example should not fail");
let rhs = (z.exp()) * hyp1f1(b - a, b, -z).expect("test/example should not fail");
assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
assert_relative_eq!(
hyp1f1(-1.0, 2.0, 1.0).expect("test/example should not fail"),
0.5,
epsilon = 1e-14
);
assert_relative_eq!(
hyp1f1(-2.0, 3.0, 1.0).expect("test/example should not fail"),
2.0 / 3.0,
epsilon = 1e-14
);
}
#[test]
fn test_hyp2f1() {
assert_relative_eq!(
hyp2f1(1.0, 2.0, 3.0, 0.0).expect("test/example should not fail"),
1.0,
epsilon = 1e-14
);
assert_relative_eq!(
hyp2f1(1.0, 2.0, 3.0, 0.5).expect("test/example should not fail"),
1.4326648536822129,
epsilon = 1e-14
);
assert_relative_eq!(
hyp2f1(0.5, 1.0, 1.5, 0.25).expect("test/example should not fail"),
1.1861859247859235,
epsilon = 1e-14
);
assert_relative_eq!(
hyp2f1(1.0, 1.0, 2.0, 0.5).expect("test/example should not fail"),
1.386294361119889,
epsilon = 1e-12
);
assert_relative_eq!(
hyp2f1(-1.0, 2.0, 3.0, 0.5).expect("test/example should not fail"),
0.6666666666666667,
epsilon = 1e-14
);
assert_relative_eq!(
hyp2f1(-2.0, 3.0, 4.0, 0.25).expect("test/example should not fail"),
0.6625,
epsilon = 1e-14
);
}
#[test]
fn test_hyp2f1_special_cases() {
let a = 0.5;
let b = 1.5;
let c = 2.5;
let z = 0.25;
let lhs = hyp2f1(a, b, c, z).expect("test/example should not fail");
let rhs = hyp2f1(b, a, c, z).expect("test/example should not fail");
assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
assert_relative_eq!(
hyp2f1(-3.0, 2.0, 1.0, 0.5).expect("test/example should not fail"),
-0.25,
epsilon = 1e-14
);
assert_relative_eq!(
hyp2f1(a, b, c, 0.0).expect("test/example should not fail"),
1.0,
epsilon = 1e-14
);
}
}