use core::arch::aarch64::*;
use super::super::ops::{complex_mul, complex_mul_i, load_neg_imag_mask};
use crate::fft::Complex32;
/// Radix-4 butterfly pass for `stride == 1`, vectorised two butterflies per
/// iteration with NEON.
///
/// With `q = src.len() / 4`, butterfly `i` combines `src[i]`, `src[i + q]`,
/// `src[i + 2q]` and `src[i + 3q]`. No twiddle multiplication is applied here,
/// in contrast to `butterfly_radix4_generic_neon`. Its four outputs are written
/// contiguously to `dst[4 * i..4 * i + 4]`. The SIMD loop covers the first
/// `simd_iters` butterflies (`q` rounded down to an even number). The remainder
/// (at most one butterfly) is handed to `butterfly_radix4_scalar` together with
/// `stage_twiddles`.
///
/// The loads and stores reinterpret `Complex32` pointers as `f32` pointers. This
/// assumes `Complex32` is laid out as two contiguous `f32`s (re, im) — confirm
/// against its definition.
///
/// # Safety
///
/// * The CPU must support NEON (required by `#[target_feature]`).
/// * `dst.len() >= 4 * simd_iters`. The SIMD loop stores through raw pointers
///   without bounds checks. This is only verified by a `debug_assert!`.
/// * Any further requirements of `butterfly_radix4_scalar` (not visible here).
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix4_stride1_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    let samples = src.len();
    let quarter_samples = samples >> 2;
    // Two butterflies per SIMD iteration, so round the butterfly count down to even.
    let simd_iters = (quarter_samples >> 1) << 1;

    // The SIMD loop writes dst[0..4 * simd_iters) through raw pointers.
    debug_assert!(
        dst.len() >= 4 * simd_iters,
        "dst too short for radix-4 stride-1 pass"
    );

    // SAFETY: for each even `i < simd_iters`, the loads read complex elements
    // `i`, `i + 1` of each quarter. The highest is `i + 1 + 3 * quarter_samples`,
    // which is `< 4 * quarter_samples <= src.len()`. The stores write
    // `dst[4 * i..4 * i + 8]`, which is in bounds by the caller's contract
    // `dst.len() >= 4 * simd_iters`.
    unsafe {
        let neg_imag = load_neg_imag_mask();
        let src_ptr = src.as_ptr();
        let dst_base = dst.as_mut_ptr();
        for i in (0..simd_iters).step_by(2) {
            // Each vector holds two adjacent complex values: butterflies i and i + 1.
            let z0 = vld1q_f32(src_ptr.add(i) as *const f32);
            let z1 = vld1q_f32(src_ptr.add(i + quarter_samples) as *const f32);
            let z2 = vld1q_f32(src_ptr.add(i + 2 * quarter_samples) as *const f32);
            let z3 = vld1q_f32(src_ptr.add(i + 3 * quarter_samples) as *const f32);

            // At stride 1 the inputs feed the butterfly without twiddle multiplication.
            let a0 = vaddq_f32(z0, z2);
            let a1 = vsubq_f32(z0, z2);
            let a2 = vaddq_f32(z1, z3);
            // Rotation of (z1 - z3) as defined by `complex_mul_i` / `neg_imag`
            // (presumably multiplication by -i for the forward transform — confirm in `ops`).
            let a3 = complex_mul_i(vsubq_f32(z1, z3), neg_imag);

            let out0 = vaddq_f32(a0, a2);
            let out1 = vaddq_f32(a1, a3);
            let out2 = vsubq_f32(a0, a2);
            let out3 = vsubq_f32(a1, a3);

            // Transpose so each butterfly's four outputs become contiguous. Each
            // complex value is treated as one 64-bit lane, and lanes are
            // interleaved across the output vectors.
            let o0 = vreinterpretq_f64_f32(out0);
            let o1 = vreinterpretq_f64_f32(out1);
            let o2 = vreinterpretq_f64_f32(out2);
            let o3 = vreinterpretq_f64_f32(out3);
            let first01 = vreinterpretq_f32_f64(vzip1q_f64(o0, o1)); // [out0_i, out1_i]
            let first23 = vreinterpretq_f32_f64(vzip1q_f64(o2, o3)); // [out2_i, out3_i]
            let second01 = vreinterpretq_f32_f64(vzip2q_f64(o0, o1)); // [out0_i+1, out1_i+1]
            let second23 = vreinterpretq_f32_f64(vzip2q_f64(o2, o3)); // [out2_i+1, out3_i+1]

            // dst[4i..4i+4] receives butterfly i; dst[4i+4..4i+8] receives butterfly i + 1.
            let d = dst_base.add(4 * i) as *mut f32;
            vst1q_f32(d, first01);
            vst1q_f32(d.add(4), first23);
            vst1q_f32(d.add(8), second01);
            vst1q_f32(d.add(12), second23);
        }
    }

    // Leftover butterflies (presumably starting at index `simd_iters`) use the scalar path.
    super::butterfly_radix4_scalar::<2>(src, dst, stage_twiddles, 1, simd_iters);
}
/// Radix-4 butterfly pass for an arbitrary `stride`, vectorised two butterflies
/// per iteration with NEON.
///
/// With `q = src.len() / 4`, `k = i % stride` and `j = 4 * i - 3 * k`,
/// butterfly `i`:
/// * reads `src[i + m * q]` for `m = 0..4`;
/// * multiplies the last three inputs by twiddles;
/// * writes its outputs to `dst[j + m * stride]`.
///
/// For an even `i`, the code reads twiddles from `stage_twiddles[3 * i..3 * i + 6]`
/// in the order `[w1_i, w1_i+1, w2_i, w2_i+1, w3_i, w3_i+1]`.
///
/// A `stride` of 0 returns immediately without touching `dst`. After the
/// even-sized SIMD prefix (`simd_iters` butterflies), the rest is handed to
/// `butterfly_radix4_scalar`.
///
/// The loads and stores assume `Complex32` is two contiguous `f32`s (re, im) —
/// confirm against its definition.
///
/// # Safety
///
/// * The CPU must support NEON (required by `#[target_feature]`).
/// * `stage_twiddles.len() >= 3 * simd_iters`, where `simd_iters` is
///   `src.len() / 4` rounded down to an even number.
/// * Every store index `j + 3 * stride` must be `< dst.len()`. This holds when
///   `dst.len() >= 4 * (src.len() / 4)` and `src.len() / 4` is a multiple of `stride`.
/// * Any further requirements of `butterfly_radix4_scalar` (not visible here).
///
/// These conditions are only verified by `debug_assert!`s. In release builds,
/// violating them is undefined behaviour.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix4_generic_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    // `k = i % stride` is meaningless for a zero stride.
    if stride == 0 {
        return;
    }
    let samples = src.len();
    let quarter_samples = samples >> 2;
    // Two butterflies per SIMD iteration, so round the butterfly count down to even.
    let simd_iters = (quarter_samples >> 1) << 1;
    // Read before taking the raw pointer so the loop never re-borrows `dst`.
    let dst_len = dst.len();

    // Each SIMD iteration reads six twiddles starting at index `3 * i`.
    debug_assert!(
        stage_twiddles.len() >= 3 * simd_iters,
        "stage_twiddles too short for radix-4 pass"
    );

    // Reduce `x` (always `<= stride` at the call sites below) modulo `stride`
    // without an integer division.
    let wrap = |x: usize| if x >= stride { x - stride } else { x };

    // SAFETY:
    // * Source loads read elements `i`, `i + 1` of each quarter, at most
    //   `i + 1 + 3 * quarter_samples < 4 * quarter_samples <= src.len()`.
    // * Twiddle loads read `stage_twiddles[3 * i..3 * i + 6]`, in bounds when
    //   `len >= 3 * simd_iters` (caller contract, debug-asserted above).
    // * Stores write `dst[j + m * stride]` for `m` in 0..4, in bounds by the
    //   caller contract (debug-asserted per iteration).
    unsafe {
        let neg_imag = load_neg_imag_mask();
        let src_ptr = src.as_ptr();
        let tw_base = stage_twiddles.as_ptr();
        let dst_ptr = dst.as_mut_ptr() as *mut f32;

        // Invariant: `k0 == i % stride` at the top of every iteration. It is
        // advanced incrementally to keep a division out of the hot loop.
        let mut k0 = 0usize;
        for i in (0..simd_iters).step_by(2) {
            let k1 = wrap(k0 + 1); // (i + 1) % stride

            // Each vector holds two adjacent complex values: butterflies i and i + 1.
            let z0 = vld1q_f32(src_ptr.add(i) as *const f32);
            let z1 = vld1q_f32(src_ptr.add(i + quarter_samples) as *const f32);
            let z2 = vld1q_f32(src_ptr.add(i + 2 * quarter_samples) as *const f32);
            let z3 = vld1q_f32(src_ptr.add(i + 3 * quarter_samples) as *const f32);

            // Twiddle pairs for butterflies i and i + 1.
            let tw = tw_base.add(3 * i) as *const f32;
            let w1 = vld1q_f32(tw);
            let w2 = vld1q_f32(tw.add(4));
            let w3 = vld1q_f32(tw.add(8));

            let t1 = complex_mul(z1, w1);
            let t2 = complex_mul(z2, w2);
            let t3 = complex_mul(z3, w3);

            let a0 = vaddq_f32(z0, t2);
            let a1 = vsubq_f32(z0, t2);
            let a2 = vaddq_f32(t1, t3);
            // Rotation of (t1 - t3) as defined by `complex_mul_i` / `neg_imag`
            // (presumably multiplication by -i for the forward transform — confirm in `ops`).
            let a3 = complex_mul_i(vsubq_f32(t1, t3), neg_imag);

            let out0 = vaddq_f32(a0, a2);
            let out1 = vaddq_f32(a1, a3);
            let out2 = vsubq_f32(a0, a2);
            let out3 = vsubq_f32(a1, a3);

            // Output base indices; `k <= i` so neither subtraction underflows.
            let j0 = 4 * i - 3 * k0;
            let j1 = 4 * (i + 1) - 3 * k1;
            debug_assert!(
                j0.max(j1) + 3 * stride < dst_len,
                "radix-4 output index out of bounds"
            );

            // Low half = butterfly i, high half = butterfly i + 1.
            // `<< 1` converts a complex index into an f32 offset.
            vst1_f32(dst_ptr.add(j0 << 1), vget_low_f32(out0));
            vst1_f32(dst_ptr.add((j0 + stride) << 1), vget_low_f32(out1));
            vst1_f32(dst_ptr.add((j0 + 2 * stride) << 1), vget_low_f32(out2));
            vst1_f32(dst_ptr.add((j0 + 3 * stride) << 1), vget_low_f32(out3));

            vst1_f32(dst_ptr.add(j1 << 1), vget_high_f32(out0));
            vst1_f32(dst_ptr.add((j1 + stride) << 1), vget_high_f32(out1));
            vst1_f32(dst_ptr.add((j1 + 2 * stride) << 1), vget_high_f32(out2));
            vst1_f32(dst_ptr.add((j1 + 3 * stride) << 1), vget_high_f32(out3));

            k0 = wrap(k1 + 1); // (i + 2) % stride, i.e. the next iteration's `i % stride`
        }
    }

    // Leftover butterflies (presumably starting at index `simd_iters`) use the scalar path.
    super::butterfly_radix4_scalar::<2>(src, dst, stage_twiddles, stride, simd_iters);
}