use scirs2_core::numeric::{Float, FloatConst, FromPrimitive};
use crate::error::{SciRS2Error, SciRS2Result, check_domain};
/// Converts an `f64` literal into the generic float type `T`.
///
/// Used throughout this module to spell numeric constants once, in `f64`,
/// regardless of the caller's float type.
///
/// # Panics
///
/// Panics if the conversion into `T` yields `None`.
#[inline]
fn const_f64<T: Float + FromPrimitive>(value: f64) -> T {
    match T::from(value) {
        Some(converted) => converted,
        None => panic!("Failed to convert constant to target float type"),
    }
}
/// Lanczos parameter `g`, paired with [`LANCZOS_COEFFS`] (n = 9 terms).
const LANCZOS_G: f64 = 7.0;

/// Lanczos series coefficients for `g = 7`, `n = 9` (the widely published set);
/// together they give roughly 15 significant digits in `f64`.
#[allow(clippy::excessive_precision)] // published values kept verbatim for cross-checking
const LANCZOS_COEFFS: [f64; 9] = [
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
];

/// Evaluates the Lanczos partial-fraction sum `A_g(z) = c₀ + Σ cₖ / (z + k)`
/// shared by [`gamma`] and [`lgamma`], where `z = x - 1`.
fn lanczos_series<T: Float + FromPrimitive>(z: T) -> T {
    let mut sum = const_f64::<T>(LANCZOS_COEFFS[0]);
    let mut k = T::zero();
    for &c in &LANCZOS_COEFFS[1..] {
        k = k + T::one();
        sum = sum + const_f64::<T>(c) / (z + k);
    }
    sum
}

/// Gamma function `Γ(x)`.
///
/// * Positive integers are evaluated exactly as `(x - 1)!`. The product stops
///   as soon as it overflows, so huge integers return `+∞` instead of looping.
/// * Other `x >= 0.5` use the Lanczos approximation (g = 7, n = 9).
/// * `x < 0.5` (including negative non-integers) uses the reflection formula
///   `Γ(x) = π / (sin(πx) · Γ(1 - x))`.
///
/// Returns NaN for NaN input and at the poles `x ∈ {0, -1, -2, …}` (and `-∞`).
/// Returns `+∞` for `+∞` and whenever the true value exceeds `T::max_value()`.
#[allow(dead_code)]
pub fn gamma<T: Float + FromPrimitive + FloatConst>(x: T) -> T {
    if x.is_nan() {
        return T::nan();
    }
    if x == T::infinity() {
        return T::infinity();
    }
    let n = x.round();
    if x <= T::zero() && x == n {
        // Poles of Γ. Also catches -∞, whose rounding is itself.
        return T::nan();
    }
    if n >= T::one() && (x - n).abs() < T::epsilon() {
        // Γ(n) = (n - 1)!, multiplied out exactly while it stays finite.
        let mut result = T::one();
        let mut k = T::one();
        while k < n && result.is_finite() {
            result = result * k;
            k = k + T::one();
        }
        return result;
    }
    let half = const_f64::<T>(0.5);
    if x < half {
        // 1 - x > 0.5 here, so the recursion terminates after one step.
        return T::PI() / ((T::PI() * x).sin() * gamma(T::one() - x));
    }
    let z = x - T::one();
    let t = z + const_f64::<T>(LANCZOS_G) + half;
    let scale = (const_f64::<T>(2.0) * T::PI()).sqrt() * lanczos_series(z);
    let exponent = z + half;
    // Decide overflow in log space first: t^(z + 1/2) alone overflows well
    // before Γ(x) itself does.
    if exponent * t.ln() - t + scale.ln() > T::max_value().ln() {
        return T::infinity();
    }
    // Split the power so that neither t^(z+1/2) nor e^(-t) overflows or
    // underflows on its own.
    let root = t.powf(exponent * half);
    scale * root * ((-t).exp() * root)
}

