use proc_macro2::TokenStream;
use quote::quote;
/// Generates the NEON f64 size-2 FFT codelet.
///
/// The emitted `codelet_simd_2_neon_f64` performs an in-place 2-point
/// butterfly on `data`, treated as two interleaved (re, im) f64 pairs
/// (4 scalars). The 2-point transform is identical in both directions,
/// so `_sign` is unused.
pub(super) fn gen_neon_size_2() -> TokenStream {
quote! {
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn codelet_simd_2_neon_f64(data: &mut [f64], _sign: i32) {
use core::arch::aarch64::*;
let ptr = data.as_mut_ptr();
// Each 128-bit load holds one complex value as (re, im).
let a = vld1q_f64(ptr);
let b = vld1q_f64(ptr.add(2));
// Butterfly: X0 = x0 + x1, X1 = x0 - x1.
let sum = vaddq_f64(a, b);
let diff = vsubq_f64(a, b);
vst1q_f64(ptr, sum);
vst1q_f64(ptr.add(2), diff);
}
}
}
/// Generates the NEON f64 size-4 FFT codelet.
///
/// The emitted `codelet_simd_4_neon_f64` performs an in-place 4-point
/// DIT FFT on `data`, treated as four interleaved (re, im) f64 pairs
/// (8 scalars). `sign < 0` selects the forward transform, otherwise the
/// inverse.
pub(super) fn gen_neon_size_4() -> TokenStream {
quote! {
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn codelet_simd_4_neon_f64(data: &mut [f64], sign: i32) {
use core::arch::aarch64::*;
let ptr = data.as_mut_ptr();
// One complex value (re, im) per 128-bit load.
let x0 = vld1q_f64(ptr);
let x1 = vld1q_f64(ptr.add(2));
let x2 = vld1q_f64(ptr.add(4));
let x3 = vld1q_f64(ptr.add(6));
// Stage 1: size-2 butterflies on (x0, x2) and (x1, x3).
let t0 = vaddq_f64(x0, x2);
let t1 = vsubq_f64(x0, x2);
let t2 = vaddq_f64(x1, x3);
let t3 = vsubq_f64(x1, x3);
// Multiply t3 by -i (forward) or +i (inverse): swap (re, im)
// with vextq, then negate one component via the sign mask.
// e.g. forward: (a, b) -> (b, -a) == -i * (a + bi).
let t3_swapped = vextq_f64(t3, t3, 1); let t3_rot = if sign < 0 {
let neg_mask = vld1q_f64([1.0_f64, -1.0_f64].as_ptr());
vmulq_f64(t3_swapped, neg_mask)
} else {
let neg_mask = vld1q_f64([-1.0_f64, 1.0_f64].as_ptr());
vmulq_f64(t3_swapped, neg_mask)
};
// Stage 2: combine the half-transforms into the final outputs.
let out0 = vaddq_f64(t0, t2);
let out1 = vaddq_f64(t1, t3_rot);
let out2 = vsubq_f64(t0, t2);
let out3 = vsubq_f64(t1, t3_rot);
vst1q_f64(ptr, out0);
vst1q_f64(ptr.add(2), out1);
vst1q_f64(ptr.add(4), out2);
vst1q_f64(ptr.add(6), out3);
}
}
}
/// Generates the NEON f64 size-8 FFT codelet.
///
/// The emitted `codelet_simd_8_neon_f64` performs an in-place 8-point
/// decimation-in-time FFT on `data`, treated as eight interleaved
/// (re, im) f64 pairs (16 scalars). `sign < 0` selects the forward
/// transform, otherwise the inverse.
///
/// Fix: removed a dead `let combined = vzip1q_f64(diff, sum);` binding
/// in the inverse w^1 twiddle path — it was immediately shadowed by the
/// correct combine (and is absent from the f32 twin), so the generated
/// code carried a useless statement / unused-variable lint.
pub(super) fn gen_neon_size_8() -> TokenStream {
    quote! {
        #[cfg(target_arch = "aarch64")]
        #[target_feature(enable = "neon")]
        #[allow(clippy::too_many_lines)]
        unsafe fn codelet_simd_8_neon_f64(data: &mut [f64], sign: i32) {
            use core::arch::aarch64::*;
            let ptr = data.as_mut_ptr();
            // 1/sqrt(2): scale factor of the 45-degree twiddles below.
            let inv_sqrt2 = vdupq_n_f64(0.707_106_781_186_547_6_f64);
            let fwd = sign < 0;
            // Multiply one complex value (an f64x2 as (re, im)) by -i when
            // `is_fwd`, by +i otherwise: swap the components, negate one.
            let rotate_pm_i = |v: float64x2_t, is_fwd: bool| -> float64x2_t {
                let swapped = vextq_f64(v, v, 1);
                if is_fwd {
                    let mask = vld1q_f64([1.0_f64, -1.0_f64].as_ptr());
                    vmulq_f64(swapped, mask)
                } else {
                    let mask = vld1q_f64([-1.0_f64, 1.0_f64].as_ptr());
                    vmulq_f64(swapped, mask)
                }
            };
            // Load the 8 complex inputs in bit-reversed index order
            // (0, 4, 2, 6, 1, 5, 3, 7) so the three butterfly stages
            // below produce naturally ordered output.
            let mut a = [vdupq_n_f64(0.0); 8];
            a[0] = vld1q_f64(ptr);
            a[1] = vld1q_f64(ptr.add(8));
            a[2] = vld1q_f64(ptr.add(4));
            a[3] = vld1q_f64(ptr.add(12));
            a[4] = vld1q_f64(ptr.add(2));
            a[5] = vld1q_f64(ptr.add(10));
            a[6] = vld1q_f64(ptr.add(6));
            a[7] = vld1q_f64(ptr.add(14));
            // Stage 1: four size-2 butterflies (twiddle factor 1).
            for i in (0..8usize).step_by(2) {
                let t = a[i + 1];
                a[i + 1] = vsubq_f64(a[i], t);
                a[i] = vaddq_f64(a[i], t);
            }
            // Stage 2: two size-4 groups; the odd pair takes a -/+i twiddle.
            for group in (0..8usize).step_by(4) {
                let t = a[group + 2];
                a[group + 2] = vsubq_f64(a[group], t);
                a[group] = vaddq_f64(a[group], t);
                let t = a[group + 3];
                let t_tw = rotate_pm_i(t, fwd);
                a[group + 3] = vsubq_f64(a[group + 1], t_tw);
                a[group + 1] = vaddq_f64(a[group + 1], t_tw);
            }
            // Stage 3: final size-8 butterflies with twiddles 1, w, w^2, w^3
            // where w = e^(-i*pi/4) forward, e^(+i*pi/4) inverse.
            let t = a[4];
            a[4] = vsubq_f64(a[0], t);
            a[0] = vaddq_f64(a[0], t);
            {
                // Twiddle w = (1 -/+ i)/sqrt(2) applied to a[5].
                let v = a[5];
                let swapped = vextq_f64(v, v, 1);
                let t_tw = if fwd {
                    // (a+bi)(1-i)/sqrt(2) = ((a+b) + (b-a)i)/sqrt(2)
                    let sum = vaddq_f64(v, swapped);
                    let diff_sr = vsubq_f64(swapped, v);
                    let combined = vzip1q_f64(sum, diff_sr);
                    vmulq_f64(combined, inv_sqrt2)
                } else {
                    // (a+bi)(1+i)/sqrt(2) = ((a-b) + (a+b)i)/sqrt(2)
                    let diff = vsubq_f64(v, swapped);
                    let sum = vaddq_f64(v, swapped);
                    let combined = vzip1q_f64(diff, vextq_f64(sum, sum, 1));
                    vmulq_f64(combined, inv_sqrt2)
                };
                a[5] = vsubq_f64(a[1], t_tw);
                a[1] = vaddq_f64(a[1], t_tw);
            }
            {
                // Twiddle w^2 = -/+ i applied to a[6].
                let t = a[6];
                let t_tw = rotate_pm_i(t, fwd);
                a[6] = vsubq_f64(a[2], t_tw);
                a[2] = vaddq_f64(a[2], t_tw);
            }
            {
                // Twiddle w^3 = (-1 -/+ i)/sqrt(2) applied to a[7].
                let v = a[7];
                let swapped = vextq_f64(v, v, 1);
                let t_tw = if fwd {
                    // (a+bi)(-1-i)/sqrt(2) = ((b-a) - (a+b)i)/sqrt(2)
                    let diff = vsubq_f64(swapped, v);
                    let neg_sum = vnegq_f64(vaddq_f64(v, swapped));
                    let combined = vzip1q_f64(diff, neg_sum);
                    vmulq_f64(combined, inv_sqrt2)
                } else {
                    // (a+bi)(-1+i)/sqrt(2) = (-(a+b) + (a-b)i)/sqrt(2)
                    let neg_sum = vnegq_f64(vaddq_f64(v, swapped));
                    let diff = vsubq_f64(swapped, v);
                    let combined = vzip1q_f64(neg_sum, vextq_f64(diff, diff, 1));
                    vmulq_f64(combined, inv_sqrt2)
                };
                a[7] = vsubq_f64(a[3], t_tw);
                a[3] = vaddq_f64(a[3], t_tw);
            }
            // Store the results back in natural order.
            for i in 0..8usize {
                vst1q_f64(ptr.add(i * 2), a[i]);
            }
        }
    }
}
/// Generates the NEON f32 size-2 FFT codelet.
///
/// f32 analogue of the f64 size-2 codelet: an in-place 2-point
/// butterfly on `data`, treated as two interleaved (re, im) f32 pairs
/// (4 scalars). The 2-point transform is identical in both directions,
/// so `_sign` is unused.
pub(super) fn gen_neon_size_2_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn codelet_simd_2_neon_f32(data: &mut [f32], _sign: i32) {
use core::arch::aarch64::*;
let ptr = data.as_mut_ptr();
// Each 64-bit load holds one complex value as (re, im).
let a = vld1_f32(ptr);
let b = vld1_f32(ptr.add(2));
// Butterfly: X0 = x0 + x1, X1 = x0 - x1.
let sum = vadd_f32(a, b);
let diff = vsub_f32(a, b);
vst1_f32(ptr, sum);
vst1_f32(ptr.add(2), diff);
}
}
}
/// Generates the NEON f32 size-4 FFT codelet.
///
/// f32 analogue of the f64 size-4 codelet: an in-place 4-point DIT FFT
/// on `data`, treated as four interleaved (re, im) f32 pairs
/// (8 scalars). `sign < 0` selects the forward transform, otherwise the
/// inverse.
pub(super) fn gen_neon_size_4_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn codelet_simd_4_neon_f32(data: &mut [f32], sign: i32) {
use core::arch::aarch64::*;
let ptr = data.as_mut_ptr();
// One complex value (re, im) per 64-bit load.
let x0 = vld1_f32(ptr);
let x1 = vld1_f32(ptr.add(2));
let x2 = vld1_f32(ptr.add(4));
let x3 = vld1_f32(ptr.add(6));
// Stage 1: size-2 butterflies on (x0, x2) and (x1, x3).
let t0 = vadd_f32(x0, x2);
let t1 = vsub_f32(x0, x2);
let t2 = vadd_f32(x1, x3);
let t3 = vsub_f32(x1, x3);
// Multiply t3 by -i (forward) or +i (inverse): swap (re, im)
// with vext, then negate one component via the sign mask.
let t3_swapped = vext_f32(t3, t3, 1); let t3_rot = if sign < 0 {
let neg_mask = vld1_f32([1.0_f32, -1.0_f32].as_ptr());
vmul_f32(t3_swapped, neg_mask)
} else {
let neg_mask = vld1_f32([-1.0_f32, 1.0_f32].as_ptr());
vmul_f32(t3_swapped, neg_mask)
};
// Stage 2: combine the half-transforms into the final outputs.
let out0 = vadd_f32(t0, t2);
let out1 = vadd_f32(t1, t3_rot);
let out2 = vsub_f32(t0, t2);
let out3 = vsub_f32(t1, t3_rot);
vst1_f32(ptr, out0);
vst1_f32(ptr.add(2), out1);
vst1_f32(ptr.add(4), out2);
vst1_f32(ptr.add(6), out3);
}
}
}
/// Generates the NEON f32 size-8 FFT codelet.
///
/// f32 analogue of the f64 size-8 codelet: an in-place 8-point
/// decimation-in-time FFT on `data`, treated as eight interleaved
/// (re, im) f32 pairs (16 scalars). `sign < 0` selects the forward
/// transform, otherwise the inverse.
pub(super) fn gen_neon_size_8_f32() -> TokenStream {
quote! {
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_neon_f32(data: &mut [f32], sign: i32) {
use core::arch::aarch64::*;
let ptr = data.as_mut_ptr();
// 1/sqrt(2): scale factor of the 45-degree twiddles below.
let inv_sqrt2 = vdup_n_f32(0.707_106_8_f32);
let fwd = sign < 0;
// Multiply one complex value (an f32x2 as (re, im)) by -i when
// `is_fwd`, by +i otherwise: swap the components, negate one.
let rotate_pm_i = |v: float32x2_t, is_fwd: bool| -> float32x2_t {
let swapped = vext_f32(v, v, 1); if is_fwd {
let mask = vld1_f32([1.0_f32, -1.0_f32].as_ptr());
vmul_f32(swapped, mask)
} else {
let mask = vld1_f32([-1.0_f32, 1.0_f32].as_ptr());
vmul_f32(swapped, mask)
}
};
// Load the 8 complex inputs in bit-reversed index order
// (0, 4, 2, 6, 1, 5, 3, 7) so the stages below produce naturally
// ordered output.
let mut a = [vdup_n_f32(0.0); 8];
a[0] = vld1_f32(ptr); a[1] = vld1_f32(ptr.add(8)); a[2] = vld1_f32(ptr.add(4)); a[3] = vld1_f32(ptr.add(12)); a[4] = vld1_f32(ptr.add(2)); a[5] = vld1_f32(ptr.add(10)); a[6] = vld1_f32(ptr.add(6)); a[7] = vld1_f32(ptr.add(14));
// Stage 1: four size-2 butterflies (twiddle factor 1).
for i in (0..8usize).step_by(2) {
let t = a[i + 1];
a[i + 1] = vsub_f32(a[i], t);
a[i] = vadd_f32(a[i], t);
}
// Stage 2: two size-4 groups; the odd pair takes a -/+i twiddle.
for group in (0..8usize).step_by(4) {
let t = a[group + 2];
a[group + 2] = vsub_f32(a[group], t);
a[group] = vadd_f32(a[group], t);
let t = a[group + 3];
let t_tw = rotate_pm_i(t, fwd);
a[group + 3] = vsub_f32(a[group + 1], t_tw);
a[group + 1] = vadd_f32(a[group + 1], t_tw);
}
// Stage 3: final size-8 butterflies with twiddles 1, w, w^2, w^3
// where w = e^(-i*pi/4) forward, e^(+i*pi/4) inverse.
let t = a[4];
a[4] = vsub_f32(a[0], t);
a[0] = vadd_f32(a[0], t);
{
// Twiddle w = (1 -/+ i)/sqrt(2) applied to a[5].
let v = a[5];
// Forward: (a+bi)(1-i)/sqrt(2); inverse: (a+bi)(1+i)/sqrt(2).
let swapped = vext_f32(v, v, 1); let t_tw = if fwd {
let sum = vadd_f32(v, swapped); let diff = vsub_f32(swapped, v); let combined = vzip1_f32(sum, diff);
vmul_f32(combined, inv_sqrt2)
} else {
let diff = vsub_f32(v, swapped); let sum = vadd_f32(v, swapped); let combined = vzip1_f32(diff, vext_f32(sum, sum, 1));
vmul_f32(combined, inv_sqrt2)
};
a[5] = vsub_f32(a[1], t_tw);
a[1] = vadd_f32(a[1], t_tw);
}
{
// Twiddle w^2 = -/+ i applied to a[6].
let t = a[6];
let t_tw = rotate_pm_i(t, fwd);
a[6] = vsub_f32(a[2], t_tw);
a[2] = vadd_f32(a[2], t_tw);
}
{
// Twiddle w^3 = (-1 -/+ i)/sqrt(2) applied to a[7].
let v = a[7];
// Forward: (a+bi)(-1-i)/sqrt(2); inverse: (a+bi)(-1+i)/sqrt(2).
let swapped = vext_f32(v, v, 1); let t_tw = if fwd {
let diff = vsub_f32(swapped, v); let neg_sum = vneg_f32(vadd_f32(v, swapped)); let combined = vzip1_f32(diff, neg_sum);
vmul_f32(combined, inv_sqrt2)
} else {
let neg_sum = vneg_f32(vadd_f32(v, swapped)); let diff = vsub_f32(swapped, v); let combined = vzip1_f32(neg_sum, vext_f32(diff, diff, 1));
vmul_f32(combined, inv_sqrt2)
};
a[7] = vsub_f32(a[3], t_tw);
a[3] = vadd_f32(a[3], t_tw);
}
// Store the results back in natural order.
for i in 0..8usize {
vst1_f32(ptr.add(i * 2), a[i]);
}
}
}
}