use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
LitInt, Token,
};
mod scalar;
mod simd_avx2_f32;
mod simd_sse2_f32;
#[cfg(test)]
mod tests;
/// Instruction-set target a generated transform is specialised for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdIsa {
    /// x86 SSE2: 4 `f32` or 2 `f64` lanes per vector.
    Sse2,
    /// x86 AVX2: 8 `f32` or 4 `f64` lanes per vector.
    Avx2,
    /// No vector unit: one lane regardless of precision.
    Scalar,
}

impl SimdIsa {
    /// Width of one vector register in bytes, or `None` for the scalar target.
    const fn vector_bytes(self) -> Option<usize> {
        match self {
            Self::Scalar => None,
            Self::Sse2 => Some(16),
            Self::Avx2 => Some(32),
        }
    }

    /// Number of `f32` values processed per vector (1 for [`SimdIsa::Scalar`]).
    #[must_use]
    pub const fn lanes_f32(self) -> usize {
        match self.vector_bytes() {
            Some(bytes) => bytes / std::mem::size_of::<f32>(),
            None => 1,
        }
    }

    /// Number of `f64` values processed per vector (1 for [`SimdIsa::Scalar`]).
    #[must_use]
    pub const fn lanes_f64(self) -> usize {
        match self.vector_bytes() {
            Some(bytes) => bytes / std::mem::size_of::<f64>(),
            None => 1,
        }
    }

    /// Lower-case name of the target, as used in generated function names
    /// and accepted by the macro's `isa = ...` argument.
    #[must_use]
    pub const fn ident_str(self) -> &'static str {
        match self {
            SimdIsa::Scalar => "scalar",
            SimdIsa::Sse2 => "sse2",
            SimdIsa::Avx2 => "avx2",
        }
    }
}
/// Floating-point element type of a generated transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Single precision (`f32`).
    F32,
    /// Double precision (`f64`).
    F64,
}

