use super::SQRT3_2;
use crate::fft::{
Complex32,
butterflies::ops::{complex_mul_sqrt3_i_sse4_2, complex_mul_sse4_2},
};
/// Radix-3 butterfly pass specialised for `stride == 1`, vectorised with SSE4.2.
///
/// `src` is treated as three consecutive thirds `a`, `b`, `c`, each
/// `src.len() / 3` samples long. For an input index `n`, with
/// `r = complex_mul_sqrt3_i_sse4_2(b[n] - c[n], SQRT3_2)` and
/// `m = a[n] - (b[n] + c[n]) * 0.5`, the outputs are written contiguously:
///
/// - `dst[3 * n]     = a[n] + (b[n] + c[n])`
/// - `dst[3 * n + 1] = m + r`
/// - `dst[3 * n + 2] = m - r`
///
/// The vector loop applies no twiddle factors; `stage_twiddles` is only
/// forwarded to the scalar fallback (called with stride `1`). Two input indices
/// are processed per iteration, and a leftover odd index is handed to
/// `super::butterfly_radix3_scalar::<2>` starting at the first index the vector
/// loop did not process.
///
/// # Safety
///
/// - The CPU must support SSE4.2.
/// - `dst` must have room for at least `3 * (src.len() / 3)` elements; the
///   vector loop's stores are not bounds-checked.
/// - NOTE(review): any additional requirements of the scalar fallback (e.g. on
///   `stage_twiddles`) are not visible here — confirm against it.
#[target_feature(enable = "sse4.2")]
pub(super) unsafe fn butterfly_radix3_stride1_sse4_2(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
) {
    use core::arch::x86_64::*;

    let third = src.len() / 3;
    // One __m128 holds two complex values, so the vector loop needs an even count.
    let vector_end = third & !1;

    unsafe {
        let src_base = src.as_ptr();
        let dst_base = dst.as_mut_ptr();
        let half = _mm_set1_ps(0.5);

        let mut idx = 0usize;
        while idx < vector_end {
            // Two consecutive samples from each third of the input.
            let a = _mm_loadu_ps(src_base.add(idx).cast::<f32>());
            let b = _mm_loadu_ps(src_base.add(idx + third).cast::<f32>());
            let c = _mm_loadu_ps(src_base.add(idx + third * 2).cast::<f32>());

            let b_plus_c = _mm_add_ps(b, c);
            let b_minus_c = _mm_sub_ps(b, c);
            let rotated = complex_mul_sqrt3_i_sse4_2(b_minus_c, SQRT3_2);
            let mid = _mm_sub_ps(a, _mm_mul_ps(b_plus_c, half));

            // Reinterpret as 2 x f64 so each complex value is one 64-bit lane.
            let y0 = _mm_castps_pd(_mm_add_ps(a, b_plus_c));
            let y1 = _mm_castps_pd(_mm_add_ps(mid, rotated));
            let y2 = _mm_castps_pd(_mm_sub_ps(mid, rotated));

            // Target order: [y0[n], y1[n], y2[n], y0[n+1], y1[n+1], y2[n+1]].
            let lo_pair = _mm_unpacklo_pd(y0, y1); // [y0[n],   y1[n]]
            let mid_pair = _mm_move_sd(y0, y2); //    [y2[n],   y0[n+1]]
            let hi_pair = _mm_unpackhi_pd(y1, y2); // [y1[n+1], y2[n+1]]

            // Index `idx` produces outputs starting at complex slot 3 * idx.
            let out = dst_base.add(3 * idx).cast::<f32>();
            _mm_storeu_ps(out, _mm_castpd_ps(lo_pair));
            _mm_storeu_ps(out.add(4), _mm_castpd_ps(mid_pair));
            _mm_storeu_ps(out.add(8), _mm_castpd_ps(hi_pair));

            idx += 2;
        }
    }

    super::butterfly_radix3_scalar::<2>(src, dst, stage_twiddles, 1, vector_end);
}
/// Radix-3 butterfly pass for an arbitrary `stride`, vectorised with SSE4.2.
///
/// `src` is treated as three consecutive thirds `a`, `b`, `c`, each
/// `src.len() / 3` samples long. For an input index `n`, `b[n]` and `c[n]` are
/// multiplied by their twiddles and combined with `a[n]`. The three results are
/// scattered to `dst[j]`, `dst[j + stride]` and `dst[j + 2 * stride]`, where
/// `j = 3 * n - 2 * (n % stride)`.
///
/// The vector loop processes two input indices per iteration. For each even
/// `n` it reads four twiddles starting at `stage_twiddles[2 * n]`:
///
/// - the first two multiply `b[n]` and `b[n + 1]`;
/// - the next two multiply `c[n]` and `c[n + 1]`.
///
/// A leftover odd index is handed to `super::butterfly_radix3_scalar::<2>`,
/// starting at the first index the vector loop did not process.
///
/// # Panics
///
/// Panics if `stride == 0` and the vector loop runs (from `n % stride`).
///
/// # Safety
///
/// - The CPU must support SSE4.2.
/// - `stage_twiddles` must hold at least `2 * ((src.len() / 3) & !1)` elements.
/// - Every scattered destination index must be in bounds for `dst`. This holds
///   when `dst.len() >= 3 * (src.len() / 3)` and `src.len() / 3` is a multiple
///   of `stride`.
/// - NOTE(review): any additional requirements of the scalar fallback are not
///   visible here — confirm against it.
#[target_feature(enable = "sse4.2")]
pub(super) unsafe fn butterfly_radix3_generic_sse4_2(
    src: &[Complex32],
    dst: &mut [Complex32],
    stage_twiddles: &[Complex32],
    stride: usize,
) {
    use core::arch::x86_64::*;
    let samples = src.len();
    let third_samples = samples / 3;
    // One __m128 holds two complex values, so the vector loop needs an even count.
    let simd_iters = (third_samples >> 1) << 1;
    // The loop reads twiddles up to index 2 * simd_iters - 1 without bounds checks.
    debug_assert!(stage_twiddles.len() >= 2 * simd_iters);
    let dst_len = dst.len();
    unsafe {
        let half_vec = _mm_set1_ps(0.5);
        let dst_ptr = dst.as_mut_ptr();
        for i in (0..simd_iters).step_by(2) {
            // Position of lane 0 / lane 1 within their stride-sized groups.
            let k0 = i % stride;
            // (i + 1) % stride without a second division: k0 + 1 can reach
            // `stride` but never exceed it.
            let k1 = k0 + 1 - ((k0 + 1 >= stride) as usize) * stride;

            let z0 = _mm_loadu_ps(src.as_ptr().add(i) as *const f32);
            let z1 = _mm_loadu_ps(src.as_ptr().add(i + third_samples) as *const f32);
            let z2 = _mm_loadu_ps(src.as_ptr().add(i + third_samples * 2) as *const f32);

            let tw_ptr = stage_twiddles.as_ptr().add(i * 2) as *const f32;
            let w1 = _mm_loadu_ps(tw_ptr);
            let w2 = _mm_loadu_ps(tw_ptr.add(4));

            let t1 = complex_mul_sse4_2(w1, z1);
            let t2 = complex_mul_sse4_2(w2, z2);
            let sum_t = _mm_add_ps(t1, t2);
            let diff_t = _mm_sub_ps(t1, t2);
            let out0 = _mm_add_ps(z0, sum_t);
            let half_sum_t = _mm_mul_ps(sum_t, half_vec);
            let re_im_part = _mm_sub_ps(z0, half_sum_t);
            let sqrt3_diff = complex_mul_sqrt3_i_sse4_2(diff_t, SQRT3_2);
            let out1 = _mm_add_ps(re_im_part, sqrt3_diff);
            let out2 = _mm_sub_ps(re_im_part, sqrt3_diff);

            // Destination of lane 0 / lane 1 (in complex elements).
            let j0 = 3 * i - 2 * k0;
            let j1 = 3 * (i + 1) - 2 * k1;
            // j1 > j0, so j1 + 2 * stride is the largest index written here.
            debug_assert!(j1 + 2 * stride < dst_len);

            // Each 64-bit lane holds one complex result. `_mm_storel_pd` /
            // `_mm_storeh_pd` write through `*mut f64`, which requires 8-byte
            // alignment that a `Complex32` slot is not shown to guarantee.
            // `_mm_storel_epi64` is an unaligned 64-bit store of the same bits.
            let lo0 = _mm_castps_si128(out0);
            let lo1 = _mm_castps_si128(out1);
            let lo2 = _mm_castps_si128(out2);
            let hi0 = _mm_unpackhi_epi64(lo0, lo0);
            let hi1 = _mm_unpackhi_epi64(lo1, lo1);
            let hi2 = _mm_unpackhi_epi64(lo2, lo2);

            _mm_storel_epi64(dst_ptr.add(j0).cast::<__m128i>(), lo0);
            _mm_storel_epi64(dst_ptr.add(j0 + stride).cast::<__m128i>(), lo1);
            _mm_storel_epi64(dst_ptr.add(j0 + stride * 2).cast::<__m128i>(), lo2);
            _mm_storel_epi64(dst_ptr.add(j1).cast::<__m128i>(), hi0);
            _mm_storel_epi64(dst_ptr.add(j1 + stride).cast::<__m128i>(), hi1);
            _mm_storel_epi64(dst_ptr.add(j1 + stride * 2).cast::<__m128i>(), hi2);
        }
    }
    super::butterfly_radix3_scalar::<2>(src, dst, stage_twiddles, stride, simd_iters);
}