use proc_macro2::TokenStream;
use quote::quote;
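/// Emits a size-2 complex FFT codelet for `f64` using AVX.
///
/// The generated function treats `data` as two complex values stored as
/// interleaved `[re0, im0, re1, im1]` and computes the in-place butterfly
/// `X0 = x0 + x1`, `X1 = x0 - x1`; `_sign` is unused because the forward
/// and inverse size-2 transforms coincide.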
pub(super) fn gen_avx_size_2_f64() -> TokenStream {
    quote! {
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx")]
        unsafe fn codelet_simd_2_avx_f64(data: &mut [f64], _sign: i32) {
            use core::arch::x86_64::*;
            // Caller must guarantee data.len() >= 4 (two interleaved complex values).
            debug_assert!(data.len() >= 4);
            let ptr = data.as_mut_ptr();
            // v = [re0, im0, re1, im1]; split into the two complex halves.
            let v = _mm256_loadu_pd(ptr);
            let a = _mm256_castpd256_pd128(v);
            let b = _mm256_extractf128_pd(v, 1);
            // Size-2 butterfly: X0 = x0 + x1, X1 = x0 - x1.
            let sum = _mm_add_pd(a, b);
            let diff = _mm_sub_pd(a, b);
            // Repack [sum, diff] into one 256-bit vector and store in place.
            let sum_wide = _mm256_castpd128_pd256(sum);
            let diff_wide = _mm256_castpd128_pd256(diff);
            let result = _mm256_permute2f128_pd(sum_wide, diff_wide, 0x20);
            _mm256_storeu_pd(ptr, result);
        }
    }
}
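/// Emits a size-4 complex FFT codelet for `f64` using AVX.
///
/// The generated function computes an in-place radix-2 decimation-in-time
/// size-4 transform over interleaved complex `f64` data; `sign < 0` selects
/// the forward direction (odd-term twiddle `-i`), otherwise the inverse (`+i`).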
pub(super) fn gen_avx_size_4_f64() -> TokenStream {
    quote! {
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx")]
        unsafe fn codelet_simd_4_avx_f64(data: &mut [f64], sign: i32) {
            use core::arch::x86_64::*;
            // Caller must guarantee data.len() >= 8 (four interleaved complex values).
            debug_assert!(data.len() >= 8);
            let ptr = data.as_mut_ptr();
            let v_01 = _mm256_loadu_pd(ptr);
            let v_23 = _mm256_loadu_pd(ptr.add(4));
            // First stage: butterflies (x0, x2) and (x1, x3), two per 256-bit op.
            let sum = _mm256_add_pd(v_01, v_23);
            let diff = _mm256_sub_pd(v_01, v_23);
            let t0 = _mm256_castpd256_pd128(sum);   // x0 + x2
            let t2 = _mm256_extractf128_pd(sum, 1); // x1 + x3
            let t1 = _mm256_castpd256_pd128(diff);  // x0 - x2
            let t3 = _mm256_extractf128_pd(diff, 1); // x1 - x3
            // Second stage: multiply t3 by -i (forward) or +i (inverse).
            // Swapping re/im and flipping one sign implements the rotation.
            let t3_swapped = _mm_shuffle_pd(t3, t3, 0b01);
            let t3_rot = if sign < 0 {
                _mm_xor_pd(t3_swapped, _mm_set_pd(-0.0, 0.0))
            } else {
                _mm_xor_pd(t3_swapped, _mm_set_pd(0.0, -0.0))
            };
            let out0 = _mm_add_pd(t0, t2);
            let out1 = _mm_add_pd(t1, t3_rot);
            let out2 = _mm_sub_pd(t0, t2);
            let out3 = _mm_sub_pd(t1, t3_rot);
            // Repack the four outputs into two 256-bit vectors and store.
            let v_out_01 = _mm256_permute2f128_pd(
                _mm256_castpd128_pd256(out0),
                _mm256_castpd128_pd256(out1),
                0x20,
            );
            let v_out_23 = _mm256_permute2f128_pd(
                _mm256_castpd128_pd256(out2),
                _mm256_castpd128_pd256(out3),
                0x20,
            );
            _mm256_storeu_pd(ptr, v_out_01);
            _mm256_storeu_pd(ptr.add(4), v_out_23);
        }
    }
}
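/// Emits a size-8 complex FFT codelet for `f64` using AVX.
///
/// The generated function loads its inputs in bit-reversed order and runs
/// three radix-2 decimation-in-time stages; the stage-3 twiddles
/// `W8^1 = (1 ∓ i)/sqrt(2)`, `W8^2 = ∓i`, and `W8^3 = (-1 ∓ i)/sqrt(2)` are
/// applied with shuffles and sign flips instead of full complex multiplies.
/// `sign < 0` selects the forward direction.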
#[allow(clippy::too_many_lines)]
pub(super) fn gen_avx_size_8_f64() -> TokenStream {
    quote! {
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx")]
        #[allow(clippy::too_many_lines)]
        unsafe fn codelet_simd_8_avx_f64(data: &mut [f64], sign: i32) {
            use core::arch::x86_64::*;
            // Caller must guarantee data.len() >= 16 (eight interleaved complex values).
            debug_assert!(data.len() >= 16);
            let ptr = data.as_mut_ptr();
            // 1/sqrt(2), shared by the odd twiddles below.
            let inv_sqrt2 = _mm_set1_pd(core::f64::consts::FRAC_1_SQRT_2);
            // Multiply a packed complex (re, im) by -i (fwd) or +i (inverse):
            // swap re/im and flip the sign of one lane.
            let rotate_pm_i = |v: __m128d, fwd: bool| -> __m128d {
                let swapped = _mm_shuffle_pd(v, v, 0b01);
                if fwd {
                    _mm_xor_pd(swapped, _mm_set_pd(-0.0, 0.0))
                } else {
                    _mm_xor_pd(swapped, _mm_set_pd(0.0, -0.0))
                }
            };
            // General complex multiply without FMA; unused below because every
            // size-8 twiddle reduces to shuffles, sign flips, and one scale.
            // Kept for reference.
            let cmul_no_fma = |v: __m128d, twd: __m128d| -> __m128d {
                let c = _mm_shuffle_pd(twd, twd, 0b00);
                let d = _mm_shuffle_pd(twd, twd, 0b11);
                let v_re = _mm_shuffle_pd(v, v, 0b00);
                let v_im = _mm_shuffle_pd(v, v, 0b11);
                let real = _mm_sub_pd(_mm_mul_pd(v_re, c), _mm_mul_pd(v_im, d));
                let imag = _mm_add_pd(_mm_mul_pd(v_re, d), _mm_mul_pd(v_im, c));
                _mm_shuffle_pd(real, imag, 0b00)
            };
            let _ = cmul_no_fma;
            let fwd = sign < 0;
            // Load the eight complex inputs in bit-reversed order
            // (0, 4, 2, 6, 1, 5, 3, 7) so the DIT stages write natural order.
            let mut a = [_mm_setzero_pd(); 8];
            a[0] = _mm_loadu_pd(ptr);          // x0
            a[1] = _mm_loadu_pd(ptr.add(8));   // x4
            a[2] = _mm_loadu_pd(ptr.add(4));   // x2
            a[3] = _mm_loadu_pd(ptr.add(12));  // x6
            a[4] = _mm_loadu_pd(ptr.add(2));   // x1
            a[5] = _mm_loadu_pd(ptr.add(10));  // x5
            a[6] = _mm_loadu_pd(ptr.add(6));   // x3
            a[7] = _mm_loadu_pd(ptr.add(14));  // x7
            // Stage 1: four radix-2 butterflies on adjacent pairs.
            for i in (0..8usize).step_by(2) {
                let t = a[i + 1];
                a[i + 1] = _mm_sub_pd(a[i], t);
                a[i] = _mm_add_pd(a[i], t);
            }
            // Stage 2: size-4 butterflies; the second odd term picks up a ±i twiddle.
            for group in (0..8usize).step_by(4) {
                let t = a[group + 2];
                a[group + 2] = _mm_sub_pd(a[group], t);
                a[group] = _mm_add_pd(a[group], t);
                let t = a[group + 3];
                let t_tw = rotate_pm_i(t, fwd);
                a[group + 3] = _mm_sub_pd(a[group + 1], t_tw);
                a[group + 1] = _mm_add_pd(a[group + 1], t_tw);
            }
            // Stage 3, twiddle W8^0 = 1: plain butterfly on (a[0], a[4]).
            {
                let t = a[4];
                a[4] = _mm_sub_pd(a[0], t);
                a[0] = _mm_add_pd(a[0], t);
            }
            // Stage 3, twiddle W8^1 = (1 ∓ i)/sqrt(2), applied to a[5].
            {
                let v = a[5];
                let swapped = _mm_shuffle_pd(v, v, 0b01);
                let t_tw = if fwd {
                    // (re + i*im) * (1 - i)/sqrt(2) = ((re + im) + i*(im - re))/sqrt(2)
                    let sum = _mm_add_pd(v, swapped);
                    let diff = _mm_sub_pd(swapped, v);
                    let combined = _mm_shuffle_pd(sum, diff, 0b00);
                    _mm_mul_pd(combined, inv_sqrt2)
                } else {
                    // (re + i*im) * (1 + i)/sqrt(2) = ((re - im) + i*(re + im))/sqrt(2)
                    let diff = _mm_sub_pd(v, swapped);
                    let sum = _mm_add_pd(v, swapped);
                    let combined = _mm_shuffle_pd(diff, sum, 0b10);
                    _mm_mul_pd(combined, inv_sqrt2)
                };
                a[5] = _mm_sub_pd(a[1], t_tw);
                a[1] = _mm_add_pd(a[1], t_tw);
            }
            // Stage 3, twiddle W8^2 = ∓i, applied to a[6].
            {
                let t = a[6];
                let t_tw = rotate_pm_i(t, fwd);
                a[6] = _mm_sub_pd(a[2], t_tw);
                a[2] = _mm_add_pd(a[2], t_tw);
            }
            // Stage 3, twiddle W8^3 = (-1 ∓ i)/sqrt(2), applied to a[7].
            {
                let v = a[7];
                let swapped = _mm_shuffle_pd(v, v, 0b01);
                let t_tw = if fwd {
                    // (re + i*im) * (-1 - i)/sqrt(2) = ((im - re) - i*(re + im))/sqrt(2)
                    let diff = _mm_sub_pd(swapped, v);
                    let neg_sum = _mm_xor_pd(_mm_add_pd(v, swapped), _mm_set1_pd(-0.0));
                    let combined = _mm_shuffle_pd(diff, neg_sum, 0b00);
                    _mm_mul_pd(combined, inv_sqrt2)
                } else {
                    // (re + i*im) * (-1 + i)/sqrt(2) = (-(re + im) + i*(re - im))/sqrt(2)
                    let neg_sum = _mm_xor_pd(_mm_add_pd(v, swapped), _mm_set1_pd(-0.0));
                    let diff = _mm_sub_pd(swapped, v);
                    let combined = _mm_shuffle_pd(neg_sum, diff, 0b10);
                    _mm_mul_pd(combined, inv_sqrt2)
                };
                a[7] = _mm_sub_pd(a[3], t_tw);
                a[3] = _mm_add_pd(a[3], t_tw);
            }
            // Outputs are already in natural order; pack pairs and store.
            for i in (0..8usize).step_by(2) {
                let packed = _mm256_permute2f128_pd(
                    _mm256_castpd128_pd256(a[i]),
                    _mm256_castpd128_pd256(a[i + 1]),
                    0x20,
                );
                _mm256_storeu_pd(ptr.add(i * 2), packed);
            }
        }
    }
}
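// A minimal sanity check, as a sketch: it only asserts that each generator
// emits a token stream naming the expected codelet, not that the generated
// code is numerically correct (that belongs in the consuming crate's tests).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generators_emit_expected_codelet_names() {
        assert!(gen_avx_size_2_f64().to_string().contains("codelet_simd_2_avx_f64"));
        assert!(gen_avx_size_4_f64().to_string().contains("codelet_simd_4_avx_f64"));
        assert!(gen_avx_size_8_f64().to_string().contains("codelet_simd_8_avx_f64"));
    }
}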