oxifft_codegen_impl/gen_simd/mod.rs
1//! SIMD codelet generation.
2//!
3//! Generates architecture-aware SIMD FFT codelets with multi-architecture dispatch.
4//! At compile time, generates:
//! - AVX-512F variant (`avx512f` target feature: 256-bit vectors for sizes 2–8, full 512-bit `16×f32` for size 16) for `x86_64`
6//! - AVX2+FMA variant (256-bit, `4×f64`) for `x86_64`
7//! - Pure-AVX variant (256-bit, `4×f64`, no FMA, no AVX2) for `x86_64`
8//! - SSE2 variant (128-bit, `2×f64`) for `x86_64`
9//! - NEON variant (128-bit, `2×f64`) for `aarch64`
10//! - Scalar fallback for all architectures
11//!
12//! The dispatcher function selects the best SIMD path at runtime using
13//! `is_x86_feature_detected!` (`x86_64`) or compile-time cfg (`aarch64`).
14//!
//! Probe order for `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar for f64;
//! f32 has no pure-AVX variant, so it probes AVX-512F > AVX2+FMA > SSE2 > scalar.
//! AVX-512F is probed first (when the host supports it) to enable
//! `_mm512_fmadd_pd`/`_mm512_fmsub_pd` based butterfly arithmetic.
18
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::LitInt;
22
23mod avx;
24mod avx2;
25mod avx512;
26mod neon;
27mod scalar;
28mod sse2;
29
30pub mod multi_transform;
31pub mod runtime_dispatch;
32
33/// Generate a SIMD-optimized codelet for the given FFT size.
34///
35/// Supports sizes 2, 4, 8 for all ISAs and size 16 (f32 only, AVX-512F or scalar).
36/// The macro generates:
37/// - A public dispatcher that picks the best SIMD path at runtime
38/// - Architecture-specific inner functions with `#[target_feature]`
39/// - A generic scalar fallback
40///
41/// AVX-512F is probed before AVX2+FMA for sizes that have AVX-512 emitters.
42///
43/// # Errors
44/// Returns a `syn::Error` when the input does not parse as a valid size literal,
45/// or when the size is not in the supported set {2, 4, 8, 16}.
46pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
47 let size: LitInt = syn::parse2(input)?;
48 let n: usize = size.base10_parse().map_err(|_| {
49 syn::Error::new(
50 size.span(),
51 "gen_simd_codelet: expected an integer size literal",
52 )
53 })?;
54
55 match n {
56 2 => Ok(gen_simd_size_2()),
57 4 => Ok(gen_simd_size_4()),
58 8 => Ok(gen_simd_size_8()),
59 16 => Ok(gen_simd_size_16()),
60 _ => Err(syn::Error::new(
61 size.span(),
62 format!("gen_simd_codelet: unsupported size {n} (expected one of 2, 4, 8, 16)"),
63 )),
64 }
65}
66
67/// Generate size-2 SIMD butterfly codelet.
68///
69/// Size-2 butterfly: out[0] = a + b, out[1] = a - b
70///
71/// SIMD strategy:
72/// - AVX-512F (`x86_64`): 256-bit YMM under avx512f feature umbrella, f64 only
73/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, vector add/sub
74/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, vector add/sub
75fn gen_simd_size_2() -> TokenStream {
76 let dispatcher = gen_dispatcher(2);
77 let scalar = scalar::gen_scalar_size_2();
78 let sse2_f64 = sse2::gen_sse2_size_2();
79 let sse2_f32 = sse2::gen_sse2_size_2_f32();
80 let plain_avx_f64 = avx::gen_avx_size_2_f64();
81 let avx2_f64 = avx2::gen_avx2_size_2();
82 let avx2_f32 = avx2::gen_avx2_size_2_f32();
83 let avx512_f64 = avx512::gen_avx512_size_2_f64();
84 let avx512_f32 = avx512::gen_avx512_size_2_f32();
85 let neon_f64 = neon::gen_neon_size_2();
86 let neon_f32 = neon::gen_neon_size_2_f32();
87
88 quote! {
89 #dispatcher
90 #scalar
91 #sse2_f64
92 #sse2_f32
93 #plain_avx_f64
94 #avx2_f64
95 #avx2_f32
96 #avx512_f64
97 #avx512_f32
98 #neon_f64
99 #neon_f32
100 }
101}
102
103/// Generate size-4 SIMD radix-4 codelet.
104///
105/// Size-4 FFT: radix-4 butterfly with sign-dependent ±i rotation.
106///
107/// SIMD strategy:
108/// - AVX-512F (`x86_64`): 256-bit f64 + f32 butterfly under avx512f feature
109/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, shuffle-based rotation
110/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, ext-based rotation
111fn gen_simd_size_4() -> TokenStream {
112 let dispatcher = gen_dispatcher(4);
113 let scalar = scalar::gen_scalar_size_4();
114 let sse2_f64 = sse2::gen_sse2_size_4();
115 let sse2_f32 = sse2::gen_sse2_size_4_f32();
116 let plain_avx_f64 = avx::gen_avx_size_4_f64();
117 let avx2_f64 = avx2::gen_avx2_size_4();
118 let avx2_f32 = avx2::gen_avx2_size_4_f32();
119 let avx512_f64 = avx512::gen_avx512_size_4_f64();
120 let avx512_f32 = avx512::gen_avx512_size_4_f32();
121 let neon_f64 = neon::gen_neon_size_4();
122 let neon_f32 = neon::gen_neon_size_4_f32();
123
124 quote! {
125 #dispatcher
126 #scalar
127 #sse2_f64
128 #sse2_f32
129 #plain_avx_f64
130 #avx2_f64
131 #avx2_f32
132 #avx512_f64
133 #avx512_f32
134 #neon_f64
135 #neon_f32
136 }
137}
138
139/// Generate size-8 SIMD radix-8 codelet.
140///
141/// Size-8 FFT: radix-2 DIT with 3 butterfly stages.
142///
143/// SIMD strategy:
144/// - AVX-512F (`x86_64`): 256-bit + FMA via ZMM promotion for twiddles
145/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, FMA twiddles
146/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, FMA twiddles
147fn gen_simd_size_8() -> TokenStream {
148 let dispatcher = gen_dispatcher(8);
149 let scalar = scalar::gen_scalar_size_8();
150 let sse2_f64 = sse2::gen_sse2_size_8();
151 let sse2_f32 = sse2::gen_sse2_size_8_f32();
152 let plain_avx_f64 = avx::gen_avx_size_8_f64();
153 let avx2_f64 = avx2::gen_avx2_size_8();
154 let avx2_f32 = avx2::gen_avx2_size_8_f32();
155 let avx512_f64 = avx512::gen_avx512_size_8_f64();
156 let avx512_f32 = avx512::gen_avx512_size_8_f32();
157 let neon_f64 = neon::gen_neon_size_8();
158 let neon_f32 = neon::gen_neon_size_8_f32();
159
160 quote! {
161 #dispatcher
162 #scalar
163 #sse2_f64
164 #sse2_f32
165 #plain_avx_f64
166 #avx2_f64
167 #avx2_f32
168 #avx512_f64
169 #avx512_f32
170 #neon_f64
171 #neon_f32
172 }
173}
174
175/// Generate size-16 SIMD radix-2 DIT codelet (f32 only via AVX-512F).
176///
177/// Size-16 FFT: radix-2 DIT with 4 butterfly stages.
178///
179/// SIMD strategy:
180/// - AVX-512F (`x86_64`): full `__m512` 16-lane f32 butterfly with FMA W16 twiddles
181/// - Scalar fallback for all other architectures (AVX2/SSE2/NEON lack this size)
182fn gen_simd_size_16() -> TokenStream {
183 let dispatcher = gen_dispatcher_16();
184 let scalar = scalar::gen_scalar_size_16();
185 let avx512_f32 = avx512::gen_avx512_size_16_f32();
186
187 quote! {
188 #dispatcher
189 #scalar
190 #avx512_f32
191 }
192}
193
194// ---------------------------------------------------------------------------
195// Dispatcher generation
196// ---------------------------------------------------------------------------
197
198/// Generate the public dispatcher function for sizes 2, 4, 8 (all ISAs).
199///
200/// Priority on `x86_64` (f64): AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
201/// Priority on `x86_64` (f32): AVX-512F > AVX2+FMA > SSE2 > scalar (no pure-AVX f32 path).
202/// Priority on `aarch64`: NEON > scalar.
203///
204/// The dispatcher:
205/// 1. Checks `T` via `core::any::TypeId` (since `Float: 'static`):
206/// - `f64`: uses f64 SIMD path
207/// - `f32`: uses f32 SIMD path
208/// - other: scalar fallback
209/// 2. On `x86_64`: probes AVX-512F first, then AVX2+FMA, then SSE2
210/// 3. On `aarch64`: uses NEON unconditionally
211/// 4. Falls back to the generic scalar implementation
212fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {
213 let fn_name = format_ident!("codelet_simd_{}", n);
214 let scalar_name = format_ident!("codelet_simd_{}_scalar", n);
215 let avx512_f64_name = format_ident!("codelet_simd_{}_avx512_f64", n);
216 let avx2_f64_name = format_ident!("codelet_simd_{}_avx2_f64", n);
217 let plain_avx_fn_name = format_ident!("codelet_simd_{}_avx_f64", n);
218 let sse2_f64_name = format_ident!("codelet_simd_{}_sse2_f64", n);
219 let neon_f64_name = format_ident!("codelet_simd_{}_neon_f64", n);
220 let avx512_f32_name = format_ident!("codelet_simd_{}_avx512_f32", n);
221 let avx2_f32_name = format_ident!("codelet_simd_{}_avx2_f32", n);
222 let sse2_f32_name = format_ident!("codelet_simd_{}_sse2_f32", n);
223 let neon_f32_name = format_ident!("codelet_simd_{}_neon_f32", n);
224 let n_lit = n;
225 let doc = format!(
226 "Size-{n} SIMD-optimized FFT codelet with architecture dispatch.\n\n\
227 Automatically selects the best SIMD path at runtime:\n\
228 - x86_64: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar (f64; f32 uses AVX2/SSE2/512)\n\
229 - aarch64: NEON > scalar (both f64 and f32)\n\
230 - other: scalar fallback"
231 );
232
233 quote! {
234 #[doc = #doc]
235 #[inline]
236 pub fn #fn_name<T: crate::kernel::Float>(
237 data: &mut [crate::kernel::Complex<T>],
238 sign: i32,
239 ) {
240 debug_assert!(
241 data.len() >= #n_lit,
242 "codelet_simd_{}: need >= {} elements, got {}",
243 #n_lit,
244 #n_lit,
245 data.len(),
246 );
247
248 // Fast path: f64 SIMD
249 if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
250 // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
251 // When T == f64, &mut [Complex<f64>] has the same layout as
252 // &mut [f64] with twice the length: [re0, im0, re1, im1, ...].
253 let len = data.len() * 2;
254 let ptr = data.as_mut_ptr().cast::<f64>();
255 let f64_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
256
257 #[cfg(target_arch = "x86_64")]
258 {
259 if is_x86_feature_detected!("avx512f") {
260 // Safety: AVX-512F detected, pointer valid for len f64s
261 unsafe { #avx512_f64_name(f64_data, sign); }
262 return;
263 }
264 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
265 // Safety: AVX2+FMA detected, pointer valid for len f64s
266 unsafe { #avx2_f64_name(f64_data, sign); }
267 return;
268 }
269 // Pure AVX (no FMA, no AVX2) — probe after AVX2+FMA, before SSE2
270 if is_x86_feature_detected!("avx") {
271 // Safety: AVX detected (superset of SSE2), pointer valid
272 unsafe { #plain_avx_fn_name(f64_data, sign); }
273 return;
274 }
275 if is_x86_feature_detected!("sse2") {
276 // Safety: SSE2 detected (guaranteed on x86_64), pointer valid
277 unsafe { #sse2_f64_name(f64_data, sign); }
278 return;
279 }
280 }
281
282 #[cfg(target_arch = "aarch64")]
283 {
284 // NEON is mandatory on aarch64
285 unsafe { #neon_f64_name(f64_data, sign); }
286 return;
287 }
288 }
289
290 // Fast path: f32 SIMD
291 if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
292 // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
293 // When T == f32, &mut [Complex<f32>] has the same layout as
294 // &mut [f32] with twice the length: [re0, im0, re1, im1, ...].
295 let len = data.len() * 2;
296 let ptr = data.as_mut_ptr().cast::<f32>();
297 let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
298
299 #[cfg(target_arch = "x86_64")]
300 {
301 if is_x86_feature_detected!("avx512f") {
302 // Safety: AVX-512F detected
303 unsafe { #avx512_f32_name(f32_data, sign); }
304 return;
305 }
306 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
307 unsafe { #avx2_f32_name(f32_data, sign); }
308 return;
309 }
310 // f32 has no dedicated AVX (non-AVX2) path; fall through to SSE2
311 if is_x86_feature_detected!("sse2") {
312 unsafe { #sse2_f32_name(f32_data, sign); }
313 return;
314 }
315 }
316
317 #[cfg(target_arch = "aarch64")]
318 {
319 // NEON is mandatory on aarch64
320 unsafe { #neon_f32_name(f32_data, sign); }
321 return;
322 }
323 }
324
325 // Scalar fallback for other float types or unsupported architectures
326 #scalar_name(data, sign);
327 }
328 }
329}
330
331/// Generate the dispatcher for size-16 (f32 only via AVX-512F; scalar fallback).
332///
333/// Size-16 is only available as f32 via AVX-512F. All other paths fall through
334/// to the scalar implementation, including f64 (no size-16 f64 SIMD emitter).
335fn gen_dispatcher_16() -> proc_macro2::TokenStream {
336 let avx512_f32_name = format_ident!("codelet_simd_16_avx512_f32");
337 let scalar_name = format_ident!("codelet_simd_16_scalar");
338
339 quote! {
340 /// Size-16 SIMD-optimized FFT codelet.
341 ///
342 /// Selects AVX-512F f32 path when available; otherwise falls back to scalar.
343 /// No f64 SIMD path at size 16 (scalar is used instead).
344 ///
345 /// - x86_64 + avx512f: `__m512` 16-lane f32 butterfly with FMA twiddles
346 /// - all other: scalar fallback
347 #[inline]
348 pub fn codelet_simd_16<T: crate::kernel::Float>(
349 data: &mut [crate::kernel::Complex<T>],
350 sign: i32,
351 ) {
352 debug_assert!(
353 data.len() >= 16_usize,
354 "codelet_simd_16: need >= 16 elements, got {}",
355 data.len(),
356 );
357
358 // AVX-512F f32 path only
359 if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
360 let len = data.len() * 2;
361 let ptr = data.as_mut_ptr().cast::<f32>();
362 let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
363
364 #[cfg(target_arch = "x86_64")]
365 {
366 if is_x86_feature_detected!("avx512f") {
367 // Safety: AVX-512F detected, pointer valid for len f32s
368 unsafe { #avx512_f32_name(f32_data, sign); }
369 return;
370 }
371 }
372 }
373
374 // Scalar fallback for f64, other types, or no AVX-512F
375 #scalar_name(data, sign);
376 }
377 }
378}