// oxifft_codegen_impl/gen_simd/mod.rs
//! SIMD codelet generation.
//!
//! Generates architecture-aware SIMD FFT codelets with multi-architecture dispatch.
//! At compile time, generates:
//! - AVX-512F variant (512-bit, `8×f64` / `16×f32`) for `x86_64`
//! - AVX2+FMA variant (256-bit, `4×f64`) for `x86_64`
//! - Pure-AVX variant (256-bit, `4×f64`, no FMA, no AVX2) for `x86_64`
//! - SSE2 variant (128-bit, `2×f64`) for `x86_64`
//! - NEON variant (128-bit, `2×f64`) for `aarch64`
//! - Scalar fallback for all architectures
//!
//! The dispatcher function selects the best SIMD path at runtime using
//! `is_x86_feature_detected!` (`x86_64`) or compile-time cfg (`aarch64`).
//!
//! Probe order for `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
//! AVX-512F is probed first (when the host supports it) to enable
//! `_mm512_fmadd_pd`/`_mm512_fmsub_pd` based butterfly arithmetic.

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::LitInt;

mod avx;
mod avx2;
mod avx512;
mod neon;
mod scalar;
mod sse2;

30/// Generate a SIMD-optimized codelet for the given FFT size.
31///
32/// Supports sizes 2, 4, 8 for all ISAs and size 16 (f32 only, AVX-512F or scalar).
33/// The macro generates:
34/// - A public dispatcher that picks the best SIMD path at runtime
35/// - Architecture-specific inner functions with `#[target_feature]`
36/// - A generic scalar fallback
37///
38/// AVX-512F is probed before AVX2+FMA for sizes that have AVX-512 emitters.
39///
40/// # Errors
41/// Returns a `syn::Error` when the input does not parse as a valid size literal,
42/// or when the size is not in the supported set {2, 4, 8, 16}.
43pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
44    let size: LitInt = syn::parse2(input)?;
45    let n: usize = size.base10_parse().map_err(|_| {
46        syn::Error::new(
47            size.span(),
48            "gen_simd_codelet: expected an integer size literal",
49        )
50    })?;
51
52    match n {
53        2 => Ok(gen_simd_size_2()),
54        4 => Ok(gen_simd_size_4()),
55        8 => Ok(gen_simd_size_8()),
56        16 => Ok(gen_simd_size_16()),
57        _ => Err(syn::Error::new(
58            size.span(),
59            format!("gen_simd_codelet: unsupported size {n} (expected one of 2, 4, 8, 16)"),
60        )),
61    }
62}
63
64/// Generate size-2 SIMD butterfly codelet.
65///
66/// Size-2 butterfly: out[0] = a + b, out[1] = a - b
67///
68/// SIMD strategy:
69/// - AVX-512F (`x86_64`): 256-bit YMM under avx512f feature umbrella, f64 only
70/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, vector add/sub
71/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, vector add/sub
72fn gen_simd_size_2() -> TokenStream {
73    let dispatcher = gen_dispatcher(2);
74    let scalar = scalar::gen_scalar_size_2();
75    let sse2_f64 = sse2::gen_sse2_size_2();
76    let sse2_f32 = sse2::gen_sse2_size_2_f32();
77    let plain_avx_f64 = avx::gen_avx_size_2_f64();
78    let avx2_f64 = avx2::gen_avx2_size_2();
79    let avx2_f32 = avx2::gen_avx2_size_2_f32();
80    let avx512_f64 = avx512::gen_avx512_size_2_f64();
81    let avx512_f32 = avx512::gen_avx512_size_2_f32();
82    let neon_f64 = neon::gen_neon_size_2();
83    let neon_f32 = neon::gen_neon_size_2_f32();
84
85    quote! {
86        #dispatcher
87        #scalar
88        #sse2_f64
89        #sse2_f32
90        #plain_avx_f64
91        #avx2_f64
92        #avx2_f32
93        #avx512_f64
94        #avx512_f32
95        #neon_f64
96        #neon_f32
97    }
98}
99
100/// Generate size-4 SIMD radix-4 codelet.
101///
102/// Size-4 FFT: radix-4 butterfly with sign-dependent ±i rotation.
103///
104/// SIMD strategy:
105/// - AVX-512F (`x86_64`): 256-bit f64 + f32 butterfly under avx512f feature
106/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, shuffle-based rotation
107/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, ext-based rotation
108fn gen_simd_size_4() -> TokenStream {
109    let dispatcher = gen_dispatcher(4);
110    let scalar = scalar::gen_scalar_size_4();
111    let sse2_f64 = sse2::gen_sse2_size_4();
112    let sse2_f32 = sse2::gen_sse2_size_4_f32();
113    let plain_avx_f64 = avx::gen_avx_size_4_f64();
114    let avx2_f64 = avx2::gen_avx2_size_4();
115    let avx2_f32 = avx2::gen_avx2_size_4_f32();
116    let avx512_f64 = avx512::gen_avx512_size_4_f64();
117    let avx512_f32 = avx512::gen_avx512_size_4_f32();
118    let neon_f64 = neon::gen_neon_size_4();
119    let neon_f32 = neon::gen_neon_size_4_f32();
120
121    quote! {
122        #dispatcher
123        #scalar
124        #sse2_f64
125        #sse2_f32
126        #plain_avx_f64
127        #avx2_f64
128        #avx2_f32
129        #avx512_f64
130        #avx512_f32
131        #neon_f64
132        #neon_f32
133    }
134}
135
136/// Generate size-8 SIMD radix-8 codelet.
137///
138/// Size-8 FFT: radix-2 DIT with 3 butterfly stages.
139///
140/// SIMD strategy:
141/// - AVX-512F (`x86_64`): 256-bit + FMA via ZMM promotion for twiddles
142/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, FMA twiddles
143/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, FMA twiddles
144fn gen_simd_size_8() -> TokenStream {
145    let dispatcher = gen_dispatcher(8);
146    let scalar = scalar::gen_scalar_size_8();
147    let sse2_f64 = sse2::gen_sse2_size_8();
148    let sse2_f32 = sse2::gen_sse2_size_8_f32();
149    let plain_avx_f64 = avx::gen_avx_size_8_f64();
150    let avx2_f64 = avx2::gen_avx2_size_8();
151    let avx2_f32 = avx2::gen_avx2_size_8_f32();
152    let avx512_f64 = avx512::gen_avx512_size_8_f64();
153    let avx512_f32 = avx512::gen_avx512_size_8_f32();
154    let neon_f64 = neon::gen_neon_size_8();
155    let neon_f32 = neon::gen_neon_size_8_f32();
156
157    quote! {
158        #dispatcher
159        #scalar
160        #sse2_f64
161        #sse2_f32
162        #plain_avx_f64
163        #avx2_f64
164        #avx2_f32
165        #avx512_f64
166        #avx512_f32
167        #neon_f64
168        #neon_f32
169    }
170}
171
172/// Generate size-16 SIMD radix-2 DIT codelet (f32 only via AVX-512F).
173///
174/// Size-16 FFT: radix-2 DIT with 4 butterfly stages.
175///
176/// SIMD strategy:
177/// - AVX-512F (`x86_64`): full `__m512` 16-lane f32 butterfly with FMA W16 twiddles
178/// - Scalar fallback for all other architectures (AVX2/SSE2/NEON lack this size)
179fn gen_simd_size_16() -> TokenStream {
180    let dispatcher = gen_dispatcher_16();
181    let scalar = scalar::gen_scalar_size_16();
182    let avx512_f32 = avx512::gen_avx512_size_16_f32();
183
184    quote! {
185        #dispatcher
186        #scalar
187        #avx512_f32
188    }
189}
190
// ---------------------------------------------------------------------------
// Dispatcher generation
// ---------------------------------------------------------------------------
195/// Generate the public dispatcher function for sizes 2, 4, 8 (all ISAs).
196///
197/// Priority on `x86_64` (f64): AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
198/// Priority on `x86_64` (f32): AVX-512F > AVX2+FMA > SSE2 > scalar (no pure-AVX f32 path).
199/// Priority on `aarch64`: NEON > scalar.
200///
201/// The dispatcher:
202/// 1. Checks `T` via `core::any::TypeId` (since `Float: 'static`):
203///    - `f64`: uses f64 SIMD path
204///    - `f32`: uses f32 SIMD path
205///    - other: scalar fallback
206/// 2. On `x86_64`: probes AVX-512F first, then AVX2+FMA, then SSE2
207/// 3. On `aarch64`: uses NEON unconditionally
208/// 4. Falls back to the generic scalar implementation
209fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {
210    let fn_name = format_ident!("codelet_simd_{}", n);
211    let scalar_name = format_ident!("codelet_simd_{}_scalar", n);
212    let avx512_f64_name = format_ident!("codelet_simd_{}_avx512_f64", n);
213    let avx2_f64_name = format_ident!("codelet_simd_{}_avx2_f64", n);
214    let plain_avx_fn_name = format_ident!("codelet_simd_{}_avx_f64", n);
215    let sse2_f64_name = format_ident!("codelet_simd_{}_sse2_f64", n);
216    let neon_f64_name = format_ident!("codelet_simd_{}_neon_f64", n);
217    let avx512_f32_name = format_ident!("codelet_simd_{}_avx512_f32", n);
218    let avx2_f32_name = format_ident!("codelet_simd_{}_avx2_f32", n);
219    let sse2_f32_name = format_ident!("codelet_simd_{}_sse2_f32", n);
220    let neon_f32_name = format_ident!("codelet_simd_{}_neon_f32", n);
221    let n_lit = n;
222    let doc = format!(
223        "Size-{n} SIMD-optimized FFT codelet with architecture dispatch.\n\n\
224         Automatically selects the best SIMD path at runtime:\n\
225         - x86_64: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar  (f64; f32 uses AVX2/SSE2/512)\n\
226         - aarch64: NEON > scalar                               (both f64 and f32)\n\
227         - other: scalar fallback"
228    );
229
230    quote! {
231        #[doc = #doc]
232        #[inline]
233        pub fn #fn_name<T: crate::kernel::Float>(
234            data: &mut [crate::kernel::Complex<T>],
235            sign: i32,
236        ) {
237            debug_assert!(
238                data.len() >= #n_lit,
239                "codelet_simd_{}: need >= {} elements, got {}",
240                #n_lit,
241                #n_lit,
242                data.len(),
243            );
244
245            // Fast path: f64 SIMD
246            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
247                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
248                // When T == f64, &mut [Complex<f64>] has the same layout as
249                // &mut [f64] with twice the length: [re0, im0, re1, im1, ...].
250                let len = data.len() * 2;
251                let ptr = data.as_mut_ptr().cast::<f64>();
252                let f64_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
253
254                #[cfg(target_arch = "x86_64")]
255                {
256                    if is_x86_feature_detected!("avx512f") {
257                        // Safety: AVX-512F detected, pointer valid for len f64s
258                        unsafe { #avx512_f64_name(f64_data, sign); }
259                        return;
260                    }
261                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
262                        // Safety: AVX2+FMA detected, pointer valid for len f64s
263                        unsafe { #avx2_f64_name(f64_data, sign); }
264                        return;
265                    }
266                    // Pure AVX (no FMA, no AVX2) — probe after AVX2+FMA, before SSE2
267                    if is_x86_feature_detected!("avx") {
268                        // Safety: AVX detected (superset of SSE2), pointer valid
269                        unsafe { #plain_avx_fn_name(f64_data, sign); }
270                        return;
271                    }
272                    if is_x86_feature_detected!("sse2") {
273                        // Safety: SSE2 detected (guaranteed on x86_64), pointer valid
274                        unsafe { #sse2_f64_name(f64_data, sign); }
275                        return;
276                    }
277                }
278
279                #[cfg(target_arch = "aarch64")]
280                {
281                    // NEON is mandatory on aarch64
282                    unsafe { #neon_f64_name(f64_data, sign); }
283                    return;
284                }
285            }
286
287            // Fast path: f32 SIMD
288            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
289                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
290                // When T == f32, &mut [Complex<f32>] has the same layout as
291                // &mut [f32] with twice the length: [re0, im0, re1, im1, ...].
292                let len = data.len() * 2;
293                let ptr = data.as_mut_ptr().cast::<f32>();
294                let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
295
296                #[cfg(target_arch = "x86_64")]
297                {
298                    if is_x86_feature_detected!("avx512f") {
299                        // Safety: AVX-512F detected
300                        unsafe { #avx512_f32_name(f32_data, sign); }
301                        return;
302                    }
303                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
304                        unsafe { #avx2_f32_name(f32_data, sign); }
305                        return;
306                    }
307                    // f32 has no dedicated AVX (non-AVX2) path; fall through to SSE2
308                    if is_x86_feature_detected!("sse2") {
309                        unsafe { #sse2_f32_name(f32_data, sign); }
310                        return;
311                    }
312                }
313
314                #[cfg(target_arch = "aarch64")]
315                {
316                    // NEON is mandatory on aarch64
317                    unsafe { #neon_f32_name(f32_data, sign); }
318                    return;
319                }
320            }
321
322            // Scalar fallback for other float types or unsupported architectures
323            #scalar_name(data, sign);
324        }
325    }
326}
327
328/// Generate the dispatcher for size-16 (f32 only via AVX-512F; scalar fallback).
329///
330/// Size-16 is only available as f32 via AVX-512F. All other paths fall through
331/// to the scalar implementation, including f64 (no size-16 f64 SIMD emitter).
332fn gen_dispatcher_16() -> proc_macro2::TokenStream {
333    let avx512_f32_name = format_ident!("codelet_simd_16_avx512_f32");
334    let scalar_name = format_ident!("codelet_simd_16_scalar");
335
336    quote! {
337        /// Size-16 SIMD-optimized FFT codelet.
338        ///
339        /// Selects AVX-512F f32 path when available; otherwise falls back to scalar.
340        /// No f64 SIMD path at size 16 (scalar is used instead).
341        ///
342        /// - x86_64 + avx512f: `__m512` 16-lane f32 butterfly with FMA twiddles
343        /// - all other: scalar fallback
344        #[inline]
345        pub fn codelet_simd_16<T: crate::kernel::Float>(
346            data: &mut [crate::kernel::Complex<T>],
347            sign: i32,
348        ) {
349            debug_assert!(
350                data.len() >= 16_usize,
351                "codelet_simd_16: need >= 16 elements, got {}",
352                data.len(),
353            );
354
355            // AVX-512F f32 path only
356            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
357                let len = data.len() * 2;
358                let ptr = data.as_mut_ptr().cast::<f32>();
359                let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
360
361                #[cfg(target_arch = "x86_64")]
362                {
363                    if is_x86_feature_detected!("avx512f") {
364                        // Safety: AVX-512F detected, pointer valid for len f32s
365                        unsafe { #avx512_f32_name(f32_data, sign); }
366                        return;
367                    }
368                }
369            }
370
371            // Scalar fallback for f64, other types, or no AVX-512F
372            #scalar_name(data, sign);
373        }
374    }
375}