use proc_macro2::TokenStream;
use quote::quote;
/// Generates the AVX-512-gated size-2 FFT codelet for `f64` complex data.
///
/// The emitted function performs a single radix-2 butterfly on two
/// interleaved complex values (4 `f64`): out0 = in0 + in1, out1 = in0 - in1.
/// The sign parameter is accepted but unused (`_sign`) because the size-2
/// butterfly is identical for forward and inverse transforms.
/// NOTE(review): only 256-bit intrinsics are used despite the `avx512f`
/// gate — presumably to keep the feature gate uniform across the codelet
/// family; confirm against the dispatcher.
pub(super) fn gen_avx512_size_2_f64() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_2_avx512_f64(data: &mut [f64], _sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
// Load both complex values at once: lanes [re0, im0, re1, im1].
let v = _mm256_loadu_pd(ptr);
// Split into the two complex operands (128 bits each).
let a = _mm256_castpd256_pd128(v); let b = _mm256_extractf128_pd(v, 1);
let sum = _mm_add_pd(a, b);
let diff = _mm_sub_pd(a, b);
// Repack as [sum, diff] and store back in place.
let result = _mm256_insertf128_pd(
_mm256_castpd128_pd256(sum),
diff,
1,
);
_mm256_storeu_pd(ptr, result);
}
}
}
/// Generates the AVX-512-gated size-4 FFT codelet for `f64` complex data.
///
/// The emitted function performs a two-stage radix-2 size-4 FFT over 4
/// interleaved complex values (8 `f64`). `sign < 0` selects the forward
/// transform (the odd leg is rotated by -i); otherwise the inverse (+i).
pub(super) fn gen_avx512_size_4_f64() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_4_avx512_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
// v_01 = [c0, c1], v_23 = [c2, c3]; each complex is [re, im].
let v_01 = _mm256_loadu_pd(ptr); let v_23 = _mm256_loadu_pd(ptr.add(4));
// Stage 1 butterflies: t0 = c0+c2, t2 = c1+c3, t1 = c0-c2, t3 = c1-c3.
let sum = _mm256_add_pd(v_01, v_23); let diff = _mm256_sub_pd(v_01, v_23);
let t0 = _mm256_castpd256_pd128(sum);
let t2 = _mm256_extractf128_pd(sum, 1);
let t1 = _mm256_castpd256_pd128(diff);
let t3 = _mm256_extractf128_pd(diff, 1);
// Rotate t3 by -i (forward) or +i (inverse): swap re/im, then flip
// the sign of one lane with an XOR against a signed-zero mask.
let t3_swapped = _mm_shuffle_pd(t3, t3, 0b01); let t3_rot = if sign < 0 {
_mm_xor_pd(t3_swapped, _mm_set_pd(-0.0, 0.0))
} else {
_mm_xor_pd(t3_swapped, _mm_set_pd(0.0, -0.0))
};
// Stage 2 butterflies combine the twiddled legs.
let out0 = _mm_add_pd(t0, t2);
let out1 = _mm_add_pd(t1, t3_rot);
let out2 = _mm_sub_pd(t0, t2);
let out3 = _mm_sub_pd(t1, t3_rot);
// Repack and store the four outputs in natural order.
let v_out_01 = _mm256_insertf128_pd(
_mm256_castpd128_pd256(out0), out1, 1
);
let v_out_23 = _mm256_insertf128_pd(
_mm256_castpd128_pd256(out2), out3, 1
);
_mm256_storeu_pd(ptr, v_out_01);
_mm256_storeu_pd(ptr.add(4), v_out_23);
}
}
}
#[allow(clippy::too_many_lines)]
/// Generates the AVX-512-gated size-8 FFT codelet for `f64` complex data.
///
/// The emitted function runs a three-stage radix-2 decimation-in-time FFT
/// over 8 interleaved complex values (16 `f64`), loading the inputs in
/// bit-reversed order and storing the results in natural order.
/// `sign < 0` selects the forward transform; any other sign selects the
/// inverse (conjugate twiddles).
pub(super) fn gen_avx512_size_8_f64() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx512_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
let inv_sqrt2 = _mm512_set1_pd(core::f64::consts::FRAC_1_SQRT_2);
// Multiply one complex value [re, im] by -i (fwd) or +i (inverse):
// swap the two lanes, then flip the sign of one of them.
let rotate_i = |v: __m128d, fwd: bool| -> __m128d {
let sw = _mm_shuffle_pd(v, v, 0b01);
if fwd {
_mm_xor_pd(sw, _mm_set_pd(-0.0, 0.0))
} else {
_mm_xor_pd(sw, _mm_set_pd(0.0, -0.0))
}
};
let fwd = sign < 0;
let mut a = [_mm_setzero_pd(); 8];
// Gather the 8 complex inputs in bit-reversed order (0,4,2,6,1,5,3,7).
a[0] = _mm_loadu_pd(ptr); a[1] = _mm_loadu_pd(ptr.add(8)); a[2] = _mm_loadu_pd(ptr.add(4)); a[3] = _mm_loadu_pd(ptr.add(12)); a[4] = _mm_loadu_pd(ptr.add(2)); a[5] = _mm_loadu_pd(ptr.add(10)); a[6] = _mm_loadu_pd(ptr.add(6)); a[7] = _mm_loadu_pd(ptr.add(14));
// Stage 1: size-2 butterflies on adjacent pairs (twiddle = 1).
for i in (0..8_usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = _mm_sub_pd(a[i], t);
a[i] = _mm_add_pd(a[i], t);
}
// Stage 2: size-4 butterflies; the odd leg is twiddled by ∓i.
for group in (0..8_usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = _mm_sub_pd(a[group], t);
a[group] = _mm_add_pd(a[group], t);
let t_tw = rotate_i(a[group + 3], fwd);
a[group + 3] = _mm_sub_pd(a[group + 1], t_tw);
a[group + 1] = _mm_add_pd(a[group + 1], t_tw);
}
// Stage 3, k = 0: twiddle = 1, plain butterfly.
{
let t = a[4];
a[4] = _mm_sub_pd(a[0], t);
a[0] = _mm_add_pd(a[0], t);
}
// Stage 3, k = 1: twiddle = (1 ∓ i)/sqrt(2), applied as a full
// complex multiply by (c + i*d) via FMA.
{
let v = a[5];
let v_re = _mm_shuffle_pd(v, v, 0b00); let v_im = _mm_shuffle_pd(v, v, 0b11); let vr = _mm512_castpd128_pd512(v_re);
let vi = _mm512_castpd128_pd512(v_im);
let (c, d) = if fwd {
(inv_sqrt2, _mm512_xor_pd(inv_sqrt2, _mm512_set1_pd(-0.0)))
} else {
(inv_sqrt2, inv_sqrt2)
};
// Only the low 128 bits are extracted below; whatever the
// 128->512 casts leave in the upper lanes is discarded.
let re_out = _mm512_castpd512_pd128(
_mm512_fmsub_pd(vr, c, _mm512_mul_pd(vi, d))
);
let im_out = _mm512_castpd512_pd128(
_mm512_fmadd_pd(vr, d, _mm512_mul_pd(vi, c))
);
let t_tw = _mm_shuffle_pd(re_out, im_out, 0b00);
a[5] = _mm_sub_pd(a[1], t_tw);
a[1] = _mm_add_pd(a[1], t_tw);
}
// Stage 3, k = 2: twiddle = ∓i.
{
let t_tw = rotate_i(a[6], fwd);
a[6] = _mm_sub_pd(a[2], t_tw);
a[2] = _mm_add_pd(a[2], t_tw);
}
// Stage 3, k = 3: twiddle = (-1 ∓ i)/sqrt(2).
{
let v = a[7];
let v_re = _mm_shuffle_pd(v, v, 0b00);
let v_im = _mm_shuffle_pd(v, v, 0b11);
let vr = _mm512_castpd128_pd512(v_re);
let vi = _mm512_castpd128_pd512(v_im);
let neg_is2 = _mm512_xor_pd(inv_sqrt2, _mm512_set1_pd(-0.0));
let (c, d) = if fwd {
(neg_is2, neg_is2)
} else {
(neg_is2, inv_sqrt2)
};
let re_out = _mm512_castpd512_pd128(
_mm512_fmsub_pd(vr, c, _mm512_mul_pd(vi, d))
);
let im_out = _mm512_castpd512_pd128(
_mm512_fmadd_pd(vr, d, _mm512_mul_pd(vi, c))
);
let t_tw = _mm_shuffle_pd(re_out, im_out, 0b00);
a[7] = _mm_sub_pd(a[3], t_tw);
a[3] = _mm_add_pd(a[3], t_tw);
}
// Store results in natural order, two complex values per 256-bit store.
for i in (0..8_usize).step_by(2) {
let packed = _mm256_insertf128_pd(
_mm256_castpd128_pd256(a[i]),
a[i + 1],
1,
);
_mm256_storeu_pd(ptr.add(i * 2), packed);
}
}
}
}
/// Generates the AVX-512-gated size-2 FFT codelet for `f32` complex data.
///
/// The emitted function performs a single radix-2 butterfly on two
/// interleaved complex values (4 `f32`): out0 = in0 + in1, out1 = in0 - in1.
/// The sign parameter is accepted but unused (`_sign`) because the size-2
/// butterfly is identical for forward and inverse transforms.
pub(super) fn gen_avx512_size_2_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_2_avx512_f32(data: &mut [f32], _sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
// Load both complex values: lanes [re0, im0, re1, im1].
let v = _mm_loadu_ps(ptr);
// Broadcast each complex into both 64-bit halves: a = [c0, c0], b = [c1, c1].
let a = _mm_shuffle_ps(v, v, 0b01_00_01_00);
let b = _mm_shuffle_ps(v, v, 0b11_10_11_10);
let sum = _mm_add_ps(a, b);
let diff = _mm_sub_ps(a, b);
// Pack [sum, diff] back into one register and store in place.
let out = _mm_shuffle_ps(sum, diff, 0b01_00_01_00);
_mm_storeu_ps(ptr, out);
}
}
}
/// Generates the AVX-512-gated size-4 FFT codelet for `f32` complex data.
///
/// The emitted function performs a two-stage radix-2 size-4 FFT over 4
/// interleaved complex values (8 `f32`). `sign < 0` selects the forward
/// transform (the odd leg is rotated by -i); otherwise the inverse (+i).
pub(super) fn gen_avx512_size_4_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_4_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
// all = [c0, c1, c2, c3]; v_01 = [c0, c1], v_23 = [c2, c3].
let all = _mm256_loadu_ps(ptr);
let v_01 = _mm256_castps256_ps128(all);
let v_23 = _mm256_extractf128_ps(all, 1);
// Stage 1 butterflies: sum = [c0+c2, c1+c3], diff = [c0-c2, c1-c3].
let sum = _mm_add_ps(v_01, v_23); let diff = _mm_sub_ps(v_01, v_23);
// Duplicate each intermediate complex into both register halves.
let t0 = _mm_shuffle_ps(sum, sum, 0b01_00_01_00);
let t2 = _mm_shuffle_ps(sum, sum, 0b11_10_11_10);
let t1 = _mm_shuffle_ps(diff, diff, 0b01_00_01_00);
let t3 = _mm_shuffle_ps(diff, diff, 0b11_10_11_10);
// Rotate t3 by -i (forward) or +i (inverse): swap re/im per complex,
// then flip the sign of alternating lanes with a signed-zero mask.
let t3_sw = _mm_shuffle_ps(t3, t3, 0b00_01_00_01);
let t3_rot = if sign < 0 {
_mm_xor_ps(t3_sw, _mm_set_ps(-0.0, 0.0, -0.0, 0.0))
} else {
_mm_xor_ps(t3_sw, _mm_set_ps(0.0, -0.0, 0.0, -0.0))
};
// Stage 2 butterflies combine the twiddled legs.
let out0 = _mm_add_ps(t0, t2);
let out1 = _mm_add_ps(t1, t3_rot);
let out2 = _mm_sub_ps(t0, t2);
let out3 = _mm_sub_ps(t1, t3_rot);
// Each outN holds its complex in both halves; take the low halves
// pairwise and store all four outputs in natural order.
let packed_01 = _mm_shuffle_ps(out0, out1, 0b01_00_01_00);
let packed_23 = _mm_shuffle_ps(out2, out3, 0b01_00_01_00);
let result = _mm256_insertf128_ps(
_mm256_castps128_ps256(packed_01),
packed_23,
1,
);
_mm256_storeu_ps(ptr, result);
}
}
}
#[allow(clippy::too_many_lines)]
/// Generates the AVX-512-gated size-8 FFT codelet for `f32` complex data.
///
/// The emitted function runs a three-stage radix-2 decimation-in-time FFT
/// over 8 interleaved complex values (16 `f32`), loading the inputs in
/// bit-reversed order and storing the results in natural order.
/// `sign < 0` selects the forward transform; any other sign selects the
/// inverse (conjugate twiddles).
pub(super) fn gen_avx512_size_8_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
let inv_sqrt2 = _mm_set1_ps(core::f32::consts::FRAC_1_SQRT_2);
// Load one complex value (8 bytes) and broadcast it into both
// 64-bit halves of the register: [re, im, re, im].
let load_cx = |base: *const f32| -> __m128 {
let v = _mm_castsi128_ps(_mm_loadl_epi64(base.cast::<__m128i>()));
_mm_shuffle_ps(v, v, 0b01_00_01_00)
};
// Store only the low complex value (8 bytes) of a register.
let store_cx = |base: *mut f32, v: __m128| {
_mm_storel_epi64(base.cast::<__m128i>(), _mm_castps_si128(v));
};
// Multiply each complex in the register by -i (fwd) or +i (inverse):
// swap re/im per complex, then flip alternating signs.
let rotate_i = |v: __m128, fwd: bool| -> __m128 {
let sw = _mm_shuffle_ps(v, v, 0b00_01_00_01);
if fwd {
_mm_xor_ps(sw, _mm_set_ps(-0.0, 0.0, -0.0, 0.0))
} else {
_mm_xor_ps(sw, _mm_set_ps(0.0, -0.0, 0.0, -0.0))
}
};
let fwd = sign < 0;
let mut a = [_mm_setzero_ps(); 8];
// Gather the 8 complex inputs in bit-reversed order (0,4,2,6,1,5,3,7).
a[0] = load_cx(ptr); a[1] = load_cx(ptr.add(8)); a[2] = load_cx(ptr.add(4)); a[3] = load_cx(ptr.add(12)); a[4] = load_cx(ptr.add(2)); a[5] = load_cx(ptr.add(10)); a[6] = load_cx(ptr.add(6)); a[7] = load_cx(ptr.add(14));
// Stage 1: size-2 butterflies on adjacent pairs (twiddle = 1).
for i in (0..8_usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = _mm_sub_ps(a[i], t);
a[i] = _mm_add_ps(a[i], t);
}
// Stage 2: size-4 butterflies; the odd leg is twiddled by ∓i.
for group in (0..8_usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = _mm_sub_ps(a[group], t);
a[group] = _mm_add_ps(a[group], t);
let t_tw = rotate_i(a[group + 3], fwd);
a[group + 3] = _mm_sub_ps(a[group + 1], t_tw);
a[group + 1] = _mm_add_ps(a[group + 1], t_tw);
}
// Stage 3, k = 0: twiddle = 1, plain butterfly.
{
let t = a[4];
a[4] = _mm_sub_ps(a[0], t);
a[0] = _mm_add_ps(a[0], t);
}
// Stage 3, k = 1: twiddle = (1 ∓ i)/sqrt(2), applied as a full
// complex multiply by (c + i*d) via FMA.
{
let v = a[5];
let v_re = _mm_shuffle_ps(v, v, 0b00_00_00_00); let v_im = _mm_shuffle_ps(v, v, 0b01_01_01_01); let vr = _mm512_castps128_ps512(v_re);
let vi = _mm512_castps128_ps512(v_im);
let is2 = _mm512_castps128_ps512(inv_sqrt2);
let (c, d) = if fwd {
let neg = _mm512_xor_ps(is2, _mm512_set1_ps(-0.0_f32));
(is2, neg)
} else {
(is2, is2)
};
// Only the low 128 bits are extracted below; whatever the
// 128->512 casts leave in the upper lanes is discarded.
let re_out = _mm512_castps512_ps128(
_mm512_fmsub_ps(vr, c, _mm512_mul_ps(vi, d))
);
let im_out = _mm512_castps512_ps128(
_mm512_fmadd_ps(vr, d, _mm512_mul_ps(vi, c))
);
let t_tw = _mm_unpacklo_ps(re_out, im_out);
a[5] = _mm_sub_ps(a[1], t_tw);
a[1] = _mm_add_ps(a[1], t_tw);
}
// Stage 3, k = 2: twiddle = ∓i.
{
let t_tw = rotate_i(a[6], fwd);
a[6] = _mm_sub_ps(a[2], t_tw);
a[2] = _mm_add_ps(a[2], t_tw);
}
// Stage 3, k = 3: twiddle = (-1 ∓ i)/sqrt(2).
{
let v = a[7];
let v_re = _mm_shuffle_ps(v, v, 0b00_00_00_00);
let v_im = _mm_shuffle_ps(v, v, 0b01_01_01_01);
let vr = _mm512_castps128_ps512(v_re);
let vi = _mm512_castps128_ps512(v_im);
let is2 = _mm512_castps128_ps512(inv_sqrt2);
let neg_is2 = _mm512_xor_ps(is2, _mm512_set1_ps(-0.0_f32));
let (c, d) = if fwd {
(neg_is2, neg_is2)
} else {
(neg_is2, is2)
};
let re_out = _mm512_castps512_ps128(
_mm512_fmsub_ps(vr, c, _mm512_mul_ps(vi, d))
);
let im_out = _mm512_castps512_ps128(
_mm512_fmadd_ps(vr, d, _mm512_mul_ps(vi, c))
);
let t_tw = _mm_unpacklo_ps(re_out, im_out);
a[7] = _mm_sub_ps(a[3], t_tw);
a[3] = _mm_add_ps(a[3], t_tw);
}
// Store results in natural order, one complex value per store.
for i in 0..8_usize {
store_cx(ptr.add(i * 2), a[i]);
}
}
}
}
#[allow(clippy::too_many_lines)]
/// Generates the AVX-512-gated size-16 FFT codelet for `f32` complex data.
///
/// The emitted function runs a four-stage radix-2 decimation-in-time FFT
/// over 16 interleaved complex values (32 `f32`), loading the inputs in
/// bit-reversed order and storing the results in natural order. Twiddles
/// for the last two stages are computed at runtime from sin/cos.
/// `sign < 0` selects the forward transform; any other sign the inverse.
pub(super) fn gen_avx512_size_16_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_16_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
let ptr = data.as_mut_ptr();
let inv_sqrt2 = _mm_set1_ps(core::f32::consts::FRAC_1_SQRT_2);
// Load one complex value (8 bytes) and broadcast it into both
// 64-bit halves of the register: [re, im, re, im].
let load_cx = |base: *const f32| -> __m128 {
let v = _mm_castsi128_ps(_mm_loadl_epi64(base.cast::<__m128i>()));
_mm_shuffle_ps(v, v, 0b01_00_01_00)
};
// Store only the low complex value (8 bytes) of a register.
let store_cx = |base: *mut f32, v: __m128| {
_mm_storel_epi64(base.cast::<__m128i>(), _mm_castps_si128(v));
};
// Multiply each complex in the register by -i (fwd) or +i (inverse).
let rotate_i = |v: __m128, fwd: bool| -> __m128 {
let sw = _mm_shuffle_ps(v, v, 0b00_01_00_01);
if fwd {
_mm_xor_ps(sw, _mm_set_ps(-0.0, 0.0, -0.0, 0.0))
} else {
_mm_xor_ps(sw, _mm_set_ps(0.0, -0.0, 0.0, -0.0))
}
};
// Complex multiply v * (c + i*d) via FMA; only the low 128 bits of
// the widened 512-bit lanes are kept, so the undefined upper lanes
// introduced by the casts are discarded.
let cmul_fma = |v: __m128, c: f32, d: f32| -> __m128 {
let v_re = _mm_shuffle_ps(v, v, 0b00_00_00_00);
let v_im = _mm_shuffle_ps(v, v, 0b01_01_01_01);
let vr = _mm512_castps128_ps512(v_re);
let vi = _mm512_castps128_ps512(v_im);
let vc = _mm512_set1_ps(c);
let vd = _mm512_set1_ps(d);
let re_out = _mm512_castps512_ps128(_mm512_fmsub_ps(vr, vc, _mm512_mul_ps(vi, vd)));
let im_out = _mm512_castps512_ps128(_mm512_fmadd_ps(vr, vd, _mm512_mul_ps(vi, vc)));
_mm_unpacklo_ps(re_out, im_out)
};
let fwd = sign < 0;
let sign_f = if fwd { -1.0_f32 } else { 1.0_f32 };
// Size-16 twiddles W16^k = e^(sign * 2*pi*i*k/16), k = 0..8.
// NOTE(review): f32::cos/sin come from std, not core — this codelet
// will not build in a no_std consumer; confirm the crate's std policy.
let w16: [(f32, f32); 8] = {
let mut arr = [(0.0_f32, 0.0_f32); 8];
for (k, item) in arr.iter_mut().enumerate() {
let angle = sign_f * 2.0 * core::f32::consts::PI * (k as f32) / 16.0;
*item = (angle.cos(), angle.sin());
}
arr
};
let mut a = [_mm_setzero_ps(); 16];
// Gather the 16 complex inputs in bit-reversed order
// (0,8,4,12,2,10,6,14,1,9,5,13,3,11,7,15).
a[0] = load_cx(ptr); a[1] = load_cx(ptr.add(16)); a[2] = load_cx(ptr.add(8)); a[3] = load_cx(ptr.add(24)); a[4] = load_cx(ptr.add(4)); a[5] = load_cx(ptr.add(20)); a[6] = load_cx(ptr.add(12)); a[7] = load_cx(ptr.add(28)); a[8] = load_cx(ptr.add(2)); a[9] = load_cx(ptr.add(18)); a[10] = load_cx(ptr.add(10)); a[11] = load_cx(ptr.add(26)); a[12] = load_cx(ptr.add(6)); a[13] = load_cx(ptr.add(22)); a[14] = load_cx(ptr.add(14)); a[15] = load_cx(ptr.add(30));
// Stage 1: size-2 butterflies on adjacent pairs (twiddle = 1).
for i in (0..16_usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = _mm_sub_ps(a[i], t);
a[i] = _mm_add_ps(a[i], t);
}
// Stage 2: size-4 butterflies; the odd leg is twiddled by ∓i.
for group in (0..16_usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = _mm_sub_ps(a[group], t);
a[group] = _mm_add_ps(a[group], t);
let t_tw = rotate_i(a[group + 3], fwd);
a[group + 3] = _mm_sub_ps(a[group + 1], t_tw);
a[group + 1] = _mm_add_ps(a[group + 1], t_tw);
}
// Size-8 twiddles W8^k = e^(sign * 2*pi*i*k/8), k = 0..4.
let w8: [(f32, f32); 4] = {
let mut arr = [(0.0_f32, 0.0_f32); 4];
for (k, item) in arr.iter_mut().enumerate() {
let angle = sign_f * 2.0 * core::f32::consts::PI * (k as f32) / 8.0;
*item = (angle.cos(), angle.sin());
}
arr
};
// Stage 3: size-8 butterflies with twiddles W8^k.
for group in (0..16_usize).step_by(8) {
for k in 0..4_usize {
let (c, d) = w8[k];
// W8^0 = 1 exactly; skip the complex multiply.
let t_tw = if k == 0 {
a[group + k + 4]
} else {
cmul_fma(a[group + k + 4], c, d)
};
a[group + k + 4] = _mm_sub_ps(a[group + k], t_tw);
a[group + k] = _mm_add_ps(a[group + k], t_tw);
}
}
// Stage 4: final size-16 butterflies with twiddles W16^k.
for k in 0..8_usize {
let (c, d) = w16[k];
let t_tw = if k == 0 {
a[k + 8]
} else {
cmul_fma(a[k + 8], c, d)
};
a[k + 8] = _mm_sub_ps(a[k], t_tw);
a[k] = _mm_add_ps(a[k], t_tw);
}
// Store results in natural order, one complex value per store.
for i in 0..16_usize {
store_cx(ptr.add(i * 2), a[i]);
}
}
}
}