use scirs2_core::error::{check_domain, CoreError, CoreResult};
use scirs2_core::numeric::{Float, FloatConst, FromPrimitive};
/// Result type returned by the fallible functions in this module
/// (`beta`, `ellipk`, `ellipe`); alias of `scirs2_core`'s `CoreResult`.
type SciRS2Result<T> = CoreResult<T>;
/// Alias of `scirs2_core`'s `CoreError`. Not referenced elsewhere in this
/// file.
type SciRS2Error = CoreError;
/// Converts an `f64` constant into `T` (through `T::from`), falling back to
/// `T::nan()` when that conversion yields `None`.
#[inline(always)]
fn const_f64<T: Float + FromPrimitive>(value: f64) -> T {
    match T::from(value) {
        Some(converted) => converted,
        None => T::nan(),
    }
}
/// Lanczos parameter `g` matching [`LANCZOS_C`].
const LANCZOS_G: f64 = 7.0;
/// Lanczos coefficients for `g = 7`, `n = 9` (about 15 significant digits for
/// real arguments `z >= 0.5`).
#[allow(clippy::excessive_precision, clippy::inconsistent_digit_grouping)]
const LANCZOS_C: [f64; 9] = [
    0.999_999_999_999_809_93,
    676.520_368_121_885_1,
    -1_259.139_216_722_402_8,
    771.323_428_777_653_13,
    -176.615_029_162_140_59,
    12.507_343_278_686_905,
    -0.138_571_095_265_720_12,
    9.984_369_578_019_571_6e-6,
    1.505_632_735_149_311_6e-7,
];
/// For `z` at or above this, `Γ(z)` exceeds `f64::MAX` (overflow starts near
/// `z ≈ 171.62`).
const LANCZOS_GAMMA_OVERFLOW: f64 = 172.0;

/// Lanczos partial-fraction sum `c₀ + Σₖ cₖ / (zm1 + k)`, where `zm1 = z − 1`.
#[inline]
fn lanczos_series(zm1: f64) -> f64 {
    LANCZOS_C[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS_C[0], |acc, (k, &ck)| acc + ck / (zm1 + (k + 1) as f64))
}

/// `Γ(z)` for `z >= 0.5` by the Lanczos approximation.
///
/// Returns `+∞` once `Γ(z)` overflows (including `z = +∞`) rather than the
/// `∞ · 0 = NaN` that the naive `t^(z−½) · e^(−t)` product produces.
#[inline]
fn lanczos_gamma_pos(z: f64) -> f64 {
    if z >= LANCZOS_GAMMA_OVERFLOW {
        return f64::INFINITY;
    }
    let zm1 = z - 1.0;
    let ag = lanczos_series(zm1);
    let t = zm1 + LANCZOS_G + 0.5;
    // Split t^(z − ½) into two half powers around e^(−t): the full power
    // overflows from z ≈ 143 on, although Γ(z) stays finite up to ≈ 171.6.
    let half_pow = t.powf(0.5 * (zm1 + 0.5));
    (2.0 * std::f64::consts::PI).sqrt() * half_pow * ((-t).exp() * half_pow) * ag
}

/// `ln Γ(z)` for `z >= 0.5` by the log form of the Lanczos approximation.
///
/// Returns `+∞` for `z = +∞` (the general formula would evaluate `∞ − ∞`).
#[inline]
fn lanczos_lgamma_pos(z: f64) -> f64 {
    if z == f64::INFINITY {
        return f64::INFINITY;
    }
    let zm1 = z - 1.0;
    let ag = lanczos_series(zm1);
    let t = zm1 + LANCZOS_G + 0.5;
    0.5 * (2.0 * std::f64::consts::PI).ln() + (zm1 + 0.5) * t.ln() - t + ag.abs().ln()
}
/// `sin(πv)` evaluated after an exact reduction of `v` modulo 2.
///
/// Unlike `(PI * v).sin()`, this keeps full relative precision for large `|v|`
/// and next to integers, and it is zero at integers.
fn sin_pi(v: f64) -> f64 {
    use std::f64::consts::PI;
    // r = v − 2·round(v/2) lies in [−1, 1]; the subtraction is exact
    // (Sterbenz), as are the `1 ∓ r` below.
    let r = v - 2.0 * (0.5 * v).round();
    if r > 0.5 {
        (PI * (1.0 - r)).sin()
    } else if r < -0.5 {
        -(PI * (1.0 + r)).sin()
    } else {
        (PI * r).sin()
    }
}

