use crate::FloatScalar;
use super::{LANCZOS_G, lanczos_sum};
/// Exact factorials `0!` through `20!`, indexed by `n` (so `FACTORIAL[n] == n!`).
///
/// Every entry is exactly representable as an `f64`, which lets callers return
/// exact results for small integer arguments instead of an approximation.
const FACTORIAL: [f64; 21] = [
    1.0,                           //  0!
    1.0,                           //  1!
    2.0,                           //  2!
    6.0,                           //  3!
    24.0,                          //  4!
    120.0,                         //  5!
    720.0,                         //  6!
    5_040.0,                       //  7!
    40_320.0,                      //  8!
    362_880.0,                     //  9!
    3_628_800.0,                   // 10!
    39_916_800.0,                  // 11!
    479_001_600.0,                 // 12!
    6_227_020_800.0,               // 13!
    87_178_291_200.0,              // 14!
    1_307_674_368_000.0,           // 15!
    20_922_789_888_000.0,          // 16!
    355_687_428_096_000.0,         // 17!
    6_402_373_705_728_000.0,       // 18!
    121_645_100_408_832_000.0,     // 19!
    2_432_902_008_176_640_000.0,   // 20!
];
/// Gamma function Γ(x).
///
/// Evaluation strategy, in order:
///
/// 1. NaN is propagated unchanged and `+∞` maps to `+∞`.
/// 2. Zero and negative integers (the poles of Γ, including `-∞`) return `+∞`.
/// 3. Positive integers `n` with `n - 1` inside the `FACTORIAL` table return the
///    exact value `(n - 1)!`.
/// 4. `x < 0.5` uses the reflection formula `Γ(x) = π / (sin(πx) · Γ(1 − x))`.
/// 5. Everything else uses the Lanczos-style formula
///    `√(2π) · t^(z + ½) · e^(−t) · lanczos_sum(z)` with `z = x − 1` and
///    `t = z + LANCZOS_G + ½`.
///
/// Results too large for `T` are returned as `+∞`.
///
/// # Panics
///
/// Panics if a constant used internally (e.g. `0.5`, `π`, `LANCZOS_G`) cannot be
/// converted to `T`.
pub fn gamma<T: FloatScalar>(x: T) -> T {
    let zero = T::zero();
    let one = T::one();
    let half = T::from(0.5).unwrap();
    let infinity = T::infinity();

    if x.is_nan() {
        return x;
    }
    // Without this guard the Lanczos branch would evaluate `inf * 0` and yield NaN.
    if x == infinity {
        return x;
    }

    if x == x.floor() {
        // Poles at zero and the negative integers.
        if x <= zero {
            return infinity;
        }
        // Γ(n) = (n − 1)!; answer exactly from the table while it lasts.
        let exact = num_traits::cast::<T, usize>(x)
            .and_then(|n| n.checked_sub(1))
            .and_then(|index| FACTORIAL.get(index));
        if let Some(&factorial) = exact {
            return T::from(factorial).unwrap();
        }
    }

    if x < half {
        // Reflection formula.
        let pi = T::from(core::f64::consts::PI).unwrap();
        let sin_pi_x = (pi * x).sin();
        if sin_pi_x == zero {
            return infinity;
        }
        return pi / (sin_pi_x * gamma(one - x));
    }

    let z = x - one;
    let g = T::from(LANCZOS_G).unwrap();
    let t = z + g + half;
    let sqrt_2pi = T::from(core::f64::consts::TAU.sqrt()).unwrap();
    let exponent = z + half;
    let decay = (-t).exp();
    let series = lanczos_sum(z);

    // Common case: the power fits, so evaluate the formula in one shot.
    let power = t.powf(exponent);
    if power < infinity {
        return sqrt_2pi * power * decay * series;
    }

    // `t^(z + ½)` overflows long before Γ(x) itself does. Split the power in two
    // and apply the decaying exponential between the halves so intermediate
    // products stay in range whenever the final result is representable.
    let half_power = t.powf(exponent * half);
    if half_power == infinity {
        // Even the square root of the power overflows; for the usual float types
        // and a modest `LANCZOS_G` this only happens when Γ(x) is far out of
        // range, and `e^(−t)` may already have underflowed to zero, so return
        // the overflow value rather than risking `inf * 0 = NaN`.
        return infinity;
    }
    sqrt_2pi * series * (half_power * decay) * half_power
}
/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// Evaluation strategy, in order:
///
/// 1. NaN is propagated unchanged and `+∞` maps to `+∞`.
/// 2. Zero and negative integers (the poles of Γ, including `-∞`) return `+∞`.
/// 3. `x < 0.5` uses the reflection formula
///    `ln|Γ(x)| = ln π − ln|sin(πx)| − ln|Γ(1 − x)|`.
/// 4. `x > 10⁶` uses a four-term Stirling series.
/// 5. Everything else uses the logarithm of the Lanczos-style formula
///    `½·ln(2π) + (z + ½)·ln t − t + ln(lanczos_sum(z))` with `z = x − 1` and
///    `t = z + LANCZOS_G + ½`.
///
/// # Panics
///
/// Panics if a constant used internally (e.g. `0.5`, `π`, `LANCZOS_G`) cannot be
/// converted to `T`.
pub fn lgamma<T: FloatScalar>(x: T) -> T {
    let zero = T::zero();
    let one = T::one();
    let half = T::from(0.5).unwrap();
    let infinity = T::infinity();

    if x.is_nan() {
        return x;
    }
    // Without this guard the Stirling branch would evaluate `inf − inf` and yield NaN.
    if x == infinity {
        return x;
    }
    // Poles at zero and the negative integers.
    if x <= zero && x == x.floor() {
        return infinity;
    }

    if x < half {
        // Reflection formula, in log space.
        let pi = T::from(core::f64::consts::PI).unwrap();
        let abs_sin_pi_x = (pi * x).sin().abs();
        if abs_sin_pi_x == zero {
            return infinity;
        }
        return pi.ln() - abs_sin_pi_x.ln() - lgamma(one - x);
    }

    // ½·ln(2π), shared by both remaining branches.
    let ln_sqrt_2pi = T::from(0.5 * core::f64::consts::TAU.ln()).unwrap();

    let large_threshold = T::from(1e6).unwrap();
    if x > large_threshold {
        // Stirling series:
        //   (x − ½)·ln x − x + ½·ln(2π)
        //     + 1/(12x) − 1/(360x³) + 1/(1260x⁵) − 1/(1680x⁷)
        let inv_x = one / x;
        let inv_x2 = inv_x * inv_x;
        let c1 = T::from(1.0 / 12.0).unwrap();
        let c2 = T::from(1.0 / 360.0).unwrap();
        let c3 = T::from(1.0 / 1260.0).unwrap();
        let c4 = T::from(1.0 / 1680.0).unwrap();
        // Alternating-sign Horner evaluation of the correction terms.
        let series = inv_x * (c1 - inv_x2 * (c2 - inv_x2 * (c3 - c4 * inv_x2)));
        return (x - half) * x.ln() - x + ln_sqrt_2pi + series;
    }

    let z = x - one;
    let g = T::from(LANCZOS_G).unwrap();
    let t = z + g + half;
    ln_sqrt_2pi + (z + half) * t.ln() - t + lanczos_sum(z).ln()
}