//! NEON (aarch64) kernels for special functions: vectorized erf/erfc, plus
//! element-wise scalar fallbacks for the Bessel functions J0, J1, I0, and I1.
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use crate::algorithm::special::scalar::{
    bessel_i0_scalar, bessel_i1_scalar, bessel_j0_scalar, bessel_j1_scalar, erf_scalar,
};
/// Vectorized error function for `f32`, four lanes at a time, using the
/// Abramowitz & Stegun 7.1.26 rational approximation (maximum absolute error
/// about 1.5e-7, before f32 rounding).
///
/// # Safety
///
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn erf_f32(input: *const f32, output: *mut f32, len: usize) {
    let lanes = 4;
    let chunks = len / lanes;
    // A&S 7.1.26 coefficients: erf(x) ~= 1 - (a1*t + ... + a5*t^5) * exp(-x^2),
    // where t = 1 / (1 + p*|x|).
    let a1 = vdupq_n_f32(0.254829592);
    let a2 = vdupq_n_f32(-0.284496736);
    let a3 = vdupq_n_f32(1.421413741);
    let a4 = vdupq_n_f32(-1.453152027);
    let a5 = vdupq_n_f32(1.061405429);
    let p = vdupq_n_f32(0.3275911);
    let one = vdupq_n_f32(1.0);
    let neg_one = vdupq_n_f32(-1.0);
    for i in 0..chunks {
        let idx = i * lanes;
        let x = vld1q_f32(input.add(idx));
        // erf is odd: compute on |x| and restore the sign at the end.
        let sign = vbslq_f32(vcltq_f32(x, vdupq_n_f32(0.0)), neg_one, one);
        let absx = vabsq_f32(x);
        let t = vdivq_f32(one, vaddq_f32(one, vmulq_f32(p, absx)));
        // Evaluate the degree-5 polynomial in Horner form:
        // poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5)))).
        let mut poly = vaddq_f32(a4, vmulq_f32(t, a5));
        poly = vaddq_f32(a3, vmulq_f32(t, poly));
        poly = vaddq_f32(a2, vmulq_f32(t, poly));
        poly = vaddq_f32(a1, vmulq_f32(t, poly));
        let poly = vmulq_f32(t, poly);
        let x2 = vmulq_f32(absx, absx);
        // NEON has no vector exp intrinsic; evaluate exp(-x^2) lane by lane.
        let exp_arr = [
            (-vgetq_lane_f32(x2, 0)).exp(),
            (-vgetq_lane_f32(x2, 1)).exp(),
            (-vgetq_lane_f32(x2, 2)).exp(),
            (-vgetq_lane_f32(x2, 3)).exp(),
        ];
        let exp_neg_x2 = vld1q_f32(exp_arr.as_ptr());
        let result = vmulq_f32(sign, vsubq_f32(one, vmulq_f32(poly, exp_neg_x2)));
        vst1q_f32(output.add(idx), result);
    }
    // Scalar tail for the remaining len % 4 elements.
    for i in (chunks * lanes)..len {
        *output.add(i) = erf_scalar(*input.add(i) as f64) as f32;
    }
}
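
// A minimal usage sketch (an assumption of this rewrite, not part of the
// original API): a safe wrapper that checks slice lengths before dispatching
// to the unsafe kernel. The name `erf_f32_slice` is hypothetical.
#[cfg(target_arch = "aarch64")]
pub fn erf_f32_slice(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "slices must have equal length");
    // SAFETY: the pointers come from equal-length slices checked above, and
    // NEON is a mandatory feature of the aarch64 baseline.
    unsafe { erf_f32(input.as_ptr(), output.as_mut_ptr(), input.len()) };
}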
/// Vectorized error function for `f64`, two lanes at a time, selecting between
/// three regimes per lane: a 30-term Maclaurin series for |x| < 3, a 50-level
/// continued fraction for erfc on 3 <= |x| < 6, and erf(x) = +-1 for |x| >= 6
/// (there erfc(x) < 3e-17, below f64 machine epsilon).
///
/// # Safety
///
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn erf_f64(input: *const f64, output: *mut f64, len: usize) {
    let lanes = 2;
    let chunks = len / lanes;
    let zero = vdupq_n_f64(0.0);
    let one = vdupq_n_f64(1.0);
    let neg_one = vdupq_n_f64(-1.0);
    let three = vdupq_n_f64(3.0);
    let six = vdupq_n_f64(6.0);
    let two_over_sqrt_pi = vdupq_n_f64(1.1283791670955126); // 2/sqrt(pi)
    let frac_1_sqrt_pi = vdupq_n_f64(0.5641895835477563); // 1/sqrt(pi)
    for i in 0..chunks {
        let idx = i * lanes;
        let x = vld1q_f64(input.add(idx));
        // erf is odd: compute on |x| and restore the sign at the end.
        let sign = vbslq_f64(vcltq_f64(x, zero), neg_one, one);
        let ax = vabsq_f64(x);
        let x2 = vmulq_f64(ax, ax);
        let neg_x2 = vnegq_f64(x2);
        // Maclaurin series:
        // erf(x) = 2/sqrt(pi) * sum_n (-1)^n x^(2n+1) / (n! * (2n+1)).
        let mut term = ax;
        let mut sum = ax;
        for n in 1..30 {
            let n_f = n as f64;
            term = vmulq_f64(term, vdivq_f64(neg_x2, vdupq_n_f64(n_f)));
            let contrib = vdivq_f64(term, vdupq_n_f64(2.0 * n_f + 1.0));
            sum = vaddq_f64(sum, contrib);
        }
        let maclaurin_result = vmulq_f64(sum, two_over_sqrt_pi);
        // Continued fraction for erfc, evaluated bottom-up from the 50th
        // partial numerator (numerators are n/2 for n = 1, 2, 3, ...):
        // erfc(x) = exp(-x^2)/sqrt(pi) * 1/(x + (1/2)/(x + 1/(x + (3/2)/(x + ...)))).
        let mut f = zero;
        for n in (1..=50_u32).rev() {
            f = vdivq_f64(vdupq_n_f64(n as f64 * 0.5), vaddq_f64(ax, f));
        }
        let cf = vdivq_f64(one, vaddq_f64(ax, f));
        // NEON has no vector exp intrinsic; evaluate exp(-x^2) lane by lane.
        let exp_arr = [
            (-vgetq_lane_f64(x2, 0)).exp(),
            (-vgetq_lane_f64(x2, 1)).exp(),
        ];
        let exp_neg_x2 = vld1q_f64(exp_arr.as_ptr());
        let erfc_val = vmulq_f64(vmulq_f64(exp_neg_x2, frac_1_sqrt_pi), cf);
        let cf_result = vsubq_f64(one, erfc_val);
        // Blend the three regimes per lane.
        let mask_small = vcltq_f64(ax, three);
        let mask_large = vcgeq_f64(ax, six);
        let mut result = vbslq_f64(mask_small, maclaurin_result, cf_result);
        result = vbslq_f64(mask_large, one, result);
        result = vmulq_f64(sign, result);
        vst1q_f64(output.add(idx), result);
    }
    // Scalar tail for the remaining len % 2 elements.
    for i in (chunks * lanes)..len {
        *output.add(i) = erf_scalar(*input.add(i));
    }
}
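
// A scalar transcription of the backward continued-fraction recurrence used in
// `erf_f64` above, handy as a unit-test reference (a sketch added here, not
// part of the original API):
//
//   erfc(x) = exp(-x^2)/sqrt(pi) * 1/(x + (1/2)/(x + 1/(x + (3/2)/(x + ...))))
//
// evaluated bottom-up from the 50th partial numerator, exactly like the SIMD loop.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
fn erfc_cf_reference(x: f64) -> f64 {
    let mut f = 0.0_f64;
    for n in (1..=50_u32).rev() {
        f = (n as f64 * 0.5) / (x + f);
    }
    (-x * x).exp() * 0.5641895835477563 / (x + f)
}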
/// Vectorized complementary error function for `f32`, computed in place as
/// 1 - erf(x). Note that this formulation loses relative accuracy for large
/// positive x: the result flushes to 0 once erf(x) rounds to 1.
///
/// # Safety
///
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn erfc_f32(input: *const f32, output: *mut f32, len: usize) {
    let lanes = 4;
    let chunks = len / lanes;
    let one = vdupq_n_f32(1.0);
    // Fill `output` with erf(x), then flip it to 1 - erf(x) in place.
    erf_f32(input, output, len);
    for i in 0..chunks {
        let idx = i * lanes;
        let erf_val = vld1q_f32(output.add(idx));
        let result = vsubq_f32(one, erf_val);
        vst1q_f32(output.add(idx), result);
    }
    // Scalar tail for the remaining len % 4 elements.
    for i in (chunks * lanes)..len {
        *output.add(i) = 1.0 - *output.add(i);
    }
}
/// Vectorized complementary error function for `f64`, computed in place as
/// 1 - erf(x). As with [`erfc_f32`], relative accuracy degrades for large
/// positive x, where the result flushes to 0 once erf(x) rounds to 1.
///
/// # Safety
///
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn erfc_f64(input: *const f64, output: *mut f64, len: usize) {
    let lanes = 2;
    let chunks = len / lanes;
    let one = vdupq_n_f64(1.0);
    // Fill `output` with erf(x), then flip it to 1 - erf(x) in place.
    erf_f64(input, output, len);
    for i in 0..chunks {
        let idx = i * lanes;
        let erf_val = vld1q_f64(output.add(idx));
        let result = vsubq_f64(one, erf_val);
        vst1q_f64(output.add(idx), result);
    }
    // Scalar tail for the remaining len % 2 elements.
    for i in (chunks * lanes)..len {
        *output.add(i) = 1.0 - *output.add(i);
    }
}
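
// Design note (an observation added in this rewrite, not original commentary):
// the 1 - erf formulation above underflows erfc to 0 around |x| >= 6 in f64,
// even though erfc itself stays representable well past x = 20. A dedicated
// large-x path could return the continued-fraction quantity `erfc_val` computed
// inside `erf_f64` directly, instead of 1 - (1 - erfc_val); `erfc_cf_reference`
// above sketches that value in scalar form.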
// The Bessel kernels below are element-wise scalar fallbacks: no dedicated
// NEON kernel exists yet, so every element is routed through the f64 scalar
// routine (with a cast round-trip for the f32 variants).

/// Bessel function of the first kind J0 for `f32` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_j0_f32(input: *const f32, output: *mut f32, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_j0_scalar(*input.add(i) as f64) as f32;
    }
}

/// Bessel function of the first kind J0 for `f64` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_j0_f64(input: *const f64, output: *mut f64, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_j0_scalar(*input.add(i));
    }
}

/// Bessel function of the first kind J1 for `f32` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_j1_f32(input: *const f32, output: *mut f32, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_j1_scalar(*input.add(i) as f64) as f32;
    }
}

/// Bessel function of the first kind J1 for `f64` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_j1_f64(input: *const f64, output: *mut f64, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_j1_scalar(*input.add(i));
    }
}

/// Modified Bessel function of the first kind I0 for `f32` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_i0_f32(input: *const f32, output: *mut f32, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_i0_scalar(*input.add(i) as f64) as f32;
    }
}

/// Modified Bessel function of the first kind I0 for `f64` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_i0_f64(input: *const f64, output: *mut f64, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_i0_scalar(*input.add(i));
    }
}

/// Modified Bessel function of the first kind I1 for `f32` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_i1_f32(input: *const f32, output: *mut f32, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_i1_scalar(*input.add(i) as f64) as f32;
    }
}

/// Modified Bessel function of the first kind I1 for `f64` (scalar fallback).
///
/// # Safety
/// `input` must be valid for `len` reads and `output` for `len` writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn bessel_i1_f64(input: *const f64, output: *mut f64, len: usize) {
    for i in 0..len {
        *output.add(i) = bessel_i1_scalar(*input.add(i));
    }
}
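
// A minimal test sketch (an assumption of this rewrite, not part of the
// original file): cross-checks the NEON erf path against the scalar reference
// on a grid spanning all three f64 regimes, plus the continued-fraction
// transcription above.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    #[test]
    fn erf_f64_matches_scalar() {
        let input: Vec<f64> = (-80..=80).map(|i| i as f64 * 0.1).collect();
        let mut out = vec![0.0_f64; input.len()];
        // SAFETY: the pointers come from equal-length Vecs, and NEON is
        // mandatory on aarch64.
        unsafe { erf_f64(input.as_ptr(), out.as_mut_ptr(), input.len()) };
        for (&x, &y) in input.iter().zip(&out) {
            // Loose bound: the 30-term Maclaurin series is least accurate just
            // below the |x| = 3 switch point.
            assert!((y - erf_scalar(x)).abs() < 1e-4, "erf({x}): got {y}");
        }
    }

    #[test]
    fn erfc_cf_reference_matches_scalar() {
        // Restrict to 3.0..=4.0, where 1 - erf_scalar(x) still carries enough
        // relative accuracy to serve as a reference for erfc.
        for i in 30..=40 {
            let x = i as f64 * 0.1;
            let expected = 1.0 - erf_scalar(x);
            let got = erfc_cf_reference(x);
            assert!((got - expected).abs() / expected < 1e-6, "erfc({x})");
        }
    }
}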