/// Gamma function `Γ(x)`.
///
/// Computed in `f64` and converted back to `T`:
///
/// * `0 < x < 0.5`: `Γ(x) = Γ(x + 1) / x`;
/// * `0.5 <= x < 142`: Lanczos approximation;
/// * `142 <= x < 172`: `exp(ln Γ(x))`, which avoids overflowing intermediate
///   powers;
/// * non-integer `x < 0`: reflection `Γ(x) = π / (sin(πx) Γ(1 − x))`, also in
///   log space once `1 − x >= 142`.
///
/// Special values: `NaN` and `−∞` give `NaN`. `x >= 172`, including `+∞`,
/// gives `+∞` because `Γ(x)` exceeds `f64::MAX`. The poles `0, −1, −2, …` give
/// `+∞`, but only exact integers are treated as poles. Returns `NaN` when `x`
/// cannot be converted to `f64`.
#[allow(dead_code)]
pub fn gamma<T: Float + FromPrimitive + FloatConst>(x: T) -> T {
    use std::f64::consts::PI;
    // Exclusive upper bound for calling `lanczos_gamma_pos` directly; above
    // it the power term of the direct formula can overflow, so go through ln Γ.
    const DIRECT_MAX: f64 = 142.0;

    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    if xf.is_nan() || xf == f64::NEG_INFINITY {
        return T::nan();
    }
    if xf >= 172.0 {
        return T::infinity();
    }
    if (xf - 1.0).abs() < 1e-15 || (xf - 2.0).abs() < 1e-15 {
        return T::one();
    }
    let val = if xf <= 0.0 {
        // Only exact integers are poles; points merely close to one are not
        // (e.g. Γ(−1e−15) ≈ −1e15).
        if xf == xf.round() {
            return T::infinity();
        }
        let s = sin_pi(xf);
        let reflected = 1.0 - xf;
        if reflected < DIRECT_MAX {
            PI / (s * lanczos_gamma_pos(reflected))
        } else {
            // Γ(1 − x) > 0, so the sign of Γ(x) is the sign of sin(πx).
            s.signum() * (PI.ln() - s.abs().ln() - lanczos_lgamma_pos(reflected)).exp()
        }
    } else if xf < 0.5 {
        lanczos_gamma_pos(xf + 1.0) / xf
    } else if xf < DIRECT_MAX {
        lanczos_gamma_pos(xf)
    } else {
        lanczos_lgamma_pos(xf).exp()
    };
    T::from_f64(val).unwrap_or(T::nan())
}

/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// Computed in `f64` and converted back to `T`:
///
/// * `0 < x < 0.5`: `ln Γ(x) = ln Γ(x + 1) − ln x`;
/// * `x >= 0.5`: log form of the Lanczos approximation;
/// * non-integer `x < 0`: reflection
///   `ln|Γ(x)| = ln π − ln|sin πx| − ln Γ(1 − x)`.
///
/// Special values: `NaN` gives `NaN`, and `±∞` gives `+∞` (as C's `lgamma`).
/// The poles `0, −1, −2, …` give `+∞`, with only exact integers treated as
/// poles. At (and within `1e-15` of) `x = 1` and `x = 2` the result is `0`.
/// Returns `NaN` when `x` cannot be converted to `f64`.
#[allow(dead_code)]
pub fn lgamma<T: Float + FromPrimitive + FloatConst>(x: T) -> T {
    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    if xf.is_nan() {
        return T::nan();
    }
    if xf.is_infinite() {
        return T::infinity();
    }
    if xf <= 0.0 {
        if xf == xf.round() {
            return T::infinity();
        }
        let val = std::f64::consts::PI.ln()
            - sin_pi(xf).abs().ln()
            - lanczos_lgamma_pos(1.0 - xf);
        return T::from_f64(val).unwrap_or(T::nan());
    }
    if (xf - 1.0).abs() < 1e-15 || (xf - 2.0).abs() < 1e-15 {
        return T::zero();
    }
    let val = if xf < 0.5 {
        lanczos_lgamma_pos(xf + 1.0) - xf.ln()
    } else {
        lanczos_lgamma_pos(xf)
    };
    T::from_f64(val).unwrap_or(T::nan())
}
#[allow(dead_code)]
pub fn beta<T: Float + FromPrimitive + FloatConst>(a: T, b: T) -> SciRS2Result<T> {
check_domain(
a > T::zero() && b > T::zero(),
"Beta function parameters must be positive",
)?;
let a_int = a.round().to_usize().unwrap_or(0);
let b_int = b.round().to_usize().unwrap_or(0);
if (a - T::from(a_int).expect("Failed to convert to float")).abs() < T::epsilon()
&& (b - T::from(b_int).expect("Failed to convert to float")).abs() < T::epsilon()
{
let num1 = gamma(a);
let num2 = gamma(b);
let denom = gamma(a + b);
return Ok(num1 * num2 / denom);
}
Ok(T::nan())
}
/// `|x|` below this uses the power series; at or above it the continued
/// fraction for `erfc` is used.
const ERF_SERIES_MAX: f64 = 2.0;

