Skip to main content

oxifft_codegen_impl/gen_simd/multi_transform/
mod.rs

1//! Build-time codegen for SIMD vrank multi-transform codelets.
2//!
3//! A multi-transform codelet processes `V` DFTs of size `N` simultaneously.
4//!
5//! # Implementations
6//!
7//! - **SSE2 f32 (V=4)**: true SIMD for sizes 2 and 4 via `notw_{size}_v4_sse2_f32_soa`.
8//! - **AVX2 f32 (V=8)**: true SIMD for sizes 2, 4, and 8 via `notw_{size}_v8_avx2_f32_soa`.
9//! - **All other combos**: sequential scalar fallback over `AoS` layout.
10//!
11//! # Data layouts
12//!
13//! ## `AoS` (Array-of-Structs) — outer function signature
14//!
15//! For `V` transforms of size `N`:
16//! ```text
17//! data[element_idx * v * 2 + transform_idx * 2 + 0]  = re of x[element_idx] for transform transform_idx
18//! data[element_idx * v * 2 + transform_idx * 2 + 1]  = im of x[element_idx] for transform transform_idx
19//! ```
20//!
21//! ## `SoA` (Struct-of-Arrays) — inner SIMD function signature
22//!
23//! For `V` transforms of size `N` (only used internally by SIMD paths):
24//! ```text
25//! re_in[element_idx * v + transform_idx] = real  part of x[element_idx] for transform transform_idx
26//! im_in[element_idx * v + transform_idx] = imag  part of x[element_idx] for transform transform_idx
27//! ```
28//!
//! The SIMD functions operate natively in `SoA`. The outer `AoS` function always
//! runs the sequential scalar loop; when `ISA` + precision + size match a SIMD path,
//! the inner `SoA` function is emitted alongside it for callers to invoke directly.
32//!
33//! # Generated function signatures
34//!
35//! Outer (`AoS`, called by users):
36//! ```rust,ignore
37//! pub unsafe fn notw_4_v8_avx2_f32(
38//!     input: *const f32, output: *mut f32,
39//!     istride: usize, ostride: usize, count: usize,
40//! )
41//! ```
42//!
43//! Inner `SoA` SIMD helpers (emitted alongside, for direct use or testing):
44//! ```rust,ignore
45//! pub unsafe fn notw_4_v8_avx2_f32_soa(
46//!     re_in: *const f32, im_in: *const f32,
47//!     re_out: *mut f32, im_out: *mut f32,
48//! )
49//! ```
50
51use proc_macro2::TokenStream;
52use quote::{format_ident, quote};
53use syn::{
54    parse::{Parse, ParseStream},
55    LitInt, Token,
56};
57
58mod scalar;
59mod simd_avx2_f32;
60mod simd_sse2_f32;
61
62#[cfg(test)]
63mod tests;
64
65// ============================================================================
66// Public types
67// ============================================================================
68
/// Instruction-set target that a multi-transform codelet is generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdIsa {
    /// SSE2: 128-bit registers (4 `f32` lanes or 2 `f64` lanes).
    Sse2,
    /// AVX2+FMA: 256-bit registers (8 `f32` lanes or 4 `f64` lanes).
    Avx2,
    /// No SIMD; values are processed one at a time.
    Scalar,
}

impl SimdIsa {
    /// Lane count for `f32` data: register width divided by 32 bits, or 1 for scalar.
    #[must_use]
    pub const fn lanes_f32(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse2 => 128 / 32,
            Self::Avx2 => 256 / 32,
        }
    }

    /// Lane count for `f64` data: register width divided by 64 bits, or 1 for scalar.
    #[must_use]
    pub const fn lanes_f64(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse2 => 128 / 64,
            Self::Avx2 => 256 / 64,
        }
    }

    /// Lowercase ISA tag that is spliced into generated function names.
    #[must_use]
    pub const fn ident_str(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            Self::Sse2 => "sse2",
            Self::Avx2 => "avx2",
        }
    }
}
111
/// Element precision that a multi-transform codelet operates on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Single precision (`f32`).
    F32,
    /// Double precision (`f64`).
    F64,
}

