oxifft_codegen_impl/gen_twiddle.rs
1//! Twiddle-factor codelet generation.
2//!
3//! Generates codelets that apply twiddle factors during multi-radix FFT computation.
4
5use proc_macro2::TokenStream;
6use quote::quote;
7use syn::LitInt;
8
9/// Generate a twiddle codelet for the given radix.
10///
11/// # Errors
12/// Returns a `syn::Error` when the input does not parse as a valid radix literal,
13/// or when the radix is not in the supported set {2, 4, 8, 16}.
14pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
15 let radix: LitInt = syn::parse2(input)?;
16 let r: usize = radix.base10_parse().map_err(|_| {
17 syn::Error::new(
18 radix.span(),
19 "gen_twiddle_codelet: expected an integer radix literal",
20 )
21 })?;
22
23 match r {
24 2 => Ok(gen_twiddle_2()),
25 4 => Ok(gen_twiddle_4()),
26 8 => Ok(gen_twiddle_8()),
27 16 => Ok(gen_twiddle_16()),
28 _ => Err(syn::Error::new(
29 radix.span(),
30 format!("gen_twiddle_codelet: unsupported radix {r} (expected one of 2, 4, 8, 16)"),
31 )),
32 }
33}
34
35fn gen_twiddle_2() -> TokenStream {
36 quote! {
37 /// Radix-2 twiddle codelet.
38 ///
39 /// Applies a single twiddle factor and computes a 2-point butterfly.
40 #[inline(always)]
41 pub fn codelet_twiddle_2<T: crate::kernel::Float>(
42 x: &mut [crate::kernel::Complex<T>],
43 twiddle: crate::kernel::Complex<T>,
44 ) {
45 debug_assert!(x.len() >= 2);
46 let a = x[0];
47 let b = x[1] * twiddle;
48 x[0] = a + b;
49 x[1] = a - b;
50 }
51 }
52}
53
54fn gen_twiddle_4() -> TokenStream {
55 quote! {
56 /// Radix-4 twiddle codelet.
57 ///
58 /// Applies twiddle factors w1, w2, w3 to inputs x[1], x[2], x[3]
59 /// and computes a 4-point FFT.
60 #[inline(always)]
61 pub fn codelet_twiddle_4<T: crate::kernel::Float>(
62 x: &mut [crate::kernel::Complex<T>],
63 tw1: crate::kernel::Complex<T>,
64 tw2: crate::kernel::Complex<T>,
65 tw3: crate::kernel::Complex<T>,
66 sign: i32,
67 ) {
68 debug_assert!(x.len() >= 4);
69
70 let x0 = x[0];
71 let x1 = x[1] * tw1;
72 let x2 = x[2] * tw2;
73 let x3 = x[3] * tw3;
74
75 let t0 = x0 + x2;
76 let t1 = x0 - x2;
77 let t2 = x1 + x3;
78 let t3 = x1 - x3;
79
80 let t3_rot = if sign < 0 {
81 crate::kernel::Complex::new(t3.im, -t3.re)
82 } else {
83 crate::kernel::Complex::new(-t3.im, t3.re)
84 };
85
86 x[0] = t0 + t2;
87 x[1] = t1 + t3_rot;
88 x[2] = t0 - t2;
89 x[3] = t1 - t3_rot;
90 }
91 }
92}
93
94#[allow(clippy::too_many_lines)]
95fn gen_twiddle_16() -> TokenStream {
96 quote! {
97 /// Radix-16 twiddle codelet.
98 ///
99 /// Applies 15 twiddle factors to inputs x[1]..x[15] and computes a 16-point FFT
100 /// using a radix-2 DIT (decimation-in-time) butterfly structure.
101 ///
102 /// # Arguments
103 /// * `x` - Input/output slice of at least 16 complex values
104 /// * `twiddles` - Array of 15 precomputed twiddle factors for positions 1..=15
105 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
106 #[inline(always)]
107 pub fn codelet_twiddle_16<T: crate::kernel::Float>(
108 x: &mut [crate::kernel::Complex<T>],
109 twiddles: &[crate::kernel::Complex<T>; 15],
110 sign: i32,
111 ) {
112 debug_assert!(x.len() >= 16);
113
114 // Step 1: Apply external twiddle factors to positions 1..=15
115 let x0 = x[0];
116 let x1 = x[1] * twiddles[0];
117 let x2 = x[2] * twiddles[1];
118 let x3 = x[3] * twiddles[2];
119 let x4 = x[4] * twiddles[3];
120 let x5 = x[5] * twiddles[4];
121 let x6 = x[6] * twiddles[5];
122 let x7 = x[7] * twiddles[6];
123 let x8 = x[8] * twiddles[7];
124 let x9 = x[9] * twiddles[8];
125 let x10 = x[10] * twiddles[9];
126 let x11 = x[11] * twiddles[10];
127 let x12 = x[12] * twiddles[11];
128 let x13 = x[13] * twiddles[12];
129 let x14 = x[14] * twiddles[13];
130 let x15 = x[15] * twiddles[14];
131
132 // Step 2: Compute 16-point DFT using radix-2 DIT.
133 // Place twiddle-applied values in bit-reversed order, then apply 4 DIT stages.
134 //
135 // Bit-reversal permutation for 16 (4-bit reversal):
136 // 0->0, 1->8, 2->4, 3->12, 4->2, 5->10, 6->6, 7->14,
137 // 8->1, 9->9, 10->5, 11->13, 12->3, 13->11, 14->7, 15->15
138 let mut a = [crate::kernel::Complex::<T>::zero(); 16];
139 a[0] = x0;
140 a[1] = x8;
141 a[2] = x4;
142 a[3] = x12;
143 a[4] = x2;
144 a[5] = x10;
145 a[6] = x6;
146 a[7] = x14;
147 a[8] = x1;
148 a[9] = x9;
149 a[10] = x5;
150 a[11] = x13;
151 a[12] = x3;
152 a[13] = x11;
153 a[14] = x7;
154 a[15] = x15;
155
156 // DIT Stage 1: 8 butterflies, span 1 (W2^0 = 1, no twiddle)
157 for i in (0..16usize).step_by(2) {
158 let t = a[i + 1];
159 a[i + 1] = a[i] - t;
160 a[i] = a[i] + t;
161 }
162
163 // DIT Stage 2: 4 groups of 2 butterflies, span 2
164 // W4^0 = 1, W4^1 = -i (forward) or +i (inverse)
165 for group in (0..16usize).step_by(4) {
166 // k=0: W4^0 = 1
167 let t = a[group + 2];
168 a[group + 2] = a[group] - t;
169 a[group] = a[group] + t;
170
171 // k=1: W4^1
172 let t = a[group + 3];
173 let t_tw = if sign < 0 {
174 crate::kernel::Complex::new(t.im, -t.re)
175 } else {
176 crate::kernel::Complex::new(-t.im, t.re)
177 };
178 a[group + 3] = a[group + 1] - t_tw;
179 a[group + 1] = a[group + 1] + t_tw;
180 }
181
182 // DIT Stage 3: 2 groups of 4 butterflies, span 4
183 // W8^k for k in 0..4
184 // c2 = 1/sqrt(2) ≈ 0.7071067811865476
185 let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
186 for group in (0..16usize).step_by(8) {
187 // k=0: W8^0 = 1
188 let t = a[group + 4];
189 a[group + 4] = a[group] - t;
190 a[group] = a[group] + t;
191
192 // k=1: W8^1 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
193 let t = a[group + 5];
194 let t_tw = if sign < 0 {
195 crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
196 } else {
197 crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
198 };
199 a[group + 5] = a[group + 1] - t_tw;
200 a[group + 1] = a[group + 1] + t_tw;
201
202 // k=2: W8^2 = -i forward, +i inverse
203 let t = a[group + 6];
204 let t_tw = if sign < 0 {
205 crate::kernel::Complex::new(t.im, -t.re)
206 } else {
207 crate::kernel::Complex::new(-t.im, t.re)
208 };
209 a[group + 6] = a[group + 2] - t_tw;
210 a[group + 2] = a[group + 2] + t_tw;
211
212 // k=3: W8^3 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
213 let t = a[group + 7];
214 let t_tw = if sign < 0 {
215 crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
216 } else {
217 crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
218 };
219 a[group + 7] = a[group + 3] - t_tw;
220 a[group + 3] = a[group + 3] + t_tw;
221 }
222
223 // DIT Stage 4: 1 group of 8 butterflies, span 8
224 // W16^k for k in 0..8
225 // Constants: cos(π/8), sin(π/8), 1/sqrt(2), cos(3π/8)=sin(π/8), sin(3π/8)=cos(π/8)
226 let c1 = T::from_f64(0.923_879_532_511_286_7_f64); // cos(π/8)
227 let s1 = T::from_f64(0.382_683_432_365_089_8_f64); // sin(π/8)
228
229 // k=0: W16^0 = 1
230 let t = a[8];
231 a[8] = a[0] - t;
232 a[0] = a[0] + t;
233
234 // k=1: W16^1 = cos(π/8) - i*sin(π/8) forward, cos(π/8) + i*sin(π/8) inverse
235 let t = a[9];
236 let t_tw = if sign < 0 {
237 crate::kernel::Complex::new(t.re * c1 + t.im * s1, t.im * c1 - t.re * s1)
238 } else {
239 crate::kernel::Complex::new(t.re * c1 - t.im * s1, t.im * c1 + t.re * s1)
240 };
241 a[9] = a[1] - t_tw;
242 a[1] = a[1] + t_tw;
243
244 // k=2: W16^2 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
245 let t = a[10];
246 let t_tw = if sign < 0 {
247 crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
248 } else {
249 crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
250 };
251 a[10] = a[2] - t_tw;
252 a[2] = a[2] + t_tw;
253
254 // k=3: W16^3 = cos(3π/8) - i*sin(3π/8) forward
255 // = sin(π/8) - i*cos(π/8) forward (since cos(3π/8)=sin(π/8))
256 let c3 = s1; // cos(3π/8) = sin(π/8)
257 let s3 = c1; // sin(3π/8) = cos(π/8)
258 let t = a[11];
259 let t_tw = if sign < 0 {
260 crate::kernel::Complex::new(t.re * c3 + t.im * s3, t.im * c3 - t.re * s3)
261 } else {
262 crate::kernel::Complex::new(t.re * c3 - t.im * s3, t.im * c3 + t.re * s3)
263 };
264 a[11] = a[3] - t_tw;
265 a[3] = a[3] + t_tw;
266
267 // k=4: W16^4 = -i forward, +i inverse
268 let t = a[12];
269 let t_tw = if sign < 0 {
270 crate::kernel::Complex::new(t.im, -t.re)
271 } else {
272 crate::kernel::Complex::new(-t.im, t.re)
273 };
274 a[12] = a[4] - t_tw;
275 a[4] = a[4] + t_tw;
276
277 // k=5: W16^5 = cos(5π/8) - i*sin(5π/8) = -sin(π/8) - i*cos(π/8) forward
278 let t = a[13];
279 let t_tw = if sign < 0 {
280 crate::kernel::Complex::new(-t.re * s1 + t.im * c1, -t.im * s1 - t.re * c1)
281 } else {
282 crate::kernel::Complex::new(-t.re * s1 - t.im * c1, -t.im * s1 + t.re * c1)
283 };
284 a[13] = a[5] - t_tw;
285 a[5] = a[5] + t_tw;
286
287 // k=6: W16^6 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
288 let t = a[14];
289 let t_tw = if sign < 0 {
290 crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
291 } else {
292 crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
293 };
294 a[14] = a[6] - t_tw;
295 a[6] = a[6] + t_tw;
296
297 // k=7: W16^7 = cos(7π/8) - i*sin(7π/8) = -cos(π/8) - i*sin(π/8) forward
298 let t = a[15];
299 let t_tw = if sign < 0 {
300 crate::kernel::Complex::new(-t.re * c1 + t.im * s1, -t.im * c1 - t.re * s1)
301 } else {
302 crate::kernel::Complex::new(-t.re * c1 - t.im * s1, -t.im * c1 + t.re * s1)
303 };
304 a[15] = a[7] - t_tw;
305 a[7] = a[7] + t_tw;
306
307 // Write back in natural order (DIT produces natural order after bit-reversal input)
308 for i in 0..16usize {
309 x[i] = a[i];
310 }
311 }
312 }
313}
314
315#[allow(clippy::too_many_lines)]
316fn gen_twiddle_8() -> TokenStream {
317 quote! {
318 /// Radix-8 twiddle codelet.
319 ///
320 /// Applies 7 external twiddle factors to inputs x[1]..x[7], then computes
321 /// an 8-point FFT using a radix-2 DIT butterfly structure.
322 ///
323 /// # Arguments
324 /// * `x` - Input/output slice of at least 8 complex values
325 /// * `twiddles` - Array of 7 precomputed twiddle factors for positions 1..=7
326 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
327 #[inline(always)]
328 pub fn codelet_twiddle_8<T: crate::kernel::Float>(
329 x: &mut [crate::kernel::Complex<T>],
330 twiddles: &[crate::kernel::Complex<T>; 7],
331 sign: i32,
332 ) {
333 debug_assert!(x.len() >= 8);
334
335 // Step 1: Apply external twiddle factors to positions 1..=7
336 let x0 = x[0];
337 let x1 = x[1] * twiddles[0];
338 let x2 = x[2] * twiddles[1];
339 let x3 = x[3] * twiddles[2];
340 let x4 = x[4] * twiddles[3];
341 let x5 = x[5] * twiddles[4];
342 let x6 = x[6] * twiddles[5];
343 let x7 = x[7] * twiddles[6];
344
345 // Step 2: Compute 8-point DFT using radix-2 DIT.
346 // Place twiddle-applied values in bit-reversed order, then apply 3 DIT stages.
347 // Bit-reversal for 8 (3-bit): 0→0, 1→4, 2→2, 3→6, 4→1, 5→5, 6→3, 7→7
348 let mut a = [crate::kernel::Complex::<T>::zero(); 8];
349 a[0] = x0; a[1] = x4;
350 a[2] = x2; a[3] = x6;
351 a[4] = x1; a[5] = x5;
352 a[6] = x3; a[7] = x7;
353
354 // DIT Stage 1: 4 butterflies, span 1 (W2^0 = 1)
355 for i in (0..8usize).step_by(2) {
356 let t = a[i + 1];
357 a[i + 1] = a[i] - t;
358 a[i] = a[i] + t;
359 }
360
361 // DIT Stage 2: 2 groups of 2 butterflies, span 2
362 // W4^0 = 1, W4^1 = -i (forward) or +i (inverse)
363 for group in (0..8usize).step_by(4) {
364 // k=0: W4^0 = 1
365 let t = a[group + 2];
366 a[group + 2] = a[group] - t;
367 a[group] = a[group] + t;
368
369 // k=1: W4^1
370 let t = a[group + 3];
371 let t_tw = if sign < 0 {
372 crate::kernel::Complex::new(t.im, -t.re)
373 } else {
374 crate::kernel::Complex::new(-t.im, t.re)
375 };
376 a[group + 3] = a[group + 1] - t_tw;
377 a[group + 1] = a[group + 1] + t_tw;
378 }
379
380 // DIT Stage 3: 1 group of 4 butterflies, span 4
381 // W8^k for k in 0..4. c2 = 1/sqrt(2) ≈ 0.7071067811865476
382 let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
383
384 // k=0: W8^0 = 1
385 let t = a[4];
386 a[4] = a[0] - t;
387 a[0] = a[0] + t;
388
389 // k=1: W8^1 = (1-i)/sqrt(2) forward, (1+i)/sqrt(2) inverse
390 let t = a[5];
391 let t_tw = if sign < 0 {
392 crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
393 } else {
394 crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
395 };
396 a[5] = a[1] - t_tw;
397 a[1] = a[1] + t_tw;
398
399 // k=2: W8^2 = -i (forward) or +i (inverse)
400 let t = a[6];
401 let t_tw = if sign < 0 {
402 crate::kernel::Complex::new(t.im, -t.re)
403 } else {
404 crate::kernel::Complex::new(-t.im, t.re)
405 };
406 a[6] = a[2] - t_tw;
407 a[2] = a[2] + t_tw;
408
409 // k=3: W8^3 = (-1-i)/sqrt(2) forward, (-1+i)/sqrt(2) inverse
410 let t = a[7];
411 let t_tw = if sign < 0 {
412 crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
413 } else {
414 crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
415 };
416 a[7] = a[3] - t_tw;
417 a[3] = a[3] + t_tw;
418
419 // Write back in natural order
420 for i in 0..8usize {
421 x[i] = a[i];
422 }
423 }
424
425 /// Radix-8 twiddle codelet with inline twiddle computation.
426 ///
427 /// This version computes twiddles from angle step, useful when twiddles
428 /// are not precomputed.
429 #[inline(always)]
430 pub fn codelet_twiddle_8_inline<T: crate::kernel::Float>(
431 x: &mut [crate::kernel::Complex<T>],
432 angle_step: T,
433 sign: i32,
434 ) {
435 debug_assert!(x.len() >= 8);
436
437 // Compute twiddles inline via fully-qualified Float trait methods to avoid ambiguity
438 let tw1 = crate::kernel::Complex::new(
439 <T as crate::kernel::Float>::cos(angle_step),
440 <T as crate::kernel::Float>::sin(angle_step),
441 );
442 let tw2 = crate::kernel::Complex::new(
443 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(2)),
444 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(2)),
445 );
446 let tw3 = crate::kernel::Complex::new(
447 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(3)),
448 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(3)),
449 );
450 let tw4 = crate::kernel::Complex::new(
451 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(4)),
452 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(4)),
453 );
454 let tw5 = crate::kernel::Complex::new(
455 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(5)),
456 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(5)),
457 );
458 let tw6 = crate::kernel::Complex::new(
459 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(6)),
460 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(6)),
461 );
462 let tw7 = crate::kernel::Complex::new(
463 <T as crate::kernel::Float>::cos(angle_step * T::from_usize(7)),
464 <T as crate::kernel::Float>::sin(angle_step * T::from_usize(7)),
465 );
466
467 let twiddles = [tw1, tw2, tw3, tw4, tw5, tw6, tw7];
468 codelet_twiddle_8(x, &twiddles, sign);
469 }
470 }
471}
472
473/// Generate a split-radix twiddle codelet for the given size.
474///
475/// If no size is specified (empty input), generates the generic runtime-parameterized
476/// split-radix twiddle codelet. If a size is given (8 or 16), generates a specialized
477/// unrolled version for that size.
478///
479/// # Errors
480/// Returns a `syn::Error` when the input does not parse as a valid size literal,
481/// or when the size is not in the supported set {8, 16} (or empty for the generic variant).
482pub fn generate_split_radix(input: TokenStream) -> Result<TokenStream, syn::Error> {
483 if input.is_empty() {
484 return Ok(gen_split_radix_twiddle());
485 }
486 let size: LitInt = syn::parse2(input)?;
487 let n: usize = size.base10_parse().map_err(|_| {
488 syn::Error::new(
489 size.span(),
490 "gen_split_radix_twiddle_codelet: expected an integer size literal",
491 )
492 })?;
493 match n {
494 8 => Ok(gen_split_radix_twiddle_8()),
495 16 => Ok(gen_split_radix_twiddle_16()),
496 _ => Err(syn::Error::new(
497 size.span(),
498 format!("gen_split_radix_twiddle_codelet: unsupported size {n} (use 8 or 16, or empty for generic)"),
499 )),
500 }
501}
502
503/// Generate the generic split-radix twiddle codelet (L-shaped butterfly).
504///
505/// The split-radix FFT decomposes an N-point DFT into:
506/// - One N/2-point DFT of even-indexed elements
507/// - Two N/4-point DFTs of odd-indexed elements (with twiddle factors `W_N^k` and `W_N^{3k`})
508///
509/// This codelet performs the combining step (L-shaped butterfly):
510/// For k = 0..N/4-1:
511/// t1 = `W_N^k` * O1[k], t2 = `W_N^{3k`} * O3[k]
512/// p = t1 + t2, m = t1 - t2
513/// X[k] = E[k] + p
514/// X[k+N/4] = E[k+N/4] - j*(t1-t2) (forward)
515/// X[k+N/2] = E[k] - p
516/// X[k+3N/4] = E[k+N/4] + j*(t1-t2) (forward)
517fn gen_split_radix_twiddle() -> TokenStream {
518 let expanded = quote! {
519 /// Split-radix twiddle codelet (L-shaped butterfly).
520 ///
521 /// Combines the results of an N/2-point even DFT with two N/4-point odd DFTs
522 /// using split-radix decomposition. This reduces the total number of real
523 /// multiplications compared to standard radix-2: approximately 4N log2(N) - 6N + 8
524 /// vs. 5N log2(N) for radix-2.
525 ///
526 /// # Data Layout
527 ///
528 /// On input, `data[0..n]` is organized as:
529 /// - `data[0..n/2]`: N/2-point DFT of even-indexed elements (E)
530 /// - `data[n/2..3n/4]`: N/4-point DFT of odd-1 elements (O1, indices 1,5,9,...)
531 /// - `data[3n/4..n]`: N/4-point DFT of odd-3 elements (O3, indices 3,7,11,...)
532 ///
533 /// On output, `data[0..n]` contains the combined N-point DFT result.
534 ///
535 /// # Arguments
536 /// * `data` - Input/output slice of at least `n` complex values
537 /// * `n` - Transform size (must be divisible by 4 and >= 4)
538 /// * `twiddles` - Twiddle factors W_N^k for k = 0..n/4-1
539 /// * `twiddles3`- Twiddle factors W_N^{3k} for k = 0..n/4-1
540 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
541 #[inline]
542 pub fn codelet_split_radix_twiddle<T: crate::kernel::Float>(
543 data: &mut [crate::kernel::Complex<T>],
544 n: usize,
545 twiddles: &[crate::kernel::Complex<T>],
546 twiddles3: &[crate::kernel::Complex<T>],
547 sign: i32,
548 ) {
549 debug_assert!(n >= 4 && n % 4 == 0, "n must be >= 4 and divisible by 4");
550 debug_assert!(data.len() >= n, "data slice too short for split-radix");
551
552 let n2 = n >> 1; // N/2
553 let n4 = n >> 2; // N/4
554
555 debug_assert!(twiddles.len() >= n4, "twiddles slice too short");
556 debug_assert!(twiddles3.len() >= n4, "twiddles3 slice too short");
557
558 for k in 0..n4 {
559 // Read the three sub-DFT results
560 let e_k = data[k]; // E[k] from even DFT
561 let e_k_q = data[k + n4]; // E[k + N/4] from even DFT
562 let o1_k = data[n2 + k]; // O1[k] from first odd DFT
563 let o3_k = data[n2 + n4 + k]; // O3[k] from second odd DFT
564
565 // Apply twiddle factors to odd sub-DFT results
566 let t1 = o1_k * twiddles[k]; // W_N^k * O1[k]
567 let t2 = o3_k * twiddles3[k]; // W_N^{3k} * O3[k]
568
569 // Split-radix butterfly computation
570 let p = t1 + t2; // Sum of twiddled odd results
571 let m = t1 - t2; // Difference of twiddled odd results
572
573 // Rotate difference by -j (forward) or +j (inverse)
574 // -j * (a + bi) = (b, -a); +j * (a + bi) = (-b, a)
575 let m_rot = if sign < 0 {
576 crate::kernel::Complex::new(m.im, -m.re)
577 } else {
578 crate::kernel::Complex::new(-m.im, m.re)
579 };
580
581 // Write combined results to all four quarters
582 data[k] = e_k + p; // X[k]
583 data[k + n4] = e_k_q + m_rot; // X[k + N/4]
584 data[k + n2] = e_k - p; // X[k + N/2]
585 data[k + n2 + n4] = e_k_q - m_rot; // X[k + 3N/4]
586 }
587 }
588
589 /// Split-radix twiddle codelet with inline twiddle computation.
590 ///
591 /// Computes W_N^k and W_N^{3k} twiddle factors on the fly from the
592 /// base angle step, useful when twiddle tables are not precomputed.
593 ///
594 /// # Arguments
595 /// * `data` - Input/output slice of at least `n` complex values
596 /// * `n` - Transform size (must be divisible by 4 and >= 4)
597 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
598 #[inline]
599 pub fn codelet_split_radix_twiddle_inline<T: crate::kernel::Float>(
600 data: &mut [crate::kernel::Complex<T>],
601 n: usize,
602 sign: i32,
603 ) {
604 debug_assert!(n >= 4 && n % 4 == 0, "n must be >= 4 and divisible by 4");
605 debug_assert!(data.len() >= n, "data slice too short for split-radix");
606
607 let n2 = n >> 1;
608 let n4 = n >> 2;
609
610 // Base angle: -2π/N (forward) or +2π/N (inverse)
611 let base_angle = if sign < 0 {
612 -2.0_f64 * core::f64::consts::PI / (n as f64)
613 } else {
614 2.0_f64 * core::f64::consts::PI / (n as f64)
615 };
616
617 for k in 0..n4 {
618 let angle_k = base_angle * (k as f64);
619 let angle_3k = base_angle * (3 * k) as f64;
620
621 let tw = crate::kernel::Complex::new(
622 T::from_f64(angle_k.cos()),
623 T::from_f64(angle_k.sin()),
624 );
625 let tw3 = crate::kernel::Complex::new(
626 T::from_f64(angle_3k.cos()),
627 T::from_f64(angle_3k.sin()),
628 );
629
630 let e_k = data[k];
631 let e_k_q = data[k + n4];
632 let o1_k = data[n2 + k];
633 let o3_k = data[n2 + n4 + k];
634
635 let t1 = o1_k * tw;
636 let t2 = o3_k * tw3;
637
638 let p = t1 + t2;
639 let m = t1 - t2;
640
641 let m_rot = if sign < 0 {
642 crate::kernel::Complex::new(m.im, -m.re)
643 } else {
644 crate::kernel::Complex::new(-m.im, m.re)
645 };
646
647 data[k] = e_k + p;
648 data[k + n4] = e_k_q + m_rot;
649 data[k + n2] = e_k - p;
650 data[k + n2 + n4] = e_k_q - m_rot;
651 }
652 }
653 };
654 expanded
655}
656
657/// Generate a specialized 8-point split-radix twiddle codelet (fully unrolled).
658///
659/// N=8: N/2=4 even, N/4=2 odd-1, N/4=2 odd-3.
660/// Unrolls the L-shaped butterfly for k=0,1.
661fn gen_split_radix_twiddle_8() -> TokenStream {
662 let expanded = quote! {
663 /// Split-radix twiddle codelet for N=8 (fully unrolled).
664 ///
665 /// Combines a 4-point even DFT with two 2-point odd DFTs using
666 /// the L-shaped butterfly. All 2 iterations fully unrolled.
667 ///
668 /// # Data Layout
669 /// - `data[0..4]`: 4-point even DFT result (E)
670 /// - `data[4..6]`: 2-point odd-1 DFT result (O1)
671 /// - `data[6..8]`: 2-point odd-3 DFT result (O3)
672 ///
673 /// # Arguments
674 /// * `data` - Input/output slice of at least 8 complex values
675 /// * `twiddles` - `[W_8^0, W_8^1]` twiddle factors
676 /// * `twiddles3`- `[W_8^0, W_8^3]` twiddle factors
677 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
678 #[inline(always)]
679 pub fn codelet_split_radix_twiddle_8<T: crate::kernel::Float>(
680 data: &mut [crate::kernel::Complex<T>],
681 twiddles: &[crate::kernel::Complex<T>; 2],
682 twiddles3: &[crate::kernel::Complex<T>; 2],
683 sign: i32,
684 ) {
685 debug_assert!(data.len() >= 8);
686
687 // k=0: E[0], E[2], O1[0], O3[0]
688 let e0 = data[0];
689 let e0_q = data[2]; // E[0 + N/4] = E[2]
690 let o1_0 = data[4];
691 let o3_0 = data[6];
692
693 let t1_0 = o1_0 * twiddles[0]; // W_8^0 * O1[0]
694 let t2_0 = o3_0 * twiddles3[0]; // W_8^0 * O3[0]
695
696 let p0 = t1_0 + t2_0;
697 let m0 = t1_0 - t2_0;
698 let m0_rot = if sign < 0 {
699 crate::kernel::Complex::new(m0.im, -m0.re)
700 } else {
701 crate::kernel::Complex::new(-m0.im, m0.re)
702 };
703
704 // k=1: E[1], E[3], O1[1], O3[1]
705 let e1 = data[1];
706 let e1_q = data[3]; // E[1 + N/4] = E[3]
707 let o1_1 = data[5];
708 let o3_1 = data[7];
709
710 let t1_1 = o1_1 * twiddles[1]; // W_8^1 * O1[1]
711 let t2_1 = o3_1 * twiddles3[1]; // W_8^3 * O3[1]
712
713 let p1 = t1_1 + t2_1;
714 let m1 = t1_1 - t2_1;
715 let m1_rot = if sign < 0 {
716 crate::kernel::Complex::new(m1.im, -m1.re)
717 } else {
718 crate::kernel::Complex::new(-m1.im, m1.re)
719 };
720
721 // Write all 8 outputs (4 pairs from 2 butterfly iterations)
722 data[0] = e0 + p0; // X[0]
723 data[2] = e0_q + m0_rot; // X[0 + N/4] = X[2]
724 data[4] = e0 - p0; // X[0 + N/2] = X[4]
725 data[6] = e0_q - m0_rot; // X[0 + 3N/4] = X[6]
726
727 data[1] = e1 + p1; // X[1]
728 data[3] = e1_q + m1_rot; // X[1 + N/4] = X[3]
729 data[5] = e1 - p1; // X[1 + N/2] = X[5]
730 data[7] = e1_q - m1_rot; // X[1 + 3N/4] = X[7]
731 }
732 };
733 expanded
734}
735
736/// Generate a specialized 16-point split-radix twiddle codelet (fully unrolled).
737///
738/// N=16: N/2=8 even, N/4=4 odd-1, N/4=4 odd-3.
739/// Unrolls the L-shaped butterfly for k=0,1,2,3.
740#[allow(clippy::too_many_lines)]
741fn gen_split_radix_twiddle_16() -> TokenStream {
742 let expanded = quote! {
743 /// Split-radix twiddle codelet for N=16 (fully unrolled).
744 ///
745 /// Combines an 8-point even DFT with two 4-point odd DFTs using
746 /// the L-shaped butterfly. All 4 iterations fully unrolled.
747 ///
748 /// # Data Layout
749 /// - `data[0..8]`: 8-point even DFT result (E)
750 /// - `data[8..12]`: 4-point odd-1 DFT result (O1)
751 /// - `data[12..16]`: 4-point odd-3 DFT result (O3)
752 ///
753 /// # Arguments
754 /// * `data` - Input/output slice of at least 16 complex values
755 /// * `twiddles` - `[W_16^0, W_16^1, W_16^2, W_16^3]` twiddle factors
756 /// * `twiddles3`- `[W_16^0, W_16^3, W_16^6, W_16^9]` twiddle factors
757 /// * `sign` - Transform direction: -1 for forward, +1 for inverse
758 #[inline(always)]
759 pub fn codelet_split_radix_twiddle_16<T: crate::kernel::Float>(
760 data: &mut [crate::kernel::Complex<T>],
761 twiddles: &[crate::kernel::Complex<T>; 4],
762 twiddles3: &[crate::kernel::Complex<T>; 4],
763 sign: i32,
764 ) {
765 debug_assert!(data.len() >= 16);
766
767 // k=0: E[0], E[4], O1[0], O3[0]
768 let e0 = data[0];
769 let e0_q = data[4];
770 let o1_0 = data[8];
771 let o3_0 = data[12];
772
773 let t1_0 = o1_0 * twiddles[0];
774 let t2_0 = o3_0 * twiddles3[0];
775 let p0 = t1_0 + t2_0;
776 let m0 = t1_0 - t2_0;
777 let m0_rot = if sign < 0 {
778 crate::kernel::Complex::new(m0.im, -m0.re)
779 } else {
780 crate::kernel::Complex::new(-m0.im, m0.re)
781 };
782
783 // k=1: E[1], E[5], O1[1], O3[1]
784 let e1 = data[1];
785 let e1_q = data[5];
786 let o1_1 = data[9];
787 let o3_1 = data[13];
788
789 let t1_1 = o1_1 * twiddles[1];
790 let t2_1 = o3_1 * twiddles3[1];
791 let p1 = t1_1 + t2_1;
792 let m1 = t1_1 - t2_1;
793 let m1_rot = if sign < 0 {
794 crate::kernel::Complex::new(m1.im, -m1.re)
795 } else {
796 crate::kernel::Complex::new(-m1.im, m1.re)
797 };
798
799 // k=2: E[2], E[6], O1[2], O3[2]
800 let e2 = data[2];
801 let e2_q = data[6];
802 let o1_2 = data[10];
803 let o3_2 = data[14];
804
805 let t1_2 = o1_2 * twiddles[2];
806 let t2_2 = o3_2 * twiddles3[2];
807 let p2 = t1_2 + t2_2;
808 let m2 = t1_2 - t2_2;
809 let m2_rot = if sign < 0 {
810 crate::kernel::Complex::new(m2.im, -m2.re)
811 } else {
812 crate::kernel::Complex::new(-m2.im, m2.re)
813 };
814
815 // k=3: E[3], E[7], O1[3], O3[3]
816 let e3 = data[3];
817 let e3_q = data[7];
818 let o1_3 = data[11];
819 let o3_3 = data[15];
820
821 let t1_3 = o1_3 * twiddles[3];
822 let t2_3 = o3_3 * twiddles3[3];
823 let p3 = t1_3 + t2_3;
824 let m3 = t1_3 - t2_3;
825 let m3_rot = if sign < 0 {
826 crate::kernel::Complex::new(m3.im, -m3.re)
827 } else {
828 crate::kernel::Complex::new(-m3.im, m3.re)
829 };
830
831 // Write all 16 outputs
832 // Quarter 0: X[k]
833 data[0] = e0 + p0;
834 data[1] = e1 + p1;
835 data[2] = e2 + p2;
836 data[3] = e3 + p3;
837
838 // Quarter 1: X[k + N/4]
839 data[4] = e0_q + m0_rot;
840 data[5] = e1_q + m1_rot;
841 data[6] = e2_q + m2_rot;
842 data[7] = e3_q + m3_rot;
843
844 // Quarter 2: X[k + N/2]
845 data[8] = e0 - p0;
846 data[9] = e1 - p1;
847 data[10] = e2 - p2;
848 data[11] = e3 - p3;
849
850 // Quarter 3: X[k + 3N/4]
851 data[12] = e0_q - m0_rot;
852 data[13] = e1_q - m1_rot;
853 data[14] = e2_q - m2_rot;
854 data[15] = e3_q - m3_rot;
855 }
856 };
857 expanded
858}