/// `erf(x)` for `0 <= x < ERF_SERIES_MAX` from the series (A&S 7.1.6)
/// `erf x = (2/√π) e^(−x²) Σₙ 2ⁿ x^(2n+1) / (1·3·…·(2n+1))`.
/// Every term is positive, so there is no cancellation.
fn erf_series(x: f64) -> f64 {
    let two_x_sq = 2.0 * x * x;
    let mut term = x;
    let mut sum = x;
    let mut odd = 1.0_f64;
    for _ in 0..200 {
        odd += 2.0;
        term *= two_x_sq / odd;
        sum += term;
        if term <= 0.5 * f64::EPSILON * sum {
            break;
        }
    }
    std::f64::consts::FRAC_2_SQRT_PI * (-x * x).exp() * sum
}

/// `erfc(x)` for `x >= ERF_SERIES_MAX` from the continued fraction
/// (A&S 7.1.14) `√π e^(x²) erfc x = 1 / (x + (1/2) / (x + 1 / (x + (3/2) / …)))`.
/// It is evaluated with the modified Lentz algorithm and returns `0` for
/// `x = +∞`.
fn erfc_continued_fraction(x: f64) -> f64 {
    const TINY: f64 = 1e-300;
    if x.is_infinite() {
        return 0.0;
    }
    // f = b₀ + a₁/(b₁ + a₂/(b₂ + …)) with every bₖ = x and aₖ = k/2.
    let mut f = x;
    let mut c = x;
    let mut d = 0.0_f64;
    for k in 1..=500_u32 {
        let a = 0.5 * f64::from(k);
        d = x + a * d;
        if d.abs() < TINY {
            d = TINY;
        }
        c = x + a / c;
        if c.abs() < TINY {
            c = TINY;
        }
        d = d.recip();
        let delta = c * d;
        f *= delta;
        if (delta - 1.0).abs() <= f64::EPSILON {
            break;
        }
    }
    (-x * x).exp() / (std::f64::consts::PI.sqrt() * f)
}

/// Error function `erf(x) = (2/√π) ∫₀ˣ e^(−t²) dt`.
///
/// Evaluated in `f64` (then converted to `T`) to near machine precision.
/// It uses a positive-term series for `|x| < 2` and `1 − erfc(|x|)` from a
/// continued fraction otherwise. `erf` is odd; `erf(±∞) = ±1`,
/// `erf(NaN) = NaN`, and `erf(±0) = +0`.
#[allow(dead_code)]
pub fn erf<T: Float + FromPrimitive>(x: T) -> T {
    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    if xf.is_nan() {
        return T::nan();
    }
    let ax = xf.abs();
    let magnitude = if ax < ERF_SERIES_MAX {
        erf_series(ax)
    } else {
        1.0 - erfc_continued_fraction(ax)
    };
    const_f64::<T>(if xf < 0.0 { -magnitude } else { magnitude })
}

