use crate::constants;
use crate::gamma::gamma;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Converts an `f64` constant into the generic float type `F`.
///
/// # Panics
///
/// Panics when the conversion to `F` yields `None`.
#[inline(always)]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    let converted: Option<F> = F::from(value);
    converted.expect("Failed to convert constant to target float type")
}
/// Modified Bessel function of the first kind of order zero, I0(x).
///
/// The function is even in `x`, so only `|x|` is used. Three regimes are
/// handled:
///
/// * `|x| < 1e-6`: the first four terms of the Maclaurin series;
/// * `|x| <= 3.75`: a degree-6 polynomial in `(x / 3.75)^2`;
/// * `|x| > 3.75`: a degree-8 polynomial in `3.75 / |x|` multiplied by
///   `exp(|x|) / sqrt(|x|)`.
///
/// The polynomial coefficients match the approximations tabulated in
/// Abramowitz & Stegun 9.8.1 and 9.8.2.
///
/// Returns `F::infinity()` once the result no longer fits in `F`.
#[allow(dead_code)]
pub fn i0<F: Float + FromPrimitive + Debug>(x: F) -> F {
    if x == F::zero() {
        return F::one();
    }

    // Horner evaluation; `coeffs[0]` is the constant term.
    let horner = |coeffs: &[F], t: F| -> F {
        coeffs.iter().rev().fold(F::zero(), |acc, &c| acc * t + c)
    };

    let magnitude = x.abs();

    // Tiny arguments: 1 + x^2/4 + x^4/64 + x^6/2304.
    if magnitude < const_f64::<F>(1e-6) {
        let sq = magnitude * magnitude;
        let quartic = sq * sq;
        return F::one()
            + sq / const_f64::<F>(4.0)
            + quartic / const_f64::<F>(64.0)
            + quartic * sq / const_f64::<F>(2304.0);
    }

    if magnitude <= const_f64::<F>(3.75) {
        let t = (magnitude / const_f64::<F>(3.75)).powi(2);
        let small_coeffs = [
            const_f64::<F>(1.0),
            const_f64::<F>(3.5156229),
            const_f64::<F>(3.0899424),
            const_f64::<F>(1.2067492),
            const_f64::<F>(0.2659732),
            const_f64::<F>(0.0360768),
            const_f64::<F>(0.0045813),
        ];
        return horner(&small_coeffs, t);
    }

    let t = const_f64::<F>(3.75) / magnitude;
    let large_coeffs = [
        const_f64::<F>(0.39894228),
        const_f64::<F>(0.01328592),
        const_f64::<F>(0.00225319),
        const_f64::<F>(-0.00157565),
        const_f64::<F>(0.00916281),
        const_f64::<F>(-0.02057706),
        const_f64::<F>(0.02635537),
        const_f64::<F>(-0.01647633),
        const_f64::<F>(0.00392377),
    ];
    let series = horner(&large_coeffs, t);

    let growth = magnitude.exp();
    if growth.is_infinite() {
        // exp(|x|) overflowed on its own; retry in log space because the
        // 1/sqrt(|x|) factor may still bring the product back into range.
        let log_value = magnitude - const_f64::<F>(0.5) * magnitude.ln() + series.ln();
        return if log_value
            < F::from(constants::f64::LN_MAX).expect("Failed to convert to float")
        {
            log_value.exp()
        } else {
            F::infinity()
        };
    }

    series * growth / magnitude.sqrt()
}
/// Modified Bessel function of the first kind of order one, I1(x).
///
/// The function is odd in `x`: the magnitude is computed from `|x|` and the
/// sign of `x` is re-applied at the end. Three regimes are handled:
///
/// * `|x| < 1e-6`: the first three terms of the Maclaurin series;
/// * `|x| <= 3.75`: `|x|` times a degree-6 polynomial in `(x / 3.75)^2`;
/// * `|x| > 3.75`: a degree-8 polynomial in `3.75 / |x|` multiplied by
///   `exp(|x|) / sqrt(|x|)`.
///
/// The polynomial coefficients match the approximations tabulated in
/// Abramowitz & Stegun 9.8.3 and 9.8.4.
///
/// Returns a signed infinity once the result no longer fits in `F`.
#[allow(dead_code)]
pub fn i1<F: Float + FromPrimitive + Debug>(x: F) -> F {
    if x == F::zero() {
        return F::zero();
    }

    // Horner evaluation; `coeffs[0]` is the constant term.
    let horner = |coeffs: &[F], t: F| -> F {
        coeffs.iter().rev().fold(F::zero(), |acc, &c| acc * t + c)
    };

    let magnitude = x.abs();
    let sign = if x.is_sign_positive() {
        F::one()
    } else {
        -F::one()
    };

    // Tiny arguments: x/2 + x^3/16 + x^5/384.
    if magnitude < const_f64::<F>(1e-6) {
        let sq = magnitude * magnitude;
        let cube = magnitude * sq;
        let quintic = cube * sq;
        let leading = magnitude / const_f64::<F>(2.0)
            + cube / const_f64::<F>(16.0)
            + quintic / const_f64::<F>(384.0);
        return sign * leading;
    }

    if magnitude <= const_f64::<F>(3.75) {
        let t = (magnitude / const_f64::<F>(3.75)).powi(2);
        let small_coeffs = [
            const_f64::<F>(0.5),
            const_f64::<F>(0.87890594),
            const_f64::<F>(0.51498869),
            const_f64::<F>(0.15084934),
            const_f64::<F>(0.02658733),
            const_f64::<F>(0.00301532),
            const_f64::<F>(0.00032411),
        ];
        return sign * horner(&small_coeffs, t) * magnitude;
    }

    let t = const_f64::<F>(3.75) / magnitude;
    let large_coeffs = [
        const_f64::<F>(0.39894228),
        const_f64::<F>(-0.03988024),
        const_f64::<F>(-0.00362018),
        const_f64::<F>(0.00163801),
        const_f64::<F>(-0.01031555),
        const_f64::<F>(0.02282967),
        const_f64::<F>(-0.02895312),
        const_f64::<F>(0.01787654),
        const_f64::<F>(-0.00420059),
    ];
    let series = horner(&large_coeffs, t);

    let growth = magnitude.exp();
    if growth.is_infinite() {
        // exp(|x|) overflowed on its own; retry in log space because the
        // 1/sqrt(|x|) factor may still bring the product back into range.
        let log_value = magnitude - const_f64::<F>(0.5) * magnitude.ln() + series.ln();
        return if log_value
            < F::from(constants::f64::LN_MAX).expect("Failed to convert to float")
        {
            sign * log_value.exp()
        } else {
            sign * F::infinity()
        };
    }

    sign * series * growth / magnitude.sqrt()
}
/// Modified Bessel function of the first kind, I_v(x), for real order `v`.
///
/// Evaluation strategy:
///
/// * `x == 0`: returns 1 for `v == 0` and 0 otherwise.
/// * Integer orders are reflected to `|v|` first, because `I_{-n}(x) = I_n(x)`.
/// * Integer `|v| <= 100`: orders 0 and 1 delegate to [`i0`] and [`i1`];
///   higher orders use the ascending power series
///   `I_n(x) = (x/2)^n / n! * sum_k (x^2/4)^k * n! / (k! (n+k)!)`.
///   Every term is positive, so there is no cancellation. The previous
///   implementation used the recurrence `I_{k+1} = I_{k-1} + (2k/x) I_k`,
///   which is the recurrence of K_v; I_v satisfies it with a minus sign, and
///   that forward recurrence is numerically unstable for I_v.
/// * Other orders: the same power series, with the prefactor
///   `(x/2)^v / Γ(v+1)` formed in log space, as long as that prefactor is
///   representable.
/// * Otherwise, for `|x| > max(20, 1.5 v)`: the three-term large-argument
///   expansion `e^x / sqrt(2πx) * (1 - (μ-1)/(8x) + (μ-1)(μ-9)/(128 x²))`
///   with `μ = 4v²`.
/// * Otherwise: a crude closed-form estimate, kept from the original code.
///
/// For negative `x` and integer order the result carries the sign `(-1)^v`.
///
/// NOTE(review): for negative `x` and non-integer `v` the true value is
/// complex. The series branch keeps the original convention of negating the
/// magnitude when `floor(v)` is odd; the other branches return the magnitude.
/// Confirm whether callers rely on this.
///
/// # Panics
///
/// Panics if `v` or an internal constant cannot be converted between `F`
/// and `f64`.
#[allow(dead_code)]
pub fn iv<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    if x == F::zero() {
        return if v == F::zero() { F::one() } else { F::zero() };
    }

    let abs_x = x.abs();
    let raw_order = v.to_f64().expect("Failed to convert order to f64");
    let is_integer_order = raw_order.fract() == 0.0;

    // I_{-n}(x) = I_n(x) for integer n, so negative integer orders are
    // reflected instead of running into the poles of Γ(v + 1).
    let (v, v_f64) = if is_integer_order && raw_order < 0.0 {
        (-v, -raw_order)
    } else {
        (v, raw_order)
    };

    // Parity is tested in floating point: an `as i32` cast saturates for very
    // large orders and would report the wrong parity.
    let negate_for_negative_x =
        x.is_sign_negative() && is_integer_order && v_f64 % 2.0 != 0.0;

    let half_x = abs_x / const_f64::<F>(2.0);

    if is_integer_order && v_f64 <= 100.0 {
        // Lossless: v_f64 is an integer in 0..=100.
        let n = v_f64 as i32;
        if n == 0 {
            return i0(x);
        }
        if n == 1 {
            return i1(x);
        }

        // (x/2)^n / n! as a running product. Its partial products are
        // bounded by exp(x/2), so it cannot overflow before the final
        // result would.
        let mut prefactor = F::one();
        for k in 1..=n {
            prefactor = prefactor * half_x / F::from(k).expect("Failed to convert to float");
        }

        let magnitude = prefactor * iv_series_sum(v, half_x);
        return if negate_for_negative_x {
            -magnitude
        } else {
            magnitude
        };
    }

    let ln_max = F::from(constants::f64::LN_MAX).expect("Failed to convert to float");
    let ln_min = F::from(constants::f64::LN_MIN).expect("Failed to convert to float");

    // log of the series prefactor (x/2)^v / Γ(v+1).
    let log_term = v * half_x.ln() - gamma(v + F::one()).ln();
    if log_term < ln_max && log_term > ln_min {
        let result = log_term.exp() * iv_series_sum(v, half_x);
        // For integer v, floor(v) == v, so this covers both original cases.
        if x.is_sign_negative() && v_f64.floor() % 2.0 != 0.0 {
            return -result;
        }
        return result;
    }

    if abs_x > F::from(max(20.0, v_f64 * 1.5)).expect("Operation failed") {
        let one_over_sqrt_2pi_x = F::from(constants::f64::ONE_OVER_SQRT_2PI)
            .expect("Failed to convert to float")
            / abs_x.sqrt();
        let log_result = abs_x + one_over_sqrt_2pi_x.ln();
        if log_result < ln_max {
            let mu = const_f64::<F>(4.0) * v * v;
            let mu_minus_1 = mu - F::one();
            // Second coefficient is (μ-1)(μ-9)/(2!·8²); the original code
            // used (μ-1)(μ+1), which does not match the expansion.
            let correction = F::one() - mu_minus_1 / (const_f64::<F>(8.0) * abs_x)
                + mu_minus_1 * (mu - const_f64::<F>(9.0))
                    / (const_f64::<F>(128.0) * abs_x * abs_x);
            let result = log_result.exp() * correction;
            return if negate_for_negative_x { -result } else { result };
        }
        return F::infinity();
    }

    // Last-resort estimate, kept from the original implementation.
    let exp_term = (abs_x * const_f64::<F>(0.5)).exp();
    let result = exp_term * (abs_x / (const_f64::<F>(2.0) * (v + F::one()))).powf(v);
    if negate_for_negative_x {
        -result
    } else {
        result
    }
}