/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// * Positive integers up to 256 sum `ln k` term by term, i.e. `ln((x - 1)!)`.
/// * Other `x >= 0.5` use the logarithmic form of the Lanczos approximation,
///   which never forms `Γ(x)` and so cannot overflow through it.
/// * `0 < x < 0.5` uses `ln Γ(x) = ln π - ln sin(πx) - ln Γ(1 - x)`.
///
/// Returns NaN for `x <= 0` or NaN, and `+∞` for `+∞`.
#[allow(dead_code)]
pub fn lgamma<T: Float + FromPrimitive + FloatConst>(x: T) -> T {
    // Largest integer argument summed term by term. Beyond it, the O(1)
    // Lanczos form is used so that huge integers cannot loop for ages.
    const EXACT_SUM_LIMIT: f64 = 256.0;

    if x.is_nan() || x <= T::zero() {
        return T::nan();
    }
    if x.is_infinite() {
        return T::infinity();
    }
    let n = x.round();
    if n >= T::one() && (x - n).abs() < T::epsilon() && n <= const_f64::<T>(EXACT_SUM_LIMIT) {
        // ln((n - 1)!) = ln 2 + ln 3 + … + ln(n - 1). The ln 1 = 0 term is skipped.
        let mut result = T::zero();
        let mut k = const_f64::<T>(2.0);
        while k < n {
            result = result + k.ln();
            k = k + T::one();
        }
        return result;
    }
    let half = const_f64::<T>(0.5);
    if x < half {
        // sin(πx) > 0 on (0, 0.5), so the logarithm is real.
        return T::PI().ln() - (T::PI() * x).sin().ln() - lgamma(T::one() - x);
    }
    let z = x - T::one();
    let t = z + const_f64::<T>(LANCZOS_G) + half;
    let ln_sqrt_two_pi = half * (const_f64::<T>(2.0) * T::PI()).ln();
    ln_sqrt_two_pi + (z + half) * t.ln() - t + lanczos_series(z).ln()
}

/// Beta function `B(a, b) = Γ(a) Γ(b) / Γ(a + b)` for `a, b > 0`.
///
/// Evaluated directly through [`gamma`] while `Γ(a + b)` is finite, which is
/// exact for small integer arguments. Otherwise it goes through [`lgamma`], so
/// large parameters do not produce `∞ / ∞ = NaN`. Returns 0 when either
/// parameter is `+∞`.
///
/// # Errors
///
/// Propagates the error produced by `check_domain` unless both `a > 0` and
/// `b > 0` (NaN parameters are therefore rejected too).
#[allow(dead_code)]
pub fn beta<T: Float + FromPrimitive + FloatConst>(a: T, b: T) -> SciRS2Result<T> {
    check_domain(
        a > T::zero() && b > T::zero(),
        "Beta function parameters must be positive",
    )?;
    if a.is_infinite() || b.is_infinite() {
        return Ok(T::zero());
    }
    let denom = gamma(a + b);
    if denom.is_finite() {
        // Divide before multiplying: Γ(a)·Γ(b) can overflow (e.g. both
        // arguments tiny) even when B(a, b) itself is representable.
        return Ok(gamma(a) / denom * gamma(b));
    }
    Ok((lgamma(a) + lgamma(b) - lgamma(a + b)).exp())
}
/// Tail term of Abramowitz & Stegun 7.1.25,
/// `(a₁t + a₂t² + a₃t³) · e^(−x²)` with `t = 1 / (1 + p·x)`, for `x_abs >= 0`.
///
/// For this approximation the value equals `1 − erf(x_abs)`. Computing it
/// directly lets `erfc` avoid the cancellation of `1 − erf(x)`.
fn erf_tail<T: Float + FromPrimitive>(x_abs: T) -> T {
    let t = T::one() / (T::one() + const_f64::<T>(0.47047) * x_abs);
    // Nested form of a₁t + a₂t² + a₃t³ with a₁ = 0.3480242, a₂ = −0.0958798,
    // a₃ = 0.7478556.
    let polynomial = t
        * (const_f64::<T>(0.3480242)
            - t * (const_f64::<T>(0.0958798) - t * const_f64::<T>(0.7478556)));
    polynomial * (-x_abs * x_abs).exp()
}