/// Complementary error function `erfc(x) = 1 − erf(x)`.
///
/// Computed directly rather than as `1 − erf(x)`, so the tail keeps its
/// relative accuracy and does not round to `0` while still representable:
///
/// * `x >= 2`: continued fraction;
/// * `x <= −2`: `2 − erfc(−x)`;
/// * `|x| < 2`: `1 ∓ erf(|x|)` from the series. This loses at most a couple
///   of digits as `erfc` approaches `erfc(2) ≈ 4.7e-3`.
///
/// `erfc(+∞) = 0`, `erfc(−∞) = 2`, `erfc(NaN) = NaN`.
#[allow(dead_code)]
pub fn erfc<T: Float + FromPrimitive>(x: T) -> T {
    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    let val = if xf.is_nan() {
        f64::NAN
    } else if xf >= ERF_SERIES_MAX {
        erfc_continued_fraction(xf)
    } else if xf <= -ERF_SERIES_MAX {
        2.0 - erfc_continued_fraction(-xf)
    } else if xf < 0.0 {
        1.0 + erf_series(-xf)
    } else {
        1.0 - erf_series(xf)
    };
    const_f64::<T>(val)
}
/// Modified Bessel function of the first kind of order zero, `I₀(x)`.
///
/// Uses the polynomial approximations of Abramowitz & Stegun:
///
/// * 9.8.1 for `|x| < 3.75`, in `(x/3.75)²`;
/// * 9.8.2 for `|x| >= 3.75`, for `√x e^(−x) I₀(x)` in `3.75/x`.
///
/// Accuracy is roughly `1e-7` (relative). `I₀` is even, so only `|x|` is used.
/// `I₀(±∞) = +∞` and `NaN` propagates. Coefficients that `T` cannot represent
/// become `NaN` instead of panicking.
#[allow(dead_code)]
pub fn i0<T: Float + FromPrimitive>(x: T) -> T {
    // Σ coeffs[i] · zⁱ by Horner's rule (coefficients in ascending order).
    fn horner<U: Float + FromPrimitive>(coeffs: &[f64], z: U) -> U {
        coeffs
            .iter()
            .rev()
            .fold(U::zero(), |acc, &k| acc * z + U::from_f64(k).unwrap_or_else(U::nan))
    }

    let abs_x = x.abs();
    if abs_x.is_infinite() {
        // e^x / √x would be ∞/∞ = NaN.
        return T::infinity();
    }
    let knee = T::from_f64(3.75).unwrap_or_else(T::nan);
    if abs_x < knee {
        let t = abs_x / knee;
        horner(
            &[1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813],
            t * t,
        )
    } else {
        let p = horner(
            &[
                0.39894228,
                0.01328592,
                0.00225319,
                -0.00157565,
                0.00916281,
                -0.02057706,
                0.02635537,
                -0.01647633,
                0.00392377,
            ],
            knee / abs_x,
        );
        // Apply e^|x| as two halves around the 1/√|x| factor: I₀ stays finite
        // slightly past the point where e^|x| alone overflows.
        let half_exp = (abs_x * T::from_f64(0.5).unwrap_or_else(T::nan)).exp();
        half_exp * p / abs_x.sqrt() * half_exp
    }
}
/// Unnormalised sinc function, `sin(x) / x`.
///
/// For `|x| < T::epsilon()` the limit value `1` is returned directly, which
/// also avoids `0 / 0` at the origin. `NaN` propagates.
#[allow(dead_code)]
pub fn sinc<T: Float>(x: T) -> T {
    if x.abs() < T::epsilon() {
        return T::one();
    }
    x.sin() / x
}
/// Bessel function of the first kind of integer order, `J_n(x)`.
///
/// Reduces to `n >= 0`, `x > 0` with the symmetries `J_{−n}(x) = (−1)ⁿ J_n(x)`
/// and `J_n(−x) = (−1)ⁿ J_n(x)`. It then dispatches to `bessel_j0`,
/// `bessel_j1` or `bessel_jn_recurrence`. `J_n(0)` is `1` for `n == 0` and `0`
/// otherwise; `NaN` propagates.
#[allow(dead_code)]
pub fn jn<T: Float + FromPrimitive>(n: i32, x: T) -> T {
    if n == i32::MIN {
        // `-i32::MIN` is not representable: negating it panicked in debug
        // builds and recursed forever in release. Since 2^31 is even,
        // J_{i32::MIN} = J_{2^31}. Step up from the two orders below with
        // J_{m+1}(x) = (2m / x) J_m(x) − J_{m−1}(x), where m = i32::MAX.
        if x == T::zero() {
            return T::zero();
        }
        let two_m = T::from_f64(2.0 * f64::from(i32::MAX)).unwrap_or(T::nan());
        return (two_m / x) * jn(i32::MAX, x) - jn(i32::MAX - 1, x);
    }
    // Multiplies by (−1)ⁿ.
    let parity = |value: T| if n % 2 == 0 { value } else { -value };
    if n < 0 {
        parity(jn(-n, x))
    } else if x < T::zero() {
        parity(jn(n, -x))
    } else if x == T::zero() {
        if n == 0 {
            T::one()
        } else {
            T::zero()
        }
    } else {
        match n {
            0 => bessel_j0(x),
            1 => bessel_j1(x),
            _ => bessel_jn_recurrence(n, x),
        }
    }
}
/// `|x|` below this uses the power series; at or above it the Hankel
/// asymptotic expansion is used. Both are accurate to roughly `1e-12`
/// (absolute) near the crossover and better away from it.
const BESSEL_SERIES_MAX: f64 = 12.0;

/// Euler–Mascheroni constant γ.
const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;