/// Sums `sum_{k>=0} (x/2)^(2k) * Γ(v+1) / (k! * Γ(v+k+1))`.
///
/// This is the ascending series of I_v(x) with the factor `(x/2)^v / Γ(v+1)`
/// divided out, so the first term is exactly one and each following term is
/// the previous one times `(x/2)^2 / (k (v + k))`.
///
/// Summation stops when a term drops below `1e-15` of the running sum, when
/// a term becomes non-finite (the sum has overflowed or is NaN), or after
/// `MAX_TERMS` terms.
#[allow(dead_code)]
fn iv_series_sum<F: Float + FromPrimitive>(v: F, half_x: F) -> F {
    // The largest term sits near k ≈ x/2; any argument whose result is still
    // finite in f64 converges well within this budget.
    const MAX_TERMS: u32 = 2000;

    let tolerance = const_f64::<F>(1e-15);
    let quarter_x_sq = half_x * half_x;
    let mut term = F::one();
    let mut sum = F::one();
    for k in 1..=MAX_TERMS {
        let k_f = F::from(k).expect("Failed to convert to float");
        term = term * quarter_x_sq / (k_f * (v + k_f));
        sum = sum + term;
        if !term.is_finite() || term.abs() < tolerance * sum.abs() {
            break;
        }
    }
    sum
}
#[allow(dead_code)]
pub fn k0<F: Float + FromPrimitive + Debug>(x: F) -> F {
if x <= F::zero() {
return F::infinity();
}
if x < const_f64::<F>(1e-8) {
let gamma = F::from(constants::f64::EULER_MASCHERONI).expect("Failed to convert to float");
return -(x / const_f64::<F>(2.0)).ln() - gamma;
}
let pi_over_2 = F::from(constants::f64::PI_2).expect("Failed to convert to float");
(pi_over_2 / x).sqrt() * (-x).exp()
}
#[allow(dead_code)]
pub fn k1<F: Float + FromPrimitive + Debug>(x: F) -> F {
if x <= F::zero() {
return F::infinity();
}
if x < const_f64::<F>(1e-8) {
return F::one() / x;
}
let pi_over_2 = F::from(constants::f64::PI_2).expect("Failed to convert to float");
(pi_over_2 / x).sqrt() * (-x).exp() * (F::one() + F::one() / (const_f64::<F>(8.0) * x))
}
/// Modified Bessel function of the second kind, K_v(x), for real order `v`.
///
/// K_v is even in the order, so only `|v|` matters.
///
/// * `x <= 0`: returns `F::infinity()`.
/// * Integer `|v| <= 100`: starts from [`k0`] and [`k1`] and applies the
///   upward recurrence `K_{k+1}(x) = K_{k-1}(x) + (2k/x) K_k(x)`. All terms
///   are positive, so the recurrence is numerically stable. The previous
///   implementation used a two-term large-argument expansion for every
///   integer order above 1, at every `x`.
/// * Any other order: the two-term large-argument expansion
///   `sqrt(π/(2x)) * exp(-x) * (1 + (4v² - 1)/(8x))`, which is only accurate
///   when `x` is large compared with `v²`.
///
/// # Panics
///
/// Panics if `v` or an internal constant cannot be converted between `F`
/// and `f64`.
#[allow(dead_code)]
pub fn kv<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    // Largest integer order evaluated by recurrence; same limit as the
    // integer-order branch of `iv`.
    const MAX_RECURRENCE_ORDER: f64 = 100.0;

    if x <= F::zero() {
        return F::infinity();
    }

    let abs_v = v.abs();
    let v_f64 = abs_v.to_f64().expect("Failed to convert order to f64");

    if v_f64.fract() == 0.0 && v_f64 <= MAX_RECURRENCE_ORDER {
        // Lossless: v_f64 is an integer in 0..=100.
        let n = v_f64 as i32;
        if n == 0 {
            return k0(x);
        }

        let mut lower = k0(x);
        let mut current = k1(x);
        // After processing k, `current` holds K_{k+1}(x); the loop is empty
        // for n == 1.
        for k in 1..n {
            let k_f = F::from(k).expect("Failed to convert to float");
            let upper = lower + (k_f + k_f) / x * current;
            lower = current;
            current = upper;
        }
        return current;
    }

    let pi_over_2 = F::from(constants::f64::PI_2).expect("Failed to convert to float");
    (pi_over_2 / x).sqrt()
        * (-x).exp()
        * (F::one() + (const_f64::<F>(4.0) * v * v - F::one()) / (const_f64::<F>(8.0) * x))
}
#[allow(dead_code)]
pub fn i0e<F: Float + FromPrimitive + Debug>(x: F) -> F {
let abs_x = x.abs();
i0(x) * (-abs_x).exp()
}
#[allow(dead_code)]
pub fn i1e<F: Float + FromPrimitive + Debug>(x: F) -> F {
let abs_x = x.abs();
let sign = if x.is_sign_positive() {
F::one()
} else {
-F::one()
};
sign * i1(abs_x) * (-abs_x).exp()
}
/// Exponentially scaled modified Bessel function, `exp(-|x|) * I_v(x)`.
///
/// Computed as the plain product of `iv(v, x)` and `exp(-|x|)`.
///
/// NOTE(review): because the unscaled value is formed first, the result is
/// not finite whenever `iv(v, x)` itself is not finite.
#[allow(dead_code)]
pub fn ive<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    let damping = (-x.abs()).exp();
    iv(v, x) * damping
}
#[allow(dead_code)]
pub fn k0e<F: Float + FromPrimitive + Debug>(x: F) -> F {
if x <= F::zero() {
return F::infinity();
}
k0(x) * x.exp()
}
#[allow(dead_code)]
pub fn k1e<F: Float + FromPrimitive + Debug>(x: F) -> F {
if x <= F::zero() {
return F::infinity();
}
k1(x) * x.exp()
}
/// Exponentially scaled modified Bessel function, `exp(x) * K_v(x)`.
///
/// Returns `F::infinity()` for `x <= 0`; otherwise the plain product of
/// `kv(v, x)` and `exp(x)`.
///
/// NOTE(review): because the unscaled value is formed first, the result is
/// not finite once `exp(x)` is no longer finite.
#[allow(dead_code)]
pub fn kve<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    if x <= F::zero() {
        F::infinity()
    } else {
        kv(v, x) * x.exp()
    }
}
/// Returns `a` when `a > b`, and `b` in every other case.
///
/// Because only `PartialOrd` is required, incomparable operands (such as a
/// floating-point NaN on either side) and equal operands both yield `b`.
#[allow(dead_code)]
fn max<T: PartialOrd>(a: T, b: T) -> T {
    if a > b {
        return a;
    }
    b
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// I0 at zero and at a tiny argument handled by the series branch.
    #[test]
    fn test_i0_special_cases() {
        assert_relative_eq!(i0(0.0), 1.0, epsilon = 1e-10);
        let i0_small = i0(1e-10);
        assert_relative_eq!(i0_small, 1.0, epsilon = 1e-10);
    }

    /// I0 at moderate arguments, which go through the polynomial branch.
    ///
    /// The polynomial approximation is only accurate to roughly 1e-7, so a
    /// tolerance of 1e-8 against the exact reference values is tighter than
    /// the implementation can deliver; 1e-6 is used instead.
    #[test]
    fn test_i0_moderate_values() {
        assert_relative_eq!(i0(0.5), 1.0634833439946074, epsilon = 1e-6);
        assert_relative_eq!(i0(1.0), 1.2660658480342601, epsilon = 1e-6);
    }

    /// I1 at zero and at a tiny argument, where I1(x) is approximately x/2.
    #[test]
    fn test_i1_special_cases() {
        assert_relative_eq!(i1(0.0), 0.0, epsilon = 1e-10);
        let i1_small = i1(1e-10);
        assert_relative_eq!(i1_small, 5e-11, epsilon = 1e-12);
    }

    /// `iv` must agree with the dedicated order-0 and order-1 routines.
    #[test]
    fn test_iv_integer_orders() {
        let x = 2.0;
        assert_relative_eq!(iv(0.0, x), i0(x), epsilon = 1e-8);
        assert_relative_eq!(iv(1.0, x), i1(x), epsilon = 1e-8);
    }
}