use super::gamma_functions::lgamma_scalar;
/// Exclusive upper bound on the term index `n` in the non-terminating series loops of this
/// module: at most `MAX_SERIES_TERMS - 1` terms are added after the leading 1.
const MAX_SERIES_TERMS: usize = 500;
/// Relative tolerance of the series stopping tests: a term smaller than `EPSILON` times the
/// running sum is treated as negligible.
const EPSILON: f64 = 1e-15;
/// |z| threshold separating direct summation of the Gauss series (`|z| < SMALL_Z`) from the
/// transformations in `hyp2f1_transform`.
const SMALL_Z: f64 = 0.5;
/// Evaluates the Gauss hypergeometric function ₂F₁(a, b; c; z) for real arguments.
///
/// Dispatch order:
/// * any NaN argument → NaN;
/// * `c` a non-positive integer → NaN, unless `a` or `b` is a negative integer `>= c`, in
///   which case the series terminates before its denominator `(c)_k` vanishes;
/// * `z == 0` → 1;
/// * `a` or `b` a non-positive integer → terminating polynomial via `hyp2f1_polynomial`.
///   When both are, the one closer to zero (the lower degree) is used. The other
///   parameter's Pochhammer factor then never reaches zero inside the loop, so a vanishing
///   `(c)_k` cannot turn the sum into `0 * inf = NaN`;
/// * `|z| < SMALL_Z` → direct Gauss series via `hyp2f1_series`;
/// * every other `z`, including `z == 1` and the real axis beyond ±1 → `hyp2f1_transform`.
pub fn hyp2f1_scalar(a: f64, b: f64, c: f64, z: f64) -> f64 {
    /// True for 0, -1, -2, ...: a numerator parameter with such a value terminates the series.
    fn is_nonpositive_integer(x: f64) -> bool {
        x <= 0.0 && x == x.floor()
    }

    if z.is_nan() || a.is_nan() || b.is_nan() || c.is_nan() {
        return f64::NAN;
    }
    if is_nonpositive_integer(c) {
        // (c)_k becomes zero at k = 1 - c; the value is only defined if a numerator
        // Pochhammer symbol reaches zero no later than that.
        let a_terminates_first = a < 0.0 && a == a.floor() && a >= c;
        let b_terminates_first = b < 0.0 && b == b.floor() && b >= c;
        if !(a_terminates_first || b_terminates_first) {
            return f64::NAN;
        }
    }
    if z == 0.0 {
        return 1.0;
    }
    // ₂F₁ is symmetric in a and b, so the terminating parameter can always be passed first.
    match (is_nonpositive_integer(a), is_nonpositive_integer(b)) {
        // Both terminate: expand in the one with fewer terms (closer to zero).
        (true, true) if b > a => return hyp2f1_polynomial(b, a, c, z),
        (true, _) => return hyp2f1_polynomial(a, b, c, z),
        (false, true) => return hyp2f1_polynomial(b, a, c, z),
        (false, false) => {}
    }
    if z.abs() < SMALL_Z {
        hyp2f1_series(a, b, c, z)
    } else {
        hyp2f1_transform(a, b, c, z)
    }
}
/// Sums the Gauss series ₂F₁(a, b; c; z) = Σ (a)ₙ (b)ₙ / ((c)ₙ n!) zⁿ term by term.
///
/// Each term is derived from its predecessor through the ratio
/// (a + n - 1)(b + n - 1) / ((c + n - 1) n) · z. Summation stops as soon as a term drops
/// below `EPSILON` times the running total, or after the terms n = 1 .. MAX_SERIES_TERMS - 1.
/// A non-terminating series only converges for |z| < 1, and slowly as |z| approaches 1.
fn hyp2f1_series(a: f64, b: f64, c: f64, z: f64) -> f64 {
    let (mut total, mut current) = (1.0_f64, 1.0_f64);
    let mut step: usize = 1;
    while step < MAX_SERIES_TERMS {
        let m = step as f64;
        let ratio = (a + m - 1.0) * (b + m - 1.0) / ((c + m - 1.0) * m) * z;
        current *= ratio;
        total += current;
        if current.abs() < EPSILON * total.abs() {
            break;
        }
        step += 1;
    }
    total
}
/// Evaluates ₂F₁(a, b; c; z) when `a` is a non-positive integer, i.e. when the Gauss series
/// terminates and ₂F₁ is a polynomial in `z` of degree at most `-a`.
///
/// Callers pass non-positive integer `a`. Other values are truncated toward zero by the
/// cast, so any `a > -1` yields `1.0`.
///
/// The sum also stops as soon as the numerator factor `(a + k - 1)(b + k - 1)` is exactly
/// zero, which happens when `b` is a non-positive integer with `b > a`. Every later term is
/// then zero as well. Continuing could divide by a vanishing `(c + k - 1)` when `c` is a
/// non-positive integer, which would turn the result into `0 * inf = NaN`.
fn hyp2f1_polynomial(a: f64, b: f64, c: f64, z: f64) -> f64 {
    // `as` saturates for huge |a|; for the integral `a` callers pass this is the exact degree.
    let degree = (-a) as i32;
    if degree < 0 {
        return 1.0;
    }
    let mut sum = 1.0;
    let mut term = 1.0;
    for k in 1..=degree {
        let k_f = f64::from(k);
        let numerator = (a + k_f - 1.0) * (b + k_f - 1.0);
        if numerator == 0.0 {
            // A numerator Pochhammer symbol reached zero: the series has already terminated.
            break;
        }
        term *= numerator / ((c + k_f - 1.0) * k_f) * z;
        sum += term;
    }
    sum
}
/// Evaluates ₂F₁(a, b; c; z) where the direct Gauss series is unsuitable. It either maps
/// `z` to an argument at which `hyp2f1_series` converges or uses a closed form.
///
/// `hyp2f1_scalar` calls this for non-terminating parameters with `|z| >= SMALL_Z`.
///
/// * `z < 0`: Pfaff's transformation
///   ₂F₁(a, b; c; z) = (1 - z)^(-a) ₂F₁(a, c - b; c; z / (z - 1)), where z / (z - 1) ∈ (0, 1).
/// * `0 <= z < SMALL_Z`: the direct series.
/// * `SMALL_Z <= z < 1`: Euler's transformation
///   ₂F₁(a, b; c; z) = (1 - z)^(c - a - b) ₂F₁(c - a, c - b; c; z).
///   The Pfaff map does not help here because it sends [0.5, 1) to (-inf, -1].
/// * `z == 1`: Gauss's summation Γ(c) Γ(c - a - b) / (Γ(c - a) Γ(c - b)) when
///   c - a - b > 0. Otherwise the series diverges and +inf is returned.
/// * `z > 1`: this is the branch cut, where ₂F₁ is in general complex.
///   - The closed forms ₂F₁(a, b; b; z) = (1 - z)^(-a) and ₂F₁(a, b; a; z) = (1 - z)^(-b)
///     are returned. `powf` gives NaN when the exponent is not an integer.
///   - Every other case returns NaN. A single 1/z term, lacking the second term and the
///     gamma prefactors of the connection formula, is not ₂F₁, so no such fallback is tried.
///
/// The series arguments approach 1 as z → 1⁻ or z → -inf. There `hyp2f1_series` converges
/// slowly and stops after MAX_SERIES_TERMS terms, so accuracy degrades in those regions.
fn hyp2f1_transform(a: f64, b: f64, c: f64, z: f64) -> f64 {
    /// Sign of Γ(x) for x not a non-positive integer: Γ(x) < 0 exactly on (-1, 0), (-3, -2), ...
    fn gamma_sign(x: f64) -> f64 {
        if x < 0.0 && x.floor().rem_euclid(2.0) == 1.0 {
            -1.0
        } else {
            1.0
        }
    }

    if z < 0.0 {
        let z_new = z / (z - 1.0);
        let prefactor = (1.0 - z).powf(-a);
        return prefactor * hyp2f1_series(a, c - b, c, z_new);
    }
    if z < SMALL_Z {
        return hyp2f1_series(a, b, c, z);
    }
    if z < 1.0 {
        let euler_prefactor = (1.0 - z).powf(c - a - b);
        return euler_prefactor * hyp2f1_series(c - a, c - b, c, z);
    }
    if z == 1.0 {
        let cab = c - a - b;
        if cab > 0.0 {
            // NOTE(review): assumes lgamma_scalar returns ln|Γ(x)| (the usual lgamma
            // convention), so the sign of each Γ factor is restored separately; Γ(cab) > 0.
            let magnitude = (lgamma_scalar(c) + lgamma_scalar(cab)
                - lgamma_scalar(c - a)
                - lgamma_scalar(c - b))
                .exp();
            return gamma_sign(c) * gamma_sign(c - a) * gamma_sign(c - b) * magnitude;
        }
        return f64::INFINITY;
    }
    // z > 1 (or NaN, which propagates through `powf` or the final NaN).
    if c == b {
        return (1.0 - z).powf(-a);
    }
    if c == a {
        return (1.0 - z).powf(-b);
    }
    f64::NAN
}
/// Evaluates Kummer's confluent hypergeometric function ₁F₁(a; b; z) = M(a, b, z) for real
/// arguments.
///
/// * Any NaN argument → NaN.
/// * `b` a non-positive integer → NaN, unless `a` is a negative integer with `a >= b`, so
///   that the series terminates before `(b)_k` vanishes.
/// * `z == 0` or `a == 0` → 1.
/// * `a` a negative integer → terminating polynomial via `hyp1f1_polynomial`.
/// * `|z| > 50` → `hyp1f1_asymp`.
/// * `-50 <= z < 0` → Kummer's transformation M(a, b, z) = e^z M(b - a, b, -z).
///   - For positive `a`, `b` the direct series alternates at negative `z`, and its terms can
///     be far larger than the result, so summing it directly cancels catastrophically.
///   - The transformed series has a positive argument; when b > a > 0 all its terms are
///     positive.
/// * Otherwise → direct Maclaurin series via `hyp1f1_series`.
pub fn hyp1f1_scalar(a: f64, b: f64, z: f64) -> f64 {
    // Beyond this |z| evaluation is delegated to `hyp1f1_asymp`.
    const LARGE_Z: f64 = 50.0;

    if z.is_nan() || a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    let a_is_negative_integer = a < 0.0 && a == a.floor();
    // (b)_k becomes zero at k = 1 - b; only a numerator that terminates first keeps it finite.
    if b <= 0.0 && b == b.floor() && !(a_is_negative_integer && a >= b) {
        return f64::NAN;
    }
    if z == 0.0 || a == 0.0 {
        return 1.0;
    }
    if a_is_negative_integer {
        return hyp1f1_polynomial(a, b, z);
    }
    if z.abs() > LARGE_Z {
        return hyp1f1_asymp(a, b, z);
    }
    if z < 0.0 {
        return z.exp() * hyp1f1_series(b - a, b, -z);
    }
    hyp1f1_series(a, b, z)
}
/// Sums the Maclaurin series ₁F₁(a; b; z) = Σ (a)ₙ / ((b)ₙ n!) zⁿ term by term.
///
/// Successive terms differ by the factor (a + n - 1) / ((b + n - 1) n) · z. Summation ends
/// once a term is smaller than `EPSILON` times the running total, or after the terms
/// n = 1 .. MAX_SERIES_TERMS - 1 have been added.
fn hyp1f1_series(a: f64, b: f64, z: f64) -> f64 {
    let mut partial = 1.0_f64;
    let mut summand = 1.0_f64;
    for m in (1..MAX_SERIES_TERMS).map(|idx| idx as f64) {
        summand *= (a + m - 1.0) / ((b + m - 1.0) * m) * z;
        partial += summand;
        if summand.abs() < EPSILON * partial.abs() {
            return partial;
        }
    }
    partial
}
/// Evaluates ₁F₁(a; b; z) when `a` is a non-positive integer, so the series is a polynomial
/// in `z` of degree `-a`.
///
/// `a` is truncated toward zero by the cast, so any `a > -1` (including positive `a`)
/// yields `1.0`.
fn hyp1f1_polynomial(a: f64, b: f64, z: f64) -> f64 {
    let degree = (-a) as i32;
    if degree < 0 {
        return 1.0;
    }
    // Carry (running sum, previous term) through the finite list of terms.
    let (total, _) = (1..=degree).fold((1.0_f64, 1.0_f64), |(acc, prev), k| {
        let kf = f64::from(k);
        let next = prev * ((a + kf - 1.0) / ((b + kf - 1.0) * kf) * z);
        (acc + next, next)
    });
    total
}
/// Evaluates ₁F₁(a; b; z) for large |z|; `hyp1f1_scalar` calls it when |z| > 50.
///
/// Despite the name, no asymptotic expansion is used.
/// * For `z > 0` the direct series with a cancellation check (`hyp1f1_series_large`) is
///   tried first.
///   - Its result is returned unless it is NaN, which signals severe cancellation.
///   - An infinite sum means the partial sums overflowed. It is returned as is: feeding it
///     through the Kummer form below would multiply `exp(z)` by an alternating series and
///     turn a genuine overflow into NaN.
/// * Otherwise Kummer's transformation M(a, b, z) = e^z M(b - a, b, -z) is used.
///   - For `z < 0` this sums a series with positive argument `-z`.
///   - For the `z > 0` NaN fallback it is exact when `b - a` is a non-positive integer, since
///     the series terminates. Otherwise it is subject to cancellation as well.
fn hyp1f1_asymp(a: f64, b: f64, z: f64) -> f64 {
    if z > 0.0 {
        let direct = hyp1f1_series_large(a, b, z);
        if !direct.is_nan() {
            return direct;
        }
    }
    z.exp() * hyp1f1_series(b - a, b, -z)
}
/// Sums the ₁F₁(a; b; z) Maclaurin series while tracking the largest term magnitude, so that
/// catastrophic cancellation can be detected.
///
/// Stopping rule: summation stops when a term is below `EPSILON` times both the largest term
/// seen and the running total, or after the terms n = 1 .. MAX_SERIES_TERMS - 1.
///
/// If the largest term exceeds the final sum's magnitude by more than a factor of 1e10, about
/// ten significant digits have been lost to cancellation, and NaN is returned instead of the
/// sum.
fn hyp1f1_series_large(a: f64, b: f64, z: f64) -> f64 {
    let mut peak: f64 = 1.0;
    let mut current: f64 = 1.0;
    let mut total: f64 = 1.0;
    for step in 1..MAX_SERIES_TERMS {
        let m = step as f64;
        current *= (a + m - 1.0) / ((b + m - 1.0) * m) * z;
        let size = current.abs();
        peak = peak.max(size);
        total += current;
        if size < EPSILON * peak && size < EPSILON * total.abs() {
            break;
        }
    }
    let cancelled = peak > 1e10 * total.abs();
    if cancelled {
        f64::NAN
    } else {
        total
    }
}
// Unit tests comparing hyp2f1_scalar / hyp1f1_scalar against closed forms and identities.
#[cfg(test)]
mod tests {
use super::*;
// TOL for comparisons with closed forms; TOL_LOOSE for identities that chain two evaluations.
const TOL: f64 = 1e-8;
const TOL_LOOSE: f64 = 1e-5;
/// Asserts that the computed value `a` is within `tol` of the expected value `b`, also
/// accepting both NaN or both infinite (of either sign).
fn assert_close(a: f64, b: f64, tol: f64, msg: &str) {
let diff = (a - b).abs();
assert!(
diff < tol || (a.is_nan() && b.is_nan()) || (a.is_infinite() && b.is_infinite()),
"{}: expected {}, got {}, diff {}",
msg,
b,
a,
diff
);
}
// 2F1(a,b;c;0) = 1; 2F1(1,1;2;z) = -ln(1-z)/z; 2F1(a,b;b;z) = (1-z)^(-a).
#[test]
fn test_hyp2f1_special_values() {
assert_close(hyp2f1_scalar(1.0, 2.0, 3.0, 0.0), 1.0, TOL, "2F1(1,2;3;0)");
let z: f64 = 0.3;
let expected = -(1.0 - z).ln() / z;
assert_close(
hyp2f1_scalar(1.0, 1.0, 2.0, z),
expected,
TOL,
"2F1(1,1;2;0.3)",
);
assert_close(
hyp2f1_scalar(2.0, 3.0, 3.0, 0.3),
(1.0_f64 - 0.3).powf(-2.0),
TOL,
"2F1(2,3;3;0.3)",
);
}
// Negative-integer a terminates the series: 2F1(-1,2;3;z) = 1 - 2z/3, 2F1(-2,1;1;z) = (1-z)^2.
#[test]
fn test_hyp2f1_polynomial() {
assert_close(
hyp2f1_scalar(-1.0, 2.0, 3.0, 0.5),
1.0 - 2.0 * 0.5 / 3.0,
TOL,
"2F1(-1,2;3;0.5)",
);
let z = 0.4;
assert_close(
hyp2f1_scalar(-2.0, 1.0, 1.0, z),
(1.0 - z).powi(2),
TOL,
"2F1(-2,1;1;0.4)",
);
}
// Gauss's summation: 2F1(1,1;3;1) = Γ(3)Γ(1) / (Γ(2)Γ(2)) = 2.
#[test]
fn test_hyp2f1_at_one() {
assert_close(hyp2f1_scalar(1.0, 1.0, 3.0, 1.0), 2.0, TOL, "2F1(1,1;3;1)");
}
// -ln(1-z)/z at z = -1 equals ln 2; exercises the z < 0 transformation path.
#[test]
fn test_hyp2f1_negative_z() {
assert_close(
hyp2f1_scalar(1.0, 1.0, 2.0, -1.0),
2.0_f64.ln(),
TOL,
"2F1(1,1;2;-1)",
);
}
// 1F1(a;b;0) = 1, 1F1(0;b;z) = 1, 1F1(a;a;z) = e^z.
#[test]
fn test_hyp1f1_special_values() {
assert_close(hyp1f1_scalar(1.0, 2.0, 0.0), 1.0, TOL, "1F1(1;2;0)");
assert_close(hyp1f1_scalar(0.0, 2.0, 5.0), 1.0, TOL, "1F1(0;2;5)");
let z = 1.5;
assert_close(hyp1f1_scalar(1.0, 1.0, z), z.exp(), TOL, "1F1(1;1;1.5)");
assert_close(hyp1f1_scalar(3.0, 3.0, z), z.exp(), TOL, "1F1(3;3;1.5)");
}
// Negative-integer a: 1F1(-1;b;z) = 1 - z/b, 1F1(-2;1;z) = 1 - 2z + z^2/2.
#[test]
fn test_hyp1f1_polynomial() {
assert_close(
hyp1f1_scalar(-1.0, 2.0, 3.0),
1.0 - 3.0 / 2.0,
TOL,
"1F1(-1;2;3)",
);
let z: f64 = 0.5;
assert_close(
hyp1f1_scalar(-2.0, 1.0, z),
1.0 - 2.0 * z + z * z / 2.0,
TOL,
"1F1(-2;1;0.5)",
);
}
// Kummer's transformation: 1F1(a;b;z) = e^z 1F1(b-a;b;-z).
#[test]
fn test_hyp1f1_kummer_relation() {
let a = 1.5;
let b = 2.5;
let z = 2.0;
let lhs = hyp1f1_scalar(a, b, z);
let rhs = z.exp() * hyp1f1_scalar(b - a, b, -z);
assert_close(lhs, rhs, TOL_LOOSE, "Kummer transformation");
}
// NaN arguments propagate; c = 0 (2F1) or b = 0 (1F1) without a terminating
// negative-integer numerator parameter is undefined.
#[test]
fn test_hyp_nan_handling() {
assert!(hyp2f1_scalar(1.0, 2.0, 3.0, f64::NAN).is_nan());
assert!(hyp1f1_scalar(1.0, 2.0, f64::NAN).is_nan());
assert!(hyp2f1_scalar(1.0, 2.0, 0.0, 0.5).is_nan());
assert!(hyp1f1_scalar(1.0, 0.0, 0.5).is_nan());
}
// 1F1(1/2;3/2;-x^2) = sqrt(pi) erf(x) / (2x); at x = sqrt(2) this is about 0.598.
#[test]
fn test_hyp1f1_negative_z() {
let result = hyp1f1_scalar(0.5, 1.5, -2.0);
assert!(result.is_finite(), "1F1(0.5;1.5;-2) should be finite");
assert!(result > 0.0 && result < 1.0, "1F1(0.5;1.5;-2) in (0,1)");
}
// 2F1 is symmetric in its numerator parameters a and b.
#[test]
fn test_hyp2f1_symmetry() {
let a = 1.5;
let b = 2.5;
let c = 3.5;
let z = 0.3;
let f1 = hyp2f1_scalar(a, b, c, z);
let f2 = hyp2f1_scalar(b, a, c, z);
assert_close(f1, f2, TOL, "2F1 symmetry");
}
}