use core::arch::aarch64::*;
use super::{
super::ops::{complex_mul, complex_mul_i, load_neg_imag_mask},
COS_2PI_5, COS_4PI_5, SIN_2PI_5, SIN_4PI_5,
};
use crate::fft::Complex32;
/// Radix-5 butterfly pass for `stride == 1`, vectorised with NEON.
///
/// Let `m = src.len() / 5`. Butterfly `i` (for `i` in `0..m`) reads `src[i + q * m]`
/// for `q` in `0..5` and writes its five results contiguously to `dst[5 * i + q]`.
/// Unlike `butterfly_radix5_generic_neon`, no twiddle multiplication is applied here.
///
/// The NEON loop handles butterflies in pairs (one `float32x4_t` carries two complex
/// values), covering the first `simd_iters` butterflies, where `simd_iters` is `m`
/// rounded down to an even number. The remaining butterflies are delegated to
/// `super::butterfly_radix5_scalar`, which is the only consumer of `stage_twiddles`
/// in this function.
///
/// # Safety
///
/// - NEON must be available on the executing CPU.
/// - `dst.len()` must be at least `5 * simd_iters`. The vector loop stores through raw
///   pointers without bounds checks; this is verified with `debug_assert!` only.
/// - The `*const f32` / `*mut f32` casts assume `Complex32` is exactly two contiguous
///   `f32` values (NOTE(review): confirm `Complex32` has a C-compatible layout).
/// - Any preconditions of `butterfly_radix5_scalar` for the tail also apply.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix5_stride1_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    let samples = src.len();
    let fifth_samples = samples / 5;
    // Even number of butterflies: the vector loop consumes them two at a time.
    let simd_iters = (fifth_samples >> 1) << 1;
    // Source reads stay below 5 * fifth_samples <= samples by construction, but the
    // destination length is the caller's responsibility; verify it in debug builds.
    debug_assert!(
        dst.len() >= 5 * simd_iters,
        "butterfly_radix5_stride1_neon: dst has {} elements, needs at least {}",
        dst.len(),
        5 * simd_iters
    );
    // SAFETY: NEON is enabled for this function. Each load reads two complex values
    // starting at `src[i + q * fifth_samples]` with `i + 1 < simd_iters <= fifth_samples`
    // and `q <= 4`, so every index read is below `5 * fifth_samples <= src.len()`. Each
    // store writes one complex value at `dst[5 * i + q]` or `dst[5 * (i + 1) + q]`, which
    // is below `5 * simd_iters <= dst.len()` per this function's safety contract.
    unsafe {
        let cos_2pi_5_vec = vdupq_n_f32(COS_2PI_5);
        let sin_2pi_5_vec = vdupq_n_f32(SIN_2PI_5);
        let cos_4pi_5_vec = vdupq_n_f32(COS_4PI_5);
        let sin_4pi_5_vec = vdupq_n_f32(SIN_4PI_5);
        let neg_imag = load_neg_imag_mask();
        let src_ptr = src.as_ptr();
        let dst_ptr = dst.as_mut_ptr() as *mut f32;
        for i in (0..simd_iters).step_by(2) {
            // Low half of each vector belongs to butterfly i, high half to butterfly i + 1.
            let z0 = vld1q_f32(src_ptr.add(i) as *const f32);
            let z1 = vld1q_f32(src_ptr.add(i + fifth_samples) as *const f32);
            let z2 = vld1q_f32(src_ptr.add(i + fifth_samples * 2) as *const f32);
            let z3 = vld1q_f32(src_ptr.add(i + fifth_samples * 3) as *const f32);
            let z4 = vld1q_f32(src_ptr.add(i + fifth_samples * 4) as *const f32);
            // No twiddle multiplication at stride 1; aliases mirror the generic kernel.
            let t1 = z1;
            let t2 = z2;
            let t3 = z3;
            let t4 = z4;
            // Symmetric pair sums/differences of the 5-point kernel.
            let sum_all = vaddq_f32(vaddq_f32(vaddq_f32(t1, t2), t3), t4);
            let a1 = vaddq_f32(t1, t4);
            let a2 = vaddq_f32(t2, t3);
            // NOTE(review): `complex_mul_i` with the `neg_imag` mask presumably multiplies
            // by -i; confirm against the `ops` module.
            let b1 = complex_mul_i(vsubq_f32(t1, t4), neg_imag);
            let b2 = complex_mul_i(vsubq_f32(t2, t3), neg_imag);
            // Cosine-weighted parts shared by output pairs (1, 4) and (2, 3).
            let c1 = vaddq_f32(
                z0,
                vaddq_f32(vmulq_f32(cos_2pi_5_vec, a1), vmulq_f32(cos_4pi_5_vec, a2)),
            );
            let c2 = vaddq_f32(
                z0,
                vaddq_f32(vmulq_f32(cos_4pi_5_vec, a1), vmulq_f32(cos_2pi_5_vec, a2)),
            );
            // Sine-weighted parts: added for one output of each pair, subtracted for the other.
            let d1 = vaddq_f32(vmulq_f32(sin_2pi_5_vec, b1), vmulq_f32(sin_4pi_5_vec, b2));
            let d2 = vsubq_f32(vmulq_f32(sin_4pi_5_vec, b1), vmulq_f32(sin_2pi_5_vec, b2));
            let out0 = vaddq_f32(z0, sum_all);
            let out1 = vaddq_f32(c1, d1);
            let out4 = vsubq_f32(c1, d1);
            let out2 = vaddq_f32(c2, d2);
            let out3 = vsubq_f32(c2, d2);
            // Butterfly i: outputs land contiguously at dst[5 * i .. 5 * i + 5].
            let j = 5 * i;
            vst1_f32(dst_ptr.add(j * 2), vget_low_f32(out0));
            vst1_f32(dst_ptr.add((j + 1) * 2), vget_low_f32(out1));
            vst1_f32(dst_ptr.add((j + 2) * 2), vget_low_f32(out2));
            vst1_f32(dst_ptr.add((j + 3) * 2), vget_low_f32(out3));
            vst1_f32(dst_ptr.add((j + 4) * 2), vget_low_f32(out4));
            // Butterfly i + 1: outputs land at dst[5 * (i + 1) .. 5 * (i + 1) + 5].
            let j1 = 5 * (i + 1);
            vst1_f32(dst_ptr.add(j1 * 2), vget_high_f32(out0));
            vst1_f32(dst_ptr.add((j1 + 1) * 2), vget_high_f32(out1));
            vst1_f32(dst_ptr.add((j1 + 2) * 2), vget_high_f32(out2));
            vst1_f32(dst_ptr.add((j1 + 3) * 2), vget_high_f32(out3));
            vst1_f32(dst_ptr.add((j1 + 4) * 2), vget_high_f32(out4));
        }
    }
    super::butterfly_radix5_scalar::<2>(src, dst, stage_twiddles, 1, simd_iters);
}
/// Radix-5 butterfly pass for an arbitrary `stride`, vectorised with NEON.
///
/// Let `m = src.len() / 5` and, for butterfly `i` in `0..m`, `k = i % stride`.
/// Butterfly `i` reads `z_q = src[i + q * m]` for `q` in `0..5`. It multiplies `z_1..z_4`
/// by their twiddles via `complex_mul`, then runs the 5-point kernel. Output `q` is
/// written to `dst[(5 * i - 4 * k) + q * stride]`; equivalently, the group base
/// `5 * stride * (i / stride)` plus `k`, with the five outputs `stride` apart.
///
/// Twiddle layout read by the vector loop: for each butterfly pair `(i, i + 1)`, with
/// `i` even, the eight values starting at `stage_twiddles[4 * i]` are
/// `[w1(i), w1(i+1), w2(i), w2(i+1), w3(i), w3(i+1), w4(i), w4(i+1)]`. Here `wq(n)` is
/// the factor applied to `z_q` of butterfly `n`. Twiddles are therefore indexed by
/// butterfly position `i`, not by `k`.
///
/// The first `simd_iters` butterflies (`m` rounded down to even) are processed two at
/// a time. The rest are delegated to `super::butterfly_radix5_scalar`. A `stride` of 0
/// returns immediately without touching `dst`.
///
/// # Safety
///
/// - NEON must be available on the executing CPU.
/// - `stage_twiddles.len()` must be at least `4 * simd_iters`.
/// - For every butterfly `i < simd_iters`, `5 * i - 4 * (i % stride) + 4 * stride` must
///   be a valid index into `dst`.
/// - Both length conditions above are verified with `debug_assert!` only.
/// - The pointer casts assume `Complex32` is two contiguous `f32`s (NOTE(review):
///   confirm `Complex32` has a C-compatible layout).
/// - Any preconditions of `butterfly_radix5_scalar` for the tail also apply.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix5_generic_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    // `i % stride` below would divide by zero; treat stride 0 as a no-op.
    if stride == 0 {
        return;
    }
    let samples = src.len();
    let fifth_samples = samples / 5;
    // Even number of butterflies: the vector loop consumes them two at a time.
    let simd_iters = (fifth_samples >> 1) << 1;
    // Each pair reads stage_twiddles[4 * i .. 4 * i + 8], so the loop touches
    // stage_twiddles[0 .. 4 * simd_iters].
    debug_assert!(
        stage_twiddles.len() >= 4 * simd_iters,
        "butterfly_radix5_generic_neon: stage_twiddles has {} elements, needs at least {}",
        stage_twiddles.len(),
        4 * simd_iters
    );
    let dst_len = dst.len();
    // SAFETY: NEON is enabled for this function. Source loads read two complex values
    // starting at `src[i + q * fifth_samples]` with `i + 1 < simd_iters <= fifth_samples`
    // and `q <= 4`, so they stay below `5 * fifth_samples <= src.len()`. Twiddle loads
    // stay below `4 * simd_iters <= stage_twiddles.len()`, and every store index is at
    // most `j + 4 * stride < dst.len()`; both follow from this function's safety contract.
    unsafe {
        let cos_2pi_5_vec = vdupq_n_f32(COS_2PI_5);
        let sin_2pi_5_vec = vdupq_n_f32(SIN_2PI_5);
        let cos_4pi_5_vec = vdupq_n_f32(COS_4PI_5);
        let sin_4pi_5_vec = vdupq_n_f32(SIN_4PI_5);
        let neg_imag = load_neg_imag_mask();
        let src_ptr = src.as_ptr();
        let tw_base = stage_twiddles.as_ptr();
        let dst_ptr = dst.as_mut_ptr() as *mut f32;
        for i in (0..simd_iters).step_by(2) {
            // Position inside the current group for butterflies i and i + 1; k1 wraps to 0
            // (branch-free) when i + 1 starts a new group.
            let k0 = i % stride;
            let k1 = k0 + 1 - ((k0 + 1 >= stride) as usize) * stride;
            // Low half of each vector belongs to butterfly i, high half to butterfly i + 1.
            let z0 = vld1q_f32(src_ptr.add(i) as *const f32);
            let z1 = vld1q_f32(src_ptr.add(i + fifth_samples) as *const f32);
            let z2 = vld1q_f32(src_ptr.add(i + fifth_samples * 2) as *const f32);
            let z3 = vld1q_f32(src_ptr.add(i + fifth_samples * 3) as *const f32);
            let z4 = vld1q_f32(src_ptr.add(i + fifth_samples * 4) as *const f32);
            // Interleaved twiddles for this pair: [w1 x2, w2 x2, w3 x2, w4 x2].
            let tw_ptr = tw_base.add(i * 4) as *const f32;
            let w1 = vld1q_f32(tw_ptr);
            let w2 = vld1q_f32(tw_ptr.add(4));
            let w3 = vld1q_f32(tw_ptr.add(8));
            let w4 = vld1q_f32(tw_ptr.add(12));
            let t1 = complex_mul(z1, w1);
            let t2 = complex_mul(z2, w2);
            let t3 = complex_mul(z3, w3);
            let t4 = complex_mul(z4, w4);
            // Symmetric pair sums/differences of the 5-point kernel.
            let sum_all = vaddq_f32(vaddq_f32(vaddq_f32(t1, t2), t3), t4);
            let a1 = vaddq_f32(t1, t4);
            let a2 = vaddq_f32(t2, t3);
            // NOTE(review): `complex_mul_i` with the `neg_imag` mask presumably multiplies
            // by -i; confirm against the `ops` module.
            let b1 = complex_mul_i(vsubq_f32(t1, t4), neg_imag);
            let b2 = complex_mul_i(vsubq_f32(t2, t3), neg_imag);
            // Cosine-weighted parts shared by output pairs (1, 4) and (2, 3).
            let c1 = vaddq_f32(
                z0,
                vaddq_f32(vmulq_f32(cos_2pi_5_vec, a1), vmulq_f32(cos_4pi_5_vec, a2)),
            );
            let c2 = vaddq_f32(
                z0,
                vaddq_f32(vmulq_f32(cos_4pi_5_vec, a1), vmulq_f32(cos_2pi_5_vec, a2)),
            );
            // Sine-weighted parts: added for one output of each pair, subtracted for the other.
            let d1 = vaddq_f32(vmulq_f32(sin_2pi_5_vec, b1), vmulq_f32(sin_4pi_5_vec, b2));
            let d2 = vsubq_f32(vmulq_f32(sin_4pi_5_vec, b1), vmulq_f32(sin_2pi_5_vec, b2));
            let out0 = vaddq_f32(z0, sum_all);
            let out1 = vaddq_f32(c1, d1);
            let out4 = vsubq_f32(c1, d1);
            let out2 = vaddq_f32(c2, d2);
            let out3 = vsubq_f32(c2, d2);
            // Output base for each butterfly: 5 * (group start) + k.
            let j0 = 5 * i - 4 * k0;
            let j1 = 5 * (i + 1) - 4 * k1;
            // The farthest store of each butterfly is output 4 at j + 4 * stride.
            debug_assert!(
                j0 + stride * 4 < dst_len && j1 + stride * 4 < dst_len,
                "butterfly_radix5_generic_neon: store past end of dst (len {})",
                dst_len
            );
            vst1_f32(dst_ptr.add(j0 << 1), vget_low_f32(out0));
            vst1_f32(dst_ptr.add((j0 + stride) << 1), vget_low_f32(out1));
            vst1_f32(dst_ptr.add((j0 + stride * 2) << 1), vget_low_f32(out2));
            vst1_f32(dst_ptr.add((j0 + stride * 3) << 1), vget_low_f32(out3));
            vst1_f32(dst_ptr.add((j0 + stride * 4) << 1), vget_low_f32(out4));
            vst1_f32(dst_ptr.add(j1 << 1), vget_high_f32(out0));
            vst1_f32(dst_ptr.add((j1 + stride) << 1), vget_high_f32(out1));
            vst1_f32(dst_ptr.add((j1 + stride * 2) << 1), vget_high_f32(out2));
            vst1_f32(dst_ptr.add((j1 + stride * 3) << 1), vget_high_f32(out3));
            vst1_f32(dst_ptr.add((j1 + stride * 4) << 1), vget_high_f32(out4));
        }
    }
    super::butterfly_radix5_scalar::<2>(src, dst, stage_twiddles, stride, simd_iters);
}