impl Precision {
    /// Rust type name for this precision, exactly as it is spelled in source
    /// (`"f32"` or `"f64"`).
    #[must_use]
    pub const fn type_str(self) -> &'static str {
        match self {
            Precision::F64 => "f64",
            Precision::F32 => "f32",
        }
    }
}
/// Parameters describing one multi-transform function to generate.
///
/// Consumed by `generate_multi_transform`, which validates `size` and `v`
/// before emitting any tokens.
#[derive(Debug, Clone)]
pub struct MultiTransformConfig {
/// Transform size; `generate_multi_transform` accepts only 2, 4, or 8.
pub size: usize,
/// Batch width: appears in the generated function name and is the inner
/// loop bound of the generated body. Must be at least 1.
pub v: usize,
/// Target instruction set; selects the SIMD inner kernel (if any) and is
/// part of the generated function name.
pub isa: SimdIsa,
/// Element type of the generated function's `input` / `output` pointers.
pub precision: Precision,
}
/// Reports whether a SIMD kernel exists for this ISA / precision / size
/// combination.
///
/// Only `f32` is covered: SSE2 for sizes 2 and 4, AVX2 for sizes 2, 4 and 8.
/// This table mirrors the arms of `gen_simd_inner`; keep the two in sync.
const fn has_simd_impl(isa: SimdIsa, precision: Precision, size: usize) -> bool {
    match (isa, precision) {
        (SimdIsa::Sse2, Precision::F32) => matches!(size, 2 | 4),
        (SimdIsa::Avx2, Precision::F32) => matches!(size, 2 | 4 | 8),
        // No vector kernels for the scalar target or for double precision.
        (SimdIsa::Scalar, _) | (_, Precision::F64) => false,
    }
}
/// Emits the SIMD kernel for `config`, or `None` when no kernel exists for
/// that ISA / precision / size combination.
///
/// `config.v` is not consulted: each generator is fixed to its own width
/// (the `v4` / `v8` in the generator names).
fn gen_simd_inner(config: &MultiTransformConfig) -> Option<TokenStream> {
    // Every available kernel is single precision.
    if config.precision != Precision::F32 {
        return None;
    }
    let kernel = match (config.isa, config.size) {
        (SimdIsa::Sse2, 2) => simd_sse2_f32::gen_sse2_f32_v4_size2_soa(),
        (SimdIsa::Sse2, 4) => simd_sse2_f32::gen_sse2_f32_v4_size4_soa(),
        (SimdIsa::Avx2, 2) => simd_avx2_f32::gen_avx2_f32_v8_size2_soa(),
        (SimdIsa::Avx2, 4) => simd_avx2_f32::gen_avx2_f32_v8_size4_soa(),
        (SimdIsa::Avx2, 8) => simd_avx2_f32::gen_avx2_f32_v8_size8_soa(),
        _ => return None,
    };
    Some(kernel)
}
/// Emits the statements that form the body of the generated outer function.
///
/// The emitted code refers to a `count` binding from the enclosing function
/// and splits it into `count / v` full batches plus a remainder. For every
/// transform it binds `base_in` / `base_out` and then splices in the scalar
/// butterfly produced by `scalar::gen_scalar_butterfly(size, precision)`.
///
/// Both loops compute the offset as `transform_index * 2`, so together they
/// visit transform indices `0..count` in ascending order.
///
/// NOTE(review): how `base_in` / `base_out` combine with the strides depends on
/// the butterfly tokens, which are produced outside this function. The doc
/// text emitted by `generate_multi_transform` describes an element stride of
/// `2 * v`; confirm that transforms past the first batch cannot alias earlier
/// elements under that layout.
fn gen_outer_body(config: &MultiTransformConfig, size: usize, v: usize) -> TokenStream {
    let butterfly = scalar::gen_scalar_butterfly(size, config.precision);

    // Code for a single transform `t` inside the batch whose index expression
    // is `batch` (the loop variable `b`, or `batches` for the remainder pass).
    let transform_in = |batch: TokenStream| -> TokenStream {
        quote! {
            let base_in = (#batch * #v + t) * 2;
            let base_out = (#batch * #v + t) * 2;
            #butterfly
        }
    };
    let full_batch_step = transform_in(quote!(b));
    let remainder_step = transform_in(quote!(batches));

    quote! {
        let batches = count / #v;
        let remainder = count % #v;
        for b in 0..batches {
            for t in 0..#v {
                #full_batch_step
            }
        }
        for t in 0..remainder {
            #remainder_step
        }
        // Discards the size literal; it has no effect on the emitted logic.
        let _ = #size;
    }
}
pub fn generate_multi_transform(config: &MultiTransformConfig) -> Result<TokenStream, syn::Error> {
if !matches!(config.size, 2 | 4 | 8) {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
format!(
"multi_transform: unsupported size {} (expected 2, 4, or 8)",
config.size
),
));
}
if config.v == 0 {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"multi_transform: v must be >= 1",
));
}
let fn_name = format_ident!(
"notw_{}_v{}_{}_{}",
config.size,
config.v,
config.isa.ident_str(),
config.precision.type_str()
);
let size = config.size;
let v = config.v;
let ty_str = config.precision.type_str();
let ty_tokens: TokenStream = ty_str.parse().expect("valid type token");
let use_simd = has_simd_impl(config.isa, config.precision, size);
let simd_inner = gen_simd_inner(config);
let outer_body = gen_outer_body(config, size, v);
let stride = v * 2;
let simd_note = if use_simd {
format!(
"True SIMD available via `notw_{size}_v{v}_{isa}_{ty}_soa` (`SoA` layout).",
isa = config.isa.ident_str(),
ty = ty_str,
)
} else {
"Sequential scalar fallback (no SIMD for this `ISA`+precision+size combination).".into()
};
let fn_doc = format!(
"Process `count` transforms of size {size} in batches of {v} (v={v}) using {isa} ISA.\n\n\
# Data layout (`AoS`)\n\
Interleaved with stride {v}: `data[element * {stride} + transform * 2 + c]`\n\
where `c` is 0 for real, 1 for imaginary.\n\n\
# SIMD acceleration\n\
{simd_note}\n\n\
# Safety\n\
- `input` must be valid for `count * {size} * 2 * {v}` reads of `{ty_str}`.\n\
- `output` must be valid for `count * {size} * 2 * {v}` writes of `{ty_str}`.\n\
- `istride` / `ostride` must be `2 * {v}` for the canonical `AoS` layout.\n\
- No alignment requirement; uses unaligned loads.",
size = size,
v = v,
isa = config.isa.ident_str(),
stride = stride,
ty_str = ty_str,
simd_note = simd_note,
);
let outer_fn = quote! {
#[doc = #fn_doc]
pub unsafe fn #fn_name(
input: *const #ty_tokens,
output: *mut #ty_tokens,
istride: usize,
ostride: usize,
count: usize,
) {
#outer_body
}
};
Ok(if let Some(inner) = simd_inner {
quote! {
#inner
#outer_fn
}
} else {
outer_fn
})
}
/// Arguments accepted by the macro front end, after parsing.
///
/// Populated by the `Parse` impl from `key = value` pairs; all four values
/// are required.
struct MacroArgs {
// Value of the `size = <int>` argument.
size: usize,
// Value of the `v = <int>` argument.
v: usize,
// Value of the `isa = sse2 | avx2 | scalar` argument.
isa: SimdIsa,
// Value of the `ty = f32 | f64` argument.
precision: Precision,
}
impl Parse for MacroArgs {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut size: Option<usize> = None;
let mut v: Option<usize> = None;
let mut isa: Option<SimdIsa> = None;
let mut precision: Option<Precision> = None;
while !input.is_empty() {
let key: syn::Ident = input.parse()?;
let _eq: Token![=] = input.parse()?;
match key.to_string().as_str() {
"size" => {
let lit: LitInt = input.parse()?;
size = Some(lit.base10_parse::<usize>().map_err(|_| {
syn::Error::new(lit.span(), "expected an integer literal for `size`")
})?);
}
"v" => {
let lit: LitInt = input.parse()?;
v = Some(lit.base10_parse::<usize>().map_err(|_| {
syn::Error::new(lit.span(), "expected an integer literal for `v`")
})?);
}
"isa" => {
let ident: syn::Ident = input.parse()?;
isa = Some(match ident.to_string().as_str() {
"sse2" => SimdIsa::Sse2,
"avx2" => SimdIsa::Avx2,
"scalar" => SimdIsa::Scalar,
other => {
return Err(syn::Error::new(
ident.span(),
format!(
"unknown isa `{other}`, expected one of: sse2, avx2, scalar"
),
));
}
});
}
"ty" => {
let ident: syn::Ident = input.parse()?;
precision = Some(match ident.to_string().as_str() {
"f32" => Precision::F32,
"f64" => Precision::F64,
other => {
return Err(syn::Error::new(
ident.span(),
format!("unknown ty `{other}`, expected f32 or f64"),
));
}
});
}
other => {
return Err(syn::Error::new(
key.span(),
format!("unknown key `{other}`, expected one of: size, v, isa, ty"),
));
}
}
if input.peek(Token![,]) {
let _: Token![,] = input.parse()?;
}
}
let size = size.ok_or_else(|| {
syn::Error::new(proc_macro2::Span::call_site(), "missing `size` argument")
})?;
let v = v.ok_or_else(|| {
syn::Error::new(proc_macro2::Span::call_site(), "missing `v` argument")
})?;
let isa = isa.ok_or_else(|| {
syn::Error::new(proc_macro2::Span::call_site(), "missing `isa` argument")
})?;
let precision = precision.ok_or_else(|| {
syn::Error::new(proc_macro2::Span::call_site(), "missing `ty` argument")
})?;
Ok(Self {
size,
v,
isa,
precision,
})
}
}
pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
let args: MacroArgs = syn::parse2(input)?;
let config = MultiTransformConfig {
size: args.size,
v: args.v,
isa: args.isa,
precision: args.precision,
};
generate_multi_transform(&config)
}