use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::LitInt;
mod avx;
mod avx2;
mod avx512;
mod neon;
mod scalar;
mod sse2;
pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
let size: LitInt = syn::parse2(input)?;
let n: usize = size.base10_parse().map_err(|_| {
syn::Error::new(
size.span(),
"gen_simd_codelet: expected an integer size literal",
)
})?;
match n {
2 => Ok(gen_simd_size_2()),
4 => Ok(gen_simd_size_4()),
8 => Ok(gen_simd_size_8()),
16 => Ok(gen_simd_size_16()),
_ => Err(syn::Error::new(
size.span(),
format!("gen_simd_codelet: unsupported size {n} (expected one of 2, 4, 8, 16)"),
)),
}
}
fn gen_simd_size_2() -> TokenStream {
let dispatcher = gen_dispatcher(2);
let scalar = scalar::gen_scalar_size_2();
let sse2_f64 = sse2::gen_sse2_size_2();
let sse2_f32 = sse2::gen_sse2_size_2_f32();
let plain_avx_f64 = avx::gen_avx_size_2_f64();
let avx2_f64 = avx2::gen_avx2_size_2();
let avx2_f32 = avx2::gen_avx2_size_2_f32();
let avx512_f64 = avx512::gen_avx512_size_2_f64();
let avx512_f32 = avx512::gen_avx512_size_2_f32();
let neon_f64 = neon::gen_neon_size_2();
let neon_f32 = neon::gen_neon_size_2_f32();
quote! {
#dispatcher
#scalar
#sse2_f64
#sse2_f32
#plain_avx_f64
#avx2_f64
#avx2_f32
#avx512_f64
#avx512_f32
#neon_f64
#neon_f32
}
}
fn gen_simd_size_4() -> TokenStream {
let dispatcher = gen_dispatcher(4);
let scalar = scalar::gen_scalar_size_4();
let sse2_f64 = sse2::gen_sse2_size_4();
let sse2_f32 = sse2::gen_sse2_size_4_f32();
let plain_avx_f64 = avx::gen_avx_size_4_f64();
let avx2_f64 = avx2::gen_avx2_size_4();
let avx2_f32 = avx2::gen_avx2_size_4_f32();
let avx512_f64 = avx512::gen_avx512_size_4_f64();
let avx512_f32 = avx512::gen_avx512_size_4_f32();
let neon_f64 = neon::gen_neon_size_4();
let neon_f32 = neon::gen_neon_size_4_f32();
quote! {
#dispatcher
#scalar
#sse2_f64
#sse2_f32
#plain_avx_f64
#avx2_f64
#avx2_f32
#avx512_f64
#avx512_f32
#neon_f64
#neon_f32
}
}
fn gen_simd_size_8() -> TokenStream {
let dispatcher = gen_dispatcher(8);
let scalar = scalar::gen_scalar_size_8();
let sse2_f64 = sse2::gen_sse2_size_8();
let sse2_f32 = sse2::gen_sse2_size_8_f32();
let plain_avx_f64 = avx::gen_avx_size_8_f64();
let avx2_f64 = avx2::gen_avx2_size_8();
let avx2_f32 = avx2::gen_avx2_size_8_f32();
let avx512_f64 = avx512::gen_avx512_size_8_f64();
let avx512_f32 = avx512::gen_avx512_size_8_f32();
let neon_f64 = neon::gen_neon_size_8();
let neon_f32 = neon::gen_neon_size_8_f32();
quote! {
#dispatcher
#scalar
#sse2_f64
#sse2_f32
#plain_avx_f64
#avx2_f64
#avx2_f32
#avx512_f64
#avx512_f32
#neon_f64
#neon_f32
}
}
fn gen_simd_size_16() -> TokenStream {
let dispatcher = gen_dispatcher_16();
let scalar = scalar::gen_scalar_size_16();
let avx512_f32 = avx512::gen_avx512_size_16_f32();
quote! {
#dispatcher
#scalar
#avx512_f32
}
}
/// Emits the generic public dispatcher `codelet_simd_{n}` used for sizes 2, 4 and 8.
///
/// The generated function compares `T` (via `TypeId`) against `f64` and `f32`. On a match
/// it views the `Complex<T>` buffer as a flat scalar slice of `2 * data.len()` elements and
/// calls a specialised kernel:
///
/// - x86_64, f64: AVX-512F, then AVX2+FMA, then AVX, then SSE2 (runtime-detected)
/// - x86_64, f32: AVX-512F, then AVX2+FMA, then SSE2 (runtime-detected)
/// - aarch64, f64 or f32: the NEON kernel, unconditionally
///
/// Every other case falls through to `codelet_simd_{n}_scalar`: other `T`, other targets,
/// or no detected feature. The kernel names referenced here are the ones emitted alongside
/// this dispatcher by `gen_simd_size_{n}`.
///
/// The type-dispatch section is only emitted for x86_64/aarch64. On any other target the
/// reinterpreted slices would be built and never used (dead `unsafe` plus
/// `unused_variables` warnings in the expansion).
fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {
    let fn_name = format_ident!("codelet_simd_{}", n);
    let scalar_name = format_ident!("codelet_simd_{}_scalar", n);
    let avx512_f64_name = format_ident!("codelet_simd_{}_avx512_f64", n);
    let avx2_f64_name = format_ident!("codelet_simd_{}_avx2_f64", n);
    let plain_avx_fn_name = format_ident!("codelet_simd_{}_avx_f64", n);
    let sse2_f64_name = format_ident!("codelet_simd_{}_sse2_f64", n);
    let neon_f64_name = format_ident!("codelet_simd_{}_neon_f64", n);
    let avx512_f32_name = format_ident!("codelet_simd_{}_avx512_f32", n);
    let avx2_f32_name = format_ident!("codelet_simd_{}_avx2_f32", n);
    let sse2_f32_name = format_ident!("codelet_simd_{}_sse2_f32", n);
    let neon_f32_name = format_ident!("codelet_simd_{}_neon_f32", n);
    // Interpolated as a `usize` literal (e.g. `4usize`) in the generated code.
    let n_lit = n;
    let doc = format!(
        "Size-{n} SIMD-optimized FFT codelet with architecture dispatch.\n\n\
         Automatically selects the best SIMD path at runtime:\n\
         - x86_64 (f64): AVX-512F > AVX2+FMA > AVX > SSE2 > scalar\n\
         - x86_64 (f32): AVX-512F > AVX2+FMA > SSE2 > scalar\n\
         - aarch64: NEON > scalar (both f64 and f32)\n\
         - other: scalar fallback"
    );
    // NOTE: `//` comments inside `quote!` are not tokens and do not appear in the output;
    // they document the generated code for maintainers of this generator.
    quote! {
        #[doc = #doc]
        #[inline]
        pub fn #fn_name<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            debug_assert!(
                data.len() >= #n_lit,
                "codelet_simd_{}: need >= {} elements, got {}",
                #n_lit,
                #n_lit,
                data.len(),
            );
            // Only targets that have specialised kernels reinterpret the buffer.
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
            {
                if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
                    let len = data.len() * 2;
                    let ptr = data.as_mut_ptr().cast::<f64>();
                    // SAFETY: `ptr` is derived from the live `&mut` borrow `data`, and `T` is
                    // exactly `f64` per the `TypeId` check. The doubled length assumes each
                    // `Complex<T>` is two consecutive `T`s with no padding.
                    // NOTE(review): confirm `Complex` is `#[repr(C)]` in `crate::kernel`.
                    let f64_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
                    #[cfg(target_arch = "x86_64")]
                    {
                        // SAFETY (all kernel calls below): each kernel is called only after
                        // the CPU feature it was generated for has been detected (presumably
                        // `#[target_feature]` fns).
                        if is_x86_feature_detected!("avx512f") {
                            unsafe { #avx512_f64_name(f64_data, sign); }
                            return;
                        }
                        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                            unsafe { #avx2_f64_name(f64_data, sign); }
                            return;
                        }
                        if is_x86_feature_detected!("avx") {
                            unsafe { #plain_avx_fn_name(f64_data, sign); }
                            return;
                        }
                        if is_x86_feature_detected!("sse2") {
                            unsafe { #sse2_f64_name(f64_data, sign); }
                            return;
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        unsafe { #neon_f64_name(f64_data, sign); }
                        return;
                    }
                }
                if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
                    let len = data.len() * 2;
                    let ptr = data.as_mut_ptr().cast::<f32>();
                    // SAFETY: as for the f64 branch, with `T` exactly `f32`.
                    let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
                    #[cfg(target_arch = "x86_64")]
                    {
                        if is_x86_feature_detected!("avx512f") {
                            unsafe { #avx512_f32_name(f32_data, sign); }
                            return;
                        }
                        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                            unsafe { #avx2_f32_name(f32_data, sign); }
                            return;
                        }
                        if is_x86_feature_detected!("sse2") {
                            unsafe { #sse2_f32_name(f32_data, sign); }
                            return;
                        }
                    }
                    #[cfg(target_arch = "aarch64")]
                    {
                        unsafe { #neon_f32_name(f32_data, sign); }
                        return;
                    }
                }
            }
            #scalar_name(data, sign);
        }
    }
}
/// Emits the generic public dispatcher `codelet_simd_16`.
///
/// Size 16 has only one specialised kernel, `codelet_simd_16_avx512_f32`. The generated
/// function uses it when all three conditions hold:
///
/// - the target is x86_64,
/// - `T` is exactly `f32`,
/// - AVX-512F is detected at runtime.
///
/// Otherwise it calls `codelet_simd_16_scalar`.
///
/// The `f32` reinterpretation is emitted only inside the x86_64 path and only after the
/// feature check succeeds. Previously the slice was built on every target, including
/// aarch64, where it was never used (dead `unsafe` plus `unused_variables` in the
/// expansion).
fn gen_dispatcher_16() -> proc_macro2::TokenStream {
    let avx512_f32_name = format_ident!("codelet_simd_16_avx512_f32");
    let scalar_name = format_ident!("codelet_simd_16_scalar");
    // Matches the other dispatchers, which all carry a generated doc comment.
    let doc = "Size-16 SIMD-optimized FFT codelet with architecture dispatch.\n\n\
               Automatically selects the best SIMD path at runtime:\n\
               - x86_64 (f32): AVX-512F > scalar\n\
               - x86_64 (f64) and all other targets: scalar fallback";
    // NOTE: `//` comments inside `quote!` are not tokens and do not appear in the output.
    quote! {
        #[doc = #doc]
        #[inline]
        pub fn codelet_simd_16<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            debug_assert!(
                data.len() >= 16_usize,
                "codelet_simd_16: need >= 16 elements, got {}",
                data.len(),
            );
            #[cfg(target_arch = "x86_64")]
            {
                if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>()
                    && is_x86_feature_detected!("avx512f")
                {
                    let len = data.len() * 2;
                    let ptr = data.as_mut_ptr().cast::<f32>();
                    // SAFETY: `ptr` is derived from the live `&mut` borrow `data`, and `T` is
                    // exactly `f32` per the `TypeId` check. The doubled length assumes each
                    // `Complex<T>` is two consecutive `T`s with no padding.
                    // NOTE(review): confirm `Complex` is `#[repr(C)]` in `crate::kernel`.
                    let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
                    // SAFETY: AVX-512F was detected just above.
                    unsafe { #avx512_f32_name(f32_data, sign); }
                    return;
                }
            }
            #scalar_name(data, sign);
        }
    }
}