/// Hankel asymptotic factors `(P, Q)` for order ν, with `mu = 4ν²` and
/// `x >= BESSEL_SERIES_MAX`.
///
/// `J_ν = √(2/(πx)) (P cos χ − Q sin χ)` and
/// `Y_ν = √(2/(πx)) (P sin χ + Q cos χ)`, with `χ = x − (ν/2 + 1/4)π`.
/// The divergent series is truncated at its smallest term.
fn hankel_pq(mu: f64, x: f64) -> (f64, f64) {
    let mut p = 1.0_f64;
    let mut q = 0.0_f64;
    // term_k = a_k(ν) / x^k, with a_0 = 1.
    let mut term = 1.0_f64;
    for k in 1..=64_u32 {
        let odd = f64::from(2 * k - 1);
        let next = term * (mu - odd * odd) / (8.0 * f64::from(k) * x);
        if next.abs() >= term.abs() {
            break; // the asymptotic series has started to diverge
        }
        term = next;
        // P = t0 − t2 + t4 − …,  Q = t1 − t3 + t5 − …
        match k % 4 {
            1 => q += term,
            2 => p -= term,
            3 => q -= term,
            _ => p += term,
        }
        if term.abs() < 1e-17 {
            break;
        }
    }
    (p, q)
}

/// `J₀(x)` in `f64`. `J₀(±∞) = 0`, `NaN` propagates.
fn j0_f64(x: f64) -> f64 {
    use std::f64::consts::PI;
    let ax = x.abs();
    if ax.is_nan() {
        return f64::NAN;
    }
    if ax.is_infinite() {
        return 0.0;
    }
    if ax < BESSEL_SERIES_MAX {
        // J₀(x) = Σ (−y)^k / (k!)², y = x²/4.
        let y = 0.25 * ax * ax;
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1..=100_u32 {
            let kf = f64::from(k);
            term *= -y / (kf * kf);
            sum += term;
            if term.abs() < 1e-17 {
                break;
            }
        }
        sum
    } else {
        let (p, q) = hankel_pq(0.0, ax);
        let (s, c) = ax.sin_cos();
        // χ = x − π/4: cos χ = (c + s)/√2, sin χ = (s − c)/√2, which avoids
        // rounding in x − π/4.
        (p * (c + s) - q * (s - c)) / (PI.sqrt() * ax.sqrt())
    }
}

/// `J₁(x)` in `f64` (odd in `x`). `J₁(±∞) = 0`, `NaN` propagates.
fn j1_f64(x: f64) -> f64 {
    use std::f64::consts::PI;
    let ax = x.abs();
    if ax.is_nan() {
        return f64::NAN;
    }
    if ax.is_infinite() {
        return 0.0;
    }
    let magnitude = if ax < BESSEL_SERIES_MAX {
        // J₁(x) = (x/2) Σ (−y)^k / (k! (k+1)!), y = x²/4.
        let y = 0.25 * ax * ax;
        let mut term = 0.5 * ax;
        let mut sum = term;
        for k in 1..=100_u32 {
            let kf = f64::from(k);
            term *= -y / (kf * (kf + 1.0));
            sum += term;
            if term.abs() < 1e-17 {
                break;
            }
        }
        sum
    } else {
        let (p, q) = hankel_pq(4.0, ax);
        let (s, c) = ax.sin_cos();
        // χ = x − 3π/4: cos χ = (s − c)/√2, sin χ = −(s + c)/√2.
        (p * (s - c) + q * (s + c)) / (PI.sqrt() * ax.sqrt())
    };
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}

/// `Y₀(x)` in `f64` for finite `x > 0`.
fn y0_f64(x: f64) -> f64 {
    use std::f64::consts::{FRAC_2_PI, PI};
    if x < BESSEL_SERIES_MAX {
        // Y₀ = (2/π)[(ln(x/2) + γ) J₀(x) + Σ_{k≥1} (−1)^(k+1) H_k y^k/(k!)²].
        let y = 0.25 * x * x;
        let mut term = 1.0_f64; // (−y)^k / (k!)²
        let mut j0 = 1.0_f64;
        let mut harmonic = 0.0_f64; // H_k
        let mut tail = 0.0_f64; // Σ −H_k · term_k
        for k in 1..=100_u32 {
            let kf = f64::from(k);
            term *= -y / (kf * kf);
            harmonic += kf.recip();
            j0 += term;
            tail -= harmonic * term;
            if term.abs() < 1e-17 {
                break;
            }
        }
        FRAC_2_PI * (((0.5 * x).ln() + EULER_GAMMA) * j0 + tail)
    } else {
        let (p, q) = hankel_pq(0.0, x);
        let (s, c) = x.sin_cos();
        (p * (s - c) + q * (c + s)) / (PI.sqrt() * x.sqrt())
    }
}

