use crate::constants;
use crate::gamma::gamma;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Converts an `f64` literal into the generic float type `F`.
///
/// # Panics
/// Panics with "Failed to convert constant to target float type" when
/// `F::from` cannot represent `value` (returns `None`).
#[inline]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    match F::from(value) {
        Some(converted) => converted,
        None => panic!("Failed to convert constant to target float type"),
    }
}
/// Modified Bessel function of the first kind of order zero, `I_0(x)`.
///
/// `I_0` is even, so only `|x|` is used:
/// * `x == 0` returns exactly 1.
/// * `|x| < 1e-6`: Taylor series `1 + x²/4 + x⁴/64 + x⁶/2304`.
/// * `|x| <= 3.75`: Abramowitz & Stegun 9.8.1, a degree-6 polynomial in `(x/3.75)²`.
/// * `|x| > 3.75`: A&S 9.8.2, `e^{|x|} / sqrt(|x|)` times a degree-8 polynomial in
///   `3.75/|x|`. If `e^{|x|}` overflows, the value is rebuilt as
///   `exp(|x| - ln|x|/2 + ln(poly))`, or `+inf` when that exponent is not below
///   `LN_MAX`.
#[allow(dead_code)]
pub fn i0<F: Float + FromPrimitive + Debug>(x: F) -> F {
    // Horner's rule; `coeffs[0]` is the constant term.
    fn horner<T: Float>(coeffs: &[T], t: T) -> T {
        coeffs.iter().rev().fold(T::zero(), |acc, &c| acc * t + c)
    }

    if x == F::zero() {
        return F::one();
    }
    let ax = x.abs();

    if ax < const_f64::<F>(1e-6) {
        let ax2 = ax * ax;
        let ax4 = ax2 * ax2;
        return F::one()
            + ax2 / const_f64::<F>(4.0)
            + ax4 / const_f64::<F>(64.0)
            + ax4 * ax2 / const_f64::<F>(2304.0);
    }

    let knot = const_f64::<F>(3.75);
    if ax <= knot {
        let small: [F; 7] = [
            const_f64::<F>(1.0),
            const_f64::<F>(3.5156229),
            const_f64::<F>(3.0899424),
            const_f64::<F>(1.2067492),
            const_f64::<F>(0.2659732),
            const_f64::<F>(0.0360768),
            const_f64::<F>(0.0045813),
        ];
        return horner(&small, (ax / knot).powi(2));
    }

    let large: [F; 9] = [
        const_f64::<F>(0.39894228),
        const_f64::<F>(0.01328592),
        const_f64::<F>(0.00225319),
        const_f64::<F>(-0.00157565),
        const_f64::<F>(0.00916281),
        const_f64::<F>(-0.02057706),
        const_f64::<F>(0.02635537),
        const_f64::<F>(-0.01647633),
        const_f64::<F>(0.00392377),
    ];
    let poly = horner(&large, knot / ax);
    let growth = ax.exp();
    if growth.is_infinite() {
        // e^{|x|} alone overflowed; assemble the result in log space instead.
        let log_value = ax - const_f64::<F>(0.5) * ax.ln() + poly.ln();
        let ln_max = F::from(constants::f64::LN_MAX).expect("Failed to convert to float");
        return if log_value < ln_max {
            log_value.exp()
        } else {
            F::infinity()
        };
    }
    poly * growth / ax.sqrt()
}
/// Modified Bessel function of the first kind of order one, `I_1(x)`.
///
/// `I_1` is odd: the magnitude is computed from `|x|` and the sign of `x` is
/// applied at the end.
/// * `x == 0` returns 0.
/// * `|x| < 1e-6`: Taylor series `x/2 + x³/16 + x⁵/384`.
/// * `|x| <= 3.75`: Abramowitz & Stegun 9.8.3, `|x|` times a degree-6 polynomial
///   in `(x/3.75)²`.
/// * `|x| > 3.75`: A&S 9.8.4, `e^{|x|} / sqrt(|x|)` times a degree-8 polynomial in
///   `3.75/|x|`. When `e^{|x|}` overflows, the magnitude is rebuilt in log space
///   and saturates to infinity once the exponent reaches `LN_MAX`.
#[allow(dead_code)]
pub fn i1<F: Float + FromPrimitive + Debug>(x: F) -> F {
    // Horner's rule; `coeffs[0]` is the constant term.
    fn horner<T: Float>(coeffs: &[T], t: T) -> T {
        coeffs.iter().rev().fold(T::zero(), |acc, &c| acc * t + c)
    }

    if x == F::zero() {
        return F::zero();
    }
    let ax = x.abs();

    let magnitude = if ax < const_f64::<F>(1e-6) {
        let ax2 = ax * ax;
        let ax3 = ax * ax2;
        let ax5 = ax3 * ax2;
        ax / const_f64::<F>(2.0) + ax3 / const_f64::<F>(16.0) + ax5 / const_f64::<F>(384.0)
    } else if ax <= const_f64::<F>(3.75) {
        let small: [F; 7] = [
            const_f64::<F>(0.5),
            const_f64::<F>(0.87890594),
            const_f64::<F>(0.51498869),
            const_f64::<F>(0.15084934),
            const_f64::<F>(0.02658733),
            const_f64::<F>(0.00301532),
            const_f64::<F>(0.00032411),
        ];
        horner(&small, (ax / const_f64::<F>(3.75)).powi(2)) * ax
    } else {
        let large: [F; 9] = [
            const_f64::<F>(0.39894228),
            const_f64::<F>(-0.03988024),
            const_f64::<F>(-0.00362018),
            const_f64::<F>(0.00163801),
            const_f64::<F>(-0.01031555),
            const_f64::<F>(0.02282967),
            const_f64::<F>(-0.02895312),
            const_f64::<F>(0.01787654),
            const_f64::<F>(-0.00420059),
        ];
        let poly = horner(&large, const_f64::<F>(3.75) / ax);
        let growth = ax.exp();
        if !growth.is_infinite() {
            poly * growth / ax.sqrt()
        } else {
            // e^{|x|} alone overflowed; assemble the magnitude in log space.
            let log_value = ax - const_f64::<F>(0.5) * ax.ln() + poly.ln();
            let ln_max = F::from(constants::f64::LN_MAX).expect("Failed to convert to float");
            if log_value < ln_max {
                log_value.exp()
            } else {
                F::infinity()
            }
        }
    };

    if x.is_sign_positive() {
        magnitude
    } else {
        -magnitude
    }
}
/// Modified Bessel function of the first kind, `I_v(x)`, for real order `v`.
///
/// Evaluation strategy:
/// * NaN in either argument gives NaN; `v = +inf` gives 0.
/// * Negative integer orders use the symmetry `I_{-n}(x) = I_n(x)`. The power
///   series below has a pole in `1/Γ(v+1)` there.
/// * `x == 0`: `I_0(0) = 1`, `I_v(0) = 0` for `v > 0`, and `+inf` for negative
///   non-integer `v`, where `(x/2)^v` diverges.
/// * `v == 0` and `v == 1` delegate to [`i0`] and [`i1`].
/// * Every other order uses the ascending series
///   `I_v(x) = (x/2)^v / Γ(v+1) · Σ_{k≥0} (x²/4)^k / (k! (v+1)_k)`.
///   - The term budget grows with `|x|`, because the terms peak near `k ≈ |x|/2`;
///     a fixed cap would truncate the sum for large arguments.
///   - For `v > -1` all terms are positive, so this is also the stable choice for
///     integer orders `n >= 2`. Upward recurrence from `I_0`, `I_1` cancels
///     catastrophically once `n > |x|`.
/// * When the prefactor `(x/2)^v / Γ(v+1)` is outside `(LN_MIN, LN_MAX)` in log
///   terms, its product with the series sum is formed in log space.
/// * If `gamma` returns a non-finite value for `v >= 10`, `ln Γ(v+1)` is taken
///   from Stirling's series. `Γ` overflows `f64` for `v + 1 > ~171.6`.
///
/// For negative `x` the value `±I_v(|x|)` is returned, negated when `floor(v)` is
/// odd. For integer `v` this is the exact parity `I_n(-x) = (-1)^n I_n(x)`. For
/// non-integer `v` the true value is complex, and this is only a real-valued
/// convention.
///
/// # Panics
/// Panics if `v` cannot be converted to `f64`.
#[allow(dead_code)]
pub fn iv<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    // Stirling series for ln Γ(z). Only used when `gamma` overflows, where z is
    // large enough that the omitted O(z^-7) term is far below f64 resolution.
    fn ln_gamma_stirling(z: f64) -> f64 {
        let inv_z = 1.0 / z;
        let inv_z2 = inv_z * inv_z;
        (z - 0.5) * z.ln() - z
            + 0.5 * (2.0 * std::f64::consts::PI).ln()
            + inv_z * (1.0 / 12.0 - inv_z2 * (1.0 / 360.0 - inv_z2 / 1260.0))
    }

    if v.is_nan() || x.is_nan() {
        return F::nan();
    }
    if v == F::infinity() {
        return F::zero();
    }
    let v_f64 = v.to_f64().expect("Failed to convert order to f64");
    let is_integer_order = v_f64.fract() == 0.0;
    if is_integer_order && v_f64 < 0.0 {
        return iv(-v, x);
    }
    if x == F::zero() {
        return if v_f64 == 0.0 {
            F::one()
        } else if v_f64 > 0.0 {
            F::zero()
        } else {
            F::infinity()
        };
    }
    if v_f64 == 0.0 {
        return i0(x);
    }
    if v_f64 == 1.0 {
        return i1(x);
    }

    // Sign applied to I_v(|x|) when x < 0 (see the doc comment).
    let negate = x.is_sign_negative() && v_f64.floor().rem_euclid(2.0) == 1.0;
    let apply_sign = |value: F| if negate { -value } else { value };

    let abs_x = x.abs();
    if abs_x.is_infinite() {
        return apply_sign(F::infinity());
    }

    // ln|Γ(v+1)| and the sign of Γ(v+1).
    let gamma_v1_f64 = gamma(v + F::one()).to_f64().unwrap_or(f64::NAN);
    let (log_abs_gamma_v1, gamma_sign) = if gamma_v1_f64.is_finite() && gamma_v1_f64 != 0.0 {
        let sign = if gamma_v1_f64 < 0.0 { -1.0_f64 } else { 1.0_f64 };
        (gamma_v1_f64.abs().ln(), sign)
    } else if v_f64 >= 10.0 {
        (ln_gamma_stirling(v_f64 + 1.0), 1.0_f64)
    } else {
        // Γ(v+1) at a pole or underflowed (large negative non-integer v).
        return F::infinity();
    };

    let half_x = abs_x / const_f64::<F>(2.0);
    let half_x_f64 = half_x.to_f64().unwrap_or(1.0);
    let log_prefactor_f64 = v_f64 * half_x_f64.ln() - log_abs_gamma_v1;
    let log_prefactor = F::from(log_prefactor_f64).unwrap_or_else(|| {
        if log_prefactor_f64 > 0.0 {
            F::infinity()
        } else {
            F::neg_infinity()
        }
    });

    // Σ_k (x²/4)^k / (k! (v+1)_k). Terms peak near k ≈ |x|/2, so the budget
    // scales with |x|. Summation also stops on convergence or overflow of the sum.
    let x2 = half_x * half_x;
    let tolerance = const_f64::<F>(1e-15);
    let max_terms = 100 + abs_x.to_f64().unwrap_or(0.0).min(1.0e5) as u32;
    let mut sum = F::one();
    let mut term = F::one();
    for k in 1..=max_terms {
        let k_f = F::from(k).expect("Failed to convert to float");
        term = term * x2 / (k_f * (v + k_f));
        sum += term;
        if !sum.is_finite() || term.abs() < tolerance * sum.abs() {
            break;
        }
    }

    let ln_max = F::from(constants::f64::LN_MAX).expect("Failed to convert to float");
    let ln_min = F::from(constants::f64::LN_MIN).expect("Failed to convert to float");
    let gamma_sign = F::from(gamma_sign).unwrap_or(F::one());
    let value = if log_prefactor < ln_max && log_prefactor > ln_min {
        log_prefactor.exp() * gamma_sign * sum
    } else {
        // The prefactor alone over/underflows; combine in log space so the final
        // product saturates (or survives) according to its true size.
        gamma_sign * sum.signum() * (log_prefactor + sum.abs().ln()).exp()
    };
    apply_sign(value)
}
/// Modified Bessel function of the second kind of order zero, `K_0(x)`.
///
/// * `x <= 0` returns `+inf`.
/// * When `x` (converted to `f64`) is `<= 8`, the ascending series in
///   `k0_series_small` is used.
/// * Otherwise (including NaN) the large-argument expansion `k_asymptotic` is used
///   with `mu = 0`.
///
/// The work is done in `f64`. A failed conversion back to `F` yields `+inf` on the
/// series branch and `0` on the asymptotic branch.
#[allow(dead_code)]
pub fn k0<F: Float + FromPrimitive + Debug>(x: F) -> F {
    if x <= F::zero() {
        return F::infinity();
    }
    const EULER_GAMMA: f64 = 0.5772156649015329;
    let xf = x.to_f64().unwrap_or(1.0);
    let (value, fallback) = if xf <= 8.0 {
        (k0_series_small(xf, EULER_GAMMA), F::infinity())
    } else {
        (k_asymptotic(xf, 0.0_f64), F::zero())
    };
    F::from(value).unwrap_or(fallback)
}
/// Large-argument asymptotic expansion of `K_nu(x)` (A&S 9.7.2), with `mu = 4 nu²`:
///
/// `K_nu(x) ~ sqrt(pi / (2x)) e^{-x} [1 + (mu-1)/(8x) + (mu-1)(mu-9)/(2! (8x)²) + ...]`.
///
/// Unless it terminates, the series diverges for every fixed `x`. It is therefore
/// truncated optimally: summation stops before the first term whose magnitude
/// exceeds the previous one. It also stops once a term drops below `1e-15` of the
/// running sum, or after 25 terms.
///
/// Half-integer orders (`mu = (2j-1)²`) terminate exactly. Otherwise the accuracy
/// is roughly the size of the smallest term, so callers should only use this when
/// `x` is large compared with both 1 and `mu / 8`.
fn k_asymptotic(x: f64, mu: f64) -> f64 {
    let mut sum = 1.0_f64;
    let mut term = 1.0_f64;
    for k in 1_u32..=25 {
        let k_f = f64::from(k);
        let odd = 2.0 * k_f - 1.0;
        let next = term * ((mu - odd * odd) / (8.0 * x * k_f));
        // Past the smallest term the remaining terms only add error.
        if next.abs() > term.abs() {
            break;
        }
        term = next;
        sum += term;
        if term.abs() < 1e-15 * sum.abs() {
            break;
        }
    }
    let prefactor = (std::f64::consts::PI / (2.0 * x)).sqrt() * (-x).exp();
    prefactor * sum
}
/// Modified Bessel function of the second kind of order one, `K_1(x)`.
///
/// * `x <= 0` returns `+inf`.
/// * `0 < x <= 8`: ascending series (A&S 9.6.11 with `n = 1`), see `k1_series_small`.
///   This is the closed-form series rather than a numerical derivative of `K_0`. A
///   finite-difference stencil needs sample points below `x`, which leave the
///   domain (NaN) or lose accuracy as `x -> 0`.
/// * `x > 8` (and NaN): large-argument expansion `k_asymptotic` with `mu = 4`.
///
/// The work is done in `f64`. A failed conversion back to `F` yields `+inf` on the
/// series branch and `0` on the asymptotic branch.
#[allow(dead_code)]
pub fn k1<F: Float + FromPrimitive + Debug>(x: F) -> F {
    // K_1(x) = 1/x + ln(x/2) I_1(x)
    //          - (x/4) Σ_{k≥0} [ψ(k+1) + ψ(k+2)] (x²/4)^k / (k! (k+1)!),
    // where I_1(x) = (x/2) Σ_{k≥0} (x²/4)^k / (k! (k+1)!)
    // and ψ(k+1) = H_k - γ (H_k is the k-th harmonic number, H_0 = 0).
    fn k1_series_small(x: f64) -> f64 {
        const EULER_GAMMA: f64 = 0.5772156649015329;
        let half_x = 0.5 * x;
        let q = half_x * half_x;
        let mut term = 1.0_f64; // (x²/4)^k / (k! (k+1)!)
        let mut harmonic = 0.0_f64; // H_k
        let mut i1_sum = 1.0_f64;
        // k = 0 coefficient: ψ(1) + ψ(2) = -γ + (1 - γ).
        let mut psi_sum = 1.0 - 2.0 * EULER_GAMMA;
        for k in 1_u32..=80 {
            let k_f = f64::from(k);
            term *= q / (k_f * (k_f + 1.0));
            harmonic += 1.0 / k_f;
            i1_sum += term;
            // ψ(k+1) + ψ(k+2) = 2 H_k + 1/(k+1) - 2γ
            psi_sum += (2.0 * harmonic + 1.0 / (k_f + 1.0) - 2.0 * EULER_GAMMA) * term;
            if term < 1e-17 * i1_sum {
                break;
            }
        }
        1.0 / x + half_x.ln() * half_x * i1_sum - 0.5 * half_x * psi_sum
    }

    if x <= F::zero() {
        return F::infinity();
    }
    let x_f64 = x.to_f64().unwrap_or(1.0);
    if x_f64 <= 8.0 {
        F::from(k1_series_small(x_f64)).unwrap_or(F::infinity())
    } else {
        F::from(k_asymptotic(x_f64, 4.0_f64)).unwrap_or(F::zero())
    }
}
/// Ascending-series evaluation of `K_0(x)`:
///
/// `K_0(x) = -ln(x/2) I_0(x) + Σ_{k≥0} ψ(k+1) (x²/4)^k / (k!)²`,
/// with `ψ(k+1) = H_k - γ` (`H_k` is the k-th harmonic number, `H_0 = 0`).
///
/// `euler_gamma` supplies γ. At most 80 terms are added, stopping early once a
/// term falls below `1e-17` of the running `I_0` sum. For `x <= 0` the logarithm
/// makes the result NaN or infinite.
fn k0_series_small(x: f64, euler_gamma: f64) -> f64 {
    let half_x = x / 2.0;
    let q = half_x * half_x;
    let (mut term, mut harmonic) = (1.0_f64, 0.0_f64);
    // i0_sum accumulates I_0(x); psi_sum accumulates the ψ-weighted series.
    let (mut i0_sum, mut psi_sum) = (1.0_f64, -euler_gamma);
    let mut k = 1_u32;
    while k <= 80 {
        let kf = f64::from(k);
        term *= q / (kf * kf);
        i0_sum += term;
        harmonic += 1.0 / kf;
        psi_sum += (harmonic - euler_gamma) * term;
        if term.abs() < 1e-17 * i0_sum.abs() {
            break;
        }
        k += 1;
    }
    psi_sum - half_x.ln() * i0_sum
}
/// Modified Bessel function of the second kind, `K_v(x)`, for real order `v`.
///
/// Special values:
/// * NaN in either argument gives NaN.
/// * `x <= 0` gives `+inf`; `x = +inf` gives 0.
/// * `|v| = inf` gives `+inf`.
///
/// `K_{-v} = K_v`, so only `|v|` is used.
///
/// Integer orders use [`k0`] and [`k1`], then the upward recurrence
/// `K_{n+1}(x) = K_{n-1}(x) + (2n/x) K_n(x)`. This is stable for `K` because `K_n`
/// grows with `n`. It stops early once the values overflow, or once both underflow
/// to zero.
///
/// Non-integer orders use the trapezoidal rule on
/// `K_v(x) = ½ ∫_{-∞}^{∞} exp(-x cosh t + v t) dt`:
/// * The integrand is entire and log-concave, so the rule converges geometrically.
/// * Sampling is centred on the peak `sinh t* = v/x`.
/// * The step resolves the peak width `(x² + v²)^{-1/4}` and is capped at 0.1.
/// * Terms are taken relative to the peak value, so neither large `v` nor large
///   `x` overflows intermediate values.
///
/// This is uniformly accurate. The large-`x` asymptotic series, by contrast, is
/// only valid when `4v² ≪ 8x`, and the reflection formula
/// `π/2 (I_{-v} - I_v) / sin(vπ)` cancels badly near integer orders and for
/// large `x`.
#[allow(dead_code)]
pub fn kv<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    // Trapezoidal evaluation of K_nu(x) for nu >= 0 and finite x > 0.
    fn k_integral(nu: f64, x: f64) -> f64 {
        const MAX_STEPS: u32 = 1_000_000;
        // Exponent shifted by +x: -x (cosh t - 1) + nu t. The identity
        // cosh t - 1 = 2 sinh²(t/2) avoids cancellation near t = 0 when x is large.
        let exponent = |t: f64| {
            let s = (0.5 * t).sinh();
            -((2.0 * x) * s) * s + nu * t
        };
        let ratio = nu / x;
        let t_peak = if ratio.is_finite() {
            ratio.asinh()
        } else {
            // asinh(r) ≈ ln(2r) for huge r; avoids inf when x is subnormal.
            (2.0 * nu).ln() - x.ln()
        };
        let peak = exponent(t_peak);
        // Curvature at the peak is x cosh t* = hypot(x, nu), so the width is its
        // inverse square root. Sample at least twice per width, never coarser than 0.1.
        let width = x.hypot(nu).sqrt().recip();
        let h = (0.5 * width).min(0.1);
        let mut sum = 1.0_f64; // sample at the peak: exp(0)
        for &direction in &[1.0_f64, -1.0] {
            for j in 1..=MAX_STEPS {
                let t = t_peak + direction * (f64::from(j) * h);
                let term = (exponent(t) - peak).exp();
                sum += term;
                // Log-concave integrand: terms shrink monotonically away from the peak.
                if term <= 1e-17 * sum {
                    break;
                }
            }
        }
        // K = ½ h Σ exp(exponent - peak) · exp(peak - x)
        ((peak - x) + (0.5 * h * sum).ln()).exp()
    }

    if v.is_nan() || x.is_nan() {
        return F::nan();
    }
    if x <= F::zero() {
        return F::infinity();
    }
    if x.is_infinite() {
        return F::zero();
    }
    let abs_v = v.abs();
    if abs_v.is_infinite() {
        return F::infinity();
    }
    let v_f64 = abs_v.to_f64().unwrap_or(0.0);

    if v_f64.fract() == 0.0 {
        let n = v_f64 as i32;
        if n == 0 {
            return k0(x);
        }
        if n == 1 {
            return k1(x);
        }
        let mut k_prev = k0(x);
        let mut k_curr = k1(x);
        for k in 1..n {
            let k_f = F::from(k).unwrap_or(F::one());
            let k_next = k_prev + (k_f + k_f) / x * k_curr;
            k_prev = k_curr;
            k_curr = k_next;
            // inf stays inf, and 0 + c·0 stays 0: further steps cannot change the result.
            if !k_curr.is_finite() || (k_curr == F::zero() && k_prev == F::zero()) {
                break;
            }
        }
        return k_curr;
    }

    let x_f64 = x.to_f64().unwrap_or(1.0);
    F::from(k_integral(v_f64, x_f64)).unwrap_or(F::infinity())
}
#[allow(dead_code)]
pub fn i0e<F: Float + FromPrimitive + Debug>(x: F) -> F {
let abs_x = x.abs();
i0(x) * (-abs_x).exp()
}
#[allow(dead_code)]
pub fn i1e<F: Float + FromPrimitive + Debug>(x: F) -> F {
let abs_x = x.abs();
let sign = if x.is_sign_positive() {
F::one()
} else {
-F::one()
};
sign * i1(abs_x) * (-abs_x).exp()
}
/// Exponentially scaled modified Bessel function of the first kind, `e^{-|x|} I_v(x)`.
///
/// When `|x| >= max(25, 4v²)`, the scaled large-argument expansion is summed
/// directly (A&S 9.7.1 without the `e^{|x|}` factor):
/// `e^{-x} I_v(x) ~ (2πx)^{-1/2} Σ_k a_k`, with
/// `a_k = a_{k-1} ((2k-1)² - 4v²) / (8xk)`.
/// Summation stops when a term grows, drops below `1e-15` of the sum, or after
/// 25 terms. Forming `iv(v, x) · e^{-|x|}` there would overflow to `inf`
/// (or `inf · 0 = NaN`) once `I_v` leaves the `f64` range.
///
/// In that branch, for negative `x` the result is negated when `floor(v)` is odd.
/// This is the parity `(-1)^n` for integer orders. Everywhere else the result is
/// `iv(v, x) · e^{-|x|}`.
#[allow(dead_code)]
pub fn ive<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    let abs_x = x.abs();
    let x_f64 = abs_x.to_f64().unwrap_or(f64::NAN);
    let v_f64 = v.to_f64().unwrap_or(f64::NAN);
    let mu = 4.0 * v_f64 * v_f64;
    if x_f64 >= 25.0 && x_f64 >= mu {
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1_u32..=25 {
            let k_f = f64::from(k);
            let odd = 2.0 * k_f - 1.0;
            let next = term * ((odd * odd - mu) / (8.0 * x_f64 * k_f));
            // Optimal truncation of the divergent asymptotic series.
            if next.abs() > term.abs() {
                break;
            }
            term = next;
            sum += term;
            if term.abs() < 1e-15 * sum.abs() {
                break;
            }
        }
        let scaled = sum / (2.0 * std::f64::consts::PI * x_f64).sqrt();
        let scaled = F::from(scaled).unwrap_or(F::zero());
        let negate = x.is_sign_negative() && v_f64.floor().rem_euclid(2.0) == 1.0;
        return if negate { -scaled } else { scaled };
    }
    iv(v, x) * (-abs_x).exp()
}
/// Exponentially scaled modified Bessel function of the second kind of order zero,
/// `e^{x} K_0(x)`.
///
/// * `x <= 0` returns `+inf`.
/// * For `x <= 8` (and NaN) the result is `k0(x) · e^{x}`.
/// * For `x > 8` the large-argument asymptotic series is summed without its
///   `e^{-x}` factor, so the result stays finite (≈ `sqrt(π / (2x))`). The product
///   `k0(x) · e^{x}` would give `inf` or `0 · inf = NaN` once `e^{x}` overflows.
#[allow(dead_code)]
pub fn k0e<F: Float + FromPrimitive + Debug>(x: F) -> F {
    // e^x K_nu(x) ~ sqrt(π / (2x)) Σ_k a_k, with a_k = a_{k-1} (mu - (2k-1)²) / (8xk)
    // and mu = 4 nu². Stops before the first growing term (optimal truncation),
    // when a term drops below 1e-15 of the sum, or after 25 terms.
    fn scaled_asymptotic(x: f64, mu: f64) -> f64 {
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1_u32..=25 {
            let k_f = f64::from(k);
            let odd = 2.0 * k_f - 1.0;
            let next = term * ((mu - odd * odd) / (8.0 * x * k_f));
            if next.abs() > term.abs() {
                break;
            }
            term = next;
            sum += term;
            if term.abs() < 1e-15 * sum.abs() {
                break;
            }
        }
        (std::f64::consts::PI / (2.0 * x)).sqrt() * sum
    }

    if x <= F::zero() {
        return F::infinity();
    }
    let x_f64 = x.to_f64().unwrap_or(f64::NAN);
    if x_f64 > 8.0 {
        return F::from(scaled_asymptotic(x_f64, 0.0)).unwrap_or(F::zero());
    }
    k0(x) * x.exp()
}
/// Exponentially scaled modified Bessel function of the second kind of order one,
/// `e^{x} K_1(x)`.
///
/// * `x <= 0` returns `+inf`.
/// * For `x <= 8` (and NaN) the result is `k1(x) · e^{x}`.
/// * For `x > 8` the large-argument asymptotic series (`mu = 4`) is summed without
///   its `e^{-x}` factor, so the result stays finite for large `x`. The product
///   `k1(x) · e^{x}` would overflow to `inf` or `0 · inf = NaN` there.
#[allow(dead_code)]
pub fn k1e<F: Float + FromPrimitive + Debug>(x: F) -> F {
    // e^x K_nu(x) ~ sqrt(π / (2x)) Σ_k a_k, with a_k = a_{k-1} (mu - (2k-1)²) / (8xk)
    // and mu = 4 nu². Stops before the first growing term (optimal truncation),
    // when a term drops below 1e-15 of the sum, or after 25 terms.
    fn scaled_asymptotic(x: f64, mu: f64) -> f64 {
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1_u32..=25 {
            let k_f = f64::from(k);
            let odd = 2.0 * k_f - 1.0;
            let next = term * ((mu - odd * odd) / (8.0 * x * k_f));
            if next.abs() > term.abs() {
                break;
            }
            term = next;
            sum += term;
            if term.abs() < 1e-15 * sum.abs() {
                break;
            }
        }
        (std::f64::consts::PI / (2.0 * x)).sqrt() * sum
    }

    if x <= F::zero() {
        return F::infinity();
    }
    let x_f64 = x.to_f64().unwrap_or(f64::NAN);
    if x_f64 > 8.0 {
        return F::from(scaled_asymptotic(x_f64, 4.0)).unwrap_or(F::zero());
    }
    k1(x) * x.exp()
}
/// Exponentially scaled modified Bessel function of the second kind, `e^{x} K_v(x)`.
///
/// * `x <= 0` returns `+inf`.
/// * When `x >= max(25, 4v²)`, the large-argument asymptotic series is summed
///   without its `e^{-x}` factor. In that region the series converges to full
///   precision within its 25-term budget, and the result stays finite where
///   `kv(v, x) · e^{x}` would overflow to `inf` or `0 · inf = NaN`.
/// * Otherwise the result is `kv(v, x) · e^{x}`. That product can still overflow
///   when `x` is very large but `4v² > x`.
#[allow(dead_code)]
pub fn kve<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(v: F, x: F) -> F {
    // e^x K_nu(x) ~ sqrt(π / (2x)) Σ_k a_k, with a_k = a_{k-1} (mu - (2k-1)²) / (8xk)
    // and mu = 4 nu². Stops before the first growing term (optimal truncation),
    // when a term drops below 1e-15 of the sum, or after 25 terms.
    fn scaled_asymptotic(x: f64, mu: f64) -> f64 {
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1_u32..=25 {
            let k_f = f64::from(k);
            let odd = 2.0 * k_f - 1.0;
            let next = term * ((mu - odd * odd) / (8.0 * x * k_f));
            if next.abs() > term.abs() {
                break;
            }
            term = next;
            sum += term;
            if term.abs() < 1e-15 * sum.abs() {
                break;
            }
        }
        (std::f64::consts::PI / (2.0 * x)).sqrt() * sum
    }

    if x <= F::zero() {
        return F::infinity();
    }
    let x_f64 = x.to_f64().unwrap_or(f64::NAN);
    let v_f64 = v.abs().to_f64().unwrap_or(f64::NAN);
    let mu = 4.0 * v_f64 * v_f64;
    if x_f64 >= 25.0 && x_f64 >= mu {
        return F::from(scaled_asymptotic(x_f64, mu)).unwrap_or(F::zero());
    }
    kv(v, x) * x.exp()
}
/// Returns `a` if `a > b`, otherwise `b`.
///
/// Ties and unordered comparisons (e.g. a NaN operand) return `b`, so
/// `max(NaN, 1.0) == 1.0` while `max(1.0, NaN)` is NaN. This differs from
/// `f64::max`, which ignores NaN in either position.
// Private helper; `allow` keeps the build warning-free if no caller uses it.
#[allow(dead_code)]
fn max<T: PartialOrd>(a: T, b: T) -> T {
    if a > b {
        return a;
    }
    b
}
// Unit tests for the modified Bessel functions in this module. They cover:
// - special values and hard-coded reference values;
// - closed forms for half-integer orders;
// - the Wronskian identity I_v(x) K_{v+1}(x) + I_{v+1}(x) K_v(x) = 1/x, which
//   ties the I and K families together.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // I_0(0) = 1; 1e-10 lies below 1e-6, so it exercises the small-|x| Taylor branch.
    #[test]
    fn test_i0_special_cases() {
        assert_relative_eq!(i0(0.0), 1.0, epsilon = 1e-10);
        let i0_small = i0(1e-10);
        assert_relative_eq!(i0_small, 1.0, epsilon = 1e-10);
    }

    // Hard-coded reference values at x = 0.5 and x = 1.0 (the |x| <= 3.75 polynomial branch).
    #[test]
    fn test_i0_moderate_values() {
        assert_relative_eq!(i0(0.5), 1.0634833439946074, epsilon = 1e-8);
        assert_relative_eq!(i0(1.0), 1.2660658480342601, epsilon = 1e-8);
    }

    // I_1(0) = 0; for tiny x, I_1(x) ≈ x/2.
    #[test]
    fn test_i1_special_cases() {
        assert_relative_eq!(i1(0.0), 0.0, epsilon = 1e-10);
        let i1_small = i1(1e-10);
        assert_relative_eq!(i1_small, 5e-11, epsilon = 1e-12);
    }

    // iv must agree with the dedicated order-0 and order-1 routines.
    #[test]
    fn test_iv_integer_orders() {
        let x = 2.0;
        assert_relative_eq!(iv(0.0, x), i0(x), epsilon = 1e-8);
        assert_relative_eq!(iv(1.0, x), i1(x), epsilon = 1e-8);
    }

    // Closed form: I_{1/2}(x) = sqrt(2 / (pi x)) sinh(x).
    #[test]
    fn test_iv_half_integer_order() {
        let pi = std::f64::consts::PI;
        for &x in &[0.5_f64, 1.0, 2.0, 3.0] {
            let expected = (2.0 / (pi * x)).sqrt() * x.sinh();
            let got = iv(0.5_f64, x);
            let rel_err = (got - expected).abs() / expected.abs();
            assert!(
                rel_err < 1e-9,
                "iv(0.5, {x}) = {got}, expected {expected}, rel_err = {rel_err}"
            );
        }
    }

    // Hard-coded K_0 reference values at x = 0.5, 1, 2 and 5.
    #[test]
    fn test_k0_reference_values() {
        let k0_1 = k0(1.0_f64);
        assert_relative_eq!(k0_1, 0.4210244382407083, epsilon = 1e-9);
        let k0_2 = k0(2.0_f64);
        assert_relative_eq!(k0_2, 0.1138938727495334, epsilon = 1e-9);
        let k0_5 = k0(5.0_f64);
        assert_relative_eq!(k0_5, 0.003691098334043, epsilon = 1e-6);
        let k0_05 = k0(0.5_f64);
        assert_relative_eq!(k0_05, 0.9244190712276663, epsilon = 1e-9);
    }

    // Hard-coded K_1 reference values at x = 0.5, 1, 2 and 5.
    #[test]
    fn test_k1_reference_values() {
        let k1_1 = k1(1.0_f64);
        assert_relative_eq!(k1_1, 0.6019072301972346, epsilon = 1e-8);
        let k1_05 = k1(0.5_f64);
        assert_relative_eq!(k1_05, 1.6564411200033016, epsilon = 1e-6);
        let k1_2 = k1(2.0_f64);
        assert_relative_eq!(k1_2, 0.13986588181652243, epsilon = 1e-8);
        let k1_5 = k1(5.0_f64);
        assert_relative_eq!(k1_5, 0.004044613445452, epsilon = 1e-6);
    }

    // Closed form: K_{1/2}(x) = sqrt(pi / (2x)) e^{-x}.
    #[test]
    fn test_kv_half_integer_order() {
        let pi = std::f64::consts::PI;
        for &x in &[0.5_f64, 1.0, 2.0, 3.0] {
            let expected = (pi / (2.0 * x)).sqrt() * (-x).exp();
            let got = kv(0.5_f64, x);
            let rel_err = (got - expected).abs() / expected.abs();
            assert!(
                rel_err < 1e-6,
                "kv(0.5, {x}) = {got}, expected {expected}, rel_err = {rel_err}"
            );
        }
    }

    // Wronskian I_v K_{v+1} + I_{v+1} K_v = 1/x over integer and half-integer orders.
    #[test]
    fn test_modified_bessel_wronskian_identity() {
        for &v in &[0.0_f64, 0.5, 1.0, 1.5, 2.0] {
            for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
                let i_v = iv(v, x);
                let i_v1 = iv(v + 1.0, x);
                let k_v = kv(v, x);
                let k_v1 = kv(v + 1.0, x);
                let lhs = i_v * k_v1 + i_v1 * k_v;
                let expected = 1.0 / x;
                let rel_err = (lhs - expected).abs() / expected.abs();
                assert!(
                    rel_err < 1e-5,
                    "Wronskian: v={v}, x={x}: I_v*K_{{v+1}} + I_{{v+1}}*K_v = {lhs}, expected {expected}, rel_err={rel_err}"
                );
            }
        }
    }

    // Finite/positive checks plus the Wronskian for half-integer orders and
    // generic fractional orders (0.3, 0.7, 1.3).
    #[test]
    fn test_kv_debug_half() {
        let test_cases = [
            (0.5_f64, 0.5_f64),
            (0.5, 1.0),
            (0.5, 2.0),
            (1.5, 0.5),
            (1.5, 1.0),
            (1.5, 2.0),
            (0.3, 1.0),
            (0.7, 1.0),
            (1.3, 1.0),
        ];
        for (v_val, x_val) in &test_cases {
            let kv_val = kv(*v_val, *x_val);
            let iv_val = iv(*v_val, *x_val);
            assert!(
                kv_val.is_finite() && kv_val > 0.0,
                "kv({v_val}, {x_val}) = {kv_val}"
            );
            assert!(
                iv_val.is_finite() && iv_val > 0.0,
                "iv({v_val}, {x_val}) = {iv_val}"
            );
            let i_v1 = iv(*v_val + 1.0, *x_val);
            let k_v1 = kv(*v_val + 1.0, *x_val);
            let lhs = iv_val * k_v1 + i_v1 * kv_val;
            let expected = 1.0 / x_val;
            let rel_err = (lhs - expected).abs() / expected.abs();
            assert!(rel_err < 1e-5,
                "Wronskian fail: v={v_val}, x={x_val}, lhs={lhs}, expected={expected}, rel_err={rel_err}");
        }
    }

    // kv must agree with k0/k1 at integer orders; K_2(1) is checked against a
    // hard-coded reference value.
    #[test]
    fn test_kv_integer_orders() {
        let x = 1.0_f64;
        assert_relative_eq!(kv(0.0_f64, x), k0(x), epsilon = 1e-10);
        assert_relative_eq!(kv(1.0_f64, x), k1(x), epsilon = 1e-10);
        let k2_1 = kv(2.0_f64, x);
        assert_relative_eq!(k2_1, 1.6248389218904965, epsilon = 1e-6);
    }
}