Skip to main content

oxifft_codegen_impl/
gen_any.rs

1//! Dispatch-layer codelet generation for arbitrary user-specified sizes.
2//!
3//! Routes a user-specified FFT size `N` to the most appropriate code emitter:
4//! - **Direct codelet set** {2, 4, 8, 16, 32, 64}:
5//!   delegates to the existing hand-optimised emitters (`gen_notw`).
6//! - **Winograd odd** {3, 5, 7}: delegates to `gen_odd`.
7//! - **Rader hardcoded** {11, 13}: delegates to `gen_rader`.
8//! - **Smooth-7 composites** (all prime factors in {2, 3, 5, 7}):
9//!   emits a thin runtime wrapper that delegates to `Plan::dft_1d`.
10//! - **Primes p <= 1021**: runtime wrapper (runtime Rader/Generic path).
11//! - **Everything else** (large primes, non-smooth composites):
12//!   emits a Bluestein runtime wrapper via `Plan::dft_1d`.
13//!
14//! # Codelet convention
15//!
16//! All generated functions follow the existing `OxiFFT` codelet convention:
17//! ```ignore
18//! pub fn codelet_any_{N}<T: crate::kernel::Float>(
19//!     x: &mut [crate::kernel::Complex<T>],
20//!     sign: i32,
21//! )
22//! ```
23//! where `sign < 0` means forward (W = e^{-2*pi*i/N}) and `sign > 0` means inverse.
24//!
25//! See `Plan::dft_1d` in the `oxifft` crate for the primary entry point.
26
27use proc_macro2::TokenStream;
28use quote::{format_ident, quote};
29use syn::LitInt;
30
31// ============================================================================
32// Error type
33// ============================================================================
34
/// Error variants produced during any-size codelet generation.
///
/// Implements [`core::fmt::Display`] and [`std::error::Error`], so callers can
/// box it as `Box<dyn std::error::Error>` or propagate it through error-generic
/// code with `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodegenError {
    /// Size zero is not a valid FFT size; carries the rejected size.
    InvalidSize(usize),
    /// Size is valid but cannot be code-generated by any registered strategy.
    UnsupportedSize(usize),
    /// Code emission (or macro-input parsing) failed with the given message.
    EmitError(String),
}

impl core::fmt::Display for CodegenError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidSize(n) => write!(f, "invalid codelet size: {n}"),
            Self::UnsupportedSize(n) => write!(f, "unsupported codelet size: {n}"),
            Self::EmitError(s) => write!(f, "codegen emit error: {s}"),
        }
    }
}