/// `Y₁(x)` in `f64` for finite `x > 0`.
fn y1_f64(x: f64) -> f64 {
    use std::f64::consts::{FRAC_1_PI, FRAC_2_PI, PI};
    if x < BESSEL_SERIES_MAX {
        // A&S 9.1.11 with n = 1:
        // Y₁ = −2/(πx) + (2/π) ln(x/2) J₁(x)
        //      − (x/(2π)) Σ (ψ(k+1) + ψ(k+2)) (−y)^k / (k! (k+1)!).
        let y = 0.25 * x * x;
        let half = 0.5 * x;
        let mut term = 1.0_f64; // (−y)^k / (k! (k+1)!)
        let mut h_k = 0.0_f64; // H_k
        let mut h_k1 = 1.0_f64; // H_{k+1}
        let mut j_sum = term; // J₁(x) / (x/2)
        let mut psi_sum = (h_k + h_k1 - 2.0 * EULER_GAMMA) * term;
        for k in 1..=100_u32 {
            let kf = f64::from(k);
            term *= -y / (kf * (kf + 1.0));
            h_k = h_k1;
            h_k1 += (kf + 1.0).recip();
            j_sum += term;
            psi_sum += (h_k + h_k1 - 2.0 * EULER_GAMMA) * term;
            if term.abs() < 1e-17 {
                break;
            }
        }
        -FRAC_2_PI / x + FRAC_2_PI * half.ln() * (half * j_sum) - FRAC_1_PI * half * psi_sum
    } else {
        let (p, q) = hankel_pq(4.0, x);
        let (s, c) = x.sin_cos();
        (q * (s - c) - p * (s + c)) / (PI.sqrt() * x.sqrt())
    }
}

/// `J_n(x)` for `x > order` by forward recurrence, which is stable there.
fn jn_upward(order: u64, x: f64) -> f64 {
    let mut prev = j0_f64(x);
    let mut curr = j1_f64(x);
    for k in 1..order {
        // J_{k+1} = (2k / x) J_k − J_{k−1}
        let next = (2.0 * k as f64 / x) * curr - prev;
        prev = curr;
        curr = next;
    }
    curr
}

/// `J_n(x) ≈ (x/2)ⁿ / n!` for `0 < x < 1e-8`. The next series term is below
/// `f64` precision there; underflows gracefully to `0`.
fn jn_leading_term(order: u64, x: f64) -> f64 {
    let half = 0.5 * x;
    let mut term = 1.0_f64;
    for k in 1..=order {
        term *= half / k as f64;
        if term == 0.0 {
            break;
        }
    }
    term
}

/// `J_n(x)` for `x <= order` by Miller's backward recurrence (Numerical
/// Recipes `bessj`, double-precision start constant 160). The result is
/// normalised with `J₀ + 2(J₂ + J₄ + …) = 1`.
fn jn_miller(order: u64, x: f64) -> f64 {
    const RESCALE_ABOVE: f64 = 1e10;
    const RESCALE_BY: f64 = 1e-10;
    let two_over_x = 2.0 / x;
    // Even starting index well above `order`, so the arbitrary seed has died
    // out by the time the recurrence reaches J_order.
    let extra = (160.0 * order as f64).sqrt() as u64;
    let start = 2 * ((order + extra) / 2);
    let mut j_above = 0.0_f64; // J_{k+1} (unnormalised)
    let mut j_here = 1.0_f64; // J_k
    let mut result = 0.0_f64;
    let mut even_sum = 0.0_f64; // J₀ + J₂ + J₄ + …
    let mut add_to_sum = false;
    for k in (1..=start).rev() {
        let j_below = k as f64 * two_over_x * j_here - j_above;
        j_above = j_here;
        j_here = j_below;
        if j_here.abs() > RESCALE_ABOVE {
            j_here *= RESCALE_BY;
            j_above *= RESCALE_BY;
            result *= RESCALE_BY;
            even_sum *= RESCALE_BY;
        }
        if add_to_sum {
            even_sum += j_here;
        }
        add_to_sum = !add_to_sum;
        if k == order {
            result = j_above;
        }
    }
    result / (2.0 * even_sum - j_here)
}

/// Bessel function `J₀(x)` of the first kind, order zero (computed in `f64`).
#[allow(dead_code)]
fn bessel_j0<T: Float + FromPrimitive>(x: T) -> T {
    x.to_f64()
        .and_then(|v| T::from_f64(j0_f64(v)))
        .unwrap_or_else(T::nan)
}

