Skip to main content

oxifft_codegen_impl/
gen_notw.rs

1//! Non-twiddle codelet generation.
2//!
3//! Generates optimized base-case FFT kernels using symbolic computation,
4//! common subexpression elimination, and strength reduction.
5
6use crate::symbolic::emit_body_from_symbolic;
7use proc_macro2::TokenStream;
8use quote::quote;
9use syn::LitInt;
10
11/// Generate a non-twiddle codelet for the given size.
12///
13/// # Errors
14/// Returns a `syn::Error` when the input token stream does not parse as a valid
15/// size literal, or when the size is not in the supported set {2, 4, 8, 16, 32, 64}.
16pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
17    let size: LitInt = syn::parse2(input)?;
18    let n: usize = size.base10_parse().map_err(|_| {
19        syn::Error::new(
20            size.span(),
21            "gen_notw_codelet: expected an integer size literal",
22        )
23    })?;
24
25    match n {
26        2 => Ok(gen_size_2()),
27        4 => Ok(gen_size_4()),
28        8 => Ok(gen_size_8()),
29        16 => Ok(gen_size_16()),
30        32 => Ok(gen_size_32()),
31        64 => Ok(gen_size_64()),
32        _ => Err(syn::Error::new(
33            size.span(),
34            format!("gen_notw_codelet: unsupported size {n} (expected one of 2, 4, 8, 16, 32, 64)"),
35        )),
36    }
37}
38
/// Emit `codelet_notw_2`, an in-place size-2 DFT on `x[0..2]`.
///
/// The emitted function replaces `x[0]`/`x[1]` with their sum and difference.
/// Its `sign` argument is unused (`_sign`): the size-2 DFT matrix `[[1, 1], [1, -1]]`
/// is the same for both transform directions. No scaling is applied.
fn gen_size_2() -> TokenStream {
    quote! {
        /// Size-2 DFT codelet (butterfly).
        #[inline(always)]
        pub fn codelet_notw_2<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            _sign: i32,
        ) {
            debug_assert!(x.len() >= 2);
            // Copy both inputs out first so the in-place writes below do not alias.
            let a = x[0];
            let b = x[1];
            x[0] = a + b;
            x[1] = a - b;
        }
    }
}
55
/// Emit `codelet_notw_4`, an in-place size-4 DFT on `x[0..4]`.
///
/// The emitted body does one radix-2 stage (sums/differences of the pairs
/// `(x0, x2)` and `(x1, x3)`), rotates the odd difference by `-i` when
/// `sign < 0` and by `+i` otherwise, then combines into natural-order outputs.
/// No scaling is applied.
fn gen_size_4() -> TokenStream {
    quote! {
        /// Size-4 DFT codelet.
        #[inline(always)]
        pub fn codelet_notw_4<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            debug_assert!(x.len() >= 4);

            // Read all inputs up front; outputs are written back into `x` in place.
            let x0 = x[0];
            let x1 = x[1];
            let x2 = x[2];
            let x3 = x[3];

            // Stage 1: length-2 butterflies on (x0, x2) and (x1, x3)
            let t0 = x0 + x2;
            let t1 = x0 - x2;
            let t2 = x1 + x3;
            let t3 = x1 - x3;

            // Apply rotation: multiplying by -i maps (re, im) -> (im, -re) (sign < 0);
            // multiplying by +i maps (re, im) -> (-im, re) (otherwise)
            let t3_rot = if sign < 0 {
                crate::kernel::Complex::new(t3.im, -t3.re)
            } else {
                crate::kernel::Complex::new(-t3.im, t3.re)
            };

            // Stage 2: combine into X[0..4] in natural order
            x[0] = t0 + t2;
            x[1] = t1 + t3_rot;
            x[2] = t0 - t2;
            x[3] = t1 - t3_rot;
        }
    }
}
92
93fn gen_size_8() -> TokenStream {
94    // Size-8 DFT using radix-2 DIT with explicit butterfly stages.
95    // All constants pre-computed via T::from_f64() to avoid trait method ambiguity.
96    quote! {
97        /// Size-8 DFT codelet using radix-2 DIT decomposition.
98        ///
99        /// Inputs are taken in natural order, output is in natural order.
100        #[inline(always)]
101        pub fn codelet_notw_8<T: crate::kernel::Float>(
102            x: &mut [crate::kernel::Complex<T>],
103            sign: i32,
104        ) {
105            debug_assert!(x.len() >= 8);
106
107            // 1/sqrt(2) ≈ 0.7071067811865476
108            let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
109
110            // Bit-reversal permutation for DIT (3-bit reversal)
111            // Natural:   0 1 2 3 4 5 6 7
112            // Bit-rev:   0 4 2 6 1 5 3 7
113            let mut a = [crate::kernel::Complex::<T>::zero(); 8];
114            a[0] = x[0]; a[1] = x[4];
115            a[2] = x[2]; a[3] = x[6];
116            a[4] = x[1]; a[5] = x[5];
117            a[6] = x[3]; a[7] = x[7];
118
119            // DIT Stage 1: 4 butterflies, span 1 (W2^0 = 1)
120            for i in (0..8usize).step_by(2) {
121                let t = a[i + 1];
122                a[i + 1] = a[i] - t;
123                a[i]     = a[i] + t;
124            }
125
126            // DIT Stage 2: 2 groups of 2 butterflies, span 2
127            // W4^0 = 1, W4^1 = -i (forward) or +i (inverse)
128            for group in (0..8usize).step_by(4) {
129                // k=0: W4^0 = 1
130                let t = a[group + 2];
131                a[group + 2] = a[group] - t;
132                a[group]     = a[group] + t;
133
134                // k=1: W4^1
135                let t = a[group + 3];
136                let t_tw = if sign < 0 {
137                    crate::kernel::Complex::new(t.im, -t.re)
138                } else {
139                    crate::kernel::Complex::new(-t.im, t.re)
140                };
141                a[group + 3] = a[group + 1] - t_tw;
142                a[group + 1] = a[group + 1] + t_tw;
143            }
144
145            // DIT Stage 3: 1 group of 4 butterflies, span 4
146            // W8^k for k in 0..4
147            // k=0: W8^0 = 1
148            let t = a[4];
149            a[4] = a[0] - t;
150            a[0] = a[0] + t;
151
152            // k=1: W8^1 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
153            let t = a[5];
154            let t_tw = if sign < 0 {
155                crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
156            } else {
157                crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
158            };
159            a[5] = a[1] - t_tw;
160            a[1] = a[1] + t_tw;
161
162            // k=2: W8^2 = -i (forward) or +i (inverse)
163            let t = a[6];
164            let t_tw = if sign < 0 {
165                crate::kernel::Complex::new(t.im, -t.re)
166            } else {
167                crate::kernel::Complex::new(-t.im, t.re)
168            };
169            a[6] = a[2] - t_tw;
170            a[2] = a[2] + t_tw;
171
172            // k=3: W8^3 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
173            let t = a[7];
174            let t_tw = if sign < 0 {
175                crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
176            } else {
177                crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
178            };
179            a[7] = a[3] - t_tw;
180            a[3] = a[3] + t_tw;
181
182            // Write back in natural order
183            for i in 0..8usize {
184                x[i] = a[i];
185            }
186        }
187    }
188}
189
190fn gen_size_16() -> TokenStream {
191    // Size-16 DFT codelet generated via the symbolic optimization pipeline.
192    // Forward and inverse bodies are emitted from the optimized symbolic DAG,
193    // then dispatched at runtime by `sign`.
194    let fwd = emit_body_from_symbolic(16, true);
195    let inv = emit_body_from_symbolic(16, false);
196    quote! {
197        /// Size-16 DFT codelet — generated via symbolic CSE/constant-folding pipeline.
198        ///
199        /// `sign < 0` → forward transform; `sign > 0` → inverse (un-normalized).
200        #[inline(always)]
201        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
202        pub fn codelet_notw_16<T: crate::kernel::Float>(
203            x: &mut [crate::kernel::Complex<T>],
204            sign: i32,
205        ) {
206            debug_assert!(x.len() >= 16);
207            if sign < 0 {
208                #fwd
209            } else {
210                #inv
211            }
212        }
213    }
214}
215
216fn gen_size_32() -> TokenStream {
217    // Size-32 DFT codelet generated via the symbolic optimization pipeline.
218    let fwd = emit_body_from_symbolic(32, true);
219    let inv = emit_body_from_symbolic(32, false);
220    quote! {
221        /// Size-32 DFT codelet — generated via symbolic CSE/constant-folding pipeline.
222        ///
223        /// `sign < 0` → forward transform; `sign > 0` → inverse (un-normalized).
224        #[inline(always)]
225        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
226        pub fn codelet_notw_32<T: crate::kernel::Float>(
227            x: &mut [crate::kernel::Complex<T>],
228            sign: i32,
229        ) {
230            debug_assert!(x.len() >= 32);
231            if sign < 0 {
232                #fwd
233            } else {
234                #inv
235            }
236        }
237    }
238}
239
240fn gen_size_64() -> TokenStream {
241    // Size-64 DFT codelet generated via the symbolic optimization pipeline.
242    let fwd = emit_body_from_symbolic(64, true);
243    let inv = emit_body_from_symbolic(64, false);
244    quote! {
245        /// Size-64 DFT codelet — generated via symbolic CSE/constant-folding pipeline.
246        ///
247        /// `sign < 0` → forward transform; `sign > 0` → inverse (un-normalized).
248        #[inline(always)]
249        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
250        pub fn codelet_notw_64<T: crate::kernel::Float>(
251            x: &mut [crate::kernel::Complex<T>],
252            sign: i32,
253        ) {
254            debug_assert!(x.len() >= 64);
255            if sign < 0 {
256                #fwd
257            } else {
258                #inv
259            }
260        }
261    }
262}