use crate::error::RusTorchError; use num_traits::Float;
use std::f64::consts::PI;
/// Upper bound (exclusive) on the loop index used by the power-series kernels in this module.
const MAX_ITERATIONS: usize = 100;
/// Convergence tolerance for series loops; most loops compare a term against
/// `EPSILON` times the running sum, a few compare against `EPSILON` directly.
const EPSILON: f64 = 1e-15;
/// Evaluates the Bessel function of the first kind `J_n(x)` for a scalar.
///
/// Every integer order, including negative ones, is dispatched to
/// `bessel_j_integer`, which applies `J_{-n}(x) = (-1)^n J_n(x)`. Non-integer
/// orders use the ascending power series. (Sending a negative integer order to
/// the series would evaluate `Gamma(n + 1)` at a pole.)
///
/// # Errors
/// * `DomainError` if `n` or `x` cannot be converted to `f64`.
/// * `OverflowError` if the `f64` result cannot be converted back to `T`.
/// * Any error propagated from the gamma evaluation used by the series.
pub fn bessel_j_scalar<T: Float>(n: T, x: T) -> Result<T, RusTorchError> {
    let n_f64 = n
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert n to f64".to_string()))?;
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert x to f64".to_string()))?;
    let is_integer_order = n_f64 == n_f64.floor();

    // At the origin J_0 = 1, and J_n = 0 for every other integer or positive
    // order. Negative non-integer orders diverge at x = 0, so they fall
    // through to the series, which produces the appropriately signed infinity.
    if x_f64 == 0.0 && (is_integer_order || n_f64 > 0.0) {
        let value = if n_f64 == 0.0 { 1.0 } else { 0.0 };
        return T::from(value).ok_or(RusTorchError::OverflowError("Bessel J overflow"));
    }

    let result = if is_integer_order {
        bessel_j_integer(n_f64 as i32, x_f64)?
    } else {
        bessel_j_series(n_f64, x_f64)?
    };
    T::from(result).ok_or(RusTorchError::OverflowError("Bessel J conversion overflow"))
}
/// Computes `J_n(x)` for an integer order `n` and real `x`.
///
/// * Negative orders use `J_{-n}(x) = (-1)^n J_n(x)`.
/// * Orders 0 and 1 come directly from `bessel_j0` / `bessel_j1`.
/// * For `|x| > n` the forward recurrence `J_{k+1} = (2k/x) J_k - J_{k-1}`,
///   seeded with `J_0` and `J_1`, is stable.
/// * For `|x| <= n` forward recurrence amplifies rounding error exponentially
///   (`J_n` is the minimal solution there), so Miller's backward recurrence
///   (`miller_algorithm`) is used instead.
fn bessel_j_integer(n: i32, x: f64) -> Result<f64, RusTorchError> {
    if n < 0 {
        let sign = if n % 2 == 0 { 1.0 } else { -1.0 };
        // `saturating_neg` keeps `i32::MIN` from overflowing.
        return Ok(sign * bessel_j_integer(n.saturating_neg(), x)?);
    }
    match n {
        0 => bessel_j0(x),
        1 => bessel_j1(x),
        // J_n(0) = 0 for n >= 1; this also keeps both recurrences away from 2/x.
        _ if x == 0.0 => Ok(0.0),
        _ if x.abs() > f64::from(n) => {
            let mut j_prev = bessel_j0(x)?;
            let mut j_curr = bessel_j1(x)?;
            for k in 1..n {
                let j_next = 2.0 * f64::from(k) / x * j_curr - j_prev;
                j_prev = j_curr;
                j_curr = j_next;
            }
            Ok(j_curr)
        }
        _ => miller_algorithm(n, x),
    }
}
/// Computes the Bessel function of the first kind of order zero, `J_0(x)`.
///
/// * `|x| < 8`: ascending series `sum_k (-x^2/4)^k / (k!)^2`, summed until a
///   term falls below `EPSILON` relative to the partial sum (at most 49 terms).
/// * `|x| >= 8`: Hankel form
///   `J_0(x) = sqrt(2/(pi*|x|)) * (P cos(chi) - (8/|x|) Q sin(chi))`,
///   `chi = |x| - pi/4`, with the polynomial fits for `P` and `Q` in
///   `(8/|x|)^2` from Numerical Recipes `bessj0` (about 1e-8 absolute accuracy).
///
/// `J_0` is even, so the large-argument branch only uses `|x|`. This function
/// never returns `Err`; the `Result` keeps the kernel signatures uniform.
fn bessel_j0(x: f64) -> Result<f64, RusTorchError> {
    let x_abs = x.abs();
    if x_abs < 8.0 {
        let x2 = x * x;
        let mut sum = 1.0;
        let mut term = 1.0;
        for k in 1..50 {
            let k = f64::from(k);
            term *= -x2 / (4.0 * k * k);
            sum += term;
            if term.abs() < EPSILON * sum.abs() {
                break;
            }
        }
        Ok(sum)
    } else {
        let z = 8.0 / x_abs;
        let y = z * z;
        let chi = x_abs - 0.25 * PI;
        let p = 1.0
            + y * (-0.1098628627e-2
                + y * (0.2734510407e-4 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)));
        let q = -0.1562499995e-1
            + y * (0.1430488765e-3
                + y * (-0.6911147651e-5 + y * (0.7621095161e-6 - y * 0.934935152e-7)));
        Ok((2.0 / (PI * x_abs)).sqrt() * (p * chi.cos() - z * q * chi.sin()))
    }
}
/// Computes the Bessel function of the first kind of order one, `J_1(x)`.
///
/// * `|x| < 8`: ascending series `(x/2) * sum_k (-x^2/4)^k / (k! (k+1)!)`,
///   summed until a term falls below `EPSILON` relative to the partial sum.
/// * `|x| >= 8`: Hankel form
///   `J_1(x) = sqrt(2/(pi*|x|)) * (P cos(chi) - (8/|x|) Q sin(chi))`,
///   `chi = |x| - 3*pi/4`, with the polynomial fits for `P` and `Q` from
///   Numerical Recipes `bessj1` (about 1e-8 absolute accuracy). `J_1` is odd,
///   so the sign of `x` is restored at the end.
///
/// This function never returns `Err`; the `Result` keeps signatures uniform.
fn bessel_j1(x: f64) -> Result<f64, RusTorchError> {
    let x_abs = x.abs();
    if x_abs < 8.0 {
        let x2 = x * x;
        let mut sum = 0.5;
        let mut term = 0.5;
        for k in 1..50 {
            let k = f64::from(k);
            term *= -x2 / (4.0 * k * (k + 1.0));
            sum += term;
            if term.abs() < EPSILON * sum.abs() {
                break;
            }
        }
        Ok(x * sum)
    } else {
        let z = 8.0 / x_abs;
        let y = z * z;
        let chi = x_abs - 0.75 * PI;
        let p = 1.0
            + y * (0.183105e-2
                + y * (-0.3516396496e-4 + y * (0.2457520174e-5 + y * (-0.240337019e-6))));
        let q = 0.04687499995
            + y * (-0.2002690873e-3
                + y * (0.8449199096e-5 + y * (-0.88228987e-6 + y * 0.105787412e-6)));
        let magnitude = (2.0 / (PI * x_abs)).sqrt() * (p * chi.cos() - z * q * chi.sin());
        Ok(if x < 0.0 { -magnitude } else { magnitude })
    }
}
/// Computes `J_n(x)` with Miller's backward-recurrence algorithm.
///
/// 1. Seed the recurrence at an even start index `m`. Here `m` is at least
///    `max(n, |x|) + sqrt(160 * max(n, |x|)) + 20`, so `J_m` is negligible.
/// 2. Run `J_{k-1} = (2k/x) J_k - J_{k+1}` down to `k = 0`. Backward recurrence
///    is stable for the minimal solution `J_k`, so the resulting sequence is
///    proportional to the true values.
/// 3. Recover the scale from the identity `J_0(x) + 2 * sum_{k>=1} J_{2k}(x) = 1`.
///
/// Intermediate values are rescaled whenever they exceed 1e10, so the
/// recurrence cannot overflow for small `|x|` and large `n`. Negative orders
/// use `J_{-n} = (-1)^n J_n`, and negative `x` uses `J_n(-x) = (-1)^n J_n(x)`.
fn miller_algorithm(n: i32, x: f64) -> Result<f64, RusTorchError> {
    const RESCALE_THRESHOLD: f64 = 1e10;
    const RESCALE_FACTOR: f64 = 1e-10;

    if n < 0 {
        let sign = if n % 2 == 0 { 1.0 } else { -1.0 };
        return Ok(sign * miller_algorithm(n.saturating_neg(), x)?);
    }
    let x_abs = x.abs();
    if x_abs == 0.0 {
        return Ok(if n == 0 { 1.0 } else { 0.0 });
    }

    let top = f64::from(n).max(x_abs);
    let start = ((top + (160.0 * top).sqrt()) as i64 / 2 + 11) * 2;
    let target = i64::from(n);
    let two_over_x = 2.0 / x_abs;

    let mut j_above = 0.0; // J_{k+1} (unnormalised)
    let mut j_here = 1.0; // J_k (unnormalised); arbitrary seed at k = start
    let mut even_sum = 0.0; // sum of J_j over even j below `start`
    let mut j_target = 0.0; // J_n (unnormalised) once reached
    for k in (1..=start).rev() {
        let j_below = k as f64 * two_over_x * j_here - j_above; // J_{k-1}
        j_above = j_here;
        j_here = j_below;
        if j_here.abs() > RESCALE_THRESHOLD {
            j_here *= RESCALE_FACTOR;
            j_above *= RESCALE_FACTOR;
            even_sum *= RESCALE_FACTOR;
            j_target *= RESCALE_FACTOR;
        }
        if k == target {
            j_target = j_above; // j_above now holds J_k
        }
        if (k - 1) % 2 == 0 {
            even_sum += j_here; // j_here now holds J_{k-1}
        }
    }
    if target == 0 {
        j_target = j_here;
    }

    // even_sum counts J_0 once; the normaliser needs J_0 + 2 * sum_{k>=1} J_{2k}.
    let norm = 2.0 * even_sum - j_here;
    let value = j_target / norm;
    Ok(if x < 0.0 && n % 2 != 0 { -value } else { value })
}
/// Evaluates `J_nu(x)` from its ascending power series
/// `(x/2)^nu / Gamma(nu + 1) * sum_{k>=0} (-x^2/4)^k / (k! (nu+1)_k)`.
///
/// Summation stops once a term is below `EPSILON` relative to the partial sum,
/// or after `MAX_ITERATIONS - 1` correction terms.
///
/// NOTE(review): the terms alternate in sign and grow before they shrink, so
/// accuracy degrades as `|x|` increases (cancellation). Callers route integer
/// orders elsewhere; non-integer orders at large `|x|` may be imprecise.
///
/// # Errors
/// Propagates any error from `gamma_scalar(nu + 1)`.
fn bessel_j_series(nu: f64, x: f64) -> Result<f64, RusTorchError> {
    let half = x / 2.0;
    let neg_half_sq = -(half * half);

    let mut series = 1.0;
    let mut term = 1.0;
    for k in (1..MAX_ITERATIONS).map(|k| k as f64) {
        term *= neg_half_sq / (k * (nu + k));
        series += term;
        if term.abs() < EPSILON * series.abs() {
            break;
        }
    }

    let leading = half.powf(nu);
    let gamma = super::gamma::gamma_scalar(nu + 1.0)?;
    Ok(leading / gamma * series)
}
/// Evaluates the Bessel function of the second kind `Y_n(x)` for a scalar.
///
/// Order 0 and other integer orders use dedicated kernels. A non-integer
/// order `nu` uses the connection formula
/// `Y_nu(x) = (J_nu(x) cos(nu*pi) - J_{-nu}(x)) / sin(nu*pi)`.
///
/// # Errors
/// * `DomainError` if `n`/`x` cannot be converted to `f64`, if `x <= 0`, or if
///   `sin(nu*pi)` is numerically zero for a non-integer `nu`.
/// * `OverflowError` if the result cannot be represented as `T`.
pub fn bessel_y_scalar<T: Float>(n: T, x: T) -> Result<T, RusTorchError> {
    let order = n
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert n to f64".to_string()))?;
    let arg = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert x to f64".to_string()))?;
    if arg <= 0.0 {
        return Err(RusTorchError::DomainError(
            "Y_n(x) is undefined for x <= 0".to_string(),
        ));
    }

    let value = match order {
        o if o == 0.0 => bessel_y0(arg)?,
        o if o == o.floor() => bessel_y_integer(o as i32, arg)?,
        nu => {
            let angle = nu * PI;
            let sine = angle.sin();
            if sine.abs() < EPSILON {
                return Err(RusTorchError::DomainError(
                    "Y_n undefined for integer n through non-integer formula".to_string(),
                ));
            }
            let forward = bessel_j_series(nu, arg)?;
            let reflected = bessel_j_series(-nu, arg)?;
            (forward * angle.cos() - reflected) / sine
        }
    };
    T::from(value).ok_or(RusTorchError::OverflowError("Bessel overflow"))
}
/// Computes `Y_n(x)` for an integer order `n`.
///
/// Negative orders use `Y_{-n}(x) = (-1)^n Y_n(x)`. For `n >= 2` the value is
/// built by the forward recurrence `Y_{k+1} = (2k/x) Y_k - Y_{k-1}` seeded with
/// `Y_0` and `Y_1`; forward recurrence is stable for `Y` because it is the
/// dominant solution.
fn bessel_y_integer(n: i32, x: f64) -> Result<f64, RusTorchError> {
    if n < 0 {
        let parity = if n % 2 == 0 { 1.0 } else { -1.0 };
        return Ok(parity * bessel_y_integer(-n, x)?);
    }
    match n {
        0 => bessel_y0(x),
        1 => bessel_y1(x),
        _ => {
            let seed = (bessel_y0(x)?, bessel_y1(x)?);
            let (_, y_n) = (1..n).fold(seed, |(lower, upper), k| {
                (upper, 2.0 * k as f64 / x * upper - lower)
            });
            Ok(y_n)
        }
    }
}
fn bessel_y0(x: f64) -> Result<f64, RusTorchError> {
if x <= 0.0 {
return Err(RusTorchError::DomainError(
"Y_0(x) undefined for x <= 0".to_string(),
));
}
if x < 8.0 {
let j0 = bessel_j0(x)?;
let z = x * x / 64.0;
let mut p = -4.1298668500990866786e11;
p = p * z + 2.7424980760831259494e10;
p = p * z + -6.7476522750813283766e08;
p = p * z + 6.3235612608595102020e06;
p = p * z + -1.8955248962783297560e04;
p = p * z + 9.8604989168025943700e01;
let mut q = 1.4831876916799208776e12;
q = q * z + 1.1394980557384778174e10;
q = q * z + 3.4522363901309898027e07;
q = q * z + 4.0329284160245442156e04;
q = q * z + 1.0;
Ok(p / q + (2.0 / PI) * (x / 2.0).ln() * j0)
} else {
let z = 8.0 / x;
let y = z * z;
let xx = x - 0.785398164;
let p0 = 1.0;
let p1 = -0.1098628627e-2;
let p2 = 0.2734510407e-4;
let p3 = -0.2073370639e-5;
let p4 = 0.2093887211e-6;
let p = p0 + y * (p1 + y * (p2 + y * (p3 + y * p4)));
let q0 = -0.1562499995e-1;
let q1 = 0.1430488765e-3;
let q2 = -0.6911147651e-5;
let q3 = 0.7621095161e-6;
let q4 = -0.934945152e-7;
let q = z * (q0 + y * (q1 + y * (q2 + y * (q3 + y * q4))));
Ok((2.0 / (PI * x)).sqrt() * (p * xx.sin() + q * xx.cos()))
}
}
fn bessel_y1(x: f64) -> Result<f64, RusTorchError> {
if x <= 0.0 {
return Err(RusTorchError::DomainError(
"Y_1(x) undefined for x <= 0".to_string(),
));
}
if x < 8.0 {
let j1 = bessel_j1(x)?;
let x2 = x * x;
let mut sum = -1.0 / x;
let mut factorial = 1.0;
let mut harmonic = 1.0;
for k in 1..50 {
factorial *= k as f64 * (k + 1) as f64;
harmonic += 1.0 / k as f64 + 1.0 / (k + 1) as f64;
let term = x * x2.powi(k as i32) / (4.0_f64.powi(k as i32 + 1) * factorial);
sum += term * (harmonic - 1.0 / (2.0 * (k + 1) as f64));
if term.abs() < EPSILON {
break;
}
}
Ok((2.0 / PI) * ((x / 2.0).ln() * j1 + sum))
} else {
let z = 8.0 / x;
let z2 = z * z;
let xx = x - 0.75 * PI;
let p0 = 1.0;
let p1 = z / 8.0 * (3.0 - 5.0 * z2);
let q0 = -z / 8.0;
let q1 = z2 / 8.0 * (3.0 - 21.0 * z2) / 3.0;
let p = p0 + p1;
let q = q0 + q1;
Ok((2.0 / (PI * x)).sqrt() * (p * xx.sin() + q * xx.cos()))
}
}
/// Evaluates the modified Bessel function of the first kind `I_n(x)` for a scalar.
///
/// Every integer order, including negative ones, goes to `bessel_i_integer`
/// via `I_{-n}(x) = I_n(x)`. Sending a negative integer order to the ascending
/// series would evaluate `Gamma(n + 1)` at a pole. Non-integer orders use the
/// series directly.
///
/// # Errors
/// * `DomainError` if `n` or `x` cannot be converted to `f64`.
/// * `OverflowError` if the result cannot be represented as `T`.
/// * Any error propagated from the gamma evaluation used by the series.
pub fn bessel_i_scalar<T: Float>(n: T, x: T) -> Result<T, RusTorchError> {
    let n_f64 = n
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert n to f64".to_string()))?;
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert x to f64".to_string()))?;
    let result = if n_f64 == n_f64.floor() {
        bessel_i_integer(n_f64.abs() as i32, x_f64)?
    } else {
        bessel_i_series(n_f64, x_f64)?
    };
    T::from(result).ok_or(RusTorchError::OverflowError("Bessel overflow"))
}
/// Computes `I_n(x)` for an integer order `n` (the sign of `n` is ignored,
/// because `I_{-n} = I_n`).
///
/// The large-argument expansion
/// `I_n(x) ~ e^|x| / sqrt(2 pi |x|) * sum_k (-1)^k prod_{j<=k} (4n^2 - (2j-1)^2) / (k! (8|x|)^k)`
/// is used when `|x| >= 15` and `n^2 <= 10 |x|`. Beyond that ratio the leading
/// terms of the expansion grow before they decay and precision is lost, so
/// those cases, and all `|x| < 15`, fall back to the ascending series.
/// Negative `x` uses `I_n(-x) = (-1)^n I_n(x)`; previously `sqrt` of a negative
/// number produced NaN for `x <= -15`.
fn bessel_i_integer(n: i32, x: f64) -> Result<f64, RusTorchError> {
    let order = n.saturating_abs();
    let order_f = f64::from(order);
    let x_abs = x.abs();
    if x_abs < 15.0 || order_f * order_f > 10.0 * x_abs {
        return bessel_i_series(order_f, x);
    }

    let prefactor = x_abs.exp() / (2.0 * PI * x_abs).sqrt();
    // Computed in f64 so large orders cannot overflow i32 arithmetic.
    let four_n_sq = 4.0 * order_f * order_f;
    let mut sum = 1.0;
    let mut term = 1.0;
    for k in 1..30 {
        let kf = f64::from(k);
        let odd = 2.0 * kf - 1.0;
        term *= -(four_n_sq - odd * odd) / (8.0 * kf * x_abs);
        sum += term;
        if term.abs() < EPSILON * sum.abs() {
            break;
        }
    }
    let value = prefactor * sum;
    Ok(if x < 0.0 && order % 2 != 0 { -value } else { value })
}
/// Evaluates the modified Bessel function `I_nu(x)` from its ascending series
/// `(x/2)^nu / Gamma(nu + 1) * sum_{k>=0} (x^2/4)^k / (k! (nu+1)_k)`.
///
/// For `nu > -1` every correction term is non-negative, so there is no
/// cancellation. Summation stops once a term is below `EPSILON` relative to the
/// partial sum, or after `MAX_ITERATIONS - 1` correction terms.
///
/// # Errors
/// Propagates any error from `gamma_scalar(nu + 1)`.
fn bessel_i_series(nu: f64, x: f64) -> Result<f64, RusTorchError> {
    let half = x / 2.0;
    let half_sq = half * half;

    let mut series = 1.0;
    let mut term = 1.0;
    for k in (1..MAX_ITERATIONS).map(|k| k as f64) {
        term *= half_sq / (k * (nu + k));
        series += term;
        if term.abs() < EPSILON * series.abs() {
            break;
        }
    }

    let leading = half.powf(nu);
    let gamma = super::gamma::gamma_scalar(nu + 1.0)?;
    Ok(leading / gamma * series)
}
/// Evaluates the modified Bessel function of the second kind `K_n(x)` for a scalar.
///
/// Integer orders, and non-integer orders whose `sin(n*pi)` is numerically
/// zero, use the integer kernel with `|round(n)|`, since `K_{-n} = K_n`.
/// Previously the near-integer fallback could pass a negative order. Other
/// orders use `K_nu(x) = (pi/2) (I_{-nu}(x) - I_nu(x)) / sin(nu*pi)`.
///
/// NOTE(review): the connection formula subtracts two quantities of size
/// about `e^x`, so for non-integer orders accuracy likely degrades as `x` grows.
///
/// # Errors
/// * `DomainError` if `n`/`x` cannot be converted to `f64`, or if `x <= 0`.
/// * `OverflowError` if the result cannot be represented as `T`.
pub fn bessel_k_scalar<T: Float>(n: T, x: T) -> Result<T, RusTorchError> {
    let n_f64 = n
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert n to f64".to_string()))?;
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert x to f64".to_string()))?;
    if x_f64 <= 0.0 {
        return Err(RusTorchError::DomainError(
            "K_n(x) is undefined for x <= 0".to_string(),
        ));
    }
    let result = if n_f64 == n_f64.floor() {
        bessel_k_integer(n_f64.abs() as i32, x_f64)?
    } else {
        let nu_pi = n_f64 * PI;
        let sin_nu_pi = nu_pi.sin();
        if sin_nu_pi.abs() < EPSILON {
            bessel_k_integer(n_f64.round().abs() as i32, x_f64)?
        } else {
            let i_nu = bessel_i_series(n_f64, x_f64)?;
            let i_minus_nu = bessel_i_series(-n_f64, x_f64)?;
            PI / 2.0 * (i_minus_nu - i_nu) / sin_nu_pi
        }
    };
    T::from(result).ok_or(RusTorchError::OverflowError("Bessel overflow"))
}
/// Computes `K_n(x)` for an integer order `n` and `x > 0`.
///
/// The order is normalised with `K_{-n}(x) = K_n(x)` before dispatching, so
/// the helpers only ever see non-negative orders. `x < 2` uses the ascending
/// series, `x >= 2` the continued-fraction method.
///
/// # Errors
/// `DomainError` for `x <= 0`.
fn bessel_k_integer(n: i32, x: f64) -> Result<f64, RusTorchError> {
    if x <= 0.0 {
        return Err(RusTorchError::DomainError(
            "K_n(x) undefined for x <= 0".to_string(),
        ));
    }
    let order = n.saturating_abs();
    if x < 2.0 {
        bessel_k_small_x(order, x)
    } else {
        bessel_k_large_x(order, x)
    }
}
/// Computes `K_n(x)` for small positive `x` (the caller uses it for `x < 2`).
///
/// * `K_0(x) = -(ln(x/2) + gamma) I_0(x) + sum_{k>=1} H_k (x^2/4)^k / (k!)^2`
///   (Abramowitz & Stegun 9.6.13).
/// * `K_1(x) = 1/x + (ln(x/2) + gamma) I_1(x)
///            - (x/4) sum_{k>=0} (H_k + H_{k+1}) (x^2/4)^k / (k! (k+1)!)`
///   (A&S 9.6.11 with n = 1). The previous K_1 series was wrong: K_1(1) came
///   out as ~0.5839 instead of 0.6019.
/// * `n >= 2`: forward recurrence `K_{m+1} = K_{m-1} + (2m/x) K_m`.
///
/// Here `H_k` is the k-th harmonic number and `gamma` the Euler–Mascheroni
/// constant. The sign of `n` is ignored because `K_{-n} = K_n`.
fn bessel_k_small_x(n: i32, x: f64) -> Result<f64, RusTorchError> {
    const EULER_GAMMA: f64 = 0.5772156649015329;
    let order = n.saturating_abs();
    let x_half = x / 2.0;
    let quarter_x2 = x_half * x_half;
    let log_term = x_half.ln() + EULER_GAMMA;

    let k0 = {
        let i0 = bessel_i_series(0.0, x)?;
        let mut term = 1.0; // (x^2/4)^k / (k!)^2
        let mut harmonic = 0.0; // H_k
        let mut sum = 0.0;
        for k in 1..50 {
            let kf = f64::from(k);
            term *= quarter_x2 / (kf * kf);
            harmonic += 1.0 / kf;
            sum += term * harmonic;
            if term < EPSILON {
                break;
            }
        }
        -log_term * i0 + sum
    };
    if order == 0 {
        return Ok(k0);
    }

    let k1 = {
        let i1 = bessel_i_series(1.0, x)?;
        let mut term = 1.0; // (x^2/4)^k / (k! (k+1)!)
        let mut h_k = 0.0; // H_k
        let mut h_k_plus_1 = 1.0; // H_{k+1}
        let mut sum = h_k + h_k_plus_1;
        for k in 1..50 {
            let kf = f64::from(k);
            term *= quarter_x2 / (kf * (kf + 1.0));
            h_k = h_k_plus_1;
            h_k_plus_1 += 1.0 / (kf + 1.0);
            let contribution = term * (h_k + h_k_plus_1);
            sum += contribution;
            if contribution < EPSILON * sum {
                break;
            }
        }
        1.0 / x + log_term * i1 - x / 4.0 * sum
    };
    if order == 1 {
        return Ok(k1);
    }

    let mut k_prev = k0;
    let mut k_curr = k1;
    for m in 1..order {
        let k_next = 2.0 * f64::from(m) / x * k_curr + k_prev;
        k_prev = k_curr;
        k_curr = k_next;
    }
    Ok(k_curr)
}
/// Computes `K_n(x)` for moderate and large `x` (the caller uses it for `x >= 2`).
///
/// The previous Hankel asymptotic series was summed past its smallest term,
/// so it diverged for roughly `2 <= x < 15`. Instead:
///
/// * `K_0` and `K_1` come from Temme's method. Steed's algorithm evaluates the
///   continued fraction CF2 (value `h`) together with the auxiliary series
///   `S`, giving `K_0(x) = sqrt(pi/(2x)) e^{-x} / S` and
///   `K_1(x) = K_0(x) (x + 1/2 - h/4) / x`.
///   This follows the Numerical Recipes `bessik` routine with `mu = 0`.
/// * Higher orders use the forward recurrence `K_{m+1} = K_{m-1} + (2m/x) K_m`,
///   which is stable for `K`.
///
/// The sign of `n` is ignored (`K_{-n} = K_n`), and `x = +inf` returns 0.
fn bessel_k_large_x(n: i32, x: f64) -> Result<f64, RusTorchError> {
    const MAX_CF_ITERATIONS: usize = 10_000;

    let order = n.saturating_abs();
    if x == f64::INFINITY {
        return Ok(0.0);
    }

    let a1 = 0.25;
    let mut a = -a1;
    let mut b = 2.0 * (1.0 + x);
    let mut d = 1.0 / b;
    let mut delta_h = d;
    let mut h = d;
    let mut q_prev = 0.0;
    let mut q_curr = 1.0;
    let mut c = a1;
    let mut q = a1;
    let mut s = 1.0 + q * delta_h;
    for i in 2..=MAX_CF_ITERATIONS {
        let i_f = i as f64;
        a -= 2.0 * (i_f - 1.0);
        c = -a * c / i_f;
        let q_next = (q_prev - b * q_curr) / a;
        q_prev = q_curr;
        q_curr = q_next;
        q += c * q_next;
        b += 2.0;
        d = 1.0 / (b + a * d);
        delta_h = (b * d - 1.0) * delta_h;
        h += delta_h;
        let delta_s = q * delta_h;
        s += delta_s;
        if (delta_s / s).abs() < EPSILON {
            break;
        }
    }

    let k0 = (PI / (2.0 * x)).sqrt() * (-x).exp() / s;
    let k1 = k0 * (x + 0.5 - a1 * h) / x;
    match order {
        0 => Ok(k0),
        1 => Ok(k1),
        _ => {
            let mut k_prev = k0;
            let mut k_curr = k1;
            for m in 1..order {
                let k_next = k_prev + 2.0 * f64::from(m) / x * k_curr;
                k_prev = k_curr;
                k_curr = k_next;
            }
            Ok(k_curr)
        }
    }
}
/// Applies `bessel_j_scalar` element-wise with a fixed order `n`.
///
/// The output tensor has the same shape as `x`. Evaluation stops at the first
/// element whose scalar evaluation fails, and that error is returned.
pub fn bessel_j<T: Float + 'static>(
    n: T,
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| bessel_j_scalar(n, v))
        .collect::<Result<Vec<T>, _>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