/// Bessel function `J₁(x)` of the first kind, order one (computed in `f64`).
#[allow(dead_code)]
fn bessel_j1<T: Float + FromPrimitive>(x: T) -> T {
    x.to_f64()
        .and_then(|v| T::from_f64(j1_f64(v)))
        .unwrap_or_else(T::nan)
}

/// `J_n(x)` for any integer order, computed in `f64`.
///
/// It uses forward recurrence from `J₀`, `J₁` when `|x| > |n|` (where it is
/// stable) and Miller's backward recurrence otherwise. The recurrence is
/// `J_{k+1} = (2k/x) J_k − J_{k−1}`; the old `2k+1` factor was wrong.
/// Negative `n` or `x` are handled with `(−1)ⁿ` symmetries, and
/// `J_n(0) = J_n(±∞) = 0` for `n != 0`.
#[allow(dead_code)]
fn bessel_jn_recurrence<T: Float + FromPrimitive>(n: i32, x: T) -> T {
    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    let order = u64::from(n.unsigned_abs());
    let ax = xf.abs();
    let magnitude = if ax.is_nan() {
        f64::NAN
    } else if order == 0 {
        j0_f64(ax)
    } else if order == 1 {
        j1_f64(ax)
    } else if ax == 0.0 || ax.is_infinite() {
        0.0
    } else if ax > order as f64 {
        jn_upward(order, ax)
    } else if ax < 1e-8 {
        jn_leading_term(order, ax)
    } else {
        jn_miller(order, ax)
    };
    // J_{−n} = (−1)ⁿ J_n and J_n(−x) = (−1)ⁿ J_n(x); the two flips cancel.
    let flip = order % 2 == 1 && ((n < 0) != (xf < 0.0));
    T::from_f64(if flip { -magnitude } else { magnitude }).unwrap_or(T::nan())
}

