use core::arch::x86_64::*;
use crate::fft::{
Complex32,
butterflies::ops::{complex_mul_avx, complex_mul_i_avx, load_neg_imag_mask_avx},
};
/// Radix-4 butterfly for a stage with `stride == 1`, four butterflies per AVX step.
///
/// With `q = src.len() / 4`, butterfly `i` reads `src[i]`, `src[i + q]`, `src[i + 2q]`,
/// `src[i + 3q]` and writes its four outputs contiguously to `dst[4 * i..4 * i + 4]`.
/// The SIMD path performs no twiddle multiplication. Indices it does not cover (from the
/// largest multiple of four `<= q` onward) are handed to `butterfly_radix4_scalar`
/// together with `stage_twiddles`.
///
/// # Safety
/// The CPU must support AVX and FMA. `dst` must hold at least `src.len()` elements; the
/// SIMD loop writes `dst[..4 * (q - q % 4)]`. Whatever `butterfly_radix4_scalar` requires
/// also applies (not visible here).
#[target_feature(enable = "avx,fma")]
pub(super) unsafe fn butterfly_radix4_stride1_avx_fma(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    let q = src.len() / 4;
    // Round down to a multiple of 4: one AVX register holds four complex values.
    let simd_end = q & !3;

    unsafe {
        let neg_imag_mask = load_neg_imag_mask_avx();
        // Treat both buffers as interleaved (re, im) f32 arrays: complex index `c`
        // starts at float offset `2 * c`.
        let input = src.as_ptr() as *const f32;
        let output = dst.as_mut_ptr() as *mut f32;

        let mut i = 0;
        while i < simd_end {
            let x0 = _mm256_loadu_ps(input.add(2 * i));
            let x1 = _mm256_loadu_ps(input.add(2 * (i + q)));
            let x2 = _mm256_loadu_ps(input.add(2 * (i + 2 * q)));
            let x3 = _mm256_loadu_ps(input.add(2 * (i + 3 * q)));

            // Radix-4 butterfly; no twiddles are applied on this path.
            let even_sum = _mm256_add_ps(x0, x2);
            let even_diff = _mm256_sub_ps(x0, x2);
            let odd_sum = _mm256_add_ps(x1, x3);
            let odd_diff_rot = complex_mul_i_avx(_mm256_sub_ps(x1, x3), neg_imag_mask);

            let y0 = _mm256_castps_pd(_mm256_add_ps(even_sum, odd_sum));
            let y1 = _mm256_castps_pd(_mm256_add_ps(even_diff, odd_diff_rot));
            let y2 = _mm256_castps_pd(_mm256_sub_ps(even_sum, odd_sum));
            let y3 = _mm256_castps_pd(_mm256_sub_ps(even_diff, odd_diff_rot));

            // 4x4 transpose of 64-bit complex values (handled as f64 lanes) so that
            // butterfly `i + n` becomes four consecutive outputs.
            let lo01 = _mm256_unpacklo_pd(y0, y1); // [y0[0], y1[0] | y0[2], y1[2]]
            let hi01 = _mm256_unpackhi_pd(y0, y1); // [y0[1], y1[1] | y0[3], y1[3]]
            let lo23 = _mm256_unpacklo_pd(y2, y3); // [y2[0], y3[0] | y2[2], y3[2]]
            let hi23 = _mm256_unpackhi_pd(y2, y3); // [y2[1], y3[1] | y2[3], y3[3]]

            let row0 = _mm256_permute2f128_pd(lo01, lo23, 0x20);
            let row1 = _mm256_permute2f128_pd(hi01, hi23, 0x20);
            let row2 = _mm256_permute2f128_pd(lo01, lo23, 0x31);
            let row3 = _mm256_permute2f128_pd(hi01, hi23, 0x31);

            // Output complex index 4 * i == float offset 8 * i.
            let out = output.add(8 * i);
            _mm256_storeu_ps(out, _mm256_castpd_ps(row0));
            _mm256_storeu_ps(out.add(8), _mm256_castpd_ps(row1));
            _mm256_storeu_ps(out.add(16), _mm256_castpd_ps(row2));
            _mm256_storeu_ps(out.add(24), _mm256_castpd_ps(row3));

            i += 4;
        }
    }

    super::butterfly_radix4_scalar::<4>(src, dst, stage_twiddles, 1, simd_end);
}
/// Radix-4 Stockham butterfly for an arbitrary `stride`, four butterflies per AVX step.
///
/// With `q = src.len() / 4` and `k = i % stride`, butterfly `i` reads
/// `src[i + m * q]` for `m = 0..4`. It multiplies inputs 1..=3 by their twiddles,
/// combines them with a radix-4 butterfly, and writes output `m` to
/// `dst[4 * i - 3 * k + m * stride]`.
///
/// Twiddle layout as read here: for the group of four butterflies starting at `i`,
/// `stage_twiddles[3 * i..3 * i + 12]` holds `w1` for `i..i + 4`, then `w2`, then `w3`.
///
/// Indices not covered by the AVX loop are delegated to `butterfly_radix4_scalar`.
/// `stride == 0` is a no-op.
///
/// # Safety
/// - The CPU must support AVX and FMA.
/// - `stage_twiddles` must hold at least `3 * (q - q % 4)` elements.
/// - `dst` must be large enough for every scattered write. Presumably the caller guarantees
///   `q` is a multiple of `stride`, so all writes land in `dst[..4 * q]` — TODO confirm.
/// - Whatever `butterfly_radix4_scalar` requires also applies (not visible here).
#[target_feature(enable = "avx,fma")]
pub(super) unsafe fn butterfly_radix4_generic_avx_fma(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    // `(k + m) % stride` for `k < stride`. The common no-wrap case avoids a division;
    // `%` keeps it correct for stride == 1, where `k + m` can reach `2 * stride`.
    #[inline]
    fn wrap_add(k: usize, m: usize, stride: usize) -> usize {
        let x = k + m;
        if x < stride { x } else { x % stride }
    }

    // Writes the low 64 bits of `v` (one complex value) to `p`. `p` need not be 8-byte
    // aligned: `Complex32` storage may only guarantee `f32` alignment.
    #[inline]
    unsafe fn store_lo(p: *mut f64, v: __m128d) {
        unsafe { core::ptr::write_unaligned(p, _mm_cvtsd_f64(v)) }
    }

    // Writes the high 64 bits of `v` (one complex value) to `p` without assuming alignment.
    #[inline]
    unsafe fn store_hi(p: *mut f64, v: __m128d) {
        unsafe { core::ptr::write_unaligned(p, _mm_cvtsd_f64(_mm_unpackhi_pd(v, v))) }
    }

    if stride == 0 {
        return;
    }
    let samples = src.len();
    let quarter_samples = samples >> 2;
    // Round down to a multiple of 4: one AVX register holds four complex values.
    let simd_iters = (quarter_samples >> 2) << 2;
    unsafe {
        let neg_imag_mask = load_neg_imag_mask_avx();
        for i in (0..simd_iters).step_by(4) {
            // Lane `m` handles butterfly `i + m`, whose offset within its stride group
            // is `(i + m) % stride`.
            let k0 = i % stride;
            let k1 = wrap_add(k0, 1, stride);
            let k2 = wrap_add(k0, 2, stride);
            let k3 = wrap_add(k0, 3, stride);

            let z0_ptr = src.as_ptr().add(i) as *const f32;
            let z0 = _mm256_loadu_ps(z0_ptr);
            let z1_ptr = src.as_ptr().add(i + quarter_samples) as *const f32;
            let z1 = _mm256_loadu_ps(z1_ptr);
            let z2_ptr = src.as_ptr().add(i + quarter_samples * 2) as *const f32;
            let z2 = _mm256_loadu_ps(z2_ptr);
            let z3_ptr = src.as_ptr().add(i + quarter_samples * 3) as *const f32;
            let z3 = _mm256_loadu_ps(z3_ptr);

            // Twiddles for this group: w1, w2, w3 (four complex each), back to back.
            let tw_ptr = stage_twiddles.as_ptr().add(i * 3) as *const f32;
            let w1 = _mm256_loadu_ps(tw_ptr);
            let w2 = _mm256_loadu_ps(tw_ptr.add(8));
            let w3 = _mm256_loadu_ps(tw_ptr.add(16));
            let t1 = complex_mul_avx(w1, z1);
            let t2 = complex_mul_avx(w2, z2);
            let t3 = complex_mul_avx(w3, z3);

            let a0 = _mm256_add_ps(z0, t2);
            let a1 = _mm256_sub_ps(z0, t2);
            let a2 = _mm256_add_ps(t1, t3);
            let t1_sub_t3 = _mm256_sub_ps(t1, t3);
            let a3 = complex_mul_i_avx(t1_sub_t3, neg_imag_mask);
            let out0 = _mm256_add_ps(a0, a2);
            let out2 = _mm256_sub_ps(a0, a2);
            let out1 = _mm256_add_ps(a1, a3);
            let out3 = _mm256_sub_ps(a1, a3);

            // Destination base for each lane: 4 * (i + m) - 3 * k_m.
            let j0 = 4 * i - 3 * k0;
            let j1 = 4 * (i + 1) - 3 * k1;
            let j2 = 4 * (i + 2) - 3 * k2;
            let j3 = 4 * (i + 3) - 3 * k3;

            // Each complex is 64 bits, so split every output into 128-bit halves and
            // scatter one complex (one f64 slot) per store.
            let out0_pd = _mm256_castps_pd(out0);
            let out1_pd = _mm256_castps_pd(out1);
            let out2_pd = _mm256_castps_pd(out2);
            let out3_pd = _mm256_castps_pd(out3);
            let dst_ptr = dst.as_mut_ptr() as *mut f64;
            let out0_lo = _mm256_castpd256_pd128(out0_pd);
            let out0_hi = _mm256_extractf128_pd(out0_pd, 1);
            let out1_lo = _mm256_castpd256_pd128(out1_pd);
            let out1_hi = _mm256_extractf128_pd(out1_pd, 1);
            let out2_lo = _mm256_castpd256_pd128(out2_pd);
            let out2_hi = _mm256_extractf128_pd(out2_pd, 1);
            let out3_lo = _mm256_castpd256_pd128(out3_pd);
            let out3_hi = _mm256_extractf128_pd(out3_pd, 1);

            // Lane 0.
            store_lo(dst_ptr.add(j0), out0_lo);
            store_lo(dst_ptr.add(j0 + stride), out1_lo);
            store_lo(dst_ptr.add(j0 + stride * 2), out2_lo);
            store_lo(dst_ptr.add(j0 + stride * 3), out3_lo);
            // Lane 1.
            store_hi(dst_ptr.add(j1), out0_lo);
            store_hi(dst_ptr.add(j1 + stride), out1_lo);
            store_hi(dst_ptr.add(j1 + stride * 2), out2_lo);
            store_hi(dst_ptr.add(j1 + stride * 3), out3_lo);
            // Lane 2.
            store_lo(dst_ptr.add(j2), out0_hi);
            store_lo(dst_ptr.add(j2 + stride), out1_hi);
            store_lo(dst_ptr.add(j2 + stride * 2), out2_hi);
            store_lo(dst_ptr.add(j2 + stride * 3), out3_hi);
            // Lane 3.
            store_hi(dst_ptr.add(j3), out0_hi);
            store_hi(dst_ptr.add(j3 + stride), out1_hi);
            store_hi(dst_ptr.add(j3 + stride * 2), out2_hi);
            store_hi(dst_ptr.add(j3 + stride * 3), out3_hi);
        }
    }
    super::butterfly_radix4_scalar::<4>(src, dst, stage_twiddles, stride, simd_iters);
}