Skip to main content

oxifft_codegen_impl/gen_simd/
mod.rs

1//! SIMD codelet generation.
2//!
3//! Generates architecture-aware SIMD FFT codelets with multi-architecture dispatch.
4//! At compile time, generates:
5//! - AVX-512F variant (512-bit, `8×f64` / `16×f32`) for `x86_64`
6//! - AVX2+FMA variant (256-bit, `4×f64`) for `x86_64`
7//! - Pure-AVX variant (256-bit, `4×f64`, no FMA, no AVX2) for `x86_64`
8//! - SSE2 variant (128-bit, `2×f64`) for `x86_64`
9//! - NEON variant (128-bit, `2×f64`) for `aarch64`
10//! - Scalar fallback for all architectures
11//!
12//! The dispatcher function selects the best SIMD path at runtime using
13//! `is_x86_feature_detected!` (`x86_64`) or compile-time cfg (`aarch64`).
14//!
15//! Probe order for `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
16//! AVX-512F is probed first (when the host supports it) to enable
17//! `_mm512_fmadd_pd`/`_mm512_fmsub_pd` based butterfly arithmetic.
18
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::LitInt;
22
23mod avx;
24mod avx2;
25mod avx512;
26mod neon;
27mod scalar;
28mod sse2;
29
30pub mod multi_transform;
31pub mod runtime_dispatch;
32
33/// Generate a SIMD-optimized codelet for the given FFT size.
34///
35/// Supports sizes 2, 4, 8 for all ISAs and size 16 (f32 only, AVX-512F or scalar).
36/// The macro generates:
37/// - A public dispatcher that picks the best SIMD path at runtime
38/// - Architecture-specific inner functions with `#[target_feature]`
39/// - A generic scalar fallback
40///
41/// AVX-512F is probed before AVX2+FMA for sizes that have AVX-512 emitters.
42///
43/// # Errors
44/// Returns a `syn::Error` when the input does not parse as a valid size literal,
45/// or when the size is not in the supported set {2, 4, 8, 16}.
46pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
47    let size: LitInt = syn::parse2(input)?;
48    let n: usize = size.base10_parse().map_err(|_| {
49        syn::Error::new(
50            size.span(),
51            "gen_simd_codelet: expected an integer size literal",
52        )
53    })?;
54
55    match n {
56        2 => Ok(gen_simd_size_2()),
57        4 => Ok(gen_simd_size_4()),
58        8 => Ok(gen_simd_size_8()),
59        16 => Ok(gen_simd_size_16()),
60        _ => Err(syn::Error::new(
61            size.span(),
62            format!("gen_simd_codelet: unsupported size {n} (expected one of 2, 4, 8, 16)"),
63        )),
64    }
65}
66
67/// Generate size-2 SIMD butterfly codelet.
68///
69/// Size-2 butterfly: out[0] = a + b, out[1] = a - b
70///
71/// SIMD strategy:
72/// - AVX-512F (`x86_64`): 256-bit YMM under avx512f feature umbrella, f64 only
73/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, vector add/sub
74/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, vector add/sub
75fn gen_simd_size_2() -> TokenStream {
76    let dispatcher = gen_dispatcher(2);
77    let scalar = scalar::gen_scalar_size_2();
78    let sse2_f64 = sse2::gen_sse2_size_2();
79    let sse2_f32 = sse2::gen_sse2_size_2_f32();
80    let plain_avx_f64 = avx::gen_avx_size_2_f64();
81    let avx2_f64 = avx2::gen_avx2_size_2();
82    let avx2_f32 = avx2::gen_avx2_size_2_f32();
83    let avx512_f64 = avx512::gen_avx512_size_2_f64();
84    let avx512_f32 = avx512::gen_avx512_size_2_f32();
85    let neon_f64 = neon::gen_neon_size_2();
86    let neon_f32 = neon::gen_neon_size_2_f32();
87
88    quote! {
89        #dispatcher
90        #scalar
91        #sse2_f64
92        #sse2_f32
93        #plain_avx_f64
94        #avx2_f64
95        #avx2_f32
96        #avx512_f64
97        #avx512_f32
98        #neon_f64
99        #neon_f32
100    }
101}
102
103/// Generate size-4 SIMD radix-4 codelet.
104///
105/// Size-4 FFT: radix-4 butterfly with sign-dependent ±i rotation.
106///
107/// SIMD strategy:
108/// - AVX-512F (`x86_64`): 256-bit f64 + f32 butterfly under avx512f feature
109/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, shuffle-based rotation
110/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, ext-based rotation
111fn gen_simd_size_4() -> TokenStream {
112    let dispatcher = gen_dispatcher(4);
113    let scalar = scalar::gen_scalar_size_4();
114    let sse2_f64 = sse2::gen_sse2_size_4();
115    let sse2_f32 = sse2::gen_sse2_size_4_f32();
116    let plain_avx_f64 = avx::gen_avx_size_4_f64();
117    let avx2_f64 = avx2::gen_avx2_size_4();
118    let avx2_f32 = avx2::gen_avx2_size_4_f32();
119    let avx512_f64 = avx512::gen_avx512_size_4_f64();
120    let avx512_f32 = avx512::gen_avx512_size_4_f32();
121    let neon_f64 = neon::gen_neon_size_4();
122    let neon_f32 = neon::gen_neon_size_4_f32();
123
124    quote! {
125        #dispatcher
126        #scalar
127        #sse2_f64
128        #sse2_f32
129        #plain_avx_f64
130        #avx2_f64
131        #avx2_f32
132        #avx512_f64
133        #avx512_f32
134        #neon_f64
135        #neon_f32
136    }
137}
138
139/// Generate size-8 SIMD radix-8 codelet.
140///
141/// Size-8 FFT: radix-2 DIT with 3 butterfly stages.
142///
143/// SIMD strategy:
144/// - AVX-512F (`x86_64`): 256-bit + FMA via ZMM promotion for twiddles
145/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, FMA twiddles
146/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, FMA twiddles
147fn gen_simd_size_8() -> TokenStream {
148    let dispatcher = gen_dispatcher(8);
149    let scalar = scalar::gen_scalar_size_8();
150    let sse2_f64 = sse2::gen_sse2_size_8();
151    let sse2_f32 = sse2::gen_sse2_size_8_f32();
152    let plain_avx_f64 = avx::gen_avx_size_8_f64();
153    let avx2_f64 = avx2::gen_avx2_size_8();
154    let avx2_f32 = avx2::gen_avx2_size_8_f32();
155    let avx512_f64 = avx512::gen_avx512_size_8_f64();
156    let avx512_f32 = avx512::gen_avx512_size_8_f32();
157    let neon_f64 = neon::gen_neon_size_8();
158    let neon_f32 = neon::gen_neon_size_8_f32();
159
160    quote! {
161        #dispatcher
162        #scalar
163        #sse2_f64
164        #sse2_f32
165        #plain_avx_f64
166        #avx2_f64
167        #avx2_f32
168        #avx512_f64
169        #avx512_f32
170        #neon_f64
171        #neon_f32
172    }
173}
174
175/// Generate size-16 SIMD radix-2 DIT codelet (f32 only via AVX-512F).
176///
177/// Size-16 FFT: radix-2 DIT with 4 butterfly stages.
178///
179/// SIMD strategy:
180/// - AVX-512F (`x86_64`): full `__m512` 16-lane f32 butterfly with FMA W16 twiddles
181/// - Scalar fallback for all other architectures (AVX2/SSE2/NEON lack this size)
182fn gen_simd_size_16() -> TokenStream {
183    let dispatcher = gen_dispatcher_16();
184    let scalar = scalar::gen_scalar_size_16();
185    let avx512_f32 = avx512::gen_avx512_size_16_f32();
186
187    quote! {
188        #dispatcher
189        #scalar
190        #avx512_f32
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Dispatcher generation
196// ---------------------------------------------------------------------------
197
198/// Generate the public dispatcher function for sizes 2, 4, 8 (all ISAs).
199///
200/// Priority on `x86_64` (f64): AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
201/// Priority on `x86_64` (f32): AVX-512F > AVX2+FMA > SSE2 > scalar (no pure-AVX f32 path).
202/// Priority on `aarch64`: NEON > scalar.
203///
204/// The dispatcher:
205/// 1. Checks `T` via `core::any::TypeId` (since `Float: 'static`):
206///    - `f64`: uses f64 SIMD path
207///    - `f32`: uses f32 SIMD path
208///    - other: scalar fallback
209/// 2. On `x86_64`: probes AVX-512F first, then AVX2+FMA, then SSE2
210/// 3. On `aarch64`: uses NEON unconditionally
211/// 4. Falls back to the generic scalar implementation
212fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {
213    let fn_name = format_ident!("codelet_simd_{}", n);
214    let scalar_name = format_ident!("codelet_simd_{}_scalar", n);
215    let avx512_f64_name = format_ident!("codelet_simd_{}_avx512_f64", n);
216    let avx2_f64_name = format_ident!("codelet_simd_{}_avx2_f64", n);
217    let plain_avx_fn_name = format_ident!("codelet_simd_{}_avx_f64", n);
218    let sse2_f64_name = format_ident!("codelet_simd_{}_sse2_f64", n);
219    let neon_f64_name = format_ident!("codelet_simd_{}_neon_f64", n);
220    let avx512_f32_name = format_ident!("codelet_simd_{}_avx512_f32", n);
221    let avx2_f32_name = format_ident!("codelet_simd_{}_avx2_f32", n);
222    let sse2_f32_name = format_ident!("codelet_simd_{}_sse2_f32", n);
223    let neon_f32_name = format_ident!("codelet_simd_{}_neon_f32", n);
224    let n_lit = n;
225    let doc = format!(
226        "Size-{n} SIMD-optimized FFT codelet with architecture dispatch.\n\n\
227         Automatically selects the best SIMD path at runtime:\n\
228         - x86_64: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar  (f64; f32 uses AVX2/SSE2/512)\n\
229         - aarch64: NEON > scalar                               (both f64 and f32)\n\
230         - other: scalar fallback"
231    );
232
233    quote! {
234        #[doc = #doc]
235        #[inline]
236        pub fn #fn_name<T: crate::kernel::Float>(
237            data: &mut [crate::kernel::Complex<T>],
238            sign: i32,
239        ) {
240            debug_assert!(
241                data.len() >= #n_lit,
242                "codelet_simd_{}: need >= {} elements, got {}",
243                #n_lit,
244                #n_lit,
245                data.len(),
246            );
247
248            // Fast path: f64 SIMD
249            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
250                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
251                // When T == f64, &mut [Complex<f64>] has the same layout as
252                // &mut [f64] with twice the length: [re0, im0, re1, im1, ...].
253                let len = data.len() * 2;
254                let ptr = data.as_mut_ptr().cast::<f64>();
255                let f64_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
256
257                #[cfg(target_arch = "x86_64")]
258                {
259                    if is_x86_feature_detected!("avx512f") {
260                        // Safety: AVX-512F detected, pointer valid for len f64s
261                        unsafe { #avx512_f64_name(f64_data, sign); }
262                        return;
263                    }
264                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
265                        // Safety: AVX2+FMA detected, pointer valid for len f64s
266                        unsafe { #avx2_f64_name(f64_data, sign); }
267                        return;
268                    }
269                    // Pure AVX (no FMA, no AVX2) — probe after AVX2+FMA, before SSE2
270                    if is_x86_feature_detected!("avx") {
271                        // Safety: AVX detected (superset of SSE2), pointer valid
272                        unsafe { #plain_avx_fn_name(f64_data, sign); }
273                        return;
274                    }
275                    if is_x86_feature_detected!("sse2") {
276                        // Safety: SSE2 detected (guaranteed on x86_64), pointer valid
277                        unsafe { #sse2_f64_name(f64_data, sign); }
278                        return;
279                    }
280                }
281
282                #[cfg(target_arch = "aarch64")]
283                {
284                    // NEON is mandatory on aarch64
285                    unsafe { #neon_f64_name(f64_data, sign); }
286                    return;
287                }
288            }
289
290            // Fast path: f32 SIMD
291            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
292                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
293                // When T == f32, &mut [Complex<f32>] has the same layout as
294                // &mut [f32] with twice the length: [re0, im0, re1, im1, ...].
295                let len = data.len() * 2;
296                let ptr = data.as_mut_ptr().cast::<f32>();
297                let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
298
299                #[cfg(target_arch = "x86_64")]
300                {
301                    if is_x86_feature_detected!("avx512f") {
302                        // Safety: AVX-512F detected
303                        unsafe { #avx512_f32_name(f32_data, sign); }
304                        return;
305                    }
306                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
307                        unsafe { #avx2_f32_name(f32_data, sign); }
308                        return;
309                    }
310                    // f32 has no dedicated AVX (non-AVX2) path; fall through to SSE2
311                    if is_x86_feature_detected!("sse2") {
312                        unsafe { #sse2_f32_name(f32_data, sign); }
313                        return;
314                    }
315                }
316
317                #[cfg(target_arch = "aarch64")]
318                {
319                    // NEON is mandatory on aarch64
320                    unsafe { #neon_f32_name(f32_data, sign); }
321                    return;
322                }
323            }
324
325            // Scalar fallback for other float types or unsupported architectures
326            #scalar_name(data, sign);
327        }
328    }
329}
330
331/// Generate the dispatcher for size-16 (f32 only via AVX-512F; scalar fallback).
332///
333/// Size-16 is only available as f32 via AVX-512F. All other paths fall through
334/// to the scalar implementation, including f64 (no size-16 f64 SIMD emitter).
335fn gen_dispatcher_16() -> proc_macro2::TokenStream {
336    let avx512_f32_name = format_ident!("codelet_simd_16_avx512_f32");
337    let scalar_name = format_ident!("codelet_simd_16_scalar");
338
339    quote! {
340        /// Size-16 SIMD-optimized FFT codelet.
341        ///
342        /// Selects AVX-512F f32 path when available; otherwise falls back to scalar.
343        /// No f64 SIMD path at size 16 (scalar is used instead).
344        ///
345        /// - x86_64 + avx512f: `__m512` 16-lane f32 butterfly with FMA twiddles
346        /// - all other: scalar fallback
347        #[inline]
348        pub fn codelet_simd_16<T: crate::kernel::Float>(
349            data: &mut [crate::kernel::Complex<T>],
350            sign: i32,
351        ) {
352            debug_assert!(
353                data.len() >= 16_usize,
354                "codelet_simd_16: need >= 16 elements, got {}",
355                data.len(),
356            );
357
358            // AVX-512F f32 path only
359            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
360                let len = data.len() * 2;
361                let ptr = data.as_mut_ptr().cast::<f32>();
362                let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
363
364                #[cfg(target_arch = "x86_64")]
365                {
366                    if is_x86_feature_detected!("avx512f") {
367                        // Safety: AVX-512F detected, pointer valid for len f32s
368                        unsafe { #avx512_f32_name(f32_data, sign); }
369                        return;
370                    }
371                }
372            }
373
374            // Scalar fallback for f64, other types, or no AVX-512F
375            #scalar_name(data, sign);
376        }
377    }
378}