/// Bessel function of the second kind of integer order, `Y_n(x)`.
///
/// `Y₀` and `Y₁` come from their power series (`x < 12`) or the Hankel
/// asymptotic expansion (`x >= 12`). Higher orders use the forward recurrence
/// `Y_{k+1} = (2k/x) Y_k − Y_{k−1}`, which is stable for `Y`. Negative orders
/// use `Y_{−n}(x) = (−1)ⁿ Y_n(x)`.
///
/// Special values: `x < 0` or `NaN` gives `NaN` (the real result is
/// undefined). `x == 0` gives the limit `Y_n(0⁺) = −∞` (`+∞` for odd
/// negative `n`). `x == +∞` gives `0`.
#[allow(dead_code)]
pub fn yn<T: Float + FromPrimitive>(n: i32, x: T) -> T {
    let xf = match x.to_f64() {
        Some(v) => v,
        None => return T::nan(),
    };
    if xf.is_nan() || xf < 0.0 {
        return T::nan();
    }
    let order = n.unsigned_abs();
    let magnitude = if xf == 0.0 {
        f64::NEG_INFINITY
    } else if xf.is_infinite() {
        0.0
    } else if order == 0 {
        y0_f64(xf)
    } else {
        let mut prev = y0_f64(xf);
        let mut curr = y1_f64(xf);
        for k in 1..order {
            if curr.is_infinite() {
                break; // overflowed to −∞; stepping further would give NaN
            }
            let next = (2.0 * f64::from(k) / xf) * curr - prev;
            prev = curr;
            curr = next;
        }
        curr
    };
    let flip = n < 0 && order % 2 == 1;
    T::from_f64(if flip { -magnitude } else { magnitude }).unwrap_or(T::nan())
}
/// Complete elliptic integral of the first kind in the parameter convention,
/// `K(m) = ∫₀^{π/2} dθ / √(1 − m sin²θ)`.
///
/// Computed in `f64` with the arithmetic–geometric mean,
/// `K(m) = π / (2 · AGM(1, √(1 − m)))`, which converges quadratically.
/// `K(0) = π/2` and `K(−∞) = 0`.
///
/// # Errors
///
/// Returns a domain error unless `m < 1` (`NaN` is rejected as well).
#[allow(dead_code)]
pub fn ellipk<T: Float + FromPrimitive>(m: T) -> SciRS2Result<T> {
    check_domain(m < T::one(), "Parameter m must be less than 1")?;
    let mf = match m.to_f64() {
        Some(v) => v,
        None => return Ok(T::nan()),
    };
    if mf == f64::NEG_INFINITY {
        return Ok(T::zero());
    }
    let mut a = 1.0_f64;
    let mut b = (1.0 - mf).sqrt();
    for _ in 0..64 {
        if (a - b).abs() <= f64::EPSILON * a {
            break;
        }
        // √a·√b instead of √(ab) so very negative m cannot overflow the product.
        let geometric = a.sqrt() * b.sqrt();
        a = 0.5 * (a + b);
        b = geometric;
    }
    Ok(T::from_f64(std::f64::consts::FRAC_PI_2 / a).unwrap_or(T::nan()))
}
/// Complete elliptic integral of the second kind in the parameter convention,
/// `E(m) = ∫₀^{π/2} √(1 − m sin²θ) dθ`.
///
/// Computed in `f64` with the AGM. Starting from `a₀ = 1`, `b₀ = √(1 − m)`
/// and `cₙ = (aₙ₋₁ − bₙ₋₁)/2`, it uses
/// `E(m) = K(m) · (1 − Σₙ 2ⁿ⁻¹ cₙ²)` with `c₀² = m` and `K(m) = π / (2 a_∞)`.
/// `E(0) = π/2` and `E(−∞) = +∞`.
///
/// # Errors
///
/// Returns a domain error unless `m < 1` (`NaN` is rejected as well).
#[allow(dead_code)]
pub fn ellipe<T: Float + FromPrimitive>(m: T) -> SciRS2Result<T> {
    check_domain(m < T::one(), "Parameter m must be less than 1")?;
    let mf = match m.to_f64() {
        Some(v) => v,
        None => return Ok(T::nan()),
    };
    if mf == f64::NEG_INFINITY {
        return Ok(T::infinity());
    }
    let mut a = 1.0_f64;
    let mut b = (1.0 - mf).sqrt();
    // Σ 2ⁿ⁻¹ cₙ², seeded with the n = 0 term where c₀² = a₀² − b₀² = m.
    let mut weight = 0.5_f64;
    let mut deficit = weight * mf;
    for _ in 0..64 {
        let c = 0.5 * (a - b);
        if c.abs() <= f64::EPSILON * a {
            break;
        }
        weight *= 2.0;
        deficit += weight * c * c;
        // √a·√b instead of √(ab) so very negative m cannot overflow the product.
        let geometric = a.sqrt() * b.sqrt();
        a = 0.5 * (a + b);
        b = geometric;
    }
    let k = std::f64::consts::FRAC_PI_2 / a;
    Ok(T::from_f64(k * (1.0 - deficit)).unwrap_or(T::nan()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol`, reporting both values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_gamma_integers() {
        // Γ(n) = (n − 1)!
        let cases = [(1.0, 1.0), (2.0, 1.0), (3.0, 2.0), (4.0, 6.0), (5.0, 24.0)];
        for &(x, expected) in &cases {
            assert_close(gamma(x), expected, 1e-10);
        }
    }

    #[test]
    fn test_gamma_half_integers() {
        // Γ(k + ½) is a rational multiple of √π.
        let sqrt_pi = std::f64::consts::PI.sqrt();
        let cases = [(0.5, 1.0), (1.5, 0.5), (2.5, 0.75), (3.5, 1.875)];
        for &(x, factor) in &cases {
            assert_close(gamma(x), factor * sqrt_pi, 1e-10);
        }
    }

    #[test]
    fn test_gamma_general() {
        assert_close(gamma(4.5_f64), 11.631728396567448, 1e-8);
        assert_close(gamma(0.25_f64), 3.625609908221908, 1e-8);
    }

    #[test]
    fn test_lgamma() {
        assert_close(lgamma(1.0_f64), 0.0, 1e-12);
        assert_close(lgamma(2.0_f64), 0.0, 1e-12);
        assert_close(lgamma(5.0_f64), (24.0_f64).ln(), 1e-10);
        assert_close(lgamma(0.5_f64), 0.5 * std::f64::consts::PI.ln(), 1e-10);
    }

    #[test]
    fn test_beta() {
        let b = |x: f64, y: f64| beta(x, y).expect("Operation failed");
        assert_close(b(1.0, 1.0), 1.0, 1e-10);
        assert_close(b(2.0, 3.0), 1.0 / 12.0, 1e-10);
        assert_close(b(3.0, 2.0), 1.0 / 12.0, 1e-10);
    }

    #[test]
    fn test_erf() {
        assert_close(erf(0.0), 0.0, 1e-10);
        assert_close(erf(1.0), 0.8427, 1e-3);
        assert_close(erf(-1.0), -0.8427, 1e-3);
    }

    #[test]
    fn test_sinc() {
        assert_close(sinc(0.0), 1.0, 1e-10);
        assert_close(sinc(std::f64::consts::PI), 0.0, 1e-10);
    }
}