Skip to main content

oxifft_codegen/
lib.rs

//! `OxiFFT` Codelet Generator
//!
//! This proc-macro crate generates optimized FFT codelets at compile time.
//! It replaces FFTW's OCaml-based genfft with Rust procedural macros.
//!
//! # Overview
//!
//! Codelets are highly optimized kernels for small FFT sizes (typically 2–64).
//! They are generated at compile time with:
//! - Common subexpression elimination
//! - Strength reduction
//! - Optimal instruction ordering
//! - SIMD-aware code patterns
//!
//! # Usage
//!
//! ```ignore
//! use oxifft_codegen::gen_dft_codelet;
//!
//! // Generate a size-8 (non-twiddle) DFT codelet
//! gen_dft_codelet!(8);
//! ```
23
24extern crate proc_macro;
25
26use proc_macro::TokenStream;
27
28/// Generate a non-twiddle (base case) DFT codelet.
29///
30/// # Arguments
31/// * `size` - The FFT size (must be 2, 4, 8, 16, 32, or 64)
32///
33/// # Example
34/// ```ignore
35/// gen_notw_codelet!(8);
36/// ```
37#[proc_macro]
38pub fn gen_notw_codelet(input: TokenStream) -> TokenStream {
39    let input2: proc_macro2::TokenStream = input.into();
40    oxifft_codegen_impl::gen_notw::generate(input2)
41        .unwrap_or_else(|e| e.to_compile_error())
42        .into()
43}
44
45/// Generate a twiddle-factor DFT codelet.
46#[proc_macro]
47pub fn gen_twiddle_codelet(input: TokenStream) -> TokenStream {
48    let input2: proc_macro2::TokenStream = input.into();
49    oxifft_codegen_impl::gen_twiddle::generate(input2)
50        .unwrap_or_else(|e| e.to_compile_error())
51        .into()
52}
53
54/// Generate a split-radix twiddle codelet.
55///
56/// The split-radix FFT decomposes an N-point DFT into one N/2-point DFT
57/// (even-indexed elements) and two N/4-point DFTs (odd-indexed elements)
58/// with twiddle factors `W_N^k` and `W_N^{3k`}, reducing the total multiply count.
59///
60/// # Usage
61/// ```ignore
62/// // Generate generic runtime-parameterized split-radix twiddle codelet
63/// gen_split_radix_twiddle_codelet!();
64///
65/// // Generate specialized unrolled version for N=8
66/// gen_split_radix_twiddle_codelet!(8);
67///
68/// // Generate specialized unrolled version for N=16
69/// gen_split_radix_twiddle_codelet!(16);
70/// ```
71#[proc_macro]
72pub fn gen_split_radix_twiddle_codelet(input: TokenStream) -> TokenStream {
73    let input2: proc_macro2::TokenStream = input.into();
74    oxifft_codegen_impl::gen_twiddle::generate_split_radix(input2)
75        .unwrap_or_else(|e| e.to_compile_error())
76        .into()
77}
78
79/// Generate a SIMD-optimized codelet.
80#[proc_macro]
81pub fn gen_simd_codelet(input: TokenStream) -> TokenStream {
82    let input2: proc_macro2::TokenStream = input.into();
83    oxifft_codegen_impl::gen_simd::generate(input2)
84        .unwrap_or_else(|e| e.to_compile_error())
85        .into()
86}
87
88/// Convenience macro to generate all codelets for a size.
89#[proc_macro]
90pub fn gen_dft_codelet(input: TokenStream) -> TokenStream {
91    let input2: proc_macro2::TokenStream = input.into();
92    oxifft_codegen_impl::gen_notw::generate(input2)
93        .unwrap_or_else(|e| e.to_compile_error())
94        .into()
95}
96
97/// Generate an odd-size (3, 5, 7) DFT codelet using Winograd minimum-multiply factorization.
98///
99/// The generated function is an in-place `&mut [Complex<T>]` codelet with `sign: i32`
100/// for runtime forward/inverse dispatch (matching `gen_notw_codelet!` conventions).
101///
102/// # Arguments
103/// * The size literal — must be 3, 5, or 7.
104///
105/// # Example
106/// ```ignore
107/// gen_odd_codelet!(3);  // emits `codelet_notw_3`
108/// gen_odd_codelet!(5);  // emits `codelet_notw_5`
109/// gen_odd_codelet!(7);  // emits `codelet_notw_7`
110/// ```
111#[proc_macro]
112pub fn gen_odd_codelet(input: TokenStream) -> TokenStream {
113    let input2: proc_macro2::TokenStream = input.into();
114    oxifft_codegen_impl::gen_odd::generate_from_macro(input2)
115        .unwrap_or_else(|e| e.to_compile_error())
116        .into()
117}
118
119/// Generate a Rader prime DFT codelet for primes 11 and 13.
120///
121/// The generated function uses the Rader algorithm to reduce the prime-size DFT
122/// to a cyclic convolution, computed as straight-line code with hardcoded twiddle
123/// factors.  Generator g = 2 for both supported primes.
124///
125/// # Arguments
126/// * The prime literal — must be 11 or 13.
127///
128/// # Example
129/// ```ignore
130/// gen_rader_codelet!(11);  // emits `codelet_notw_11`
131/// gen_rader_codelet!(13);  // emits `codelet_notw_13`
132/// ```
133#[proc_macro]
134pub fn gen_rader_codelet(input: TokenStream) -> TokenStream {
135    let input2: proc_macro2::TokenStream = input.into();
136    oxifft_codegen_impl::gen_rader::generate_from_macro(input2)
137        .unwrap_or_else(|e| e.to_compile_error())
138        .into()
139}
140
141/// Generate a vectorized multi-transform codelet.
142///
143/// Emits a function that processes V DFT transforms of size N simultaneously,
144/// where V is the SIMD lane count for the chosen ISA and precision.
145///
146/// The generated function name follows `notw_{size}_v{v}_{isa}_{ty}`.
147///
148/// # Arguments
149/// * `size` — DFT size: 2, 4, or 8
150/// * `v`    — number of simultaneous transforms (lane count)
151/// * `isa`  — target ISA: `sse2`, `avx2`, or `scalar`
152/// * `ty`   — float type: `f32` or `f64`
153///
154/// # Example
155/// ```ignore
156/// gen_multi_transform_codelet!(size = 4, v = 8, isa = avx2, ty = f32);
157/// // emits: pub unsafe fn notw_4_v8_avx2_f32(...)
158/// ```
159#[proc_macro]
160pub fn gen_multi_transform_codelet(input: TokenStream) -> TokenStream {
161    let input2: proc_macro2::TokenStream = input.into();
162    oxifft_codegen_impl::gen_simd::multi_transform::generate_from_macro(input2)
163        .unwrap_or_else(|e| e.to_compile_error())
164        .into()
165}
166
167/// Generate a cached ISA runtime dispatcher for a SIMD codelet.
168///
169/// Emits a function `codelet_simd_{size}_cached_{ty}(data, sign)` that caches
170/// the best ISA level in an `AtomicU8` static, avoiding repeated
171/// `is_x86_feature_detected!` / `is_aarch64_feature_detected!` calls on hot
172/// paths.  The cached dispatcher delegates to the same arch-specific inner
173/// functions as the uncached `codelet_simd_{size}<T>` dispatcher.
174///
175/// # Arguments
176/// * `size` — DFT size: 2, 4, 8, or 16
177/// * `ty`   — float type: `f32` or `f64`
178///
179/// # Priority order (high → low)
180/// - `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar
181/// - `aarch64`: NEON > scalar
182/// - other: scalar
183///
184/// # Example
185/// ```ignore
186/// gen_dispatcher_codelet!(size = 4, ty = f32);
187/// // emits: pub fn codelet_simd_4_cached_f32(data: &mut [Complex<f32>], sign: i32)
188/// ```
189#[proc_macro]
190pub fn gen_dispatcher_codelet(input: TokenStream) -> TokenStream {
191    let input2: proc_macro2::TokenStream = input.into();
192    oxifft_codegen_impl::gen_simd::runtime_dispatch::generate_from_macro(input2)
193        .unwrap_or_else(|e| e.to_compile_error())
194        .into()
195}
196
197/// Generate a complete FFT codelet for any user-specified size N.
198///
199/// Routes to the most appropriate emitter based on the size:
200/// - **Direct set** {2, 4, 8, 16, 32, 64}: optimised non-twiddle codelet.
201/// - **Winograd odd** {3, 5, 7}: Winograd minimum-multiply codelet.
202/// - **Rader hardcoded** {11, 13}: straight-line Rader cyclic-convolution codelet.
203/// - **Smooth-7 composites** (all prime factors in {2, 3, 5, 7}):
204///   runtime-delegating wrapper using `Plan::dft_1d` (mixed-radix path).
205/// - **Primes p ≤ 1021**: runtime-delegating wrapper (runtime Rader/Generic).
206/// - **All other sizes**: runtime-delegating Bluestein wrapper via `Plan::dft_1d`.
207///
208/// # Syntax
209/// ```ignore
210/// gen_any_codelet!(8);     // emits codelet_any_8  (direct notw codelet)
211/// gen_any_codelet!(15);    // emits codelet_any_15 (runtime mixed-radix wrapper)
212/// gen_any_codelet!(2003);  // emits codelet_any_2003 (Bluestein wrapper)
213/// ```
214///
215/// The emitted function signature is:
216/// ```ignore
217/// pub fn codelet_any_{N}<T: crate::kernel::Float>(
218///     x: &mut [crate::kernel::Complex<T>],
219///     sign: i32,
220/// )
221/// ```
222#[proc_macro]
223pub fn gen_any_codelet(input: TokenStream) -> TokenStream {
224    let input2: proc_macro2::TokenStream = input.into();
225    oxifft_codegen_impl::gen_any::generate_from_macro(input2).into()
226}
227
228/// Generate a real-to-half-complex (R2HC) or half-complex-to-real (HC2R) codelet.
229///
230/// The generated function has the same signature and produces numerically equivalent
231/// results to the hand-written codelets in `oxifft/src/rdft/codelets/mod.rs`.
232///
233/// # Usage
234/// ```ignore
235/// use oxifft_codegen::gen_rdft_codelet;
236///
237/// // Generates `pub fn r2hc_4_gen<T: crate::kernel::Float>(x: &[T], y: &mut [Complex<T>])`
238/// gen_rdft_codelet!(size = 4, kind = R2hc);
239///
240/// // Generates `pub fn hc2r_4_gen<T: crate::kernel::Float>(y: &[Complex<T>], x: &mut [T])`
241/// gen_rdft_codelet!(size = 4, kind = Hc2r);
242/// ```
243///
244/// Supported sizes: 2, 4, 8.
245#[proc_macro]
246pub fn gen_rdft_codelet(input: TokenStream) -> TokenStream {
247    let input2: proc_macro2::TokenStream = input.into();
248    oxifft_codegen_impl::gen_rdft::generate(input2)
249        .unwrap_or_else(|e| e.to_compile_error())
250        .into()
251}