use crate::Error;
use std::f64::consts::PI;
/// Natural logarithm of pi; `lgamma` uses it when handling negative arguments.
const LOG_PI: f64 = 1.144729885849400174143427351353058711647;
/// Number of coefficients in each of the two Lanczos coefficient tables.
const LANCZOS_N: usize = 13;
/// The Lanczos parameter `g` used by `gamma` and `lgamma`.
const LANCZOS_G: f64 = 6.024680040776729583740234375;
/// `LANCZOS_G - 0.5`, spelled out as its own literal.
const LANCZOS_G_MINUS_HALF: f64 = 5.524680040776729583740234375;
/// Numerator polynomial of the Lanczos rational approximation.
///
/// Stored in ascending power order: element `i` is the coefficient of `x^i`
/// (this is the order `lanczos_sum` relies on when it evaluates the polynomial).
const LANCZOS_NUM_COEFFS: [f64; LANCZOS_N] = [
23531376880.410759688572007674451636754734846804940,
42919803642.649098768957899047001988850926355848959,
35711959237.355668049440185451547166705960488635843,
17921034426.037209699919755754458931112671403265390,
6039542586.3520280050642916443072979210699388420708,
1439720407.3117216736632230727949123939715485786772,
248874557.86205415651146038641322942321632125127801,
31426415.585400194380614231628318205362874684987640,
2876370.6289353724412254090516208496135991145378768,
186056.26539522349504029498971604569928220784236328,
8071.6720023658162106380029022722506138218516325024,
210.82427775157934587250973392071336271166969580291,
2.5066282746310002701649081771338373386264310793408,
];
/// Denominator polynomial of the Lanczos rational approximation.
///
/// Stored in ascending power order: element `i` is the coefficient of `x^i`.
/// The values are the expansion of `x * (x + 1) * ... * (x + 11)`, which is why
/// the constant term is zero and the leading coefficient is one.
const LANCZOS_DEN_COEFFS: [f64; LANCZOS_N] = [
0.0,
39916800.0,
120543840.0,
150917976.0,
105258076.0,
45995730.0,
13339535.0,
2637558.0,
357423.0,
32670.0,
1925.0,
66.0,
1.0,
];
/// Returns `a * b + c`.
///
/// When the crate is built with the `mul_add` feature the fused
/// `f64::mul_add` is used (a single rounding step); otherwise the product and
/// the sum are rounded separately.
fn mul_add(a: f64, b: f64, c: f64) -> f64 {
    let use_fused = cfg!(feature = "mul_add");
    if !use_fused {
        return a * b + c;
    }
    a.mul_add(b, c)
}
/// Evaluates the Lanczos rational function `num(x) / den(x)` built from
/// `LANCZOS_NUM_COEFFS` and `LANCZOS_DEN_COEFFS`.
///
/// For `x < 5.0` both polynomials are evaluated by Horner's rule in `x`,
/// starting from the highest-order coefficient. For every other input both
/// are evaluated in powers of `1 / x`, starting from the lowest-order
/// coefficient; this scales numerator and denominator by the same factor, so
/// the quotient is mathematically the same while the intermediate values stay
/// small for large `x`.
fn lanczos_sum(x: f64) -> f64 {
    let coeff_pairs = LANCZOS_NUM_COEFFS.iter().zip(LANCZOS_DEN_COEFFS.iter());
    let (numerator, denominator) = if x < 5.0 {
        // Highest power first: plain Horner evaluation in x.
        coeff_pairs
            .rev()
            .fold((0.0_f64, 0.0_f64), |(n, d), (&cn, &cd)| {
                (mul_add(n, x, cn), mul_add(d, x, cd))
            })
    } else {
        // Lowest power first: Horner evaluation in 1/x.
        coeff_pairs.fold((0.0_f64, 0.0_f64), |(n, d), (&cn, &cd)| {
            (n / x + cn, d / x + cd)
        })
    };
    numerator / denominator
}
/// Computes `sin(pi * x)` for finite `x`.
///
/// `|x|` is first reduced modulo 2 and then mapped onto the nearest multiple
/// of one half, so that `sin`/`cos` are only ever called with an argument of
/// magnitude at most `pi / 4`. The sign of `x` is re-applied at the end.
fn m_sinpi(x: f64) -> f64 {
    debug_assert!(x.is_finite());
    let reduced = x.abs() % 2.0;
    // Index of the nearest multiple of 0.5; always in 0..=4 because `reduced` is in [0, 2).
    let nearest_half = (2.0 * reduced).round() as i32;
    let value = match nearest_half {
        0 => (PI * reduced).sin(),
        1 => (PI * (reduced - 0.5)).cos(),
        // Written as sin(pi * (1 - y)) rather than -sin(pi * (y - 1)) so that
        // y == 1.0 yields +0.0 instead of -0.0.
        2 => (PI * (1.0 - reduced)).sin(),
        3 => -(PI * (reduced - 1.5)).cos(),
        4 => (PI * (reduced - 2.0)).sin(),
        _ => unreachable!(),
    };
    let sign = 1.0_f64.copysign(x);
    sign * value
}
/// Number of entries in `GAMMA_INTEGRAL`.
const NGAMMA_INTEGRAL: usize = 23;
/// Table of factorials: `GAMMA_INTEGRAL[n - 1]` is `(n - 1)!`, i.e. the gamma
/// function at the integer `n`, for `n` from 1 to 23.
const GAMMA_INTEGRAL: [f64; NGAMMA_INTEGRAL] = [
1.0,
1.0,
2.0,
6.0,
24.0,
120.0,
720.0,
5040.0,
40320.0,
362880.0,
3628800.0,
39916800.0,
479001600.0,
6227020800.0,
87178291200.0,
1307674368000.0,
20922789888000.0,
355687428096000.0,
6402373705728000.0,
121645100408832000.0,
2432902008176640000.0,
51090942171709440000.0,
1124000727777607680000.0,
];
pub fn gamma(x: f64) -> crate::Result<f64> {
if !x.is_finite() {
if x.is_nan() || x > 0.0 {
return Ok(x);
} else {
return Err((f64::NAN, Error::EDOM).1);
}
}
if x == 0.0 {
let v = if x.is_sign_positive() {
f64::INFINITY
} else {
f64::NEG_INFINITY
};
return Err((v, Error::EDOM).1);
}
if x == x.floor() {
if x < 0.0 {
return Err((f64::NAN, Error::EDOM).1);
}
if x < NGAMMA_INTEGRAL as f64 {
return Ok(GAMMA_INTEGRAL[x as usize - 1]);
}
}
let absx = x.abs();
if absx < 1e-20 {
let r = 1.0 / x;
if r.is_infinite() {
return Err((f64::INFINITY, Error::ERANGE).1);
} else {
return Ok(r);
}
}
if absx > 200.0 {
if x < 0.0 {
return Ok(0.0 / m_sinpi(x));
} else {
return Err((f64::INFINITY, Error::ERANGE).1);
}
}
let y = absx + LANCZOS_G_MINUS_HALF;
let z = if absx > LANCZOS_G_MINUS_HALF {
let q = y - absx;
q - LANCZOS_G_MINUS_HALF
} else {
let q = y - LANCZOS_G_MINUS_HALF;
q - absx
};
let z = z * LANCZOS_G / y;
let r = if x < 0.0 {
let mut r = -PI / m_sinpi(absx) / absx * y.exp() / lanczos_sum(absx);
r -= z * r;
if absx < 140.0 {
r /= y.powf(absx - 0.5);
} else {
let sqrtpow = y.powf(absx / 2.0 - 0.25);
r /= sqrtpow;
r /= sqrtpow;
}
r
} else {
let mut r = lanczos_sum(absx) / y.exp();
r += z * r;
if absx < 140.0 {
r *= y.powf(absx - 0.5);
} else {
let sqrtpow = y.powf(absx / 2.0 - 0.25);
r *= sqrtpow;
r *= sqrtpow;
}
r
};
if r.is_infinite() {
return Err((f64::INFINITY, Error::ERANGE).1);
} else {
return Ok(r);
}
}
pub fn lgamma(x: f64) -> crate::Result<f64> {
if !x.is_finite() {
if x.is_nan() {
return Ok(x); } else {
return Ok(f64::INFINITY); }
}
if x == x.floor() && x <= 2.0 {
if x <= 0.0 {
return Err(Error::EDOM);
} else {
return Ok(0.0);
}
}
let absx = x.abs();
if absx < 1e-20 {
return Ok(-absx.ln());
}
let mut r = lanczos_sum(absx).ln() - LANCZOS_G;
let t = absx - 0.5;
r = mul_add(t, (absx + LANCZOS_G - 0.5).ln() - 1.0, r);
if x < 0.0 {
r = LOG_PI - m_sinpi(absx).abs().ln() - absx.ln() - r;
}
if r.is_infinite() {
return Err(Error::ERANGE);
}
Ok(r)
}
// Generates the test functions for `gamma` via the parent module's `pyo3_proptest!` macro.
// NOTE(review): the macro is defined outside this file; presumably it compares the results
// against CPython's `math` module through PyO3 — confirm against the macro definition.
super::pyo3_proptest!(gamma(Result<_>), test_gamma, proptest_gamma, fulltest_gamma);
// Same test generation for `lgamma`.
super::pyo3_proptest!(
lgamma(Result<_>),
test_lgamma,
proptest_lgamma,
fulltest_lgamma
);