use numra_core::Scalar;
/// Confluent hypergeometric function of the first kind, `1F1(a; b; z)`
/// (Kummer's function `M(a, b, z)`).
///
/// The arguments are converted to `f64`, evaluated, and the result is
/// converted back to `S`.
///
/// Returns `S::NAN` when `b` is zero or a negative integer. For those values
/// a factor `(b + k)` in the series denominator vanishes.
pub fn hyp1f1<S: Scalar>(a: S, b: S, z: S) -> S {
    let (a, b, z) = (a.to_f64(), b.to_f64(), z.to_f64());
    // b = 0, -1, -2, ... are poles of 1F1 in its second parameter.
    let b_at_pole = b <= 0.0 && b == b.floor();
    if b_at_pole {
        S::NAN
    } else {
        S::from_f64(hyp1f1_f64(a, b, z))
    }
}
/// Gauss hypergeometric function `2F1(a, b; c; z)`.
///
/// The arguments are converted to `f64` and evaluated by `hyp2f1_f64`. The
/// result is converted back to `S`. This wrapper does no special handling of
/// non-positive-integer `c`.
pub fn hyp2f1<S: Scalar>(a: S, b: S, c: S, z: S) -> S {
    let value = hyp2f1_f64(a.to_f64(), b.to_f64(), c.to_f64(), z.to_f64());
    S::from_f64(value)
}
/// Evaluates `1F1(a; b; z)` in `f64`.
///
/// For `z < -10` the direct series typically suffers heavy cancellation
/// between terms of opposite sign. In that case the function first applies
/// Kummer's transformation `1F1(a; b; z) = e^z * 1F1(b - a; b; -z)` and sums
/// the series at the positive argument `-z`. Otherwise it sums the series
/// directly.
fn hyp1f1_f64(a: f64, b: f64, z: f64) -> f64 {
    let reflect = z < -10.0;
    if reflect {
        let scale = z.exp();
        scale * hyp1f1_series(b - a, b, -z)
    } else {
        hyp1f1_series(a, b, z)
    }
}
/// Sums the Kummer series `Σ (a)_k / ((b)_k k!) z^k` directly.
///
/// Each term is the previous one times `(a + k) z / ((b + k)(k + 1))`.
/// Summation stops after 300 terms, or as soon as a term's magnitude falls
/// below `1e-15` times the magnitude of the running sum.
///
/// NOTE(review): the stopping test looks only at the latest term. For very
/// small `|a|` the first term is already tiny relative to the sum of 1, so
/// the loop can stop before later, growing terms contribute. Confirm callers
/// never rely on that regime.
fn hyp1f1_series(a: f64, b: f64, z: f64) -> f64 {
    const MAX_TERMS: u32 = 300;
    const TOL: f64 = 1e-15;

    let mut total = 1.0_f64;
    let mut t = 1.0_f64;
    for n in (0..MAX_TERMS).map(f64::from) {
        t *= (a + n) * z / ((b + n) * (n + 1.0));
        total += t;
        if t.abs() < TOL * total.abs() {
            break;
        }
    }
    total
}
/// Evaluates the Gauss hypergeometric function `2F1(a, b; c; z)` for real `z`.
///
/// The argument is routed to a region where `hyp2f1_series` can be used:
///
/// * `z` is NaN: returns NaN.
/// * `a` or `b` is a non-positive integer: the series terminates (it is a
///   polynomial in `z`), so it is summed directly for any `z`.
/// * `|z| < 0.5`: direct power series.
/// * `0.5 <= |z| < 1`: the Pfaff transformation
///   `(1 - z)^(-a) * 2F1(a, c - b; c; z / (z - 1))` is used when it brings the
///   argument below 0.5 in magnitude. Otherwise the Euler transformation
///   `(1 - z)^(c - a - b) * 2F1(c - a, c - b; c; z)` is used.
/// * `z == 1`: Gauss's summation theorem
///   `Γ(c) Γ(c - a - b) / (Γ(c - a) Γ(c - b))` when `c > a + b`, otherwise `+inf`.
/// * `z > 1`: NaN (no real value is produced on the branch cut).
/// * `z <= -1`: the Pfaff transformation maps `z` into `[0.5, 1)`, then the
///   function recurses.
///
/// NOTE(review): the Euler branch sums a series in `z` itself, which converges
/// slowly as `|z| -> 1`. The `z <= -1` branch lands there for large `|z|`.
/// `hyp2f1_series` stops after a fixed number of terms, so accuracy degrades
/// in that regime.
fn hyp2f1_f64(a: f64, b: f64, c: f64, z: f64) -> f64 {
    // A NaN `z` fails every range test below and would reach the final branch.
    // There `z / (z - 1)` is NaN again, so the recursion never terminates
    // (stack overflow). `-inf` maps to NaN in that branch and is caught here
    // on the recursive call.
    if z.is_nan() {
        return f64::NAN;
    }
    if (a <= 0.0 && a == a.floor()) || (b <= 0.0 && b == b.floor()) {
        return hyp2f1_series(a, b, c, z);
    }
    if z.abs() < 0.5 {
        hyp2f1_series(a, b, c, z)
    } else if z.abs() < 1.0 {
        let z_new = z / (z - 1.0);
        if z_new.abs() < 0.5 {
            // Pfaff transformation.
            (1.0 - z).powf(-a) * hyp2f1_series(a, c - b, c, z_new)
        } else {
            // Euler transformation.
            (1.0 - z).powf(c - a - b) * hyp2f1_series(c - a, c - b, c, z)
        }
    } else if z == 1.0 {
        if c > a + b {
            // `lgamma` returns only ln|Γ(x)| and drops the sign. Γ is negative
            // on (-1, 0), (-3, -2), ..., so `lgamma_r` is used to recover the
            // sign of each factor. Otherwise e.g. a negative non-integer `c`
            // gives a result with the wrong sign.
            let (ln_c, sign_c) = libm::lgamma_r(c);
            let (ln_cab, sign_cab) = libm::lgamma_r(c - a - b);
            let (ln_ca, sign_ca) = libm::lgamma_r(c - a);
            let (ln_cb, sign_cb) = libm::lgamma_r(c - b);
            let sign = f64::from(sign_c * sign_cab * sign_ca * sign_cb);
            let num = ln_c + ln_cab;
            let den = ln_ca + ln_cb;
            sign * (num - den).exp()
        } else {
            // The non-terminating series diverges at z = 1 when c <= a + b.
            f64::INFINITY
        }
    } else if z > 1.0 {
        f64::NAN
    } else {
        // z <= -1: z / (z - 1) lies in [0.5, 1) for finite z.
        let z_new = z / (z - 1.0);
        (1.0 - z).powf(-a) * hyp2f1_f64(a, c - b, c, z_new)
    }
}
/// Sums the Gauss power series `Σ (a)_k (b)_k / ((c)_k k!) z^k` directly.
///
/// Each term is the previous one times `(a + k)(b + k) z / ((c + k)(k + 1))`.
/// Summation stops in any of three cases:
/// * after 500 terms;
/// * when a term's magnitude drops below `1e-14` times the running sum;
/// * when a term is exactly zero, e.g. `z == 0`, or `a`/`b` is a non-positive
///   integer so the series terminates.
///
/// No range check is done on `z`.
fn hyp2f1_series(a: f64, b: f64, c: f64, z: f64) -> f64 {
    const MAX_TERMS: u32 = 500;
    const TOL: f64 = 1e-14;

    let mut total = 1.0_f64;
    let mut t = 1.0_f64;
    for n in (0..MAX_TERMS).map(f64::from) {
        t *= (a + n) * (b + n) * z / ((c + n) * (n + 1.0));
        total += t;
        let negligible = t.abs() < TOL * total.abs();
        if negligible || t == 0.0 {
            break;
        }
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // 1F1(a; b; 0) = 1.
    #[test]
    fn test_hyp1f1_zero_z() {
        assert_relative_eq!(hyp1f1(2.0_f64, 3.0_f64, 0.0_f64), 1.0, epsilon = 1e-14);
    }

    // 1F1(a; a; z) = e^z.
    #[test]
    fn test_hyp1f1_exp() {
        let z = 2.0_f64;
        assert_relative_eq!(hyp1f1(3.0_f64, 3.0_f64, z), z.exp(), epsilon = 1e-10);
    }

    // 1F1(1; 2; z) = (e^z - 1) / z.
    #[test]
    fn test_hyp1f1_known() {
        let z = 1.0_f64;
        let expected = (z.exp() - 1.0) / z;
        assert_relative_eq!(hyp1f1(1.0_f64, 2.0_f64, z), expected, epsilon = 1e-10);
    }

    // Negative z above the Kummer threshold: direct series.
    #[test]
    fn test_hyp1f1_negative_z() {
        assert_relative_eq!(
            hyp1f1(1.0_f64, 1.0_f64, -5.0_f64),
            (-5.0_f64).exp(),
            epsilon = 1e-10
        );
    }

    // z < -10 goes through Kummer's transformation; 1F1(1; 2; z) = (e^z - 1) / z.
    #[test]
    fn test_hyp1f1_kummer_branch() {
        let z = -15.0_f64;
        assert_relative_eq!(
            hyp1f1(1.0_f64, 2.0_f64, z),
            (z.exp() - 1.0) / z,
            epsilon = 1e-12
        );
    }

    // b = 0, -1, -2, ... is a pole of 1F1 and is reported as NaN.
    #[test]
    fn test_hyp1f1_pole_in_b() {
        assert!(hyp1f1(1.0_f64, -2.0_f64, 0.5_f64).is_nan());
        assert!(hyp1f1(1.0_f64, 0.0_f64, 0.5_f64).is_nan());
    }

    // a = -2 terminates the series: 1F1(-2; 1; z) = 1 - 2z + z^2 / 2.
    #[test]
    fn test_hyp1f1_terminating() {
        let z = 1.0_f64;
        assert_relative_eq!(
            hyp1f1(-2.0_f64, 1.0_f64, z),
            1.0 - 2.0 * z + z * z / 2.0,
            epsilon = 1e-14
        );
    }

    // 2F1(a, b; c; 0) = 1.
    #[test]
    fn test_hyp2f1_zero() {
        assert_relative_eq!(
            hyp2f1(1.0_f64, 2.0_f64, 3.0_f64, 0.0_f64),
            1.0,
            epsilon = 1e-14
        );
    }

    // Gauss's summation: Γ(3) Γ(1) / (Γ(2) Γ(2)) = 2.
    #[test]
    fn test_hyp2f1_gauss_sum() {
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 3.0_f64, 1.0_f64),
            2.0,
            epsilon = 1e-10
        );
    }

    // 2F1(1, 1; 1; z) = 1 / (1 - z).
    #[test]
    fn test_hyp2f1_geometric() {
        let z = 0.3_f64;
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 1.0_f64, z),
            1.0 / (1.0 - z),
            epsilon = 1e-10
        );
    }

    // 2F1(1, 1; 2; -1) = ln 2.
    #[test]
    fn test_hyp2f1_negative_z() {
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, -1.0_f64),
            2.0_f64.ln(),
            epsilon = 1e-10
        );
    }

    // 0.5 <= |z| < 1 with |z / (z - 1)| < 0.5: Pfaff branch.
    // 2F1(1, 1; 2; z) = -ln(1 - z) / z.
    #[test]
    fn test_hyp2f1_pfaff_branch() {
        let z = -0.7_f64;
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, z),
            -(1.0 - z).ln() / z,
            epsilon = 1e-10
        );
    }

    // 0.5 <= |z| < 1 with |z / (z - 1)| >= 0.5: Euler branch.
    #[test]
    fn test_hyp2f1_euler_branch() {
        let z = 0.8_f64;
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, z),
            -(1.0 - z).ln() / z,
            epsilon = 1e-10
        );
    }

    // z < -1 is mapped into (0.5, 1) by the Pfaff transformation.
    #[test]
    fn test_hyp2f1_large_negative_z() {
        let z = -3.0_f64;
        assert_relative_eq!(
            hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, z),
            -(1.0 - z).ln() / z,
            epsilon = 1e-10
        );
    }

    // a = -2 terminates the series: 2F1(-2, b; b; z) = (1 - z)^2, even for z > 1.
    #[test]
    fn test_hyp2f1_terminating() {
        assert_relative_eq!(
            hyp2f1(-2.0_f64, 1.0_f64, 1.0_f64, 3.0_f64),
            4.0,
            epsilon = 1e-12
        );
    }

    // A non-terminating 2F1 has no real value on the branch cut z > 1.
    #[test]
    fn test_hyp2f1_branch_cut() {
        assert!(hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, 2.0_f64).is_nan());
    }

    // At z = 1 with c <= a + b the series diverges.
    #[test]
    fn test_hyp2f1_divergent_at_one() {
        assert!(hyp2f1(1.0_f64, 1.0_f64, 2.0_f64, 1.0_f64).is_infinite());
    }

    // The generic wrapper also works for f32 (evaluated in f64 internally).
    #[test]
    fn test_hypergeometric_f32() {
        let r = hyp1f1(1.0_f32, 1.0_f32, 1.0_f32);
        assert!((r.to_f64() - 1.0_f64.exp()).abs() < 1e-4);
    }
}