/// Error function `erf(x)`, using Abramowitz & Stegun 7.1.25
/// (maximum absolute error about 2.5e-5).
///
/// The approximation is applied to `|x|`; the odd symmetry `erf(−x) = −erf(x)`
/// supplies the sign. `erf(0)` is exactly 0.
#[allow(dead_code)]
pub fn erf<T: Float + FromPrimitive>(x: T) -> T {
    if x == T::zero() {
        return T::zero();
    }
    let sign = if x < T::zero() { -T::one() } else { T::one() };
    sign * (T::one() - erf_tail(x.abs()))
}

/// Complementary error function `erfc(x) = 1 − erf(x)`, consistent with [`erf`].
///
/// For `x > 0` the tail is returned directly instead of `1 − erf(x)`, which
/// would cancel to 0 once `erf(x)` rounds to 1 (e.g. `erfc(10) ≈ 2e-45`).
/// For `x <= 0` (and NaN) there is no cancellation and `1 − erf(x)` is used.
#[allow(dead_code)]
pub fn erfc<T: Float + FromPrimitive>(x: T) -> T {
    if x > T::zero() {
        erf_tail(x)
    } else {
        T::one() - erf(x)
    }
}
#[allow(dead_code)]
pub fn i0<T: Float + FromPrimitive>(x: T) -> T {
let abs_x = x.abs();
if abs_x < T::from_f64(3.75).expect("Operation failed") {
let y = abs_x / T::from_f64(3.75).expect("Test/example failed");
let y2 = y * y;
let mut result = T::one();
result += T::from_f64(3.5156229).expect("Operation failed") * y2;
result += T::from_f64(3.0899424).expect("Operation failed") * y2 * y2;
result += T::from_f64(1.2067492).expect("Operation failed") * y2 * y2 * y2;
result += T::from_f64(0.2659732).expect("Operation failed") * y2 * y2 * y2 * y2;
result += T::from_f64(0.0360768).expect("Operation failed") * y2 * y2 * y2 * y2 * y2;
result += T::from_f64(0.0045813).expect("Operation failed") * y2 * y2 * y2 * y2 * y2 * y2;
result
} else {
let z = T::from_f64(3.75).expect("Operation failed") / abs_x;
let mut p = T::from_f64(0.39894228).expect("Test/example failed");
p += T::from_f64(0.01328592).expect("Operation failed") * z;
p += T::from_f64(0.00225319).expect("Operation failed") * z * z;
p -= T::from_f64(0.00157565).expect("Operation failed") * z * z * z;
p += T::from_f64(0.00916281).expect("Operation failed") * z * z * z * z;
p -= T::from_f64(0.02057706).expect("Operation failed") * z * z * z * z * z;
p += T::from_f64(0.02635537).expect("Operation failed") * z * z * z * z * z * z;
p -= T::from_f64(0.01647633).expect("Operation failed") * z * z * z * z * z * z * z;
p += T::from_f64(0.00392377).expect("Operation failed") * z * z * z * z * z * z * z * z;
let exp_term = abs_x.exp();
let sqrt_term = abs_x.sqrt();
(exp_term / sqrt_term) * p
}
}
/// Unnormalized sinc function, `sin(x) / x` (not `sin(πx) / (πx)`).
///
/// When `|x|` is below machine epsilon the result is exactly `1`, sidestepping
/// the `0 / 0` at the origin. NaN input propagates as NaN.
#[allow(dead_code)]
pub fn sinc<T: Float>(x: T) -> T {
    let magnitude = x.abs();
    if magnitude < T::epsilon() {
        return T::one();
    }
    let numerator = x.sin();
    numerator / x
}
/// Bessel function of the first kind of integer order, `J_n(x)`.
///
/// Negative orders and negative arguments are folded onto `n >= 0`, `x > 0`
/// with `J_{−n}(x) = (−1)^n J_n(x)` and `J_n(−x) = (−1)^n J_n(x)`.
/// `J_n(0)` is `1` for `n = 0` and `0` otherwise. Orders 0 and 1 go straight to
/// `bessel_j0` / `bessel_j1`; higher orders go to `bessel_jn_recurrence`.
///
/// NOTE(review): `-n` overflows for `n == i32::MIN` (panics in debug builds).
#[allow(dead_code)]
pub fn jn<T: Float + FromPrimitive>(n: i32, x: T) -> T {
    // (−1)^n sign factor shared by both symmetry identities.
    let with_parity = |value: T| if n % 2 == 0 { value } else { -value };

    if n < 0 {
        return with_parity(jn(-n, x));
    }
    if x < T::zero() {
        return with_parity(jn(n, -x));
    }
    if x == T::zero() {
        return if n == 0 { T::one() } else { T::zero() };
    }
    match n {
        0 => bessel_j0(x),
        1 => bessel_j1(x),
        _ => bessel_jn_recurrence(n, x),
    }
}
#[allow(dead_code)]
fn bessel_j0<T: Float + FromPrimitive>(x: T) -> T {
let abs_x = x.abs();
if abs_x < T::from_f64(8.0).expect("Operation failed") {
let y = x * x;
let mut result = T::one();
result -= y / T::from_f64(4.0).expect("Test/example failed");
result += y * y / T::from_f64(64.0).expect("Test/example failed");
result -= y * y * y / T::from_f64(2304.0).expect("Test/example failed");
result += y * y * y * y / T::from_f64(147456.0).expect("Test/example failed");
result -= y * y * y * y * y / T::from_f64(14745600.0).expect("Test/example failed");
result
} else {
let z = T::from_f64(8.0).expect("Operation failed") / abs_x;
let z2 = z * z;
let mut p = T::one();
p -= T::from_f64(0.1098628627).expect("Operation failed") * z2;
p += T::from_f64(0.0143125463).expect("Operation failed") * z2 * z2;
p -= T::from_f64(0.0045681716).expect("Operation failed") * z2 * z2 * z2;
let mut q = z * T::from_f64(0.125).expect("Test/example failed");
q -= z * z2 * T::from_f64(0.0732421875).expect("Test/example failed");
q += z * z2 * z2 * T::from_f64(0.0227108002).expect("Test/example failed");
let sqrt_term = (T::from_f64(2.0).expect("Operation failed") / (T::from_f64(std::f64::consts::PI).expect("Operation failed") * abs_x)).sqrt();
sqrt_term * (p * (abs_x - T::from_f64(std::f64::consts::PI / 4.0).expect("Operation failed")).cos() - q * (abs_x - T::from_f64(std::f64::consts::PI / 4.0).expect("Operation failed")).sin())
}
}
#[allow(dead_code)]
fn bessel_j1<T: Float + FromPrimitive>(x: T) -> T {
let abs_x = x.abs();
if abs_x < T::from_f64(8.0).expect("Operation failed") {
let y = x * x;
let mut result = x / T::from_f64(2.0).expect("Test/example failed");
result -= x * y / T::from_f64(16.0).expect("Test/example failed");
result += x * y * y / T::from_f64(384.0).expect("Test/example failed");
result -= x * y * y * y / T::from_f64(18432.0).expect("Test/example failed");
result += x * y * y * y * y / T::from_f64(1474560.0).expect("Test/example failed");
result
} else {
let z = T::from_f64(8.0).expect("Operation failed") / abs_x;
let z2 = z * z;
let mut p = T::one();
p += T::from_f64(0.183105e-2).expect("Operation failed") * z2;
p -= T::from_f64(0.3516396496).expect("Operation failed") * z2 * z2;
p += T::from_f64(0.2457520174e-1).expect("Operation failed") * z2 * z2 * z2;
let mut q = -z * T::from_f64(0.375).expect("Test/example failed");
q += z * z2 * T::from_f64(0.2109375).expect("Test/example failed");
q -= z * z2 * z2 * T::from_f64(0.1025390625).expect("Test/example failed");
let sqrt_term = (T::from_f64(2.0).expect("Operation failed") / (T::from_f64(std::f64::consts::PI).expect("Operation failed") * abs_x)).sqrt();
let result = sqrt_term * (p * (abs_x - T::from_f64(3.0 * std::f64::consts::PI / 4.0).expect("Operation failed")).cos() - q * (abs_x - T::from_f64(3.0 * std::f64::consts::PI / 4.0).expect("Operation failed")).sin());
if x < T::zero() {
-result
} else {
result
}
}
}
#[allow(dead_code)]
fn bessel_jn_recurrence<T: Float + FromPrimitive>(n: i32, x: T) -> T {
if n == 0 {
return bessel_j0(x);
}
if n == 1 {
return bessel_j1(x);
}
let mut j_n_minus_2 = bessel_j0(x);
let mut j_n_minus_1 = bessel_j1(x);
let mut j_n = T::zero();
for i in 2..=n {
let two_i_minus_1 = T::from_i32(2 * i - 1).expect("Test/example failed");
j_n = (two_i_minus_1 / x) * j_n_minus_1 - j_n_minus_2;
j_n_minus_2 = j_n_minus_1;
j_n_minus_1 = j_n;
}
j_n
}
/// Bessel function of the second kind, `Y_n(x)`. Not implemented yet.
///
/// Currently returns NaN for every `n` and `x`.
#[allow(dead_code)]
pub fn yn<T: Float + FromPrimitive>(n: i32, x: T) -> T {
    // Placeholder: the inputs are deliberately unused until a real
    // implementation lands. Binding them silences unused-variable warnings.
    let _ = (n, x);
    T::nan()
}
#[allow(dead_code)]
pub fn ellipk<T: Float + FromPrimitive>(m: T) -> SciRS2Result<T> {
check_domain(m < T::one(), "Parameter m must be less than 1")?;
Ok(T::nan())
}
#[allow(dead_code)]
pub fn ellipe<T: Float + FromPrimitive>(m: T) -> SciRS2Result<T> {
check_domain(m < T::one(), "Parameter m must be less than 1")?;
Ok(T::nan())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol`, reporting the values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {} (tolerance {})",
            expected,
            actual,
            tol
        );
    }

    #[test]
    fn test_gamma() {
        // Γ(n) = (n - 1)! at positive integers.
        let cases = [(1.0, 1.0), (2.0, 1.0), (3.0, 2.0), (4.0, 6.0), (5.0, 24.0)];
        for (x, expected) in cases {
            assert_close(gamma(x), expected, 1e-10);
        }
    }

    #[test]
    fn test_beta() {
        let cases = [
            ((1.0, 1.0), 1.0),
            ((2.0, 3.0), 1.0 / 12.0),
            ((3.0, 2.0), 1.0 / 12.0),
        ];
        for ((a, b), expected) in cases {
            assert_close(beta(a, b).expect("Operation failed"), expected, 1e-10);
        }
    }

    #[test]
    fn test_erf() {
        assert_close(erf(0.0), 0.0, 1e-10);
        // erf(1) ≈ 0.8427. The approximation is only checked to 1e-3.
        assert_close(erf(1.0), 0.8427, 1e-3);
        assert_close(erf(-1.0), -0.8427, 1e-3);
    }

    #[test]
    fn test_sinc() {
        assert_close(sinc(0.0), 1.0, 1e-10);
        assert_close(sinc(std::f64::consts::PI), 0.0, 1e-10);
    }
}