use proc_macro2::TokenStream;
use quote::quote;
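
/// Generates the size-2 complex FFT codelet for interleaved `f64` data.
/// The size-2 butterfly has no twiddle factors, so the sign argument is unused.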
pub(super) fn gen_avx2_size_2() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn codelet_simd_2_avx2_f64(data: &mut [f64], _sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            // Load both complex values: [re0, im0, re1, im1].
            let v = _mm256_loadu_pd(ptr);
            let a = _mm256_castpd256_pd128(v);
            let b = _mm256_extractf128_pd::<1>(v);
            // Size-2 butterfly: X0 = x0 + x1, X1 = x0 - x1.
            let sum = _mm_add_pd(a, b);
            let diff = _mm_sub_pd(a, b);
            // Repack [sum, diff] into a single 256-bit store.
            let result = _mm256_permute2f128_pd::<0x20>(
                _mm256_castpd128_pd256(sum),
                _mm256_castpd128_pd256(diff),
            );
            _mm256_storeu_pd(ptr, result);
}
}
}
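
/// Generates the size-4 complex FFT codelet for interleaved `f64` data.
/// `sign < 0` selects the forward transform (a `-i` twiddle on the odd leg).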
pub(super) fn gen_avx2_size_4() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn codelet_simd_4_avx2_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            // c0, c1 live in v_01 and c2, c3 in v_23 (interleaved re/im pairs).
            let v_01 = _mm256_loadu_pd(ptr);
            let v_23 = _mm256_loadu_pd(ptr.add(4));
            let sum = _mm256_add_pd(v_01, v_23);
            let diff = _mm256_sub_pd(v_01, v_23);
            // t0 = c0 + c2, t2 = c1 + c3, t1 = c0 - c2, t3 = c1 - c3.
            let t0 = _mm256_castpd256_pd128(sum);
            let t2 = _mm256_extractf128_pd::<1>(sum);
            let t1 = _mm256_castpd256_pd128(diff);
            let t3 = _mm256_extractf128_pd::<1>(diff);
            // Multiply t3 by -i (forward, sign < 0) or +i (inverse):
            // swap re/im, then flip the sign of one lane.
            let t3_swapped = _mm_shuffle_pd::<0b01>(t3, t3);
            let t3_rot = if sign < 0 {
                _mm_xor_pd(t3_swapped, _mm_set_pd(-0.0, 0.0))
            } else {
                _mm_xor_pd(t3_swapped, _mm_set_pd(0.0, -0.0))
            };
            let out0 = _mm_add_pd(t0, t2);
            let out1 = _mm_add_pd(t1, t3_rot);
            let out2 = _mm_sub_pd(t0, t2);
            let out3 = _mm_sub_pd(t1, t3_rot);
            // Repack the four outputs into two 256-bit stores.
            let v_out_01 = _mm256_permute2f128_pd::<0x20>(
                _mm256_castpd128_pd256(out0),
                _mm256_castpd128_pd256(out1),
            );
            let v_out_23 = _mm256_permute2f128_pd::<0x20>(
                _mm256_castpd128_pd256(out2),
                _mm256_castpd128_pd256(out3),
            );
            _mm256_storeu_pd(ptr, v_out_01);
            _mm256_storeu_pd(ptr.add(4), v_out_23);
}
}
}
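
/// Generates the size-8 complex FFT codelet for interleaved `f64` data:
/// a three-stage radix-2 network over inputs loaded in bit-reversed order.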
pub(super) fn gen_avx2_size_8() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx2_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            let inv_sqrt2 = _mm_set1_pd(core::f64::consts::FRAC_1_SQRT_2);
            // Multiply a packed complex [re, im] by -i (forward) or +i
            // (inverse): swap re/im, then negate one lane.
            let rotate_pm_i_sse = |v: __m128d, fwd: bool| -> __m128d {
                let swapped = _mm_shuffle_pd::<0b01>(v, v);
                if fwd {
                    _mm_xor_pd(swapped, _mm_set_pd(-0.0, 0.0))
                } else {
                    _mm_xor_pd(swapped, _mm_set_pd(0.0, -0.0))
                }
            };
            let fwd = sign < 0;
            // Load the eight complex values in bit-reversed order
            // [0, 4, 2, 6, 1, 5, 3, 7]; pointer offsets count f64 lanes,
            // two per complex value.
            let mut a = [_mm_setzero_pd(); 8];
            a[0] = _mm_loadu_pd(ptr);
            a[1] = _mm_loadu_pd(ptr.add(8));
            a[2] = _mm_loadu_pd(ptr.add(4));
            a[3] = _mm_loadu_pd(ptr.add(12));
            a[4] = _mm_loadu_pd(ptr.add(2));
            a[5] = _mm_loadu_pd(ptr.add(10));
            a[6] = _mm_loadu_pd(ptr.add(6));
            a[7] = _mm_loadu_pd(ptr.add(14));
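            // Stage 1: size-2 butterflies between adjacent pairs.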
for i in (0..8usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = _mm_sub_pd(a[i], t);
a[i] = _mm_add_pd(a[i], t);
}
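            // Stage 2: size-4 combination; the a[group + 3] leg is rotated
            // by -i (forward) or +i (inverse).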
for group in (0..8usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = _mm_sub_pd(a[group], t);
a[group] = _mm_add_pd(a[group], t);
let t = a[group + 3];
let t_tw = rotate_pm_i_sse(t, fwd);
a[group + 3] = _mm_sub_pd(a[group + 1], t_tw);
a[group + 1] = _mm_add_pd(a[group + 1], t_tw);
}
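            // Stage 3: final combination. The a[4..8] legs carry twiddles
            // 1, w, -i, and w^3 with w = (1 - i)/sqrt(2) for the forward
            // transform (conjugate twiddles for the inverse).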
let t = a[4];
a[4] = _mm_sub_pd(a[0], t);
a[0] = _mm_add_pd(a[0], t);
            {
                // Twiddle a[5] by w = (1 - i)/sqrt(2) (forward) or its conjugate (inverse).
                let v = a[5];
                let swapped = _mm_shuffle_pd::<0b01>(v, v);
                let t_tw = if fwd {
                    // (re + i*im) * (1 - i) = (re + im) + i*(im - re)
                    let sum = _mm_add_pd(v, swapped);
                    let diff = _mm_sub_pd(swapped, v);
                    let combined = _mm_shuffle_pd::<0b00>(sum, diff);
                    _mm_mul_pd(combined, inv_sqrt2)
                } else {
                    // (re + i*im) * (1 + i) = (re - im) + i*(im + re)
                    let diff = _mm_sub_pd(v, swapped);
                    let sum = _mm_add_pd(v, swapped);
                    let combined = _mm_shuffle_pd::<0b10>(diff, sum);
                    _mm_mul_pd(combined, inv_sqrt2)
                };
                a[5] = _mm_sub_pd(a[1], t_tw);
                a[1] = _mm_add_pd(a[1], t_tw);
            }
{
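                // Twiddle a[6] by -i (forward) or +i (inverse).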
let t = a[6];
let t_tw = rotate_pm_i_sse(t, fwd);
a[6] = _mm_sub_pd(a[2], t_tw);
a[2] = _mm_add_pd(a[2], t_tw);
}
            {
                // Twiddle a[7] by w^3 = (-1 - i)/sqrt(2) (forward) or its conjugate (inverse).
                let v = a[7];
                let swapped = _mm_shuffle_pd::<0b01>(v, v);
                let t_tw = if fwd {
                    // (re + i*im) * (-1 - i) = (im - re) - i*(re + im)
                    let diff = _mm_sub_pd(swapped, v);
                    let neg_sum = _mm_xor_pd(_mm_add_pd(v, swapped), _mm_set1_pd(-0.0));
                    let combined = _mm_shuffle_pd::<0b00>(diff, neg_sum);
                    _mm_mul_pd(combined, inv_sqrt2)
                } else {
                    // (re + i*im) * (-1 + i) = -(re + im) + i*(re - im)
                    let neg_sum = _mm_xor_pd(_mm_add_pd(v, swapped), _mm_set1_pd(-0.0));
                    let diff = _mm_sub_pd(swapped, v);
                    let combined = _mm_shuffle_pd::<0b10>(neg_sum, diff);
                    _mm_mul_pd(combined, inv_sqrt2)
                };
                a[7] = _mm_sub_pd(a[3], t_tw);
                a[3] = _mm_add_pd(a[3], t_tw);
            }
            // Write the results back in natural order, packing two complex
            // values per 256-bit store.
            for i in (0..8usize).step_by(2) {
                let packed = _mm256_permute2f128_pd::<0x20>(
                    _mm256_castpd128_pd256(a[i]),
                    _mm256_castpd128_pd256(a[i + 1]),
                );
                _mm256_storeu_pd(ptr.add(i * 2), packed);
            }
}
}
}
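
/// Generates the size-2 complex FFT codelet for interleaved `f32` data.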
pub(super) fn gen_avx2_size_2_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn codelet_simd_2_avx2_f32(data: &mut [f32], _sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            // Load both complex values: [re0, im0, re1, im1].
            let v = _mm_loadu_ps(ptr);
            // Duplicate c0 into one register and c1 into another.
            let a = _mm_shuffle_ps::<0b01_00_01_00>(v, v);
            let b = _mm_shuffle_ps::<0b11_10_11_10>(v, v);
            // Size-2 butterfly: X0 = x0 + x1, X1 = x0 - x1.
            let sum = _mm_add_ps(a, b);
            let diff = _mm_sub_ps(a, b);
            let out = _mm_shuffle_ps::<0b01_00_01_00>(sum, diff);
            _mm_storeu_ps(ptr, out);
}
}
}
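
/// Generates the size-4 complex FFT codelet for interleaved `f32` data.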
pub(super) fn gen_avx2_size_4_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn codelet_simd_4_avx2_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            // All four complex values fit in one 256-bit register.
            let all = _mm256_loadu_ps(ptr);
            let v_01 = _mm256_castps256_ps128(all);
            let v_23 = _mm256_extractf128_ps::<1>(all);
            let sum = _mm_add_ps(v_01, v_23);
            let diff = _mm_sub_ps(v_01, v_23);
            // t0 = c0 + c2, t2 = c1 + c3, t1 = c0 - c2, t3 = c1 - c3,
            // each duplicated across both 64-bit halves.
            let t0 = _mm_shuffle_ps::<0b01_00_01_00>(sum, sum);
            let t2 = _mm_shuffle_ps::<0b11_10_11_10>(sum, sum);
            let t1 = _mm_shuffle_ps::<0b01_00_01_00>(diff, diff);
            let t3 = _mm_shuffle_ps::<0b11_10_11_10>(diff, diff);
            // Multiply t3 by -i (forward) or +i (inverse).
            let t3_swapped = _mm_shuffle_ps::<0b00_01_00_01>(t3, t3);
            let t3_rot = if sign < 0 {
                let mask = _mm_set_ps(-0.0, 0.0, -0.0, 0.0);
                _mm_xor_ps(t3_swapped, mask)
            } else {
                let mask = _mm_set_ps(0.0, -0.0, 0.0, -0.0);
                _mm_xor_ps(t3_swapped, mask)
            };
            let out0 = _mm_add_ps(t0, t2);
            let out1 = _mm_add_ps(t1, t3_rot);
            let out2 = _mm_sub_ps(t0, t2);
            let out3 = _mm_sub_ps(t1, t3_rot);
            // Repack the four outputs into one 256-bit store.
            let packed_01 = _mm_shuffle_ps::<0b01_00_01_00>(out0, out1);
            let packed_23 = _mm_shuffle_ps::<0b01_00_01_00>(out2, out3);
            let result = _mm256_insertf128_ps::<1>(
                _mm256_castps128_ps256(packed_01),
                packed_23,
            );
            _mm256_storeu_ps(ptr, result);
}
}
}
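
/// Generates the size-8 complex FFT codelet for interleaved `f32` data,
/// mirroring the `f64` version with one complex value duplicated across
/// each 128-bit register.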
#[allow(clippy::too_many_lines)]
pub(super) fn gen_avx2_size_8_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx2_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
            let inv_sqrt2 = _mm_set1_ps(core::f32::consts::FRAC_1_SQRT_2);
            // Load one complex value (two f32 lanes) and duplicate it
            // across both halves of the register: [re, im, re, im].
            let load_cx = |base: *const f32| -> __m128 {
                let v = _mm_castsi128_ps(_mm_loadl_epi64(base.cast::<__m128i>()));
                _mm_shuffle_ps::<0b01_00_01_00>(v, v)
            };
            // Store the low complex value (two f32 lanes) back to memory.
            let store_cx = |base: *mut f32, v: __m128| {
                _mm_storel_epi64(base.cast::<__m128i>(), _mm_castps_si128(v));
            };
            // Multiply packed complex values by -i (forward) or +i
            // (inverse): swap re/im, then negate one lane of each pair.
            let rotate_pm_i_ps = |v: __m128, fwd: bool| -> __m128 {
                let sw = _mm_shuffle_ps::<0b00_01_00_01>(v, v);
                if fwd {
                    let mask = _mm_set_ps(-0.0, 0.0, -0.0, 0.0);
                    _mm_xor_ps(sw, mask)
                } else {
                    let mask = _mm_set_ps(0.0, -0.0, 0.0, -0.0);
                    _mm_xor_ps(sw, mask)
                }
            };
            let fwd = sign < 0;
            // Load the eight complex values in bit-reversed order
            // [0, 4, 2, 6, 1, 5, 3, 7]; pointer offsets count f32 lanes,
            // two per complex value.
            let mut a = [_mm_setzero_ps(); 8];
            a[0] = load_cx(ptr);
            a[1] = load_cx(ptr.add(8));
            a[2] = load_cx(ptr.add(4));
            a[3] = load_cx(ptr.add(12));
            a[4] = load_cx(ptr.add(2));
            a[5] = load_cx(ptr.add(10));
            a[6] = load_cx(ptr.add(6));
            a[7] = load_cx(ptr.add(14));
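            // Stage 1: size-2 butterflies between adjacent pairs.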
for i in (0..8usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = _mm_sub_ps(a[i], t);
a[i] = _mm_add_ps(a[i], t);
}
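            // Stage 2: size-4 combination; the a[group + 3] leg is rotated
            // by -i (forward) or +i (inverse).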
for group in (0..8usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = _mm_sub_ps(a[group], t);
a[group] = _mm_add_ps(a[group], t);
let t = a[group + 3];
let t_tw = rotate_pm_i_ps(t, fwd);
a[group + 3] = _mm_sub_ps(a[group + 1], t_tw);
a[group + 1] = _mm_add_ps(a[group + 1], t_tw);
}
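            // Stage 3: final combination. The a[4..8] legs carry twiddles
            // 1, w, -i, and w^3 with w = (1 - i)/sqrt(2) for the forward
            // transform (conjugate twiddles for the inverse).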
let t = a[4];
a[4] = _mm_sub_ps(a[0], t);
a[0] = _mm_add_ps(a[0], t);
            {
                // Twiddle a[5] by w = (1 - i)/sqrt(2) (forward) or its conjugate (inverse).
                let v = a[5];
                let sw = _mm_shuffle_ps::<0b00_01_00_01>(v, v);
                let t_tw = if fwd {
                    // (re + i*im) * (1 - i) = (re + im) + i*(im - re)
                    let sum = _mm_add_ps(v, sw);
                    let diff = _mm_sub_ps(sw, v);
                    let combined = _mm_unpacklo_ps(sum, diff);
                    _mm_mul_ps(_mm_shuffle_ps::<0b01_00_01_00>(combined, combined), inv_sqrt2)
                } else {
                    // (re + i*im) * (1 + i) = (re - im) + i*(im + re)
                    let diff = _mm_sub_ps(v, sw);
                    let sum = _mm_add_ps(v, sw);
                    let sum_sw = _mm_shuffle_ps::<0b00_01_00_01>(sum, sum);
                    let combined = _mm_unpacklo_ps(diff, sum_sw);
                    _mm_mul_ps(_mm_shuffle_ps::<0b01_00_01_00>(combined, combined), inv_sqrt2)
                };
                a[5] = _mm_sub_ps(a[1], t_tw);
                a[1] = _mm_add_ps(a[1], t_tw);
            }
{
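                // Twiddle a[6] by -i (forward) or +i (inverse).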
let t = a[6];
let t_tw = rotate_pm_i_ps(t, fwd);
a[6] = _mm_sub_ps(a[2], t_tw);
a[2] = _mm_add_ps(a[2], t_tw);
}
            {
                // Twiddle a[7] by w^3 = (-1 - i)/sqrt(2) (forward) or its conjugate (inverse).
                let v = a[7];
                let sw = _mm_shuffle_ps::<0b00_01_00_01>(v, v);
                let t_tw = if fwd {
                    // (re + i*im) * (-1 - i) = (im - re) - i*(re + im)
                    let diff = _mm_sub_ps(sw, v);
                    let neg_sum = _mm_xor_ps(_mm_add_ps(v, sw), _mm_set1_ps(-0.0));
                    let combined = _mm_unpacklo_ps(diff, neg_sum);
                    _mm_mul_ps(_mm_shuffle_ps::<0b01_00_01_00>(combined, combined), inv_sqrt2)
                } else {
                    // (re + i*im) * (-1 + i) = -(re + im) + i*(re - im)
                    let neg_sum = _mm_xor_ps(_mm_add_ps(v, sw), _mm_set1_ps(-0.0));
                    let diff = _mm_sub_ps(sw, v);
                    let diff_sw = _mm_shuffle_ps::<0b00_01_00_01>(diff, diff);
                    let combined = _mm_unpacklo_ps(neg_sum, diff_sw);
                    _mm_mul_ps(_mm_shuffle_ps::<0b01_00_01_00>(combined, combined), inv_sqrt2)
                };
                a[7] = _mm_sub_ps(a[3], t_tw);
                a[3] = _mm_add_ps(a[3], t_tw);
            }
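            // Write the results back in natural order, one complex value per store.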
for i in (0..8usize).step_by(2) {
store_cx(ptr.add(i * 2), a[i]);
store_cx(ptr.add(i * 2 + 2), a[i + 1]);
}
}
}
}