use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use num_traits::Float;
use crate::cr_math::CrMath;
use crate::helpers::unary_float_op;
pub fn sinc<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
T: Element + Float + CrMath,
D: Dimension,
{
let pi = T::from(std::f64::consts::PI).unwrap_or_else(|| <T as Element>::one());
unary_float_op(input, |x| {
if x == <T as Element>::zero() {
<T as Element>::one()
} else {
let px = pi * x;
px.cr_sin() / px
}
})
}
pub fn i0<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
T: Element + Float + CrMath,
D: Dimension,
{
unary_float_op(input, bessel_i0)
}
fn bessel_i0<T: Float + CrMath>(x: T) -> T {
let ax = x.abs();
let three_point_seven_five = T::from(3.75).unwrap();
if ax <= three_point_seven_five {
let t = (ax / three_point_seven_five).powi(2);
let c0 = T::from(1.0).unwrap();
let c1 = T::from(3.5156229).unwrap();
let c2 = T::from(3.0899424).unwrap();
let c3 = T::from(1.2067492).unwrap();
let c4 = T::from(0.2659732).unwrap();
let c5 = T::from(0.0360768).unwrap();
let c6 = T::from(0.0045813).unwrap();
c0 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * c6)))))
} else {
let t = three_point_seven_five / ax;
let c0 = T::from(0.39894228).unwrap();
let c1 = T::from(0.01328592).unwrap();
let c2 = T::from(0.00225319).unwrap();
let c3 = T::from(-0.00157565).unwrap();
let c4 = T::from(0.00916281).unwrap();
let c5 = T::from(-0.02057706).unwrap();
let c6 = T::from(0.02635537).unwrap();
let c7 = T::from(-0.01647633).unwrap();
let c8 = T::from(0.00392377).unwrap();
let poly = c0
+ t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * (c6 + t * (c7 + t * c8)))))));
poly * ax.cr_exp() / ax.sqrt()
}
}
#[cfg(feature = "f16")]
pub fn sinc_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
D: Dimension,
{
crate::helpers::unary_f16_op(input, |x| {
if x == 0.0 {
1.0
} else {
let px = std::f32::consts::PI * x;
core_math::sinf(px) / px
}
})
}
#[cfg(test)]
mod tests {
use super::*;
use ferray_core::dimension::Ix1;
fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
let n = data.len();
Array::from_vec(Ix1::new([n]), data).unwrap()
}
#[test]
fn test_sinc_zero_ac13() {
let a = arr1(vec![0.0]);
let r = sinc(&a).unwrap();
assert_eq!(r.as_slice().unwrap()[0], 1.0);
}
#[test]
fn test_sinc_nonzero() {
let a = arr1(vec![1.0, -1.0, 0.5]);
let r = sinc(&a).unwrap();
let s = r.as_slice().unwrap();
assert!(s[0].abs() < 1e-12);
assert!(s[1].abs() < 1e-12);
assert!((s[2] - 2.0 / std::f64::consts::PI).abs() < 1e-12);
}
#[test]
fn test_i0_zero_ac13() {
let a = arr1(vec![0.0]);
let r = i0(&a).unwrap();
assert!((r.as_slice().unwrap()[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_i0_known_values() {
let a = arr1(vec![1.0, 2.0]);
let r = i0(&a).unwrap();
let s = r.as_slice().unwrap();
assert!((s[0] - 1.2660658).abs() < 1e-4);
assert!((s[1] - 2.2795853).abs() < 1e-4);
}
#[test]
fn test_sinc_f32() {
let a = Array::<f32, Ix1>::from_vec(Ix1::new([2]), vec![0.0f32, 1.0]).unwrap();
let r = sinc(&a).unwrap();
let s = r.as_slice().unwrap();
assert!((s[0] - 1.0).abs() < 1e-6);
assert!(s[1].abs() < 1e-5);
}
}