use core::arch::aarch64::*;
use super::{super::ops::complex_mul, SQRT3_2};
use crate::fft::Complex32;
/// Radix-3 butterfly pass for the stride-1 case, processed two butterflies at
/// a time with 128-bit NEON registers.
///
/// For every even `i` below `simd_iters` (the largest even number not
/// exceeding `src.len() / 3`), the loop loads two consecutive complex values
/// from each third of `src` (`i`, `i + third`, `i + 2 * third`), combines them
/// without any twiddle multiplication, and writes the three results of
/// butterfly `i` to `dst[3 * i ..= 3 * i + 2]` and those of butterfly `i + 1`
/// to `dst[3 * i + 3 ..= 3 * i + 5]`.
///
/// `stage_twiddles` is not read by the vector loop; it is only forwarded to
/// `butterfly_radix3_scalar::<2>`, which is called afterwards with a stride of
/// `1` and `simd_iters` as its last argument (presumably so it can finish the
/// butterflies the vector loop did not cover — confirm against that function).
///
/// NOTE(review): the lane handling assumes `Complex32` is laid out as two
/// consecutive `f32` values (real, imaginary) — confirm against its definition.
///
/// # Safety
///
/// * The CPU must support the `neon` target feature.
/// * `dst` is written through raw pointers without bounds checks; it must have
///   room for at least `3 * simd_iters` complex elements.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix3_stride1_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    let third = src.len() / 3;
    // Round down to an even count: each vector iteration handles two butterflies.
    let vector_count = third & !1;

    // Per-lane factors applied to the re/im-swapped difference. With (re, im)
    // lanes this yields (s * im, -s * re), where s = SQRT3_2.
    let sign_pattern: [f32; 4] = [SQRT3_2, -SQRT3_2, SQRT3_2, -SQRT3_2];

    // SAFETY: every read from `src` ends at or before index `3 * third - 1`,
    // which is inside the slice. Writes to `dst` are unchecked and rely on the
    // caller contract documented in the `# Safety` section above.
    unsafe {
        let sqrt3_lanes = vld1q_f32(sign_pattern.as_ptr());
        let one_half = vdupq_n_f32(0.5);
        let input = src.as_ptr();
        let output = dst.as_mut_ptr().cast::<f32>();

        let mut i = 0;
        while i < vector_count {
            // Two consecutive complex inputs from each third of the source.
            let a = vld1q_f32(input.add(i).cast::<f32>());
            let b = vld1q_f32(input.add(i + third).cast::<f32>());
            let c = vld1q_f32(input.add(i + third * 2).cast::<f32>());

            let bc_sum = vaddq_f32(b, c);
            let bc_diff = vsubq_f32(b, c);

            // First output: plain sum of the three inputs.
            let y0 = vaddq_f32(a, bc_sum);

            // Shared term of the other two outputs: a - 0.5 * (b + c).
            let centre = vsubq_f32(a, vmulq_f32(bc_sum, one_half));

            // Swap re/im inside each complex value, then scale by (s, -s).
            let rotated = vmulq_f32(vrev64q_f32(bc_diff), sqrt3_lanes);

            let y1 = vaddq_f32(centre, rotated);
            let y2 = vsubq_f32(centre, rotated);

            // Low halves belong to butterfly `i`: three consecutive complex
            // slots starting at complex index 3 * i (two f32 per slot).
            let first = output.add((3 * i) * 2);
            vst1_f32(first, vget_low_f32(y0));
            vst1_f32(first.add(2), vget_low_f32(y1));
            vst1_f32(first.add(4), vget_low_f32(y2));

            // High halves belong to butterfly `i + 1`.
            let second = output.add((3 * (i + 1)) * 2);
            vst1_f32(second, vget_high_f32(y0));
            vst1_f32(second.add(2), vget_high_f32(y1));
            vst1_f32(second.add(4), vget_high_f32(y2));

            i += 2;
        }
    }

    super::butterfly_radix3_scalar::<2>(src, dst, stage_twiddles, 1, vector_count);
}
/// Radix-3 butterfly pass for an arbitrary `stride`, processed two butterflies
/// at a time with 128-bit NEON registers.
///
/// Returns immediately (without touching `dst`) when `stride` is zero.
///
/// For every even `i` below `simd_iters` (the largest even number not
/// exceeding `src.len() / 3`), the loop:
///
/// 1. loads two consecutive complex values from each third of `src`
///    (`i`, `i + third`, `i + 2 * third`);
/// 2. loads two vectors of twiddles from `stage_twiddles` starting at complex
///    index `2 * i` — the first multiplies the second-third inputs, the next
///    one (two complex values later) multiplies the last-third inputs;
/// 3. combines the three values and scatters the results: butterfly `n`
///    (for `n = i` and `n = i + 1`) writes to `dst[j]`, `dst[j + stride]` and
///    `dst[j + 2 * stride]`, where `j = 3 * n - 2 * (n % stride)`.
///
/// Afterwards `butterfly_radix3_scalar::<2>` is called with `simd_iters` as
/// its last argument (presumably to finish the butterflies the vector loop did
/// not cover — confirm against that function).
///
/// NOTE(review): the lane handling assumes `Complex32` is laid out as two
/// consecutive `f32` values (real, imaginary) — confirm against its definition.
///
/// # Safety
///
/// * The CPU must support the `neon` target feature.
/// * `stage_twiddles` is read through raw pointers without bounds checks; it
///   must contain at least `2 * simd_iters` complex elements.
/// * `dst` is written through raw pointers without bounds checks; every index
///   listed in step 3 above must be inside `dst`.
#[target_feature(enable = "neon")]
pub(super) unsafe fn butterfly_radix3_generic_neon(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    // A zero stride would make the `% stride` below divide by zero.
    if stride == 0 {
        return;
    }

    let third = src.len() / 3;
    // Round down to an even count: each vector iteration handles two butterflies.
    let vector_count = third & !1;

    // Per-lane factors applied to the re/im-swapped difference. With (re, im)
    // lanes this yields (s * im, -s * re), where s = SQRT3_2.
    let sign_pattern: [f32; 4] = [SQRT3_2, -SQRT3_2, SQRT3_2, -SQRT3_2];

    // SAFETY: every read from `src` ends at or before index `3 * third - 1`,
    // which is inside the slice. Reads from `stage_twiddles` and writes to
    // `dst` are unchecked and rely on the caller contract documented in the
    // `# Safety` section above.
    unsafe {
        let sqrt3_lanes = vld1q_f32(sign_pattern.as_ptr());
        let one_half = vdupq_n_f32(0.5);
        let input = src.as_ptr();
        let twiddles = stage_twiddles.as_ptr();
        let output = dst.as_mut_ptr().cast::<f32>();
        // Distance between a butterfly's outputs, measured in f32 elements.
        let stride_f32 = stride << 1;

        let mut i = 0;
        while i < vector_count {
            // Position inside the current stride-sized group for butterfly `i`
            // and for butterfly `i + 1` (wrapping back to the group start).
            let pos_a = i % stride;
            let pos_b = if pos_a + 1 >= stride {
                pos_a + 1 - stride
            } else {
                pos_a + 1
            };

            // Two consecutive complex inputs from each third of the source.
            let a = vld1q_f32(input.add(i).cast::<f32>());
            let b = vld1q_f32(input.add(i + third).cast::<f32>());
            let c = vld1q_f32(input.add(i + third * 2).cast::<f32>());

            // Twiddles come in groups of four complex values per butterfly
            // pair: two for the `b` inputs followed by two for the `c` inputs.
            let tw = twiddles.add(i * 2).cast::<f32>();
            let tw_b = vld1q_f32(tw);
            let tw_c = vld1q_f32(tw.add(4));

            let b_rot = complex_mul(b, tw_b);
            let c_rot = complex_mul(c, tw_c);

            let bc_sum = vaddq_f32(b_rot, c_rot);
            let bc_diff = vsubq_f32(b_rot, c_rot);

            // First output: plain sum of the three (twiddled) inputs.
            let y0 = vaddq_f32(a, bc_sum);

            // Shared term of the other two outputs: a - 0.5 * (b + c).
            let centre = vsubq_f32(a, vmulq_f32(bc_sum, one_half));

            // Swap re/im inside each complex value, then scale by (s, -s).
            let rotated = vmulq_f32(vrev64q_f32(bc_diff), sqrt3_lanes);

            let y1 = vaddq_f32(centre, rotated);
            let y2 = vsubq_f32(centre, rotated);

            // Complex index of the first output of each of the two butterflies.
            let out_a = 3 * i - 2 * pos_a;
            let out_b = 3 * (i + 1) - 2 * pos_b;

            // Low halves belong to butterfly `i`.
            let first = output.add(out_a << 1);
            vst1_f32(first, vget_low_f32(y0));
            vst1_f32(first.add(stride_f32), vget_low_f32(y1));
            vst1_f32(first.add(stride_f32 * 2), vget_low_f32(y2));

            // High halves belong to butterfly `i + 1`.
            let second = output.add(out_b << 1);
            vst1_f32(second, vget_high_f32(y0));
            vst1_f32(second.add(stride_f32), vget_high_f32(y1));
            vst1_f32(second.add(stride_f32 * 2), vget_high_f32(y2));

            i += 2;
        }
    }

    super::butterfly_radix3_scalar::<2>(src, dst, stage_twiddles, stride, vector_count);
}