use core::arch::aarch64::*;
use super::super::ops::complex_mul;
use crate::fft::Complex32;
/// Radix-2 butterfly stage for `stride == 1`, vectorised with NEON.
///
/// Pairs `src[i]` with `src[i + src.len() / 2]` and writes the sum and
/// difference interleaved: `dst[2 * i] = a + b`, `dst[2 * i + 1] = a - b`.
/// Two butterflies are processed per SIMD iteration (one 128-bit vector holds
/// two `Complex32`, i.e. four `f32`s). The 0 or 1 butterfly left over when
/// `src.len() / 2` is odd is handed to `butterfly_radix2_scalar`, starting at
/// index `simd_iters`.
///
/// The SIMD loop does not read `stage_twiddles`; it only forwards them to the
/// scalar tail. NOTE(review): this relies on every twiddle of a stride-1
/// stage being 1. Confirm against the twiddle table the caller builds.
///
/// # Safety
///
/// - The CPU must support NEON.
/// - `Complex32` must be laid out as two consecutive `f32`s (re, im). The
///   pointer casts below assume this. TODO confirm `#[repr(C)]` on the type.
/// - `dst.len()` must be at least `2 * simd_iters` (checked in debug builds),
///   plus whatever `butterfly_radix2_scalar` requires for the tail.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix2_stride1_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    let samples = src.len();
    let half_samples = samples >> 1;
    // Largest even count <= half_samples: butterflies handled two at a time.
    let simd_iters = (half_samples >> 1) << 1;

    // The SIMD loop writes dst[0 .. 2 * simd_iters] through raw pointers.
    debug_assert!(
        dst.len() >= 2 * simd_iters,
        "dst too short for radix-2 stride-1 stage"
    );

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_mut_ptr() as *mut f32;

    // SAFETY: for even i < simd_iters <= half_samples, the loads read
    // src[i..i + 2] and src[i + half_samples..i + half_samples + 2]. Both lie
    // inside src because i + 1 + half_samples < 2 * half_samples <= samples.
    // The stores write dst[2i..2i + 4] with 2i + 4 <= 2 * simd_iters <= dst.len().
    unsafe {
        for i in (0..simd_iters).step_by(2) {
            let a = vld1q_f32(src_ptr.add(i) as *const f32);
            let b = vld1q_f32(src_ptr.add(i + half_samples) as *const f32);

            // Twiddle is 1 for stride 1, so the butterfly is a plain add/sub.
            let top = vaddq_f32(a, b);
            let bot = vsubq_f32(a, b);

            // Reinterpret as f64 lanes so each lane is one whole complex, then
            // interleave: lo = [top0, bot0], hi = [top1, bot1].
            let top_c = vreinterpretq_f64_f32(top);
            let bot_c = vreinterpretq_f64_f32(bot);
            let lo = vreinterpretq_f32_f64(vzip1q_f64(top_c, bot_c));
            let hi = vreinterpretq_f32_f64(vzip2q_f64(top_c, bot_c));

            // Output complex index 2i, stored as f32 offset 4i.
            let out = dst_ptr.add(i << 2);
            vst1q_f32(out, lo);
            vst1q_f32(out.add(4), hi);
        }
    }

    // Finish the odd leftover butterfly (if any) with the scalar kernel.
    super::butterfly_radix2_scalar::<2>(src, dst, stage_twiddles, 1, simd_iters);
}
/// Radix-2 butterfly stage for an arbitrary non-zero `stride`, vectorised with NEON.
///
/// For each butterfly index `i < src.len() / 2`, with `k = i % stride`:
/// - `t = src[i + N/2] * stage_twiddles[i]`
/// - `dst[2i - k] = src[i] + t`
/// - `dst[2i - k + stride] = src[i] - t`
///
/// Two butterflies are processed per SIMD iteration. The 0 or 1 leftover
/// butterfly is handed to `butterfly_radix2_scalar`, starting at `simd_iters`.
/// A `stride` of 0 is a no-op.
///
/// NOTE(review): twiddles are read contiguously at index `i` (not `k`), so
/// `stage_twiddles` appears to be a per-butterfly table of at least
/// `src.len() / 2` entries. Confirm against the code that builds it.
///
/// # Safety
///
/// - The CPU must support NEON.
/// - `Complex32` must be two consecutive `f32`s (re, im). TODO confirm `#[repr(C)]`.
/// - `stage_twiddles.len() >= simd_iters` (checked in debug builds).
/// - Every destination index `2i - k` and `2i - k + stride` must be in bounds
///   of `dst`. For a well-formed stage (`stride` divides `src.len() / 2`), this
///   means `dst.len() >= 2 * (src.len() / 2)`, which is checked in debug builds.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix2_generic_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    if stride == 0 {
        return;
    }
    let samples = src.len();
    let half_samples = samples >> 1;
    // Largest even count <= half_samples: butterflies handled two at a time.
    let simd_iters = (half_samples >> 1) << 1;

    // The SIMD loop loads twiddles [i, i + 1] for every even i < simd_iters.
    debug_assert!(
        stage_twiddles.len() >= simd_iters,
        "stage_twiddles too short for radix-2 stage"
    );
    // Necessary (not sufficient for malformed strides): a full stage fills
    // every output slot in 0 .. 2 * half_samples.
    debug_assert!(
        dst.len() >= 2 * half_samples,
        "dst too short for radix-2 stage"
    );

    let src_ptr = src.as_ptr();
    let tw_base = stage_twiddles.as_ptr();
    let dst_ptr = dst.as_mut_ptr() as *mut f32;

    // SAFETY:
    // - Reads: src[i..i + 2] and src[i + half_samples..i + half_samples + 2]
    //   stay below 2 * half_samples <= samples. Twiddles [i..i + 2] stay
    //   below simd_iters <= stage_twiddles.len().
    // - Writes: for i = m * stride + k, the targets are 2m * stride + k and
    //   (2m + 1) * stride + k. These are < 2 * half_samples <= dst.len() when
    //   stride divides half_samples (caller contract above).
    unsafe {
        for i in (0..simd_iters).step_by(2) {
            // Position of each butterfly within its group of `stride`.
            let k0 = i % stride;
            // (i + 1) % stride without a second division: k0 < stride, so it
            // wraps at most once.
            let k1 = if k0 + 1 >= stride { 0 } else { k0 + 1 };

            let a = vld1q_f32(src_ptr.add(i) as *const f32);
            let b = vld1q_f32(src_ptr.add(i + half_samples) as *const f32);
            let tw = vld1q_f32(tw_base.add(i) as *const f32);

            let t = complex_mul(b, tw);
            let top = vaddq_f32(a, t);
            let bot = vsubq_f32(a, t);

            // Stockham output positions (complex indices) for butterflies i and i + 1.
            let j0 = (i << 1) - k0;
            let j1 = ((i + 1) << 1) - k1;

            // Outputs are not adjacent for stride > 1, so store each complex
            // (64 bits) separately. Complex index c is f32 offset 2c.
            vst1_f32(dst_ptr.add(j0 << 1), vget_low_f32(top));
            vst1_f32(dst_ptr.add((j0 + stride) << 1), vget_low_f32(bot));
            vst1_f32(dst_ptr.add(j1 << 1), vget_high_f32(top));
            vst1_f32(dst_ptr.add((j1 + stride) << 1), vget_high_f32(bot));
        }
    }

    // Finish the odd leftover butterfly (if any) with the scalar kernel.
    super::butterfly_radix2_scalar::<2>(src, dst, stage_twiddles, stride, simd_iters);
}