/// Applies `bessel_y_scalar` element-wise with a fixed order `n`.
///
/// The output tensor has the same shape as `x`. Evaluation stops at the first
/// element whose scalar evaluation fails (e.g. a non-positive entry), and that
/// error is returned.
pub fn bessel_y<T: Float + 'static>(
    n: T,
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| bessel_y_scalar(n, v))
        .collect::<Result<Vec<T>, _>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
/// Applies `bessel_i_scalar` element-wise with a fixed order `n`.
///
/// The output tensor has the same shape as `x`. Evaluation stops at the first
/// element whose scalar evaluation fails, and that error is returned.
pub fn bessel_i<T: Float + 'static>(
    n: T,
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| bessel_i_scalar(n, v))
        .collect::<Result<Vec<T>, _>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
/// Applies `bessel_k_scalar` element-wise with a fixed order `n`.
///
/// The output tensor has the same shape as `x`. Evaluation stops at the first
/// element whose scalar evaluation fails (e.g. a non-positive entry), and that
/// error is returned.
pub fn bessel_k<T: Float + 'static>(
    n: T,
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| bessel_k_scalar(n, v))
        .collect::<Result<Vec<T>, _>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
#[cfg(test)]
mod tests {
    //! Reference values are standard tabulated values (e.g. A&S / SciPy).
    //! Large-argument J/Y use polynomial Hankel fits (~1e-8 absolute), so
    //! those checks use a 1e-7 tolerance; series-based values use 1e-10.
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_bessel_j() {
        assert_relative_eq!(bessel_j_scalar(0.0_f64, 0.0).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(bessel_j_scalar(1.0_f64, 0.0).unwrap(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(
            bessel_j_scalar(0.0_f64, 1.0).unwrap(),
            0.7651976865579666,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_j_scalar(1.0_f64, 1.0).unwrap(),
            0.4400505857449335,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_bessel_j_large_argument() {
        assert_relative_eq!(
            bessel_j_scalar(0.0_f64, 10.0).unwrap(),
            -0.2459357644513483,
            epsilon = 1e-7
        );
        assert_relative_eq!(
            bessel_j_scalar(1.0_f64, 10.0).unwrap(),
            0.04347274616886144,
            epsilon = 1e-7
        );
        assert_relative_eq!(
            bessel_j_scalar(1.0_f64, -10.0).unwrap(),
            -0.04347274616886144,
            epsilon = 1e-7
        );
    }

    #[test]
    fn test_bessel_j_integer_orders() {
        // |x| <= n: Miller's backward recurrence.
        assert_relative_eq!(
            bessel_j_scalar(10.0_f64, 1.0).unwrap(),
            2.630615123687453e-10,
            max_relative = 1e-9
        );
        assert_relative_eq!(
            bessel_j_scalar(5.0_f64, 5.0).unwrap(),
            0.26114054612017007,
            epsilon = 1e-10
        );
        // |x| > n: forward recurrence.
        assert_relative_eq!(
            bessel_j_scalar(3.0_f64, 5.0).unwrap(),
            0.36483123061366696,
            epsilon = 1e-10
        );
        // Negative integer orders: J_{-n}(x) = (-1)^n J_n(x).
        assert_relative_eq!(
            bessel_j_scalar(-1.0_f64, 1.0).unwrap(),
            -0.4400505857449335,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_j_scalar(-2.0_f64, 1.0).unwrap(),
            0.11490348493190048,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_bessel_y() {
        assert_relative_eq!(
            bessel_y_scalar(0.0_f64, 1.0).unwrap(),
            0.08825696421567696,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_y_scalar(1.0_f64, 1.0).unwrap(),
            -0.7812128213002887,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_y_scalar(0.0_f64, 5.0).unwrap(),
            -0.30851762524903376,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_y_scalar(1.0_f64, 5.0).unwrap(),
            0.14786314339122683,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_y_scalar(2.0_f64, 1.0).unwrap(),
            -1.6506826068162546,
            epsilon = 1e-9
        );
    }

    #[test]
    fn test_bessel_y_large_argument() {
        assert_relative_eq!(
            bessel_y_scalar(0.0_f64, 10.0).unwrap(),
            0.05567116728359939,
            epsilon = 1e-7
        );
        assert_relative_eq!(
            bessel_y_scalar(1.0_f64, 10.0).unwrap(),
            0.24901542420695388,
            epsilon = 1e-7
        );
    }

    #[test]
    fn test_bessel_i() {
        assert_relative_eq!(bessel_i_scalar(0.0_f64, 0.0).unwrap(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(bessel_i_scalar(1.0_f64, 0.0).unwrap(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(
            bessel_i_scalar(0.0_f64, 1.0).unwrap(),
            1.2660658777520082,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_i_scalar(1.0_f64, 1.0).unwrap(),
            0.5651591039924851,
            epsilon = 1e-10
        );
        // Negative integer order: I_{-n}(x) = I_n(x).
        assert_relative_eq!(
            bessel_i_scalar(-2.0_f64, 1.0).unwrap(),
            0.1357476697670383,
            epsilon = 1e-10
        );
        // Large negative argument: I_n(-x) = (-1)^n I_n(x).
        let positive = bessel_i_scalar(1.0_f64, 20.0).unwrap();
        let negative = bessel_i_scalar(1.0_f64, -20.0).unwrap();
        assert!(positive.is_finite());
        assert_relative_eq!(negative, -positive, max_relative = 1e-12);
    }

    #[test]
    fn test_bessel_k() {
        assert_relative_eq!(
            bessel_k_scalar(0.0_f64, 1.0).unwrap(),
            0.4210244382407083,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_k_scalar(1.0_f64, 1.0).unwrap(),
            0.6019072301972346,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            bessel_k_scalar(2.0_f64, 1.0).unwrap(),
            1.6248388986351774,
            epsilon = 1e-9
        );
        assert_relative_eq!(
            bessel_k_scalar(-2.0_f64, 1.0).unwrap(),
            1.6248388986351774,
            epsilon = 1e-9
        );
        assert_relative_eq!(
            bessel_k_scalar(0.0_f64, 5.0).unwrap(),
            0.0036910983340425942,
            max_relative = 1e-9
        );
        assert_relative_eq!(
            bessel_k_scalar(1.0_f64, 5.0).unwrap(),
            0.004044613445452164,
            max_relative = 1e-9
        );
    }

    #[test]
    fn test_bessel_recurrence() {
        // J_{n-1}(x) + J_{n+1}(x) = (2n/x) J_n(x), on both sides of the
        // forward-recurrence / Miller switch at |x| = n.
        for &(n, x) in &[(3.0_f64, 5.0_f64), (5.0, 2.0), (4.0, 12.0)] {
            let lhs = bessel_j_scalar(n - 1.0, x).unwrap() + bessel_j_scalar(n + 1.0, x).unwrap();
            let rhs = 2.0 * n / x * bessel_j_scalar(n, x).unwrap();
            assert_relative_eq!(lhs, rhs, epsilon = 1e-12);
        }
    }
}