// oxifft_codegen_impl/gen_simd/mod.rs
1//! SIMD codelet generation.
2//!
3//! Generates architecture-aware SIMD FFT codelets with multi-architecture dispatch.
4//! At compile time, generates:
5//! - AVX-512F variant (512-bit, `8×f64` / `16×f32`) for `x86_64`
6//! - AVX2+FMA variant (256-bit, `4×f64`) for `x86_64`
7//! - Pure-AVX variant (256-bit, `4×f64`, no FMA, no AVX2) for `x86_64`
8//! - SSE2 variant (128-bit, `2×f64`) for `x86_64`
9//! - NEON variant (128-bit, `2×f64`) for `aarch64`
10//! - Scalar fallback for all architectures
11//!
12//! The dispatcher function selects the best SIMD path at runtime using
13//! `is_x86_feature_detected!` (`x86_64`) or compile-time cfg (`aarch64`).
14//!
15//! Probe order for `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
16//! AVX-512F is probed first (when the host supports it) to enable
17//! `_mm512_fmadd_pd`/`_mm512_fmsub_pd` based butterfly arithmetic.
18
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::LitInt;
22
23mod avx;
24mod avx2;
25mod avx512;
26mod neon;
27mod scalar;
28mod sse2;
29
30/// Generate a SIMD-optimized codelet for the given FFT size.
31///
32/// Supports sizes 2, 4, 8 for all ISAs and size 16 (f32 only, AVX-512F or scalar).
33/// The macro generates:
34/// - A public dispatcher that picks the best SIMD path at runtime
35/// - Architecture-specific inner functions with `#[target_feature]`
36/// - A generic scalar fallback
37///
38/// AVX-512F is probed before AVX2+FMA for sizes that have AVX-512 emitters.
39///
40/// # Errors
41/// Returns a `syn::Error` when the input does not parse as a valid size literal,
42/// or when the size is not in the supported set {2, 4, 8, 16}.
43pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
44 let size: LitInt = syn::parse2(input)?;
45 let n: usize = size.base10_parse().map_err(|_| {
46 syn::Error::new(
47 size.span(),
48 "gen_simd_codelet: expected an integer size literal",
49 )
50 })?;
51
52 match n {
53 2 => Ok(gen_simd_size_2()),
54 4 => Ok(gen_simd_size_4()),
55 8 => Ok(gen_simd_size_8()),
56 16 => Ok(gen_simd_size_16()),
57 _ => Err(syn::Error::new(
58 size.span(),
59 format!("gen_simd_codelet: unsupported size {n} (expected one of 2, 4, 8, 16)"),
60 )),
61 }
62}
63
64/// Generate size-2 SIMD butterfly codelet.
65///
66/// Size-2 butterfly: out[0] = a + b, out[1] = a - b
67///
68/// SIMD strategy:
69/// - AVX-512F (`x86_64`): 256-bit YMM under avx512f feature umbrella, f64 only
70/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, vector add/sub
71/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, vector add/sub
72fn gen_simd_size_2() -> TokenStream {
73 let dispatcher = gen_dispatcher(2);
74 let scalar = scalar::gen_scalar_size_2();
75 let sse2_f64 = sse2::gen_sse2_size_2();
76 let sse2_f32 = sse2::gen_sse2_size_2_f32();
77 let plain_avx_f64 = avx::gen_avx_size_2_f64();
78 let avx2_f64 = avx2::gen_avx2_size_2();
79 let avx2_f32 = avx2::gen_avx2_size_2_f32();
80 let avx512_f64 = avx512::gen_avx512_size_2_f64();
81 let avx512_f32 = avx512::gen_avx512_size_2_f32();
82 let neon_f64 = neon::gen_neon_size_2();
83 let neon_f32 = neon::gen_neon_size_2_f32();
84
85 quote! {
86 #dispatcher
87 #scalar
88 #sse2_f64
89 #sse2_f32
90 #plain_avx_f64
91 #avx2_f64
92 #avx2_f32
93 #avx512_f64
94 #avx512_f32
95 #neon_f64
96 #neon_f32
97 }
98}
99
100/// Generate size-4 SIMD radix-4 codelet.
101///
102/// Size-4 FFT: radix-4 butterfly with sign-dependent ±i rotation.
103///
104/// SIMD strategy:
105/// - AVX-512F (`x86_64`): 256-bit f64 + f32 butterfly under avx512f feature
106/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, shuffle-based rotation
107/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, ext-based rotation
108fn gen_simd_size_4() -> TokenStream {
109 let dispatcher = gen_dispatcher(4);
110 let scalar = scalar::gen_scalar_size_4();
111 let sse2_f64 = sse2::gen_sse2_size_4();
112 let sse2_f32 = sse2::gen_sse2_size_4_f32();
113 let plain_avx_f64 = avx::gen_avx_size_4_f64();
114 let avx2_f64 = avx2::gen_avx2_size_4();
115 let avx2_f32 = avx2::gen_avx2_size_4_f32();
116 let avx512_f64 = avx512::gen_avx512_size_4_f64();
117 let avx512_f32 = avx512::gen_avx512_size_4_f32();
118 let neon_f64 = neon::gen_neon_size_4();
119 let neon_f32 = neon::gen_neon_size_4_f32();
120
121 quote! {
122 #dispatcher
123 #scalar
124 #sse2_f64
125 #sse2_f32
126 #plain_avx_f64
127 #avx2_f64
128 #avx2_f32
129 #avx512_f64
130 #avx512_f32
131 #neon_f64
132 #neon_f32
133 }
134}
135
136/// Generate size-8 SIMD radix-8 codelet.
137///
138/// Size-8 FFT: radix-2 DIT with 3 butterfly stages.
139///
140/// SIMD strategy:
141/// - AVX-512F (`x86_64`): 256-bit + FMA via ZMM promotion for twiddles
142/// - SSE2/AVX2 (`x86_64`): `__m128d` / `__m128` for f64/f32, FMA twiddles
143/// - NEON (`aarch64`): `float64x2_t` / `float32x2_t` for f64/f32, FMA twiddles
144fn gen_simd_size_8() -> TokenStream {
145 let dispatcher = gen_dispatcher(8);
146 let scalar = scalar::gen_scalar_size_8();
147 let sse2_f64 = sse2::gen_sse2_size_8();
148 let sse2_f32 = sse2::gen_sse2_size_8_f32();
149 let plain_avx_f64 = avx::gen_avx_size_8_f64();
150 let avx2_f64 = avx2::gen_avx2_size_8();
151 let avx2_f32 = avx2::gen_avx2_size_8_f32();
152 let avx512_f64 = avx512::gen_avx512_size_8_f64();
153 let avx512_f32 = avx512::gen_avx512_size_8_f32();
154 let neon_f64 = neon::gen_neon_size_8();
155 let neon_f32 = neon::gen_neon_size_8_f32();
156
157 quote! {
158 #dispatcher
159 #scalar
160 #sse2_f64
161 #sse2_f32
162 #plain_avx_f64
163 #avx2_f64
164 #avx2_f32
165 #avx512_f64
166 #avx512_f32
167 #neon_f64
168 #neon_f32
169 }
170}
171
172/// Generate size-16 SIMD radix-2 DIT codelet (f32 only via AVX-512F).
173///
174/// Size-16 FFT: radix-2 DIT with 4 butterfly stages.
175///
176/// SIMD strategy:
177/// - AVX-512F (`x86_64`): full `__m512` 16-lane f32 butterfly with FMA W16 twiddles
178/// - Scalar fallback for all other architectures (AVX2/SSE2/NEON lack this size)
179fn gen_simd_size_16() -> TokenStream {
180 let dispatcher = gen_dispatcher_16();
181 let scalar = scalar::gen_scalar_size_16();
182 let avx512_f32 = avx512::gen_avx512_size_16_f32();
183
184 quote! {
185 #dispatcher
186 #scalar
187 #avx512_f32
188 }
189}
190
191// ---------------------------------------------------------------------------
192// Dispatcher generation
193// ---------------------------------------------------------------------------
194
/// Generate the public dispatcher function for sizes 2, 4, 8 (all ISAs).
///
/// Priority on `x86_64` (f64): AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.
/// Priority on `x86_64` (f32): AVX-512F > AVX2+FMA > SSE2 > scalar (no pure-AVX f32 path).
/// Priority on `aarch64`: NEON > scalar.
///
/// The dispatcher:
/// 1. Checks `T` via `core::any::TypeId` (since `Float: 'static`):
///    - `f64`: uses f64 SIMD path
///    - `f32`: uses f32 SIMD path
///    - other: scalar fallback
/// 2. On `x86_64`: probes AVX-512F first, then AVX2+FMA, then SSE2
/// 3. On `aarch64`: uses NEON unconditionally
/// 4. Falls back to the generic scalar implementation
fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {
    // Identifiers for the public dispatcher and each per-ISA inner function.
    // These names must stay in sync with what the sibling generator modules
    // (`scalar`, `sse2`, `avx`, `avx2`, `avx512`, `neon`) emit.
    let fn_name = format_ident!("codelet_simd_{}", n);
    let scalar_name = format_ident!("codelet_simd_{}_scalar", n);
    let avx512_f64_name = format_ident!("codelet_simd_{}_avx512_f64", n);
    let avx2_f64_name = format_ident!("codelet_simd_{}_avx2_f64", n);
    let plain_avx_fn_name = format_ident!("codelet_simd_{}_avx_f64", n);
    let sse2_f64_name = format_ident!("codelet_simd_{}_sse2_f64", n);
    let neon_f64_name = format_ident!("codelet_simd_{}_neon_f64", n);
    let avx512_f32_name = format_ident!("codelet_simd_{}_avx512_f32", n);
    let avx2_f32_name = format_ident!("codelet_simd_{}_avx2_f32", n);
    let sse2_f32_name = format_ident!("codelet_simd_{}_sse2_f32", n);
    let neon_f32_name = format_ident!("codelet_simd_{}_neon_f32", n);
    // `usize` interpolates as a suffixed literal (e.g. `4usize`) in the
    // generated code, which is fine for the `data.len() >= #n_lit` comparison.
    let n_lit = n;
    // Runtime-rendered rustdoc for the generated dispatcher (attached below
    // via `#[doc = #doc]` so the size appears in the generated docs).
    let doc = format!(
        "Size-{n} SIMD-optimized FFT codelet with architecture dispatch.\n\n\
        Automatically selects the best SIMD path at runtime:\n\
        - x86_64: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar (f64; f32 uses AVX2/SSE2/512)\n\
        - aarch64: NEON > scalar (both f64 and f32)\n\
        - other: scalar fallback"
    );

    quote! {
        #[doc = #doc]
        #[inline]
        pub fn #fn_name<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            debug_assert!(
                data.len() >= #n_lit,
                "codelet_simd_{}: need >= {} elements, got {}",
                #n_lit,
                #n_lit,
                data.len(),
            );

            // Fast path: f64 SIMD
            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
                // When T == f64, &mut [Complex<f64>] has the same layout as
                // &mut [f64] with twice the length: [re0, im0, re1, im1, ...].
                let len = data.len() * 2;
                let ptr = data.as_mut_ptr().cast::<f64>();
                let f64_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
                // NOTE(review): on targets other than x86_64/aarch64, `f64_data`
                // is constructed but never used, so the generated code may emit
                // an unused-variable warning there — confirm whether such
                // targets are supported before tightening this.

                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx512f") {
                        // Safety: AVX-512F detected, pointer valid for len f64s
                        unsafe { #avx512_f64_name(f64_data, sign); }
                        return;
                    }
                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                        // Safety: AVX2+FMA detected, pointer valid for len f64s
                        unsafe { #avx2_f64_name(f64_data, sign); }
                        return;
                    }
                    // Pure AVX (no FMA, no AVX2) — probe after AVX2+FMA, before SSE2
                    if is_x86_feature_detected!("avx") {
                        // Safety: AVX detected (superset of SSE2), pointer valid
                        unsafe { #plain_avx_fn_name(f64_data, sign); }
                        return;
                    }
                    if is_x86_feature_detected!("sse2") {
                        // Safety: SSE2 detected (guaranteed on x86_64), pointer valid
                        unsafe { #sse2_f64_name(f64_data, sign); }
                        return;
                    }
                }

                #[cfg(target_arch = "aarch64")]
                {
                    // NEON is mandatory on aarch64
                    unsafe { #neon_f64_name(f64_data, sign); }
                    return;
                }
            }

            // Fast path: f32 SIMD
            if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
                // Safety: Complex<T> is #[repr(C)] with (re, im) fields.
                // When T == f32, &mut [Complex<f32>] has the same layout as
                // &mut [f32] with twice the length: [re0, im0, re1, im1, ...].
                let len = data.len() * 2;
                let ptr = data.as_mut_ptr().cast::<f32>();
                let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };

                #[cfg(target_arch = "x86_64")]
                {
                    if is_x86_feature_detected!("avx512f") {
                        // Safety: AVX-512F detected
                        unsafe { #avx512_f32_name(f32_data, sign); }
                        return;
                    }
                    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                        unsafe { #avx2_f32_name(f32_data, sign); }
                        return;
                    }
                    // f32 has no dedicated AVX (non-AVX2) path; fall through to SSE2
                    if is_x86_feature_detected!("sse2") {
                        unsafe { #sse2_f32_name(f32_data, sign); }
                        return;
                    }
                }

                #[cfg(target_arch = "aarch64")]
                {
                    // NEON is mandatory on aarch64
                    unsafe { #neon_f32_name(f32_data, sign); }
                    return;
                }
            }

            // Scalar fallback for other float types or unsupported architectures
            #scalar_name(data, sign);
        }
    }
}
327
328/// Generate the dispatcher for size-16 (f32 only via AVX-512F; scalar fallback).
329///
330/// Size-16 is only available as f32 via AVX-512F. All other paths fall through
331/// to the scalar implementation, including f64 (no size-16 f64 SIMD emitter).
332fn gen_dispatcher_16() -> proc_macro2::TokenStream {
333 let avx512_f32_name = format_ident!("codelet_simd_16_avx512_f32");
334 let scalar_name = format_ident!("codelet_simd_16_scalar");
335
336 quote! {
337 /// Size-16 SIMD-optimized FFT codelet.
338 ///
339 /// Selects AVX-512F f32 path when available; otherwise falls back to scalar.
340 /// No f64 SIMD path at size 16 (scalar is used instead).
341 ///
342 /// - x86_64 + avx512f: `__m512` 16-lane f32 butterfly with FMA twiddles
343 /// - all other: scalar fallback
344 #[inline]
345 pub fn codelet_simd_16<T: crate::kernel::Float>(
346 data: &mut [crate::kernel::Complex<T>],
347 sign: i32,
348 ) {
349 debug_assert!(
350 data.len() >= 16_usize,
351 "codelet_simd_16: need >= 16 elements, got {}",
352 data.len(),
353 );
354
355 // AVX-512F f32 path only
356 if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f32>() {
357 let len = data.len() * 2;
358 let ptr = data.as_mut_ptr().cast::<f32>();
359 let f32_data = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
360
361 #[cfg(target_arch = "x86_64")]
362 {
363 if is_x86_feature_detected!("avx512f") {
364 // Safety: AVX-512F detected, pointer valid for len f32s
365 unsafe { #avx512_f32_name(f32_data, sign); }
366 return;
367 }
368 }
369 }
370
371 // Scalar fallback for f64, other types, or no AVX-512F
372 #scalar_name(data, sign);
373 }
374 }
375}