oxifft_codegen_impl/gen_simd/multi_transform/mod.rs
1//! Build-time codegen for SIMD vrank multi-transform codelets.
2//!
3//! A multi-transform codelet processes `V` DFTs of size `N` simultaneously.
4//!
5//! # Implementations
6//!
7//! - **SSE2 f32 (V=4)**: true SIMD for sizes 2 and 4 via `notw_{size}_v4_sse2_f32_soa`.
8//! - **AVX2 f32 (V=8)**: true SIMD for sizes 2, 4, and 8 via `notw_{size}_v8_avx2_f32_soa`.
//! - **All other combos**: no SIMD kernel; only the sequential scalar `AoS` function is emitted.
10//!
11//! # Data layouts
12//!
13//! ## `AoS` (Array-of-Structs) — outer function signature
14//!
15//! For `V` transforms of size `N`:
16//! ```text
17//! data[element_idx * v * 2 + transform_idx * 2 + 0] = re of x[element_idx] for transform transform_idx
18//! data[element_idx * v * 2 + transform_idx * 2 + 1] = im of x[element_idx] for transform transform_idx
19//! ```
20//!
21//! ## `SoA` (Struct-of-Arrays) — inner SIMD function signature
22//!
23//! For `V` transforms of size `N` (only used internally by SIMD paths):
24//! ```text
25//! re_in[element_idx * v + transform_idx] = real part of x[element_idx] for transform transform_idx
26//! im_in[element_idx * v + transform_idx] = imag part of x[element_idx] for transform transform_idx
27//! ```
28//!
//! The SIMD functions operate natively in `SoA`. The outer `AoS` function does not
//! call them: it always runs the sequential scalar loop, even when a `_soa` SIMD
//! companion is emitted alongside it. Callers that want SIMD throughput must call
//! the `_soa` function directly with `SoA` buffers.
32//!
33//! # Generated function signatures
34//!
35//! Outer (`AoS`, called by users):
36//! ```rust,ignore
37//! pub unsafe fn notw_4_v8_avx2_f32(
38//! input: *const f32, output: *mut f32,
39//! istride: usize, ostride: usize, count: usize,
40//! )
41//! ```
42//!
43//! Inner `SoA` SIMD helpers (emitted alongside, for direct use or testing):
44//! ```rust,ignore
45//! pub unsafe fn notw_4_v8_avx2_f32_soa(
46//! re_in: *const f32, im_in: *const f32,
47//! re_out: *mut f32, im_out: *mut f32,
48//! )
49//! ```
50
51use proc_macro2::TokenStream;
52use quote::{format_ident, quote};
53use syn::{
54 parse::{Parse, ParseStream},
55 LitInt, Token,
56};
57
58mod scalar;
59mod simd_avx2_f32;
60mod simd_sse2_f32;
61
62#[cfg(test)]
63mod tests;
64
65// ============================================================================
66// Public types
67// ============================================================================
68
69/// Target ISA for a multi-transform codelet.
70#[derive(Debug, Clone, Copy, PartialEq, Eq)]
71pub enum SimdIsa {
72 /// SSE2 (128-bit, 4 f32 or 2 f64 lanes).
73 Sse2,
74 /// AVX2+FMA (256-bit, 8 f32 or 4 f64 lanes).
75 Avx2,
76 /// Scalar fallback (no SIMD).
77 Scalar,
78}
79
80impl SimdIsa {
81 /// Number of scalar lanes for `f32`.
82 #[must_use]
83 pub const fn lanes_f32(self) -> usize {
84 match self {
85 Self::Sse2 => 4,
86 Self::Avx2 => 8,
87 Self::Scalar => 1,
88 }
89 }
90
91 /// Number of scalar lanes for `f64`.
92 #[must_use]
93 pub const fn lanes_f64(self) -> usize {
94 match self {
95 Self::Sse2 => 2,
96 Self::Avx2 => 4,
97 Self::Scalar => 1,
98 }
99 }
100
101 /// Lowercase name used in generated identifiers.
102 #[must_use]
103 pub const fn ident_str(self) -> &'static str {
104 match self {
105 Self::Sse2 => "sse2",
106 Self::Avx2 => "avx2",
107 Self::Scalar => "scalar",
108 }
109 }
110}
111
112/// Floating-point precision for a multi-transform codelet.
113#[derive(Debug, Clone, Copy, PartialEq, Eq)]
114pub enum Precision {
115 /// 32-bit single precision.
116 F32,
117 /// 64-bit double precision.
118 F64,
119}
120
121impl Precision {
122 /// Lowercase type name used in generated identifiers and code.
123 #[must_use]
124 pub const fn type_str(self) -> &'static str {
125 match self {
126 Self::F32 => "f32",
127 Self::F64 => "f64",
128 }
129 }
130}
131
132/// Configuration for a vectorized multi-transform codelet.
133///
134/// Describes a (DFT size, `ISA`, V, precision) tuple used to emit a
135/// batch-of-V-transforms function at build time.
136#[derive(Debug, Clone)]
137pub struct MultiTransformConfig {
138 /// DFT size — must be 2, 4, or 8.
139 pub size: usize,
140 /// Number of simultaneous transforms (lane count: 4 for SSE2 f32, 8 for AVX2 f32, etc.).
141 pub v: usize,
142 /// Target ISA.
143 pub isa: SimdIsa,
144 /// `f32` or `f64`.
145 pub precision: Precision,
146}
147
148// ============================================================================
149// SIMD dispatch logic
150// ============================================================================
151
152/// Returns `true` when the (`ISA`, precision, size) combination has a true SIMD
153/// multi-transform implementation (`SoA` inner function).
154///
155/// - SSE2 f32: sizes 2 and 4
156/// - AVX2 f32: sizes 2, 4, and 8
157/// - All f64 combos: scalar fallback only
158const fn has_simd_impl(isa: SimdIsa, precision: Precision, size: usize) -> bool {
159 matches!(
160 (isa, precision, size),
161 (SimdIsa::Sse2, Precision::F32, 2 | 4) | (SimdIsa::Avx2, Precision::F32, 2 | 4 | 8)
162 )
163}
164
165/// Emit the inner `SoA` SIMD function `TokenStream` for the given config.
166///
167/// Returns `None` if the config has no SIMD implementation.
168fn gen_simd_inner(config: &MultiTransformConfig) -> Option<TokenStream> {
169 match (config.isa, config.precision, config.size) {
170 (SimdIsa::Sse2, Precision::F32, 2) => Some(simd_sse2_f32::gen_sse2_f32_v4_size2_soa()),
171 (SimdIsa::Sse2, Precision::F32, 4) => Some(simd_sse2_f32::gen_sse2_f32_v4_size4_soa()),
172 (SimdIsa::Avx2, Precision::F32, 2) => Some(simd_avx2_f32::gen_avx2_f32_v8_size2_soa()),
173 (SimdIsa::Avx2, Precision::F32, 4) => Some(simd_avx2_f32::gen_avx2_f32_v8_size4_soa()),
174 (SimdIsa::Avx2, Precision::F32, 8) => Some(simd_avx2_f32::gen_avx2_f32_v8_size8_soa()),
175 _ => None,
176 }
177}
178
179// ============================================================================
180// Code generation
181// ============================================================================
182
183/// Build the outer `AoS` function body for any config (scalar loop over all transforms).
184///
185/// The outer function always processes transforms sequentially (scalar `AoS` loop),
186/// regardless of whether a companion `SoA` SIMD function is also emitted.
187/// Callers that want true SIMD throughput should use the `_soa` companion directly.
188///
189/// # Panics
190///
191/// Panics only if internal constant string literals fail to parse — impossible
192/// in practice.
193fn gen_outer_body(config: &MultiTransformConfig, size: usize, v: usize) -> TokenStream {
194 let butterfly_body = scalar::gen_scalar_butterfly(size, config.precision);
195 let v_lit = v;
196 let size_lit = size;
197 quote! {
198 let batches = count / #v_lit;
199 let remainder = count % #v_lit;
200
201 for b in 0..batches {
202 for t in 0..#v_lit {
203 let base_in = (b * #v_lit + t) * 2;
204 let base_out = (b * #v_lit + t) * 2;
205 #butterfly_body
206 }
207 }
208 for t in 0..remainder {
209 let base_in = (batches * #v_lit + t) * 2;
210 let base_out = (batches * #v_lit + t) * 2;
211 #butterfly_body
212 }
213 let _ = #size_lit;
214 }
215}
216
217/// Generate a multi-transform codelet `TokenStream`.
218///
219/// # Output
220///
221/// Always emits a public outer function `notw_{size}_v{v}_{isa}_{ty}` with
222/// `AoS` signature `(input, output, istride, ostride, count)`.
223///
224/// For supported (`ISA`, precision, size) combinations (SSE2 f32 sizes 2/4,
225/// AVX2 f32 sizes 2/4/8), also emits a companion inner function
226/// `notw_{size}_v{v}_{isa}_{ty}_soa` with `SoA` signature
227/// `(re_in, im_in, re_out, im_out)` that is the **true SIMD implementation**.
228///
229/// # Errors
230///
231/// Returns a [`syn::Error`] when:
232/// - `config.size` is not one of 2, 4, or 8.
233/// - `config.v` is 0.
234///
235/// # Panics
236///
237/// Panics only if internal constant string literals that are guaranteed to be
238/// valid fail to parse as token streams — this cannot occur in practice.
239pub fn generate_multi_transform(config: &MultiTransformConfig) -> Result<TokenStream, syn::Error> {
240 if !matches!(config.size, 2 | 4 | 8) {
241 return Err(syn::Error::new(
242 proc_macro2::Span::call_site(),
243 format!(
244 "multi_transform: unsupported size {} (expected 2, 4, or 8)",
245 config.size
246 ),
247 ));
248 }
249 if config.v == 0 {
250 return Err(syn::Error::new(
251 proc_macro2::Span::call_site(),
252 "multi_transform: v must be >= 1",
253 ));
254 }
255
256 let fn_name = format_ident!(
257 "notw_{}_v{}_{}_{}",
258 config.size,
259 config.v,
260 config.isa.ident_str(),
261 config.precision.type_str()
262 );
263 let size = config.size;
264 let v = config.v;
265 let ty_str = config.precision.type_str();
266 let ty_tokens: TokenStream = ty_str.parse().expect("valid type token");
267
268 let use_simd = has_simd_impl(config.isa, config.precision, size);
269 let simd_inner = gen_simd_inner(config);
270 let outer_body = gen_outer_body(config, size, v);
271
272 let stride = v * 2;
273 let simd_note = if use_simd {
274 format!(
275 "True SIMD available via `notw_{size}_v{v}_{isa}_{ty}_soa` (`SoA` layout).",
276 isa = config.isa.ident_str(),
277 ty = ty_str,
278 )
279 } else {
280 "Sequential scalar fallback (no SIMD for this `ISA`+precision+size combination).".into()
281 };
282
283 let fn_doc = format!(
284 "Process `count` transforms of size {size} in batches of {v} (v={v}) using {isa} ISA.\n\n\
285 # Data layout (`AoS`)\n\
286 Interleaved with stride {v}: `data[element * {stride} + transform * 2 + c]`\n\
287 where `c` is 0 for real, 1 for imaginary.\n\n\
288 # SIMD acceleration\n\
289 {simd_note}\n\n\
290 # Safety\n\
291 - `input` must be valid for `count * {size} * 2 * {v}` reads of `{ty_str}`.\n\
292 - `output` must be valid for `count * {size} * 2 * {v}` writes of `{ty_str}`.\n\
293 - `istride` / `ostride` must be `2 * {v}` for the canonical `AoS` layout.\n\
294 - No alignment requirement; uses unaligned loads.",
295 size = size,
296 v = v,
297 isa = config.isa.ident_str(),
298 stride = stride,
299 ty_str = ty_str,
300 simd_note = simd_note,
301 );
302
303 let outer_fn = quote! {
304 #[doc = #fn_doc]
305 pub unsafe fn #fn_name(
306 input: *const #ty_tokens,
307 output: *mut #ty_tokens,
308 istride: usize,
309 ostride: usize,
310 count: usize,
311 ) {
312 #outer_body
313 }
314 };
315
316 Ok(if let Some(inner) = simd_inner {
317 quote! {
318 #inner
319 #outer_fn
320 }
321 } else {
322 outer_fn
323 })
324}
325
326// ============================================================================
327// Proc-macro entry point
328// ============================================================================
329
330/// Parsed arguments from `gen_multi_transform_codelet!(size=4, v=8, isa=avx2, ty=f32)`.
331struct MacroArgs {
332 size: usize,
333 v: usize,
334 isa: SimdIsa,
335 precision: Precision,
336}
337
338impl Parse for MacroArgs {
339 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
340 let mut size: Option<usize> = None;
341 let mut v: Option<usize> = None;
342 let mut isa: Option<SimdIsa> = None;
343 let mut precision: Option<Precision> = None;
344
345 while !input.is_empty() {
346 let key: syn::Ident = input.parse()?;
347 let _eq: Token![=] = input.parse()?;
348 match key.to_string().as_str() {
349 "size" => {
350 let lit: LitInt = input.parse()?;
351 size = Some(lit.base10_parse::<usize>().map_err(|_| {
352 syn::Error::new(lit.span(), "expected an integer literal for `size`")
353 })?);
354 }
355 "v" => {
356 let lit: LitInt = input.parse()?;
357 v = Some(lit.base10_parse::<usize>().map_err(|_| {
358 syn::Error::new(lit.span(), "expected an integer literal for `v`")
359 })?);
360 }
361 "isa" => {
362 let ident: syn::Ident = input.parse()?;
363 isa = Some(match ident.to_string().as_str() {
364 "sse2" => SimdIsa::Sse2,
365 "avx2" => SimdIsa::Avx2,
366 "scalar" => SimdIsa::Scalar,
367 other => {
368 return Err(syn::Error::new(
369 ident.span(),
370 format!(
371 "unknown isa `{other}`, expected one of: sse2, avx2, scalar"
372 ),
373 ));
374 }
375 });
376 }
377 "ty" => {
378 let ident: syn::Ident = input.parse()?;
379 precision = Some(match ident.to_string().as_str() {
380 "f32" => Precision::F32,
381 "f64" => Precision::F64,
382 other => {
383 return Err(syn::Error::new(
384 ident.span(),
385 format!("unknown ty `{other}`, expected f32 or f64"),
386 ));
387 }
388 });
389 }
390 other => {
391 return Err(syn::Error::new(
392 key.span(),
393 format!("unknown key `{other}`, expected one of: size, v, isa, ty"),
394 ));
395 }
396 }
397 if input.peek(Token![,]) {
398 let _: Token![,] = input.parse()?;
399 }
400 }
401
402 let size = size.ok_or_else(|| {
403 syn::Error::new(proc_macro2::Span::call_site(), "missing `size` argument")
404 })?;
405 let v = v.ok_or_else(|| {
406 syn::Error::new(proc_macro2::Span::call_site(), "missing `v` argument")
407 })?;
408 let isa = isa.ok_or_else(|| {
409 syn::Error::new(proc_macro2::Span::call_site(), "missing `isa` argument")
410 })?;
411 let precision = precision.ok_or_else(|| {
412 syn::Error::new(proc_macro2::Span::call_site(), "missing `ty` argument")
413 })?;
414
415 Ok(Self {
416 size,
417 v,
418 isa,
419 precision,
420 })
421 }
422}
423
424/// Entry point for the `gen_multi_transform_codelet!` proc-macro.
425///
426/// Parses `size=N, v=V, isa=ISA, ty=TY` from the token stream and calls
427/// [`generate_multi_transform`].
428///
429/// # Example
430/// ```ignore
431/// gen_multi_transform_codelet!(size = 4, v = 8, isa = avx2, ty = f32);
432/// ```
433///
434/// # Errors
435///
436/// Returns a [`syn::Error`] when the input does not parse as valid key-value
437/// pairs, a required key is missing, or `size` / `isa` / `ty` have unsupported
438/// values.
439pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
440 let args: MacroArgs = syn::parse2(input)?;
441 let config = MultiTransformConfig {
442 size: args.size,
443 v: args.v,
444 isa: args.isa,
445 precision: args.precision,
446 };
447 generate_multi_transform(&config)
448}