Skip to main content

oxifft_codegen_impl/
gen_twiddle.rs

//! Twiddle-factor codelet generation.
//!
//! Generates codelets that apply twiddle factors during multi-radix FFT computation:
//! radix-2/4/8/16 twiddle butterflies and split-radix (L-shaped) combining steps.
4
5use proc_macro2::TokenStream;
6use quote::quote;
7use syn::LitInt;
8
9/// Generate a twiddle codelet for the given radix.
10///
11/// # Errors
12/// Returns a `syn::Error` when the input does not parse as a valid radix literal,
13/// or when the radix is not in the supported set {2, 4, 8, 16}.
14pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
15    let radix: LitInt = syn::parse2(input)?;
16    let r: usize = radix.base10_parse().map_err(|_| {
17        syn::Error::new(
18            radix.span(),
19            "gen_twiddle_codelet: expected an integer radix literal",
20        )
21    })?;
22
23    match r {
24        2 => Ok(gen_twiddle_2()),
25        4 => Ok(gen_twiddle_4()),
26        8 => Ok(gen_twiddle_8()),
27        16 => Ok(gen_twiddle_16()),
28        _ => Err(syn::Error::new(
29            radix.span(),
30            format!("gen_twiddle_codelet: unsupported radix {r} (expected one of 2, 4, 8, 16)"),
31        )),
32    }
33}
34
/// Emit the radix-2 twiddle codelet `codelet_twiddle_2`.
///
/// The generated function multiplies `x[1]` by a single `twiddle` factor and then
/// performs an in-place 2-point butterfly on `x[0]` and `x[1]`.
fn gen_twiddle_2() -> TokenStream {
    quote! {
        /// Radix-2 twiddle codelet.
        ///
        /// Applies a single twiddle factor and computes a 2-point butterfly.
        #[inline(always)]
        pub fn codelet_twiddle_2<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            twiddle: crate::kernel::Complex<T>,
        ) {
            debug_assert!(x.len() >= 2);
            // Only the second input is twiddled; x[0] is the k = 0 term.
            let a = x[0];
            let b = x[1] * twiddle;
            // Sum/difference butterfly, written back in place.
            x[0] = a + b;
            x[1] = a - b;
        }
    }
}
53
/// Emit the radix-4 twiddle codelet `codelet_twiddle_4`.
///
/// The generated function multiplies `x[1]`, `x[2]`, `x[3]` by `tw1`, `tw2`, `tw3`
/// and then computes an in-place 4-point DFT. `sign < 0` selects the forward
/// rotation by `-i`; any other value rotates by `+i`.
fn gen_twiddle_4() -> TokenStream {
    quote! {
        /// Radix-4 twiddle codelet.
        ///
        /// Applies twiddle factors w1, w2, w3 to inputs x[1], x[2], x[3]
        /// and computes a 4-point FFT.
        #[inline(always)]
        pub fn codelet_twiddle_4<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            tw1: crate::kernel::Complex<T>,
            tw2: crate::kernel::Complex<T>,
            tw3: crate::kernel::Complex<T>,
            sign: i32,
        ) {
            debug_assert!(x.len() >= 4);

            // Apply the external twiddles; x[0] is the k = 0 term and needs none.
            let x0 = x[0];
            let x1 = x[1] * tw1;
            let x2 = x[2] * tw2;
            let x3 = x[3] * tw3;

            // First butterfly layer: pair inputs 0/2 and 1/3.
            let t0 = x0 + x2;
            let t1 = x0 - x2;
            let t2 = x1 + x3;
            let t3 = x1 - x3;

            // Multiply t3 by -i (forward) or +i (inverse):
            // -i*(a + bi) = (b, -a);  +i*(a + bi) = (-b, a).
            let t3_rot = if sign < 0 {
                crate::kernel::Complex::new(t3.im, -t3.re)
            } else {
                crate::kernel::Complex::new(-t3.im, t3.re)
            };

            // Second layer produces the outputs in natural order.
            x[0] = t0 + t2;
            x[1] = t1 + t3_rot;
            x[2] = t0 - t2;
            x[3] = t1 - t3_rot;
        }
    }
}
93
94#[allow(clippy::too_many_lines)]
95fn gen_twiddle_16() -> TokenStream {
96    quote! {
97        /// Radix-16 twiddle codelet.
98        ///
99        /// Applies 15 twiddle factors to inputs x[1]..x[15] and computes a 16-point FFT
100        /// using a radix-2 DIT (decimation-in-time) butterfly structure.
101        ///
102        /// # Arguments
103        /// * `x`        - Input/output slice of at least 16 complex values
104        /// * `twiddles` - Array of 15 precomputed twiddle factors for positions 1..=15
105        /// * `sign`     - Transform direction: -1 for forward, +1 for inverse
106        #[inline(always)]
107        pub fn codelet_twiddle_16<T: crate::kernel::Float>(
108            x: &mut [crate::kernel::Complex<T>],
109            twiddles: &[crate::kernel::Complex<T>; 15],
110            sign: i32,
111        ) {
112            debug_assert!(x.len() >= 16);
113
114            // Step 1: Apply external twiddle factors to positions 1..=15
115            let x0  = x[0];
116            let x1  = x[1]  * twiddles[0];
117            let x2  = x[2]  * twiddles[1];
118            let x3  = x[3]  * twiddles[2];
119            let x4  = x[4]  * twiddles[3];
120            let x5  = x[5]  * twiddles[4];
121            let x6  = x[6]  * twiddles[5];
122            let x7  = x[7]  * twiddles[6];
123            let x8  = x[8]  * twiddles[7];
124            let x9  = x[9]  * twiddles[8];
125            let x10 = x[10] * twiddles[9];
126            let x11 = x[11] * twiddles[10];
127            let x12 = x[12] * twiddles[11];
128            let x13 = x[13] * twiddles[12];
129            let x14 = x[14] * twiddles[13];
130            let x15 = x[15] * twiddles[14];
131
132            // Step 2: Compute 16-point DFT using radix-2 DIT.
133            // Place twiddle-applied values in bit-reversed order, then apply 4 DIT stages.
134            //
135            // Bit-reversal permutation for 16 (4-bit reversal):
136            //   0->0, 1->8, 2->4, 3->12, 4->2, 5->10, 6->6, 7->14,
137            //   8->1, 9->9, 10->5, 11->13, 12->3, 13->11, 14->7, 15->15
138            let mut a = [crate::kernel::Complex::<T>::zero(); 16];
139            a[0]  = x0;
140            a[1]  = x8;
141            a[2]  = x4;
142            a[3]  = x12;
143            a[4]  = x2;
144            a[5]  = x10;
145            a[6]  = x6;
146            a[7]  = x14;
147            a[8]  = x1;
148            a[9]  = x9;
149            a[10] = x5;
150            a[11] = x13;
151            a[12] = x3;
152            a[13] = x11;
153            a[14] = x7;
154            a[15] = x15;
155
156            // DIT Stage 1: 8 butterflies, span 1 (W2^0 = 1, no twiddle)
157            for i in (0..16usize).step_by(2) {
158                let t = a[i + 1];
159                a[i + 1] = a[i] - t;
160                a[i]     = a[i] + t;
161            }
162
163            // DIT Stage 2: 4 groups of 2 butterflies, span 2
164            // W4^0 = 1,  W4^1 = -i (forward) or +i (inverse)
165            for group in (0..16usize).step_by(4) {
166                // k=0: W4^0 = 1
167                let t = a[group + 2];
168                a[group + 2] = a[group] - t;
169                a[group]     = a[group] + t;
170
171                // k=1: W4^1
172                let t = a[group + 3];
173                let t_tw = if sign < 0 {
174                    crate::kernel::Complex::new(t.im, -t.re)
175                } else {
176                    crate::kernel::Complex::new(-t.im, t.re)
177                };
178                a[group + 3] = a[group + 1] - t_tw;
179                a[group + 1] = a[group + 1] + t_tw;
180            }
181
182            // DIT Stage 3: 2 groups of 4 butterflies, span 4
183            // W8^k for k in 0..4
184            // c2 = 1/sqrt(2) ≈ 0.7071067811865476
185            let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
186            for group in (0..16usize).step_by(8) {
187                // k=0: W8^0 = 1
188                let t = a[group + 4];
189                a[group + 4] = a[group] - t;
190                a[group]     = a[group] + t;
191
192                // k=1: W8^1 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
193                let t = a[group + 5];
194                let t_tw = if sign < 0 {
195                    crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
196                } else {
197                    crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
198                };
199                a[group + 5] = a[group + 1] - t_tw;
200                a[group + 1] = a[group + 1] + t_tw;
201
202                // k=2: W8^2 = -i forward, +i inverse
203                let t = a[group + 6];
204                let t_tw = if sign < 0 {
205                    crate::kernel::Complex::new(t.im, -t.re)
206                } else {
207                    crate::kernel::Complex::new(-t.im, t.re)
208                };
209                a[group + 6] = a[group + 2] - t_tw;
210                a[group + 2] = a[group + 2] + t_tw;
211
212                // k=3: W8^3 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
213                let t = a[group + 7];
214                let t_tw = if sign < 0 {
215                    crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
216                } else {
217                    crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
218                };
219                a[group + 7] = a[group + 3] - t_tw;
220                a[group + 3] = a[group + 3] + t_tw;
221            }
222
223            // DIT Stage 4: 1 group of 8 butterflies, span 8
224            // W16^k for k in 0..8
225            // Constants: cos(π/8), sin(π/8), 1/sqrt(2), cos(3π/8)=sin(π/8), sin(3π/8)=cos(π/8)
226            let c1 = T::from_f64(0.923_879_532_511_286_7_f64); // cos(π/8)
227            let s1 = T::from_f64(0.382_683_432_365_089_8_f64); // sin(π/8)
228
229            // k=0: W16^0 = 1
230            let t = a[8];
231            a[8] = a[0] - t;
232            a[0] = a[0] + t;
233
234            // k=1: W16^1 = cos(π/8) - i*sin(π/8) forward, cos(π/8) + i*sin(π/8) inverse
235            let t = a[9];
236            let t_tw = if sign < 0 {
237                crate::kernel::Complex::new(t.re * c1 + t.im * s1, t.im * c1 - t.re * s1)
238            } else {
239                crate::kernel::Complex::new(t.re * c1 - t.im * s1, t.im * c1 + t.re * s1)
240            };
241            a[9] = a[1] - t_tw;
242            a[1] = a[1] + t_tw;
243
244            // k=2: W16^2 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
245            let t = a[10];
246            let t_tw = if sign < 0 {
247                crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
248            } else {
249                crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
250            };
251            a[10] = a[2] - t_tw;
252            a[2] = a[2] + t_tw;
253
254            // k=3: W16^3 = cos(3π/8) - i*sin(3π/8) forward
255            //             = sin(π/8) - i*cos(π/8) forward  (since cos(3π/8)=sin(π/8))
256            let c3 = s1; // cos(3π/8) = sin(π/8)
257            let s3 = c1; // sin(3π/8) = cos(π/8)
258            let t = a[11];
259            let t_tw = if sign < 0 {
260                crate::kernel::Complex::new(t.re * c3 + t.im * s3, t.im * c3 - t.re * s3)
261            } else {
262                crate::kernel::Complex::new(t.re * c3 - t.im * s3, t.im * c3 + t.re * s3)
263            };
264            a[11] = a[3] - t_tw;
265            a[3] = a[3] + t_tw;
266
267            // k=4: W16^4 = -i forward, +i inverse
268            let t = a[12];
269            let t_tw = if sign < 0 {
270                crate::kernel::Complex::new(t.im, -t.re)
271            } else {
272                crate::kernel::Complex::new(-t.im, t.re)
273            };
274            a[12] = a[4] - t_tw;
275            a[4] = a[4] + t_tw;
276
277            // k=5: W16^5 = cos(5π/8) - i*sin(5π/8) = -sin(π/8) - i*cos(π/8) forward
278            let t = a[13];
279            let t_tw = if sign < 0 {
280                crate::kernel::Complex::new(-t.re * s1 + t.im * c1, -t.im * s1 - t.re * c1)
281            } else {
282                crate::kernel::Complex::new(-t.re * s1 - t.im * c1, -t.im * s1 + t.re * c1)
283            };
284            a[13] = a[5] - t_tw;
285            a[5] = a[5] + t_tw;
286
287            // k=6: W16^6 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
288            let t = a[14];
289            let t_tw = if sign < 0 {
290                crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
291            } else {
292                crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
293            };
294            a[14] = a[6] - t_tw;
295            a[6] = a[6] + t_tw;
296
297            // k=7: W16^7 = cos(7π/8) - i*sin(7π/8) = -cos(π/8) - i*sin(π/8) forward
298            let t = a[15];
299            let t_tw = if sign < 0 {
300                crate::kernel::Complex::new(-t.re * c1 + t.im * s1, -t.im * c1 - t.re * s1)
301            } else {
302                crate::kernel::Complex::new(-t.re * c1 - t.im * s1, -t.im * c1 + t.re * s1)
303            };
304            a[15] = a[7] - t_tw;
305            a[7] = a[7] + t_tw;
306
307            // Write back in natural order (DIT produces natural order after bit-reversal input)
308            for i in 0..16usize {
309                x[i] = a[i];
310            }
311        }
312    }
313}
314
315#[allow(clippy::too_many_lines)]
316fn gen_twiddle_8() -> TokenStream {
317    quote! {
318        /// Radix-8 twiddle codelet.
319        ///
320        /// Applies 7 external twiddle factors to inputs x[1]..x[7], then computes
321        /// an 8-point FFT using a radix-2 DIT butterfly structure.
322        ///
323        /// # Arguments
324        /// * `x`        - Input/output slice of at least 8 complex values
325        /// * `twiddles` - Array of 7 precomputed twiddle factors for positions 1..=7
326        /// * `sign`     - Transform direction: -1 for forward, +1 for inverse
327        #[inline(always)]
328        pub fn codelet_twiddle_8<T: crate::kernel::Float>(
329            x: &mut [crate::kernel::Complex<T>],
330            twiddles: &[crate::kernel::Complex<T>; 7],
331            sign: i32,
332        ) {
333            debug_assert!(x.len() >= 8);
334
335            // Step 1: Apply external twiddle factors to positions 1..=7
336            let x0 = x[0];
337            let x1 = x[1] * twiddles[0];
338            let x2 = x[2] * twiddles[1];
339            let x3 = x[3] * twiddles[2];
340            let x4 = x[4] * twiddles[3];
341            let x5 = x[5] * twiddles[4];
342            let x6 = x[6] * twiddles[5];
343            let x7 = x[7] * twiddles[6];
344
345            // Step 2: Compute 8-point DFT using radix-2 DIT.
346            // Place twiddle-applied values in bit-reversed order, then apply 3 DIT stages.
347            // Bit-reversal for 8 (3-bit): 0→0, 1→4, 2→2, 3→6, 4→1, 5→5, 6→3, 7→7
348            let mut a = [crate::kernel::Complex::<T>::zero(); 8];
349            a[0] = x0; a[1] = x4;
350            a[2] = x2; a[3] = x6;
351            a[4] = x1; a[5] = x5;
352            a[6] = x3; a[7] = x7;
353
354            // DIT Stage 1: 4 butterflies, span 1 (W2^0 = 1)
355            for i in (0..8usize).step_by(2) {
356                let t = a[i + 1];
357                a[i + 1] = a[i] - t;
358                a[i]     = a[i] + t;
359            }
360
361            // DIT Stage 2: 2 groups of 2 butterflies, span 2
362            // W4^0 = 1, W4^1 = -i (forward) or +i (inverse)
363            for group in (0..8usize).step_by(4) {
364                // k=0: W4^0 = 1
365                let t = a[group + 2];
366                a[group + 2] = a[group] - t;
367                a[group]     = a[group] + t;
368
369                // k=1: W4^1
370                let t = a[group + 3];
371                let t_tw = if sign < 0 {
372                    crate::kernel::Complex::new(t.im, -t.re)
373                } else {
374                    crate::kernel::Complex::new(-t.im, t.re)
375                };
376                a[group + 3] = a[group + 1] - t_tw;
377                a[group + 1] = a[group + 1] + t_tw;
378            }
379
380            // DIT Stage 3: 1 group of 4 butterflies, span 4
381            // W8^k for k in 0..4. c2 = 1/sqrt(2) ≈ 0.7071067811865476
382            let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
383
384            // k=0: W8^0 = 1
385            let t = a[4];
386            a[4] = a[0] - t;
387            a[0] = a[0] + t;
388
389            // k=1: W8^1 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
390            let t = a[5];
391            let t_tw = if sign < 0 {
392                crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
393            } else {
394                crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
395            };
396            a[5] = a[1] - t_tw;
397            a[1] = a[1] + t_tw;
398
399            // k=2: W8^2 = -i (forward) or +i (inverse)
400            let t = a[6];
401            let t_tw = if sign < 0 {
402                crate::kernel::Complex::new(t.im, -t.re)
403            } else {
404                crate::kernel::Complex::new(-t.im, t.re)
405            };
406            a[6] = a[2] - t_tw;
407            a[2] = a[2] + t_tw;
408
409            // k=3: W8^3 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
410            let t = a[7];
411            let t_tw = if sign < 0 {
412                crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
413            } else {
414                crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
415            };
416            a[7] = a[3] - t_tw;
417            a[3] = a[3] + t_tw;
418
419            // Write back in natural order
420            for i in 0..8usize {
421                x[i] = a[i];
422            }
423        }
424
425        /// Radix-8 twiddle codelet with inline twiddle computation.
426        ///
427        /// This version computes twiddles from angle step, useful when twiddles
428        /// are not precomputed.
429        #[inline(always)]
430        pub fn codelet_twiddle_8_inline<T: crate::kernel::Float>(
431            x: &mut [crate::kernel::Complex<T>],
432            angle_step: T,
433            sign: i32,
434        ) {
435            debug_assert!(x.len() >= 8);
436
437            // Compute twiddles inline via fully-qualified Float trait methods to avoid ambiguity
438            let tw1 = crate::kernel::Complex::new(
439                <T as crate::kernel::Float>::cos(angle_step),
440                <T as crate::kernel::Float>::sin(angle_step),
441            );
442            let tw2 = crate::kernel::Complex::new(
443                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(2)),
444                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(2)),
445            );
446            let tw3 = crate::kernel::Complex::new(
447                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(3)),
448                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(3)),
449            );
450            let tw4 = crate::kernel::Complex::new(
451                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(4)),
452                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(4)),
453            );
454            let tw5 = crate::kernel::Complex::new(
455                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(5)),
456                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(5)),
457            );
458            let tw6 = crate::kernel::Complex::new(
459                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(6)),
460                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(6)),
461            );
462            let tw7 = crate::kernel::Complex::new(
463                <T as crate::kernel::Float>::cos(angle_step * T::from_usize(7)),
464                <T as crate::kernel::Float>::sin(angle_step * T::from_usize(7)),
465            );
466
467            let twiddles = [tw1, tw2, tw3, tw4, tw5, tw6, tw7];
468            codelet_twiddle_8(x, &twiddles, sign);
469        }
470    }
471}
472
473/// Generate a split-radix twiddle codelet for the given size.
474///
475/// If no size is specified (empty input), generates the generic runtime-parameterized
476/// split-radix twiddle codelet. If a size is given (8 or 16), generates a specialized
477/// unrolled version for that size.
478///
479/// # Errors
480/// Returns a `syn::Error` when the input does not parse as a valid size literal,
481/// or when the size is not in the supported set {8, 16} (or empty for the generic variant).
482pub fn generate_split_radix(input: TokenStream) -> Result<TokenStream, syn::Error> {
483    if input.is_empty() {
484        return Ok(gen_split_radix_twiddle());
485    }
486    let size: LitInt = syn::parse2(input)?;
487    let n: usize = size.base10_parse().map_err(|_| {
488        syn::Error::new(
489            size.span(),
490            "gen_split_radix_twiddle_codelet: expected an integer size literal",
491        )
492    })?;
493    match n {
494        8 => Ok(gen_split_radix_twiddle_8()),
495        16 => Ok(gen_split_radix_twiddle_16()),
496        _ => Err(syn::Error::new(
497            size.span(),
498            format!("gen_split_radix_twiddle_codelet: unsupported size {n} (use 8 or 16, or empty for generic)"),
499        )),
500    }
501}
502
503/// Generate the generic split-radix twiddle codelet (L-shaped butterfly).
504///
505/// The split-radix FFT decomposes an N-point DFT into:
506/// - One N/2-point DFT of even-indexed elements
507/// - Two N/4-point DFTs of odd-indexed elements (with twiddle factors `W_N^k` and `W_N^{3k`})
508///
509/// This codelet performs the combining step (L-shaped butterfly):
510///   For k = 0..N/4-1:
511///     t1 = `W_N^k` * O1[k],  t2 = `W_N^{3k`} * O3[k]
512///     p = t1 + t2,  m = t1 - t2
513///     X[k]       = E[k]     + p
514///     X[k+N/4]   = E[k+N/4] - j*(t1-t2)  (forward)
515///     X[k+N/2]   = E[k]     - p
516///     X[k+3N/4]  = E[k+N/4] + j*(t1-t2)  (forward)
517fn gen_split_radix_twiddle() -> TokenStream {
518    let expanded = quote! {
519        /// Split-radix twiddle codelet (L-shaped butterfly).
520        ///
521        /// Combines the results of an N/2-point even DFT with two N/4-point odd DFTs
522        /// using split-radix decomposition. This reduces the total number of real
523        /// multiplications compared to standard radix-2: approximately 4N log2(N) - 6N + 8
524        /// vs. 5N log2(N) for radix-2.
525        ///
526        /// # Data Layout
527        ///
528        /// On input, `data[0..n]` is organized as:
529        /// - `data[0..n/2]`: N/2-point DFT of even-indexed elements (E)
530        /// - `data[n/2..3n/4]`: N/4-point DFT of odd-1 elements (O1, indices 1,5,9,...)
531        /// - `data[3n/4..n]`: N/4-point DFT of odd-3 elements (O3, indices 3,7,11,...)
532        ///
533        /// On output, `data[0..n]` contains the combined N-point DFT result.
534        ///
535        /// # Arguments
536        /// * `data`     - Input/output slice of at least `n` complex values
537        /// * `n`        - Transform size (must be divisible by 4 and >= 4)
538        /// * `twiddles` - Twiddle factors W_N^k for k = 0..n/4-1
539        /// * `twiddles3`- Twiddle factors W_N^{3k} for k = 0..n/4-1
540        /// * `sign`     - Transform direction: -1 for forward, +1 for inverse
541        #[inline]
542        pub fn codelet_split_radix_twiddle<T: crate::kernel::Float>(
543            data: &mut [crate::kernel::Complex<T>],
544            n: usize,
545            twiddles: &[crate::kernel::Complex<T>],
546            twiddles3: &[crate::kernel::Complex<T>],
547            sign: i32,
548        ) {
549            debug_assert!(n >= 4 && n % 4 == 0, "n must be >= 4 and divisible by 4");
550            debug_assert!(data.len() >= n, "data slice too short for split-radix");
551
552            let n2 = n >> 1;   // N/2
553            let n4 = n >> 2;   // N/4
554
555            debug_assert!(twiddles.len() >= n4, "twiddles slice too short");
556            debug_assert!(twiddles3.len() >= n4, "twiddles3 slice too short");
557
558            for k in 0..n4 {
559                // Read the three sub-DFT results
560                let e_k    = data[k];          // E[k] from even DFT
561                let e_k_q  = data[k + n4];     // E[k + N/4] from even DFT
562                let o1_k   = data[n2 + k];     // O1[k] from first odd DFT
563                let o3_k   = data[n2 + n4 + k]; // O3[k] from second odd DFT
564
565                // Apply twiddle factors to odd sub-DFT results
566                let t1 = o1_k * twiddles[k];    // W_N^k * O1[k]
567                let t2 = o3_k * twiddles3[k];   // W_N^{3k} * O3[k]
568
569                // Split-radix butterfly computation
570                let p = t1 + t2;    // Sum of twiddled odd results
571                let m = t1 - t2;    // Difference of twiddled odd results
572
573                // Rotate difference by -j (forward) or +j (inverse)
574                // -j * (a + bi) = (b, -a);   +j * (a + bi) = (-b, a)
575                let m_rot = if sign < 0 {
576                    crate::kernel::Complex::new(m.im, -m.re)
577                } else {
578                    crate::kernel::Complex::new(-m.im, m.re)
579                };
580
581                // Write combined results to all four quarters
582                data[k]          = e_k   + p;       // X[k]
583                data[k + n4]     = e_k_q + m_rot;   // X[k + N/4]
584                data[k + n2]     = e_k   - p;       // X[k + N/2]
585                data[k + n2 + n4] = e_k_q - m_rot;  // X[k + 3N/4]
586            }
587        }
588
589        /// Split-radix twiddle codelet with inline twiddle computation.
590        ///
591        /// Computes W_N^k and W_N^{3k} twiddle factors on the fly from the
592        /// base angle step, useful when twiddle tables are not precomputed.
593        ///
594        /// # Arguments
595        /// * `data` - Input/output slice of at least `n` complex values
596        /// * `n`    - Transform size (must be divisible by 4 and >= 4)
597        /// * `sign` - Transform direction: -1 for forward, +1 for inverse
598        #[inline]
599        pub fn codelet_split_radix_twiddle_inline<T: crate::kernel::Float>(
600            data: &mut [crate::kernel::Complex<T>],
601            n: usize,
602            sign: i32,
603        ) {
604            debug_assert!(n >= 4 && n % 4 == 0, "n must be >= 4 and divisible by 4");
605            debug_assert!(data.len() >= n, "data slice too short for split-radix");
606
607            let n2 = n >> 1;
608            let n4 = n >> 2;
609
610            // Base angle: -2π/N (forward) or +2π/N (inverse)
611            let base_angle = if sign < 0 {
612                -2.0_f64 * core::f64::consts::PI / (n as f64)
613            } else {
614                2.0_f64 * core::f64::consts::PI / (n as f64)
615            };
616
617            for k in 0..n4 {
618                let angle_k = base_angle * (k as f64);
619                let angle_3k = base_angle * (3 * k) as f64;
620
621                let tw = crate::kernel::Complex::new(
622                    T::from_f64(angle_k.cos()),
623                    T::from_f64(angle_k.sin()),
624                );
625                let tw3 = crate::kernel::Complex::new(
626                    T::from_f64(angle_3k.cos()),
627                    T::from_f64(angle_3k.sin()),
628                );
629
630                let e_k    = data[k];
631                let e_k_q  = data[k + n4];
632                let o1_k   = data[n2 + k];
633                let o3_k   = data[n2 + n4 + k];
634
635                let t1 = o1_k * tw;
636                let t2 = o3_k * tw3;
637
638                let p = t1 + t2;
639                let m = t1 - t2;
640
641                let m_rot = if sign < 0 {
642                    crate::kernel::Complex::new(m.im, -m.re)
643                } else {
644                    crate::kernel::Complex::new(-m.im, m.re)
645                };
646
647                data[k]          = e_k   + p;
648                data[k + n4]     = e_k_q + m_rot;
649                data[k + n2]     = e_k   - p;
650                data[k + n2 + n4] = e_k_q - m_rot;
651            }
652        }
653    };
654    expanded
655}
656
657/// Generate a specialized 8-point split-radix twiddle codelet (fully unrolled).
658///
659/// N=8: N/2=4 even, N/4=2 odd-1, N/4=2 odd-3.
660/// Unrolls the L-shaped butterfly for k=0,1.
661fn gen_split_radix_twiddle_8() -> TokenStream {
662    let expanded = quote! {
663        /// Split-radix twiddle codelet for N=8 (fully unrolled).
664        ///
665        /// Combines a 4-point even DFT with two 2-point odd DFTs using
666        /// the L-shaped butterfly. All 2 iterations fully unrolled.
667        ///
668        /// # Data Layout
669        /// - `data[0..4]`: 4-point even DFT result (E)
670        /// - `data[4..6]`: 2-point odd-1 DFT result (O1)
671        /// - `data[6..8]`: 2-point odd-3 DFT result (O3)
672        ///
673        /// # Arguments
674        /// * `data`     - Input/output slice of at least 8 complex values
675        /// * `twiddles` - `[W_8^0, W_8^1]` twiddle factors
676        /// * `twiddles3`- `[W_8^0, W_8^3]` twiddle factors
677        /// * `sign`     - Transform direction: -1 for forward, +1 for inverse
678        #[inline(always)]
679        pub fn codelet_split_radix_twiddle_8<T: crate::kernel::Float>(
680            data: &mut [crate::kernel::Complex<T>],
681            twiddles: &[crate::kernel::Complex<T>; 2],
682            twiddles3: &[crate::kernel::Complex<T>; 2],
683            sign: i32,
684        ) {
685            debug_assert!(data.len() >= 8);
686
687            // k=0: E[0], E[2], O1[0], O3[0]
688            let e0   = data[0];
689            let e0_q = data[2];   // E[0 + N/4] = E[2]
690            let o1_0 = data[4];
691            let o3_0 = data[6];
692
693            let t1_0 = o1_0 * twiddles[0];    // W_8^0 * O1[0]
694            let t2_0 = o3_0 * twiddles3[0];   // W_8^0 * O3[0]
695
696            let p0 = t1_0 + t2_0;
697            let m0 = t1_0 - t2_0;
698            let m0_rot = if sign < 0 {
699                crate::kernel::Complex::new(m0.im, -m0.re)
700            } else {
701                crate::kernel::Complex::new(-m0.im, m0.re)
702            };
703
704            // k=1: E[1], E[3], O1[1], O3[1]
705            let e1   = data[1];
706            let e1_q = data[3];   // E[1 + N/4] = E[3]
707            let o1_1 = data[5];
708            let o3_1 = data[7];
709
710            let t1_1 = o1_1 * twiddles[1];    // W_8^1 * O1[1]
711            let t2_1 = o3_1 * twiddles3[1];   // W_8^3 * O3[1]
712
713            let p1 = t1_1 + t2_1;
714            let m1 = t1_1 - t2_1;
715            let m1_rot = if sign < 0 {
716                crate::kernel::Complex::new(m1.im, -m1.re)
717            } else {
718                crate::kernel::Complex::new(-m1.im, m1.re)
719            };
720
721            // Write all 8 outputs (4 pairs from 2 butterfly iterations)
722            data[0] = e0   + p0;       // X[0]
723            data[2] = e0_q + m0_rot;   // X[0 + N/4] = X[2]
724            data[4] = e0   - p0;       // X[0 + N/2] = X[4]
725            data[6] = e0_q - m0_rot;   // X[0 + 3N/4] = X[6]
726
727            data[1] = e1   + p1;       // X[1]
728            data[3] = e1_q + m1_rot;   // X[1 + N/4] = X[3]
729            data[5] = e1   - p1;       // X[1 + N/2] = X[5]
730            data[7] = e1_q - m1_rot;   // X[1 + 3N/4] = X[7]
731        }
732    };
733    expanded
734}
735
736/// Generate a specialized 16-point split-radix twiddle codelet (fully unrolled).
737///
738/// N=16: N/2=8 even, N/4=4 odd-1, N/4=4 odd-3.
739/// Unrolls the L-shaped butterfly for k=0,1,2,3.
740#[allow(clippy::too_many_lines)]
741fn gen_split_radix_twiddle_16() -> TokenStream {
742    let expanded = quote! {
743        /// Split-radix twiddle codelet for N=16 (fully unrolled).
744        ///
745        /// Combines an 8-point even DFT with two 4-point odd DFTs using
746        /// the L-shaped butterfly. All 4 iterations fully unrolled.
747        ///
748        /// # Data Layout
749        /// - `data[0..8]`:   8-point even DFT result (E)
750        /// - `data[8..12]`:  4-point odd-1 DFT result (O1)
751        /// - `data[12..16]`: 4-point odd-3 DFT result (O3)
752        ///
753        /// # Arguments
754        /// * `data`     - Input/output slice of at least 16 complex values
755        /// * `twiddles` - `[W_16^0, W_16^1, W_16^2, W_16^3]` twiddle factors
756        /// * `twiddles3`- `[W_16^0, W_16^3, W_16^6, W_16^9]` twiddle factors
757        /// * `sign`     - Transform direction: -1 for forward, +1 for inverse
758        #[inline(always)]
759        pub fn codelet_split_radix_twiddle_16<T: crate::kernel::Float>(
760            data: &mut [crate::kernel::Complex<T>],
761            twiddles: &[crate::kernel::Complex<T>; 4],
762            twiddles3: &[crate::kernel::Complex<T>; 4],
763            sign: i32,
764        ) {
765            debug_assert!(data.len() >= 16);
766
767            // k=0: E[0], E[4], O1[0], O3[0]
768            let e0   = data[0];
769            let e0_q = data[4];
770            let o1_0 = data[8];
771            let o3_0 = data[12];
772
773            let t1_0 = o1_0 * twiddles[0];
774            let t2_0 = o3_0 * twiddles3[0];
775            let p0 = t1_0 + t2_0;
776            let m0 = t1_0 - t2_0;
777            let m0_rot = if sign < 0 {
778                crate::kernel::Complex::new(m0.im, -m0.re)
779            } else {
780                crate::kernel::Complex::new(-m0.im, m0.re)
781            };
782
783            // k=1: E[1], E[5], O1[1], O3[1]
784            let e1   = data[1];
785            let e1_q = data[5];
786            let o1_1 = data[9];
787            let o3_1 = data[13];
788
789            let t1_1 = o1_1 * twiddles[1];
790            let t2_1 = o3_1 * twiddles3[1];
791            let p1 = t1_1 + t2_1;
792            let m1 = t1_1 - t2_1;
793            let m1_rot = if sign < 0 {
794                crate::kernel::Complex::new(m1.im, -m1.re)
795            } else {
796                crate::kernel::Complex::new(-m1.im, m1.re)
797            };
798
799            // k=2: E[2], E[6], O1[2], O3[2]
800            let e2   = data[2];
801            let e2_q = data[6];
802            let o1_2 = data[10];
803            let o3_2 = data[14];
804
805            let t1_2 = o1_2 * twiddles[2];
806            let t2_2 = o3_2 * twiddles3[2];
807            let p2 = t1_2 + t2_2;
808            let m2 = t1_2 - t2_2;
809            let m2_rot = if sign < 0 {
810                crate::kernel::Complex::new(m2.im, -m2.re)
811            } else {
812                crate::kernel::Complex::new(-m2.im, m2.re)
813            };
814
815            // k=3: E[3], E[7], O1[3], O3[3]
816            let e3   = data[3];
817            let e3_q = data[7];
818            let o1_3 = data[11];
819            let o3_3 = data[15];
820
821            let t1_3 = o1_3 * twiddles[3];
822            let t2_3 = o3_3 * twiddles3[3];
823            let p3 = t1_3 + t2_3;
824            let m3 = t1_3 - t2_3;
825            let m3_rot = if sign < 0 {
826                crate::kernel::Complex::new(m3.im, -m3.re)
827            } else {
828                crate::kernel::Complex::new(-m3.im, m3.re)
829            };
830
831            // Write all 16 outputs
832            // Quarter 0: X[k]
833            data[0]  = e0 + p0;
834            data[1]  = e1 + p1;
835            data[2]  = e2 + p2;
836            data[3]  = e3 + p3;
837
838            // Quarter 1: X[k + N/4]
839            data[4]  = e0_q + m0_rot;
840            data[5]  = e1_q + m1_rot;
841            data[6]  = e2_q + m2_rot;
842            data[7]  = e3_q + m3_rot;
843
844            // Quarter 2: X[k + N/2]
845            data[8]  = e0 - p0;
846            data[9]  = e1 - p1;
847            data[10] = e2 - p2;
848            data[11] = e3 - p3;
849
850            // Quarter 3: X[k + 3N/4]
851            data[12] = e0_q - m0_rot;
852            data[13] = e1_q - m1_rot;
853            data[14] = e2_q - m2_rot;
854            data[15] = e3_q - m3_rot;
855        }
856    };
857    expanded
858}