impl Precision {
    /// Name of the Rust primitive type, as used in generated identifiers and code.
    #[must_use]
    pub const fn type_str(self) -> &'static str {
        match self {
            Self::F64 => "f64",
            Self::F32 => "f32",
        }
    }
}
131
132/// Configuration for a vectorized multi-transform codelet.
133///
134/// Describes a (DFT size, `ISA`, V, precision) tuple used to emit a
135/// batch-of-V-transforms function at build time.
136#[derive(Debug, Clone)]
137pub struct MultiTransformConfig {
138    /// DFT size — must be 2, 4, or 8.
139    pub size: usize,
140    /// Number of simultaneous transforms (lane count: 4 for SSE2 f32, 8 for AVX2 f32, etc.).
141    pub v: usize,
142    /// Target ISA.
143    pub isa: SimdIsa,
144    /// `f32` or `f64`.
145    pub precision: Precision,
146}
147
148// ============================================================================
149// SIMD dispatch logic
150// ============================================================================
151
152/// Returns `true` when the (`ISA`, precision, size) combination has a true SIMD
153/// multi-transform implementation (`SoA` inner function).
154///
155/// - SSE2 f32: sizes 2 and 4
156/// - AVX2 f32: sizes 2, 4, and 8
157/// - All f64 combos: scalar fallback only
158const fn has_simd_impl(isa: SimdIsa, precision: Precision, size: usize) -> bool {
159    matches!(
160        (isa, precision, size),
161        (SimdIsa::Sse2, Precision::F32, 2 | 4) | (SimdIsa::Avx2, Precision::F32, 2 | 4 | 8)
162    )
163}
164
165/// Emit the inner `SoA` SIMD function `TokenStream` for the given config.
166///
167/// Returns `None` if the config has no SIMD implementation.
168fn gen_simd_inner(config: &MultiTransformConfig) -> Option<TokenStream> {
169    match (config.isa, config.precision, config.size) {
170        (SimdIsa::Sse2, Precision::F32, 2) => Some(simd_sse2_f32::gen_sse2_f32_v4_size2_soa()),
171        (SimdIsa::Sse2, Precision::F32, 4) => Some(simd_sse2_f32::gen_sse2_f32_v4_size4_soa()),
172        (SimdIsa::Avx2, Precision::F32, 2) => Some(simd_avx2_f32::gen_avx2_f32_v8_size2_soa()),
173        (SimdIsa::Avx2, Precision::F32, 4) => Some(simd_avx2_f32::gen_avx2_f32_v8_size4_soa()),
174        (SimdIsa::Avx2, Precision::F32, 8) => Some(simd_avx2_f32::gen_avx2_f32_v8_size8_soa()),
175        _ => None,
176    }
177}
178
179// ============================================================================
180// Code generation
181// ============================================================================
182
183/// Build the outer `AoS` function body for any config (scalar loop over all transforms).
184///
185/// The outer function always processes transforms sequentially (scalar `AoS` loop),
186/// regardless of whether a companion `SoA` SIMD function is also emitted.
187/// Callers that want true SIMD throughput should use the `_soa` companion directly.
188///
189/// # Panics
190///
191/// Panics only if internal constant string literals fail to parse — impossible
192/// in practice.
193fn gen_outer_body(config: &MultiTransformConfig, size: usize, v: usize) -> TokenStream {
194    let butterfly_body = scalar::gen_scalar_butterfly(size, config.precision);
195    let v_lit = v;
196    let size_lit = size;
197    quote! {
198        let batches = count / #v_lit;
199        let remainder = count % #v_lit;
200
201        for b in 0..batches {
202            for t in 0..#v_lit {
203                let base_in  = (b * #v_lit + t) * 2;
204                let base_out = (b * #v_lit + t) * 2;
205                #butterfly_body
206            }
207        }
208        for t in 0..remainder {
209            let base_in  = (batches * #v_lit + t) * 2;
210            let base_out = (batches * #v_lit + t) * 2;
211            #butterfly_body
212        }
213        let _ = #size_lit;
214    }
215}
216
217/// Generate a multi-transform codelet `TokenStream`.
218///
219/// # Output
220///
221/// Always emits a public outer function `notw_{size}_v{v}_{isa}_{ty}` with
222/// `AoS` signature `(input, output, istride, ostride, count)`.
223///
224/// For supported (`ISA`, precision, size) combinations (SSE2 f32 sizes 2/4,
225/// AVX2 f32 sizes 2/4/8), also emits a companion inner function
226/// `notw_{size}_v{v}_{isa}_{ty}_soa` with `SoA` signature
227/// `(re_in, im_in, re_out, im_out)` that is the **true SIMD implementation**.
228///
229/// # Errors
230///
231/// Returns a [`syn::Error`] when:
232/// - `config.size` is not one of 2, 4, or 8.
233/// - `config.v` is 0.
234///
235/// # Panics
236///
237/// Panics only if internal constant string literals that are guaranteed to be
238/// valid fail to parse as token streams — this cannot occur in practice.
239pub fn generate_multi_transform(config: &MultiTransformConfig) -> Result<TokenStream, syn::Error> {
240    if !matches!(config.size, 2 | 4 | 8) {
241        return Err(syn::Error::new(
242            proc_macro2::Span::call_site(),
243            format!(
244                "multi_transform: unsupported size {} (expected 2, 4, or 8)",
245                config.size
246            ),
247        ));
248    }
249    if config.v == 0 {
250        return Err(syn::Error::new(
251            proc_macro2::Span::call_site(),
252            "multi_transform: v must be >= 1",
253        ));
254    }
255
256    let fn_name = format_ident!(
257        "notw_{}_v{}_{}_{}",
258        config.size,
259        config.v,
260        config.isa.ident_str(),
261        config.precision.type_str()
262    );
263    let size = config.size;
264    let v = config.v;
265    let ty_str = config.precision.type_str();
266    let ty_tokens: TokenStream = ty_str.parse().expect("valid type token");
267
268    let use_simd = has_simd_impl(config.isa, config.precision, size);
269    let simd_inner = gen_simd_inner(config);
270    let outer_body = gen_outer_body(config, size, v);
271
272    let stride = v * 2;
273    let simd_note = if use_simd {
274        format!(
275            "True SIMD available via `notw_{size}_v{v}_{isa}_{ty}_soa` (`SoA` layout).",
276            isa = config.isa.ident_str(),
277            ty = ty_str,
278        )
279    } else {
280        "Sequential scalar fallback (no SIMD for this `ISA`+precision+size combination).".into()
281    };
282
283    let fn_doc = format!(
284        "Process `count` transforms of size {size} in batches of {v} (v={v}) using {isa} ISA.\n\n\
285         # Data layout (`AoS`)\n\
286         Interleaved with stride {v}: `data[element * {stride} + transform * 2 + c]`\n\
287         where `c` is 0 for real, 1 for imaginary.\n\n\
288         # SIMD acceleration\n\
289         {simd_note}\n\n\
290         # Safety\n\
291         - `input` must be valid for `count * {size} * 2 * {v}` reads of `{ty_str}`.\n\
292         - `output` must be valid for `count * {size} * 2 * {v}` writes of `{ty_str}`.\n\
293         - `istride` / `ostride` must be `2 * {v}` for the canonical `AoS` layout.\n\
294         - No alignment requirement; uses unaligned loads.",
295        size = size,
296        v = v,
297        isa = config.isa.ident_str(),
298        stride = stride,
299        ty_str = ty_str,
300        simd_note = simd_note,
301    );
302
303    let outer_fn = quote! {
304        #[doc = #fn_doc]
305        pub unsafe fn #fn_name(
306            input:   *const #ty_tokens,
307            output:  *mut   #ty_tokens,
308            istride: usize,
309            ostride: usize,
310            count:   usize,
311        ) {
312            #outer_body
313        }
314    };
315
316    Ok(if let Some(inner) = simd_inner {
317        quote! {
318            #inner
319            #outer_fn
320        }
321    } else {
322        outer_fn
323    })
324}
325
326// ============================================================================
327// Proc-macro entry point
328// ============================================================================
329
330/// Parsed arguments from `gen_multi_transform_codelet!(size=4, v=8, isa=avx2, ty=f32)`.
331struct MacroArgs {
332    size: usize,
333    v: usize,
334    isa: SimdIsa,
335    precision: Precision,
336}
337
338impl Parse for MacroArgs {
339    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
340        let mut size: Option<usize> = None;
341        let mut v: Option<usize> = None;
342        let mut isa: Option<SimdIsa> = None;
343        let mut precision: Option<Precision> = None;
344
345        while !input.is_empty() {
346            let key: syn::Ident = input.parse()?;
347            let _eq: Token![=] = input.parse()?;
348            match key.to_string().as_str() {
349                "size" => {
350                    let lit: LitInt = input.parse()?;
351                    size = Some(lit.base10_parse::<usize>().map_err(|_| {
352                        syn::Error::new(lit.span(), "expected an integer literal for `size`")
353                    })?);
354                }
355                "v" => {
356                    let lit: LitInt = input.parse()?;
357                    v = Some(lit.base10_parse::<usize>().map_err(|_| {
358                        syn::Error::new(lit.span(), "expected an integer literal for `v`")
359                    })?);
360                }
361                "isa" => {
362                    let ident: syn::Ident = input.parse()?;
363                    isa = Some(match ident.to_string().as_str() {
364                        "sse2" => SimdIsa::Sse2,
365                        "avx2" => SimdIsa::Avx2,
366                        "scalar" => SimdIsa::Scalar,
367                        other => {
368                            return Err(syn::Error::new(
369                                ident.span(),
370                                format!(
371                                    "unknown isa `{other}`, expected one of: sse2, avx2, scalar"
372                                ),
373                            ));
374                        }
375                    });
376                }
377                "ty" => {
378                    let ident: syn::Ident = input.parse()?;
379                    precision = Some(match ident.to_string().as_str() {
380                        "f32" => Precision::F32,
381                        "f64" => Precision::F64,
382                        other => {
383                            return Err(syn::Error::new(
384                                ident.span(),
385                                format!("unknown ty `{other}`, expected f32 or f64"),
386                            ));
387                        }
388                    });
389                }
390                other => {
391                    return Err(syn::Error::new(
392                        key.span(),
393                        format!("unknown key `{other}`, expected one of: size, v, isa, ty"),
394                    ));
395                }
396            }
397            if input.peek(Token![,]) {
398                let _: Token![,] = input.parse()?;
399            }
400        }
401
402        let size = size.ok_or_else(|| {
403            syn::Error::new(proc_macro2::Span::call_site(), "missing `size` argument")
404        })?;
405        let v = v.ok_or_else(|| {
406            syn::Error::new(proc_macro2::Span::call_site(), "missing `v` argument")
407        })?;
408        let isa = isa.ok_or_else(|| {
409            syn::Error::new(proc_macro2::Span::call_site(), "missing `isa` argument")
410        })?;
411        let precision = precision.ok_or_else(|| {
412            syn::Error::new(proc_macro2::Span::call_site(), "missing `ty` argument")
413        })?;
414
415        Ok(Self {
416            size,
417            v,
418            isa,
419            precision,
420        })
421    }
422}
423
424/// Entry point for the `gen_multi_transform_codelet!` proc-macro.
425///
426/// Parses `size=N, v=V, isa=ISA, ty=TY` from the token stream and calls
427/// [`generate_multi_transform`].
428///
429/// # Example
430/// ```ignore
431/// gen_multi_transform_codelet!(size = 4, v = 8, isa = avx2, ty = f32);
432/// ```
433///
434/// # Errors
435///
436/// Returns a [`syn::Error`] when the input does not parse as valid key-value
437/// pairs, a required key is missing, or `size` / `isa` / `ty` have unsupported
438/// values.
439pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
440    let args: MacroArgs = syn::parse2(input)?;
441    let config = MultiTransformConfig {
442        size: args.size,
443        v: args.v,
444        isa: args.isa,
445        precision: args.precision,
446    };
447    generate_multi_transform(&config)
448}