// `Debug` + `Display` are all `Error` requires; no underlying source error is
// retained (parse/emit failures are stored as their rendered message).
impl std::error::Error for CodegenError {}
55
56// ============================================================================
57// Size classification
58// ============================================================================
59
/// Classification of an FFT size for codelet routing.
///
/// Produced by [`classify`]; [`generate`] matches on it to choose an emitter.
/// Variants are checked in declaration order, so a size lands in the first
/// variant whose condition it meets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SizeClass {
    /// Handled by the non-twiddle emitter: {2, 4, 8, 16, 32, 64}.
    /// Also used for N=1 (identity, emitted directly).
    Notw(usize),
    /// Handled by the Winograd odd emitter: {3, 5, 7}.
    Odd(usize),
    /// Handled by the hardcoded Rader emitter (straight-line twiddles): {11, 13}.
    RaderHardcoded(usize),
    /// Smooth-7 composite (all prime factors in {2, 3, 5, 7}):
    /// routed to the runtime `Plan::dft_1d` mixed-radix path.
    ///
    /// The payload is the radix decomposition found by greedily peeling the
    /// radices 16, 8, 7, 5, 4, 3, 2 (largest first), e.g. 15 -> [5, 3].
    /// `generate` ignores it; the runtime plans the size on its own.
    MixedRadix(Vec<u16>),
    /// Small prime <= 1021 not in the hardcoded set: routed to `Plan::dft_1d`
    /// (which selects Winograd, Direct, or Generic internally).
    RaderPrime(usize),
    /// All other sizes (primes above 1021, and composites with a prime factor
    /// greater than 7): routed to `Plan::dft_1d` (Bluestein or Generic path).
    Bluestein(usize),
}
79
80/// Classify `n` for codelet generation.
81///
82/// # Errors
83///
84/// Returns `CodegenError::InvalidSize(0)` when `n == 0`.
85pub fn classify(n: usize) -> Result<SizeClass, CodegenError> {
86    if n == 0 {
87        return Err(CodegenError::InvalidSize(0));
88    }
89    if n == 1 {
90        // N=1 is the identity transform; route to Notw so we emit the trivial codelet.
91        return Ok(SizeClass::Notw(1));
92    }
93
94    // Hardcoded non-twiddle direct codelets {2, 4, 8, 16, 32, 64}
95    if matches!(n, 2 | 4 | 8 | 16 | 32 | 64) {
96        return Ok(SizeClass::Notw(n));
97    }
98    // Winograd odd codelets {3, 5, 7}
99    if matches!(n, 3 | 5 | 7) {
100        return Ok(SizeClass::Odd(n));
101    }
102    // Hardcoded Rader codelets (hand-optimised straight-line code) {11, 13}
103    if matches!(n, 11 | 13) {
104        return Ok(SizeClass::RaderHardcoded(n));
105    }
106
107    // Smooth-7: try to factor using only {2, 3, 5, 7}, greedy from largest supported radix
108    if let Some(factors) = try_factor_smooth7(n) {
109        return Ok(SizeClass::MixedRadix(factors));
110    }
111
112    // Prime p <= 1021 — known to have a primitive root; routed to runtime
113    if is_prime(n) && n <= 1021 {
114        return Ok(SizeClass::RaderPrime(n));
115    }
116
117    // Fallback: Bluestein (or Generic, as the runtime decides)
118    Ok(SizeClass::Bluestein(n))
119}
120
121// ============================================================================
122// Factoring helpers
123// ============================================================================
124
/// Decompose `n` into radices drawn from {16, 8, 7, 5, 4, 3, 2}.
///
/// Radices are peeled greedily in that (descending) order, so the returned
/// vector is non-increasing, e.g. `15 -> [5, 3]` and `64 -> [16, 4]`.
///
/// Returns `None` when `n` has a prime factor greater than 7, when `n == 1`
/// (there is nothing to factor), or when `n == 0`.
fn try_factor_smooth7(mut n: usize) -> Option<Vec<u16>> {
    // Held as `u16` so pushed factors need no narrowing cast; widening each
    // radix to `usize` with `From` is lossless.
    const RADICES: [u16; 7] = [16, 8, 7, 5, 4, 3, 2];

    // Zero is divisible by every radix and `0 / r == 0`, so the peel loop
    // below would never terminate (growing `factors` forever); reject it.
    if n == 0 {
        return None;
    }

    let mut factors = Vec::new();
    for &radix in &RADICES {
        let r = usize::from(radix);
        while n % r == 0 {
            factors.push(radix);
            n /= r;
        }
    }

    if n == 1 && !factors.is_empty() {
        Some(factors)
    } else {
        None
    }
}
144
/// Trial-division primality test usable in `const` contexts.
///
/// Only odd divisors up to `sqrt(n)` are tried once even numbers are ruled
/// out. The loop bound is written as `d <= n / d` rather than `d * d <= n`
/// so no intermediate product can overflow, on any pointer width.
const fn is_prime(n: usize) -> bool {
    // 0 and 1 are not prime; 2 and 3 are.
    if n < 4 {
        return n >= 2;
    }
    if n % 2 == 0 {
        return false;
    }
    let mut d = 3usize;
    while d <= n / d {
        if n % d == 0 {
            return false;
        }
        d += 2;
    }
    true
}
168
169// ============================================================================
170// Code generation
171// ============================================================================
172
173/// Generate a codelet `TokenStream` for size `n`.
174///
175/// The emitted function is:
176/// ```ignore
177/// pub fn codelet_any_{n}<T: crate::kernel::Float>(
178///     x: &mut [crate::kernel::Complex<T>],
179///     sign: i32,
180/// )
181/// ```
182///
183/// # Errors
184///
185/// Returns `CodegenError` if `n == 0` or code emission fails.
186pub fn generate(n: usize) -> Result<TokenStream, CodegenError> {
187    match classify(n)? {
188        SizeClass::Notw(sz) => generate_notw_any(sz),
189        SizeClass::Odd(sz) => generate_odd_any(sz),
190        SizeClass::RaderHardcoded(sz) => generate_rader_hardcoded(sz),
191        SizeClass::MixedRadix(_) | SizeClass::RaderPrime(_) | SizeClass::Bluestein(_) => {
192            Ok(generate_runtime_wrapper(n))
193        }
194    }
195}
196
197// ---- delegates to existing emitters ----------------------------------------
198
199fn generate_notw_any(sz: usize) -> Result<TokenStream, CodegenError> {
200    if sz == 1 {
201        return Ok(generate_identity_codelet());
202    }
203    // Call gen_notw::generate with a synthetic literal token stream.
204    let literal = proc_macro2::Literal::usize_unsuffixed(sz);
205    let ts = quote! { #literal };
206    crate::gen_notw::generate(ts).map_err(|e| CodegenError::EmitError(e.to_string()))
207}
208
209fn generate_odd_any(sz: usize) -> Result<TokenStream, CodegenError> {
210    let literal = proc_macro2::Literal::usize_unsuffixed(sz);
211    let ts = quote! { #literal };
212    crate::gen_odd::generate_from_macro(ts).map_err(|e| CodegenError::EmitError(e.to_string()))
213}
214
215fn generate_rader_hardcoded(sz: usize) -> Result<TokenStream, CodegenError> {
216    // gen_rader::generate_rader is #[must_use] and panics on unknown primes;
217    // only call it for 11 and 13 (the hardcoded set).
218    if matches!(sz, 11 | 13) {
219        Ok(crate::gen_rader::generate_rader(sz))
220    } else {
221        Err(CodegenError::EmitError(format!(
222            "generate_rader_hardcoded: size {sz} is not in the hardcoded set {{11, 13}}"
223        )))
224    }
225}
226
227// ---- identity (N=1) ---------------------------------------------------------
228
229fn generate_identity_codelet() -> TokenStream {
230    quote! {
231        /// Size-1 DFT codelet (identity transform — single element is unchanged).
232        #[inline(always)]
233        #[allow(clippy::trivially_copy_pass_by_ref, unused_variables)]
234        pub fn codelet_any_1<T: crate::kernel::Float>(
235            x: &mut [crate::kernel::Complex<T>],
236            sign: i32,
237        ) {
238            debug_assert!(x.len() >= 1, "codelet_any_1: input must have at least 1 element");
239            // DFT-1 is the identity; nothing to do.
240        }
241    }
242}
243
244// ---- runtime-delegating wrapper --------------------------------------------
245
246/// Emit a function that delegates to `Plan::dft_1d` at runtime.
247///
248/// Used for smooth-7 composites, runtime Rader primes, and Bluestein sizes —
249/// i.e., any size handled by the `OxiFFT` runtime but not by a hardcoded codelet.
250fn generate_runtime_wrapper(n: usize) -> TokenStream {
251    let fn_name = format_ident!("codelet_any_{n}");
252    let n_lit = proc_macro2::Literal::usize_unsuffixed(n);
253
254    quote! {
255        /// Runtime-delegating codelet generated for this size.
256        ///
257        /// Constructs an OxiFFT plan on each call and executes it.
258        /// `sign < 0` selects the forward transform; `sign > 0` selects the inverse.
259        pub fn #fn_name<T: crate::kernel::Float>(
260            x: &mut [crate::kernel::Complex<T>],
261            sign: i32,
262        ) {
263            use ::oxifft::api::{Direction, Flags, Plan};
264
265            debug_assert_eq!(x.len(), #n_lit, "codelet input length mismatch");
266
267            let direction = if sign < 0 {
268                Direction::Forward
269            } else {
270                Direction::Backward
271            };
272
273            let plan = Plan::<T>::dft_1d(#n_lit, direction, Flags::ESTIMATE)
274                .unwrap_or_else(|| {
275                    panic!(
276                        "OxiFFT: Plan::dft_1d failed for compile-time-verified size {}",
277                        #n_lit
278                    )
279                });
280
281            // Plan::execute is out-of-place; copy input to scratch, then execute in-place.
282            let input_snapshot: ::std::vec::Vec<crate::kernel::Complex<T>> = x.to_vec();
283            plan.execute(&input_snapshot, x);
284        }
285    }
286}
287
288// ============================================================================
289// Proc-macro entry point
290// ============================================================================
291
292/// Parse `gen_any_codelet!(N)` input and dispatch.
293///
294/// # Syntax
295/// ```ignore
296/// gen_any_codelet!(8);    // size-8, generates codelet_any_8
297/// gen_any_codelet!(15);   // size-15, runtime-delegating wrapper
298/// gen_any_codelet!(2003); // size-2003, Bluestein runtime wrapper
299/// ```
300///
301/// Returns a `compile_error!` token stream on parse or codegen failure.
302#[must_use]
303pub fn generate_from_macro(input: TokenStream) -> TokenStream {
304    match parse_and_generate(input) {
305        Ok(ts) => ts,
306        Err(e) => {
307            let msg = e.to_string();
308            quote! { compile_error!(#msg); }
309        }
310    }
311}
312
313fn parse_and_generate(input: TokenStream) -> Result<TokenStream, CodegenError> {
314    let size: LitInt = syn::parse2(input).map_err(|e| CodegenError::EmitError(e.to_string()))?;
315    let n: usize = size
316        .base10_parse()
317        .map_err(|e| CodegenError::EmitError(e.to_string()))?;
318    generate(n)
319}
320
321// ============================================================================
322// Builder (programmatic interface)
323// ============================================================================
324
325/// Builder for programmatic codelet generation without proc-macros.
326///
327/// # Example
328/// ```no_run
329/// use oxifft_codegen_impl::CodeletBuilder;
330///
331/// let ts = CodeletBuilder::new(15).build().unwrap();
332/// println!("{ts}");
333/// ```
334pub struct CodeletBuilder {
335    n: usize,
336    /// Reserved for future use: override the emitted function name.
337    /// Not wired through `build()` yet; the generated name is always `codelet_any_{N}`.
338    #[allow(dead_code)]
339    name_override: Option<String>,
340}
341
342impl CodeletBuilder {
343    /// Create a builder for the given FFT size.
344    #[must_use]
345    pub const fn new(n: usize) -> Self {
346        Self {
347            n,
348            name_override: None,
349        }
350    }
351
352    /// Override the generated function name (reserved for future use; currently has no effect).
353    #[must_use]
354    pub fn name(mut self, name: impl Into<String>) -> Self {
355        self.name_override = Some(name.into());
356        self
357    }
358
359    /// Generate the codelet `TokenStream`.
360    ///
361    /// # Errors
362    ///
363    /// Returns `CodegenError` if `n == 0` or code emission fails.
364    pub fn build(self) -> Result<TokenStream, CodegenError> {
365        generate(self.n)
366    }
367}
368
369// ============================================================================
370// Inline tests
371// ============================================================================
372
#[cfg(test)]
mod tests {
    //! Unit tests for size classification, the factoring/primality helpers,
    //! and the dispatch-layer generators.

    use super::*;

    #[test]
    fn classify_all_notw() {
        for &n in &[2usize, 4, 8, 16, 32, 64] {
            assert!(
                matches!(classify(n).unwrap(), SizeClass::Notw(_)),
                "n={n} should be Notw"
            );
        }
    }

    #[test]
    fn classify_identity_is_notw_1() {
        assert_eq!(classify(1).unwrap(), SizeClass::Notw(1));
    }

    #[test]
    fn classify_all_odd() {
        for &n in &[3usize, 5, 7] {
            assert!(
                matches!(classify(n).unwrap(), SizeClass::Odd(_)),
                "n={n} should be Odd"
            );
        }
    }

    #[test]
    fn classify_rader_hardcoded() {
        assert!(matches!(
            classify(11).unwrap(),
            SizeClass::RaderHardcoded(11)
        ));
        assert!(matches!(
            classify(13).unwrap(),
            SizeClass::RaderHardcoded(13)
        ));
    }

    #[test]
    fn classify_rader_prime_runtime() {
        // Primes >13 and <=1021 not in the hardcoded set
        for &p in &[
            17usize, 19, 23, 29, 31, 37, 41, 43, 47, 53, 97, 101, 1013, 1019, 1021,
        ] {
            assert!(
                matches!(classify(p).unwrap(), SizeClass::RaderPrime(_)),
                "n={p} should be RaderPrime"
            );
        }
    }

    #[test]
    fn classify_bluestein_large_prime() {
        // 2003 is prime and > 1021
        assert!(matches!(
            classify(2003).unwrap(),
            SizeClass::Bluestein(2003)
        ));
    }

    #[test]
    fn classify_boundaries() {
        // 1031 is the first prime above the runtime-Rader cutoff of 1021.
        assert_eq!(classify(1031).unwrap(), SizeClass::Bluestein(1031));
        // Composites with a prime factor > 7 are not smooth-7 and not prime.
        assert_eq!(classify(22).unwrap(), SizeClass::Bluestein(22));
        assert_eq!(classify(1023).unwrap(), SizeClass::Bluestein(1023));
        // Powers of two beyond the direct set go through the mixed-radix path.
        assert!(matches!(classify(128).unwrap(), SizeClass::MixedRadix(_)));
    }

    #[test]
    fn classify_invalid_zero() {
        assert_eq!(classify(0).unwrap_err(), CodegenError::InvalidSize(0));
    }

    #[test]
    fn smooth7_factoring() {
        // These should all be MixedRadix
        for &n in &[
            6usize, 10, 12, 14, 15, 21, 24, 28, 30, 35, 40, 42, 48, 56, 60, 80, 84, 96, 112, 120,
            168, 240,
        ] {
            assert!(
                matches!(classify(n).unwrap(), SizeClass::MixedRadix(_)),
                "n={n} expected MixedRadix"
            );
        }
    }

    #[test]
    fn smooth7_factors_correct_for_15() {
        match classify(15).unwrap() {
            SizeClass::MixedRadix(factors) => {
                assert!(factors.contains(&5), "15 factors must include 5");
                assert!(factors.contains(&3), "15 factors must include 3");
            }
            other => panic!("expected MixedRadix, got {other:?}"),
        }
    }

    #[test]
    fn smooth7_factors_multiply_back_and_are_non_increasing() {
        for n in 2usize..=2048 {
            if let SizeClass::MixedRadix(factors) = classify(n).unwrap() {
                let product: usize = factors.iter().map(|&f| usize::from(f)).product();
                assert_eq!(product, n, "factors of {n} must multiply back to {n}");
                assert!(
                    factors.windows(2).all(|w| w[0] >= w[1]),
                    "factors of {n} must be non-increasing: {factors:?}"
                );
            }
        }
    }

    #[test]
    fn is_prime_helper() {
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(is_prime(5));
        assert!(is_prime(7));
        assert!(is_prime(11));
        assert!(is_prime(97));
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(!is_prime(4));
        assert!(!is_prime(100));
    }

    #[test]
    fn generate_emits_nonempty_for_direct_size() {
        let ts = generate(8).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn generate_emits_nonempty_for_odd_size() {
        let ts = generate(3).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn generate_emits_nonempty_for_rader_hardcoded() {
        let ts = generate(11).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn generate_emits_nonempty_for_mixed_radix() {
        let ts = generate(15).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn runtime_wrapper_is_named_after_size() {
        let src = generate(15).unwrap().to_string();
        assert!(src.contains("codelet_any_15"), "got: {src}");
    }

    #[test]
    fn generate_emits_nonempty_for_bluestein() {
        let ts = generate(2003).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn generate_emits_nonempty_for_identity() {
        let ts = generate(1).unwrap();
        assert!(!ts.to_string().is_empty());
    }

    #[test]
    fn generate_zero_returns_err() {
        assert!(generate(0).is_err());
    }

    #[test]
    fn generate_from_macro_reports_errors_as_compile_error() {
        let zero = generate_from_macro(quote! { 0 }).to_string();
        assert!(zero.contains("compile_error"), "got: {zero}");
        assert!(zero.contains("invalid codelet size: 0"), "got: {zero}");

        let not_a_literal = generate_from_macro(quote! { foo }).to_string();
        assert!(not_a_literal.contains("compile_error"), "got: {not_a_literal}");
    }

    #[test]
    fn codelet_builder_zero_returns_err() {
        assert!(CodeletBuilder::new(0).build().is_err());
    }

    #[test]
    fn codelet_builder_happy_path() {
        let ts = CodeletBuilder::new(8).build().unwrap();
        assert!(!ts.to_string().is_empty());
    }
}