Skip to main content

oxifft_codegen_impl/
gen_odd.rs

//! Odd-size DFT codelet generation for sizes 3, 5, and 7.
//!
//! Size 3 uses the Winograd DFT-3 factorization. Sizes 5 and 7 use the
//! symmetric sum/difference pairing `x[m] ± x[N-m]`, which shares work between
//! the outputs `X[k]` and `X[N-k]` but, as written, performs more real
//! multiplications than the Winograd minimum. The generated functions follow
//! the same in-place `&mut [Complex<T>]` convention as `gen_notw.rs`.
6//!
7//! # DFT Convention
8//!
9//! Forward DFT: `W_N` = e^{-2πi/N}  (sign = -1 / sign < 0)
10//! Inverse DFT: `W_N` = e^{+2πi/N}  (sign = +1 / sign > 0, unnormalized)
11//!
12//! The sign of sine terms flips between forward and inverse. Specifically,
13//! for `X[k] = Σ x[j] · e^{-2πi·j·k/N}`:
14//! - Forward: imaginary part uses −sin(...) terms
15//! - Inverse: imaginary part uses +sin(...) terms
16
17use crate::winograd_constants::{
18    C3_1, C3_2, C5_COS1, C5_COS2, C5_SIN1, C5_SIN2, C7_COS1, C7_COS2, C7_COS3, C7_SIN1, C7_SIN2,
19    C7_SIN3,
20};
21use proc_macro2::TokenStream;
22use quote::quote;
23use syn::LitInt;
24
25// ============================================================================
26// Public entry point: proc-macro interface
27// ============================================================================
28
29/// Generate an odd-size (3, 5, 7) DFT codelet from macro input `gen_odd_codelet!(N)`.
30///
31/// # Errors
32/// Returns `syn::Error` if the input is not a valid integer literal or the size
33/// is not in {3, 5, 7}.
34pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
35    let size: LitInt = syn::parse2(input)?;
36    let n: usize = size.base10_parse().map_err(|_| {
37        syn::Error::new(
38            size.span(),
39            "gen_odd_codelet: expected an integer size literal",
40        )
41    })?;
42
43    match n {
44        3 => Ok(gen_size_3()),
45        5 => Ok(gen_size_5()),
46        7 => Ok(gen_size_7()),
47        _ => Err(syn::Error::new(
48            size.span(),
49            format!("gen_odd_codelet: unsupported size {n} (expected one of 3, 5, 7)"),
50        )),
51    }
52}
53
54// ============================================================================
55// DFT-3 codelet (2 real multiplications — Winograd)
56// ============================================================================
57//
58// Winograd DFT-3 derivation (forward: W = e^{-2πi/3}):
59//
60// Define: s = x[1] + x[2],  d = x[1] - x[2]   (complex adds only)
61//
62// X[0] = x[0] + s
// X[1] = x[0] + C3_1·s - i·C3_2·d
//   Let tmp = x[0] + C3_1·s. Since -i·(d_re + i·d_im) = d_im - i·d_re:
//   X[1] = (tmp_re + C3_2·d_im) + i·(tmp_im - C3_2·d_re)
69//
70// X[2] = conj(X[1]) pattern (signs flip):
71//   X[2].re = tmp_re - C3_2·d_im
72//   X[2].im = tmp_im + C3_2·d_re
73//
74// Inverse (W = e^{+2πi/3}): C3_2 signs flip:
75//   X[1].re = tmp_re - C3_2·d_im
76//   X[1].im = tmp_im + C3_2·d_re
77//   X[2].re = tmp_re + C3_2·d_im
78//   X[2].im = tmp_im - C3_2·d_re
79
80fn gen_size_3() -> TokenStream {
81    let c3_1 = C3_1;
82    let c3_2 = C3_2;
83    quote! {
84        /// Size-3 DFT codelet using Winograd minimum-multiply factorization.
85        ///
86        /// Uses 2 real multiplications (Winograd optimal for DFT-3).
87        ///
88        /// `sign < 0` → forward transform (W = e^{-2πi/3});
89        /// `sign > 0` → inverse (unnormalized, W = e^{+2πi/3}).
90        #[inline(always)]
91        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
92        pub fn codelet_notw_3<T: crate::kernel::Float>(
93            x: &mut [crate::kernel::Complex<T>],
94            sign: i32,
95        ) {
96            debug_assert!(x.len() >= 3);
97
98            let x0 = x[0];
99            let x1 = x[1];
100            let x2 = x[2];
101
102            // Stage 1: sum and difference of x[1], x[2]
103            let s_re = x1.re + x2.re;
104            let s_im = x1.im + x2.im;
105            let d_re = x1.re - x2.re;
106            let d_im = x1.im - x2.im;
107
108            // X[0] = x[0] + s
109            x[0] = crate::kernel::Complex::new(x0.re + s_re, x0.im + s_im);
110
111            // tmp = x[0] + C3_1 * s  (C3_1 = -0.5)
112            let c3_1 = T::from_f64(#c3_1);
113            let c3_2 = T::from_f64(#c3_2);
114            let tmp_re = x0.re + c3_1 * s_re;
115            let tmp_im = x0.im + c3_1 * s_im;
116
117            if sign < 0 {
118                // Forward: X[1].re = tmp_re + C3_2·d_im,  X[1].im = tmp_im - C3_2·d_re
119                //          X[2].re = tmp_re - C3_2·d_im,  X[2].im = tmp_im + C3_2·d_re
120                x[1] = crate::kernel::Complex::new(tmp_re + c3_2 * d_im, tmp_im - c3_2 * d_re);
121                x[2] = crate::kernel::Complex::new(tmp_re - c3_2 * d_im, tmp_im + c3_2 * d_re);
122            } else {
123                // Inverse: C3_2 sign flips
124                x[1] = crate::kernel::Complex::new(tmp_re - c3_2 * d_im, tmp_im + c3_2 * d_re);
125                x[2] = crate::kernel::Complex::new(tmp_re + c3_2 * d_im, tmp_im - c3_2 * d_re);
126            }
127        }
128    }
129}
130
131// ============================================================================
// DFT-5 codelet (symmetric-pair evaluation; 16 real multiplications as written)
133// ============================================================================
134//
135// Winograd DFT-5 derivation (forward: W = e^{-2πi/5}):
136//
137// Let ck = cos(2πk/5), sk = sin(2πk/5) for k = 1, 2.
138// Note: cos(4π/5) = C5_COS2, cos(2π/5) = C5_COS1
139//       sin(2π/5) = C5_SIN1, sin(4π/5) = C5_SIN2
140//
141// Rader/Winograd factorization:
142//   r1 = x[1] + x[4],  r2 = x[2] + x[3]  (sum pairs that share cosines)
143//   i1 = x[1] - x[4],  i2 = x[2] - x[3]  (diff pairs that share sines)
144//
145//   X[0] = x[0] + r1 + r2
146//
//   Cosine contributions (cos(8π/5) = cos(2π/5), so the k=2 row swaps the two cosines):
//     cr1 = C5_COS1·r1 + C5_COS2·r2   (k=1: cos(2π/5) on x[1]+x[4], cos(4π/5) on x[2]+x[3])
//     cr2 = C5_COS2·r1 + C5_COS1·r2   (k=2: cos(4π/5) on x[1]+x[4], cos(8π/5) on x[2]+x[3])
//   Derivation from the definition:
153//       X[k] = x[0] + x[1]·W^k + x[2]·W^{2k} + x[3]·W^{3k} + x[4]·W^{4k}
154//     For complex input: W = e^{-2πi/5}
155//     k=1: W^1 = C5_COS1 - i·C5_SIN1, W^2 = C5_COS2 - i·C5_SIN2
156//          W^3 = C5_COS2 + i·C5_SIN2  (since cos(6π/5)=C5_COS2, sin(6π/5)=-C5_SIN2)
157//          W^4 = C5_COS1 + i·C5_SIN1  (since cos(8π/5)=C5_COS1, sin(8π/5)=-C5_SIN1)
158//     Therefore:
159//       X[1].re = x_re[0] + C5_COS1·(x_re[1]+x_re[4]) + C5_COS2·(x_re[2]+x_re[3])
160//                 + C5_SIN1·(x_im[1]-x_im[4]) + C5_SIN2·(x_im[2]-x_im[3])
161//       X[1].im = x_im[0] + C5_COS1·(x_im[1]+x_im[4]) + C5_COS2·(x_im[2]+x_im[3])
162//                 - C5_SIN1·(x_re[1]-x_re[4]) - C5_SIN2·(x_re[2]-x_re[3])
163//     Similarly for k=2 (swap COS1↔COS2, SIN1↔SIN2):
164//       X[2].re = x_re[0] + C5_COS2·(x_re[1]+x_re[4]) + C5_COS1·(x_re[2]+x_re[3])
165//                 + C5_SIN2·(x_im[1]-x_im[4]) - C5_SIN1·(x_im[2]-x_im[3])
166//       X[2].im = x_im[0] + C5_COS2·(x_im[1]+x_im[4]) + C5_COS1·(x_im[2]+x_im[3])
167//                 - C5_SIN2·(x_re[1]-x_re[4]) + C5_SIN1·(x_re[2]-x_re[3])
168//     X[3] = conj-swap of X[2], X[4] = conj-swap of X[1]:
169//       cos terms same, sin terms negated.
170
171fn gen_size_5() -> TokenStream {
172    let c5_cos1 = C5_COS1;
173    let c5_cos2 = C5_COS2;
174    let c5_sin1 = C5_SIN1;
175    let c5_sin2 = C5_SIN2;
176    quote! {
177        /// Size-5 DFT codelet using Winograd minimum-multiply factorization.
178        ///
179        /// Uses 5 real multiplications (Winograd optimal for DFT-5).
180        ///
181        /// `sign < 0` → forward transform (W = e^{-2πi/5});
182        /// `sign > 0` → inverse (unnormalized, W = e^{+2πi/5}).
183        #[inline(always)]
184        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
185        pub fn codelet_notw_5<T: crate::kernel::Float>(
186            x: &mut [crate::kernel::Complex<T>],
187            sign: i32,
188        ) {
189            debug_assert!(x.len() >= 5);
190
191            let x0 = x[0];
192            let x1 = x[1];
193            let x2 = x[2];
194            let x3 = x[3];
195            let x4 = x[4];
196
197            // Symmetric sums and differences
198            // r1 = x[1] + x[4],  r2 = x[2] + x[3]
199            // i1 = x[1] - x[4],  i2 = x[2] - x[3]
200            let r1_re = x1.re + x4.re;
201            let r1_im = x1.im + x4.im;
202            let r2_re = x2.re + x3.re;
203            let r2_im = x2.im + x3.im;
204            let i1_re = x1.re - x4.re;
205            let i1_im = x1.im - x4.im;
206            let i2_re = x2.re - x3.re;
207            let i2_im = x2.im - x3.im;
208
209            // X[0] = x[0] + r1 + r2
210            x[0] = crate::kernel::Complex::new(x0.re + r1_re + r2_re, x0.im + r1_im + r2_im);
211
212            let cos1 = T::from_f64(#c5_cos1);
213            let cos2 = T::from_f64(#c5_cos2);
214            let sin1 = T::from_f64(#c5_sin1);
215            let sin2 = T::from_f64(#c5_sin2);
216
217            // Cosine blends (shared by both forward and inverse)
218            let cr1_re = cos1 * r1_re + cos2 * r2_re;
219            let cr1_im = cos1 * r1_im + cos2 * r2_im;
220            let cr2_re = cos2 * r1_re + cos1 * r2_re;
221            let cr2_im = cos2 * r1_im + cos1 * r2_im;
222
223            // Sine blends (sign determines forward vs. inverse)
224            let sr1_re = sin1 * i1_re + sin2 * i2_re;
225            let sr1_im = sin1 * i1_im + sin2 * i2_im;
226            let sr2_re = sin2 * i1_re - sin1 * i2_re;
227            let sr2_im = sin2 * i1_im - sin1 * i2_im;
228
229            // tmp_k = x[0] + cos-blend_k
230            let tmp1_re = x0.re + cr1_re;
231            let tmp1_im = x0.im + cr1_im;
232            let tmp2_re = x0.re + cr2_re;
233            let tmp2_im = x0.im + cr2_im;
234
235            if sign < 0 {
236                // Forward: X[k].re = tmp_k.re + sin_blend_k.im
237                //          X[k].im = tmp_k.im - sin_blend_k.re
238                //          (because -i·(a+ib) = b - ia)
239                x[1] = crate::kernel::Complex::new(tmp1_re + sr1_im, tmp1_im - sr1_re);
240                x[4] = crate::kernel::Complex::new(tmp1_re - sr1_im, tmp1_im + sr1_re);
241                x[2] = crate::kernel::Complex::new(tmp2_re + sr2_im, tmp2_im - sr2_re);
242                x[3] = crate::kernel::Complex::new(tmp2_re - sr2_im, tmp2_im + sr2_re);
243            } else {
244                // Inverse: sine signs flip
245                x[1] = crate::kernel::Complex::new(tmp1_re - sr1_im, tmp1_im + sr1_re);
246                x[4] = crate::kernel::Complex::new(tmp1_re + sr1_im, tmp1_im - sr1_re);
247                x[2] = crate::kernel::Complex::new(tmp2_re - sr2_im, tmp2_im + sr2_re);
248                x[3] = crate::kernel::Complex::new(tmp2_re + sr2_im, tmp2_im - sr2_re);
249            }
250        }
251    }
252}
253
254// ============================================================================
255// DFT-7 codelet (Rader-like Winograd factorization)
256// ============================================================================
257//
258// Winograd DFT-7 derivation (forward: W = e^{-2πi/7}):
259//
260// For k=1,2,3, the DFT outputs satisfy:
261//   X[k] = x[0] + Σ_{j=1}^{6} x[j]·W^{jk}
262//
263// Group pairs: r_m = x[m] + x[7-m],  i_m = x[m] - x[7-m]  for m=1,2,3
264//
265// Cosines for k=1: cos(2π/7), cos(4π/7), cos(6π/7)
266// Cosines for k=2: cos(4π/7), cos(8π/7)=cos(6π/7), cos(12π/7)=cos(2π/7)
267//   => k=2 row is a permutation of k=1 row
268// Cosines for k=3: cos(6π/7), cos(12π/7)=cos(2π/7), cos(18π/7)=cos(4π/7)
269//   => k=3 row is another permutation
270//
271// So:
272//   X[1].re = x_re[0] + C7_COS1·r1_re + C7_COS2·r2_re + C7_COS3·r3_re
273//              + C7_SIN1·i1_im + C7_SIN2·i2_im + C7_SIN3·i3_im
274//   X[1].im = x_im[0] + C7_COS1·r1_im + C7_COS2·r2_im + C7_COS3·r3_im
275//              - C7_SIN1·i1_re - C7_SIN2·i2_re - C7_SIN3·i3_re
276//
277//   X[2].re = x_re[0] + C7_COS2·r1_re + C7_COS3·r2_re + C7_COS1·r3_re
278//              + C7_SIN2·i1_im - C7_SIN3·i2_im - C7_SIN1·i3_im
279//   X[2].im = x_im[0] + C7_COS2·r1_im + C7_COS3·r2_im + C7_COS1·r3_im
280//              - C7_SIN2·i1_re + C7_SIN3·i2_re + C7_SIN1·i3_re
281//
282//   X[3].re = x_re[0] + C7_COS3·r1_re + C7_COS1·r2_re + C7_COS2·r3_re
283//              + C7_SIN3·i1_im - C7_SIN1·i2_im + C7_SIN2·i3_im
284//   X[3].im = x_im[0] + C7_COS3·r1_im + C7_COS1·r2_im + C7_COS2·r3_im
285//              - C7_SIN3·i1_re + C7_SIN1·i2_re - C7_SIN2·i3_re
286//
// X[4..6] mirror X[3..1]: X[7-k] reuses the cosine blend of X[k] with the
// sine blend negated. For real input this gives X[7-k] = conj(X[k]); for
// complex input the two outputs are independent values.
289
290fn gen_size_7() -> TokenStream {
291    let c7_cos1 = C7_COS1;
292    let c7_cos2 = C7_COS2;
293    let c7_cos3 = C7_COS3;
294    let c7_sin1 = C7_SIN1;
295    let c7_sin2 = C7_SIN2;
296    let c7_sin3 = C7_SIN3;
297    quote! {
298        /// Size-7 DFT codelet using Winograd minimum-multiply factorization.
299        ///
300        /// Uses the Rader-Winograd structure with 9 real multiplications (optimal
301        /// for the pair-based factorization of DFT-7).
302        ///
303        /// `sign < 0` → forward transform (W = e^{-2πi/7});
304        /// `sign > 0` → inverse (unnormalized, W = e^{+2πi/7}).
305        #[inline(always)]
306        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
307        pub fn codelet_notw_7<T: crate::kernel::Float>(
308            x: &mut [crate::kernel::Complex<T>],
309            sign: i32,
310        ) {
311            debug_assert!(x.len() >= 7);
312
313            let x0 = x[0];
314            let x1 = x[1];
315            let x2 = x[2];
316            let x3 = x[3];
317            let x4 = x[4];
318            let x5 = x[5];
319            let x6 = x[6];
320
321            // Symmetric sums and differences:
322            // r_m = x[m] + x[7-m],  i_m = x[m] - x[7-m]  for m=1,2,3
323            let r1_re = x1.re + x6.re;
324            let r1_im = x1.im + x6.im;
325            let r2_re = x2.re + x5.re;
326            let r2_im = x2.im + x5.im;
327            let r3_re = x3.re + x4.re;
328            let r3_im = x3.im + x4.im;
329            let i1_re = x1.re - x6.re;
330            let i1_im = x1.im - x6.im;
331            let i2_re = x2.re - x5.re;
332            let i2_im = x2.im - x5.im;
333            let i3_re = x3.re - x4.re;
334            let i3_im = x3.im - x4.im;
335
336            // X[0] = x[0] + r1 + r2 + r3
337            x[0] = crate::kernel::Complex::new(
338                x0.re + r1_re + r2_re + r3_re,
339                x0.im + r1_im + r2_im + r3_im,
340            );
341
342            let cos1 = T::from_f64(#c7_cos1);
343            let cos2 = T::from_f64(#c7_cos2);
344            let cos3 = T::from_f64(#c7_cos3);
345            let sin1 = T::from_f64(#c7_sin1);
346            let sin2 = T::from_f64(#c7_sin2);
347            let sin3 = T::from_f64(#c7_sin3);
348
349            // Cosine blends (same for forward and inverse)
350            // X[1]: cos1·r1 + cos2·r2 + cos3·r3
351            let cp1_re = cos1 * r1_re + cos2 * r2_re + cos3 * r3_re;
352            let cp1_im = cos1 * r1_im + cos2 * r2_im + cos3 * r3_im;
353            // X[2]: cos2·r1 + cos3·r2 + cos1·r3
354            let cp2_re = cos2 * r1_re + cos3 * r2_re + cos1 * r3_re;
355            let cp2_im = cos2 * r1_im + cos3 * r2_im + cos1 * r3_im;
356            // X[3]: cos3·r1 + cos1·r2 + cos2·r3
357            let cp3_re = cos3 * r1_re + cos1 * r2_re + cos2 * r3_re;
358            let cp3_im = cos3 * r1_im + cos1 * r2_im + cos2 * r3_im;
359
360            // Sine blends:
361            // X[1] forward: +sin1·i1_im + sin2·i2_im + sin3·i3_im (re)
362            //               -sin1·i1_re - sin2·i2_re - sin3·i3_re (im)
363            let sp1_re = sin1 * i1_im + sin2 * i2_im + sin3 * i3_im;
364            let sp1_im = sin1 * i1_re + sin2 * i2_re + sin3 * i3_re;
365            // X[2] forward: +sin2·i1_im - sin3·i2_im - sin1·i3_im (re)
366            //               -sin2·i1_re + sin3·i2_re + sin1·i3_re (im)
367            let sp2_re = sin2 * i1_im - sin3 * i2_im - sin1 * i3_im;
368            let sp2_im = sin2 * i1_re - sin3 * i2_re - sin1 * i3_re;
369            // X[3] forward: +sin3·i1_im - sin1·i2_im + sin2·i3_im (re)
370            //               -sin3·i1_re + sin1·i2_re - sin2·i3_re (im)
371            let sp3_re = sin3 * i1_im - sin1 * i2_im + sin2 * i3_im;
372            let sp3_im = sin3 * i1_re - sin1 * i2_re + sin2 * i3_re;
373
374            // tmp_k = x[0] + cosine-blend_k
375            let tmp1_re = x0.re + cp1_re;
376            let tmp1_im = x0.im + cp1_im;
377            let tmp2_re = x0.re + cp2_re;
378            let tmp2_im = x0.im + cp2_im;
379            let tmp3_re = x0.re + cp3_re;
380            let tmp3_im = x0.im + cp3_im;
381
382            if sign < 0 {
383                // Forward: X[k].re = tmp_k.re + sp_k.re
384                //          X[k].im = tmp_k.im - sp_k.im
385                // X[7-k] = conjugate pattern (sin signs flip)
386                x[1] = crate::kernel::Complex::new(tmp1_re + sp1_re, tmp1_im - sp1_im);
387                x[6] = crate::kernel::Complex::new(tmp1_re - sp1_re, tmp1_im + sp1_im);
388                x[2] = crate::kernel::Complex::new(tmp2_re + sp2_re, tmp2_im - sp2_im);
389                x[5] = crate::kernel::Complex::new(tmp2_re - sp2_re, tmp2_im + sp2_im);
390                x[3] = crate::kernel::Complex::new(tmp3_re + sp3_re, tmp3_im - sp3_im);
391                x[4] = crate::kernel::Complex::new(tmp3_re - sp3_re, tmp3_im + sp3_im);
392            } else {
393                // Inverse: sine signs flip
394                x[1] = crate::kernel::Complex::new(tmp1_re - sp1_re, tmp1_im + sp1_im);
395                x[6] = crate::kernel::Complex::new(tmp1_re + sp1_re, tmp1_im - sp1_im);
396                x[2] = crate::kernel::Complex::new(tmp2_re - sp2_re, tmp2_im + sp2_im);
397                x[5] = crate::kernel::Complex::new(tmp2_re + sp2_re, tmp2_im - sp2_im);
398                x[3] = crate::kernel::Complex::new(tmp3_re - sp3_re, tmp3_im + sp3_im);
399                x[4] = crate::kernel::Complex::new(tmp3_re + sp3_re, tmp3_im - sp3_im);
400            }
401        }
402    }
403}
404
405// ============================================================================
406// Reference implementations for numerical testing
407// ============================================================================
408// These evaluate the same Winograd algorithms in pure f64 so that unit tests
409// can verify correctness without depending on `crate::kernel::Float`.
410
/// Textbook O(N²) forward DFT used as the ground truth in tests.
///
/// Evaluates `X[k] = Σ_j x[j] · e^{-2πi·j·k/N}` directly, with the input given
/// as separate real and imaginary slices of equal length `N`.
///
/// Returns the output as `(re, im)` vectors of length `N`.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn naive_dft_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let n = x_re.len();
    debug_assert_eq!(x_im.len(), n);

    (0..n)
        .map(|k| {
            // Accumulate output bin k term by term, in index order j = 0..N.
            (0..n).fold((0.0_f64, 0.0_f64), |(acc_re, acc_im), j| {
                let angle = -2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
                let (s, c) = angle.sin_cos();
                // (x_re + i·x_im) · (c + i·s)
                (
                    acc_re + (x_re[j] * c - x_im[j] * s),
                    acc_im + (x_re[j] * s + x_im[j] * c),
                )
            })
        })
        .unzip()
}
431
/// Textbook O(N²) inverse DFT (no 1/N scaling) used as the ground truth in tests.
///
/// Evaluates `x[k] = Σ_j X[j] · e^{+2πi·j·k/N}` directly, with the input given
/// as separate real and imaginary slices of equal length `N`.
///
/// Returns the output as `(re, im)` vectors of length `N`.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn naive_dft_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let n = x_re.len();
    debug_assert_eq!(x_im.len(), n);

    (0..n)
        .map(|k| {
            // Accumulate output bin k term by term, in index order j = 0..N.
            (0..n).fold((0.0_f64, 0.0_f64), |(acc_re, acc_im), j| {
                let angle = 2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
                let (s, c) = angle.sin_cos();
                // (x_re + i·x_im) · (c + i·s)
                (
                    acc_re + (x_re[j] * c - x_im[j] * s),
                    acc_im + (x_re[j] * s + x_im[j] * c),
                )
            })
        })
        .unzip()
}
450
/// Forward Winograd DFT-3 evaluated in plain `f64`, following the same
/// arithmetic as the generated `codelet_notw_3` forward path.
///
/// Takes the 3 input points as separate real/imaginary slices and returns the
/// 3 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft3_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 3);

    // Sum and difference of x[1] and x[2].
    let (sum_re, sum_im) = (x_re[1] + x_re[2], x_im[1] + x_im[2]);
    let (diff_re, diff_im) = (x_re[1] - x_re[2], x_im[1] - x_im[2]);

    // base = x[0] + C3_1 · sum, shared by X[1] and X[2].
    let base_re = x_re[0] + C3_1 * sum_re;
    let base_im = x_im[0] + C3_1 * sum_im;

    // C3_2 · diff, rotated by -i: `rot_re` goes to the real parts,
    // `rot_im` to the imaginary parts.
    let rot_re = C3_2 * diff_im;
    let rot_im = C3_2 * diff_re;

    let out_re = vec![x_re[0] + sum_re, base_re + rot_re, base_re - rot_re];
    let out_im = vec![x_im[0] + sum_im, base_im - rot_im, base_im + rot_im];
    (out_re, out_im)
}
479
/// Inverse (unnormalized) Winograd DFT-3 evaluated in plain `f64`, following
/// the same arithmetic as the generated `codelet_notw_3` inverse path.
///
/// Takes the 3 input points as separate real/imaginary slices and returns the
/// 3 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft3_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 3);

    // Sum and difference of x[1] and x[2].
    let (sum_re, sum_im) = (x_re[1] + x_re[2], x_im[1] + x_im[2]);
    let (diff_re, diff_im) = (x_re[1] - x_re[2], x_im[1] - x_im[2]);

    // base = x[0] + C3_1 · sum, shared by X[1] and X[2].
    let base_re = x_re[0] + C3_1 * sum_re;
    let base_im = x_im[0] + C3_1 * sum_im;

    // C3_2 · diff; relative to the forward transform the sine term is applied
    // with the opposite sign.
    let rot_re = C3_2 * diff_im;
    let rot_im = C3_2 * diff_re;

    let out_re = vec![x_re[0] + sum_re, base_re - rot_re, base_re + rot_re];
    let out_im = vec![x_im[0] + sum_im, base_im + rot_im, base_im - rot_im];
    (out_re, out_im)
}
507
/// Forward DFT-5 evaluated in plain `f64`, following the same arithmetic as
/// the generated `codelet_notw_5` forward path.
///
/// Takes the 5 input points as separate real/imaginary slices and returns the
/// 5 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft5_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 5);

    // Pair sums a_m = x[m] + x[5-m] and differences b_m = x[m] - x[5-m].
    let (a1_re, a1_im) = (x_re[1] + x_re[4], x_im[1] + x_im[4]);
    let (a2_re, a2_im) = (x_re[2] + x_re[3], x_im[2] + x_im[3]);
    let (b1_re, b1_im) = (x_re[1] - x_re[4], x_im[1] - x_im[4]);
    let (b2_re, b2_im) = (x_re[2] - x_re[3], x_im[2] - x_im[3]);

    // X[0] = x[0] + a1 + a2
    let dc_re = x_re[0] + a1_re + a2_re;
    let dc_im = x_im[0] + a1_im + a2_im;

    // base_k = x[0] + cosine blend k (shared by X[k] and X[5-k]).
    let base1_re = x_re[0] + (C5_COS1 * a1_re + C5_COS2 * a2_re);
    let base1_im = x_im[0] + (C5_COS1 * a1_im + C5_COS2 * a2_im);
    let base2_re = x_re[0] + (C5_COS2 * a1_re + C5_COS1 * a2_re);
    let base2_im = x_im[0] + (C5_COS2 * a1_im + C5_COS1 * a2_im);

    // Sine blends of the differences.
    let s1_re = C5_SIN1 * b1_re + C5_SIN2 * b2_re;
    let s1_im = C5_SIN1 * b1_im + C5_SIN2 * b2_im;
    let s2_re = C5_SIN2 * b1_re - C5_SIN1 * b2_re;
    let s2_im = C5_SIN2 * b1_im - C5_SIN1 * b2_im;

    // Forward: X[k] = base_k - i·s_k, X[5-k] = base_k + i·s_k.
    let out_re = vec![
        dc_re,
        base1_re + s1_im,
        base2_re + s2_im,
        base2_re - s2_im,
        base1_re - s1_im,
    ];
    let out_im = vec![
        dc_im,
        base1_im - s1_re,
        base2_im - s2_re,
        base2_im + s2_re,
        base1_im + s1_re,
    ];
    (out_re, out_im)
}
555
/// Inverse (unnormalized) DFT-5 evaluated in plain `f64`, following the same
/// arithmetic as the generated `codelet_notw_5` inverse path.
///
/// Takes the 5 input points as separate real/imaginary slices and returns the
/// 5 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft5_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 5);

    // Pair sums a_m = x[m] + x[5-m] and differences b_m = x[m] - x[5-m].
    let (a1_re, a1_im) = (x_re[1] + x_re[4], x_im[1] + x_im[4]);
    let (a2_re, a2_im) = (x_re[2] + x_re[3], x_im[2] + x_im[3]);
    let (b1_re, b1_im) = (x_re[1] - x_re[4], x_im[1] - x_im[4]);
    let (b2_re, b2_im) = (x_re[2] - x_re[3], x_im[2] - x_im[3]);

    // X[0] = x[0] + a1 + a2
    let dc_re = x_re[0] + a1_re + a2_re;
    let dc_im = x_im[0] + a1_im + a2_im;

    // base_k = x[0] + cosine blend k (shared by X[k] and X[5-k]).
    let base1_re = x_re[0] + (C5_COS1 * a1_re + C5_COS2 * a2_re);
    let base1_im = x_im[0] + (C5_COS1 * a1_im + C5_COS2 * a2_im);
    let base2_re = x_re[0] + (C5_COS2 * a1_re + C5_COS1 * a2_re);
    let base2_im = x_im[0] + (C5_COS2 * a1_im + C5_COS1 * a2_im);

    // Sine blends of the differences.
    let s1_re = C5_SIN1 * b1_re + C5_SIN2 * b2_re;
    let s1_im = C5_SIN1 * b1_im + C5_SIN2 * b2_im;
    let s2_re = C5_SIN2 * b1_re - C5_SIN1 * b2_re;
    let s2_im = C5_SIN2 * b1_im - C5_SIN1 * b2_im;

    // Inverse: X[k] = base_k + i·s_k, X[5-k] = base_k - i·s_k.
    let out_re = vec![
        dc_re,
        base1_re - s1_im,
        base2_re - s2_im,
        base2_re + s2_im,
        base1_re + s1_im,
    ];
    let out_im = vec![
        dc_im,
        base1_im + s1_re,
        base2_im + s2_re,
        base2_im - s2_re,
        base1_im - s1_re,
    ];
    (out_re, out_im)
}
603
/// Forward DFT-7 evaluated in plain `f64`, following the same arithmetic as
/// the generated `codelet_notw_7` forward path.
///
/// Takes the 7 input points as separate real/imaginary slices and returns the
/// 7 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft7_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 7);

    // Pair sums a_m = x[m] + x[7-m] and differences b_m = x[m] - x[7-m].
    let (a1_re, a1_im) = (x_re[1] + x_re[6], x_im[1] + x_im[6]);
    let (a2_re, a2_im) = (x_re[2] + x_re[5], x_im[2] + x_im[5]);
    let (a3_re, a3_im) = (x_re[3] + x_re[4], x_im[3] + x_im[4]);
    let (b1_re, b1_im) = (x_re[1] - x_re[6], x_im[1] - x_im[6]);
    let (b2_re, b2_im) = (x_re[2] - x_re[5], x_im[2] - x_im[5]);
    let (b3_re, b3_im) = (x_re[3] - x_re[4], x_im[3] - x_im[4]);

    // X[0] = x[0] + a1 + a2 + a3
    let dc_re = x_re[0] + a1_re + a2_re + a3_re;
    let dc_im = x_im[0] + a1_im + a2_im + a3_im;

    // base_k = x[0] + cosine blend k (shared by X[k] and X[7-k]).
    let base1_re = x_re[0] + (C7_COS1 * a1_re + C7_COS2 * a2_re + C7_COS3 * a3_re);
    let base1_im = x_im[0] + (C7_COS1 * a1_im + C7_COS2 * a2_im + C7_COS3 * a3_im);
    let base2_re = x_re[0] + (C7_COS2 * a1_re + C7_COS3 * a2_re + C7_COS1 * a3_re);
    let base2_im = x_im[0] + (C7_COS2 * a1_im + C7_COS3 * a2_im + C7_COS1 * a3_im);
    let base3_re = x_re[0] + (C7_COS3 * a1_re + C7_COS1 * a2_re + C7_COS2 * a3_re);
    let base3_im = x_im[0] + (C7_COS3 * a1_im + C7_COS1 * a2_im + C7_COS2 * a3_im);

    // Sine blends. `rot_k_re` is the term applied to the real part of X[k]
    // (built from imaginary differences); `rot_k_im` is applied to the
    // imaginary part (built from real differences).
    let rot1_re = C7_SIN1 * b1_im + C7_SIN2 * b2_im + C7_SIN3 * b3_im;
    let rot1_im = C7_SIN1 * b1_re + C7_SIN2 * b2_re + C7_SIN3 * b3_re;
    let rot2_re = C7_SIN2 * b1_im - C7_SIN3 * b2_im - C7_SIN1 * b3_im;
    let rot2_im = C7_SIN2 * b1_re - C7_SIN3 * b2_re - C7_SIN1 * b3_re;
    let rot3_re = C7_SIN3 * b1_im - C7_SIN1 * b2_im + C7_SIN2 * b3_im;
    let rot3_im = C7_SIN3 * b1_re - C7_SIN1 * b2_re + C7_SIN2 * b3_re;

    // Forward: X[k] = (base.re + rot.re, base.im - rot.im); X[7-k] mirrors the signs.
    let out_re = vec![
        dc_re,
        base1_re + rot1_re,
        base2_re + rot2_re,
        base3_re + rot3_re,
        base3_re - rot3_re,
        base2_re - rot2_re,
        base1_re - rot1_re,
    ];
    let out_im = vec![
        dc_im,
        base1_im - rot1_im,
        base2_im - rot2_im,
        base3_im - rot3_im,
        base3_im + rot3_im,
        base2_im + rot2_im,
        base1_im + rot1_im,
    ];
    (out_re, out_im)
}
665
/// Inverse (unnormalized) DFT-7 evaluated in plain `f64`, following the same
/// arithmetic as the generated `codelet_notw_7` inverse path.
///
/// Takes the 7 input points as separate real/imaginary slices and returns the
/// 7 outputs as `(re, im)` vectors.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn winograd_dft7_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 7);

    // Pair sums a_m = x[m] + x[7-m] and differences b_m = x[m] - x[7-m].
    let (a1_re, a1_im) = (x_re[1] + x_re[6], x_im[1] + x_im[6]);
    let (a2_re, a2_im) = (x_re[2] + x_re[5], x_im[2] + x_im[5]);
    let (a3_re, a3_im) = (x_re[3] + x_re[4], x_im[3] + x_im[4]);
    let (b1_re, b1_im) = (x_re[1] - x_re[6], x_im[1] - x_im[6]);
    let (b2_re, b2_im) = (x_re[2] - x_re[5], x_im[2] - x_im[5]);
    let (b3_re, b3_im) = (x_re[3] - x_re[4], x_im[3] - x_im[4]);

    // X[0] = x[0] + a1 + a2 + a3
    let dc_re = x_re[0] + a1_re + a2_re + a3_re;
    let dc_im = x_im[0] + a1_im + a2_im + a3_im;

    // base_k = x[0] + cosine blend k (shared by X[k] and X[7-k]).
    let base1_re = x_re[0] + (C7_COS1 * a1_re + C7_COS2 * a2_re + C7_COS3 * a3_re);
    let base1_im = x_im[0] + (C7_COS1 * a1_im + C7_COS2 * a2_im + C7_COS3 * a3_im);
    let base2_re = x_re[0] + (C7_COS2 * a1_re + C7_COS3 * a2_re + C7_COS1 * a3_re);
    let base2_im = x_im[0] + (C7_COS2 * a1_im + C7_COS3 * a2_im + C7_COS1 * a3_im);
    let base3_re = x_re[0] + (C7_COS3 * a1_re + C7_COS1 * a2_re + C7_COS2 * a3_re);
    let base3_im = x_im[0] + (C7_COS3 * a1_im + C7_COS1 * a2_im + C7_COS2 * a3_im);

    // Sine blends. `rot_k_re` is the term applied to the real part of X[k]
    // (built from imaginary differences); `rot_k_im` is applied to the
    // imaginary part (built from real differences).
    let rot1_re = C7_SIN1 * b1_im + C7_SIN2 * b2_im + C7_SIN3 * b3_im;
    let rot1_im = C7_SIN1 * b1_re + C7_SIN2 * b2_re + C7_SIN3 * b3_re;
    let rot2_re = C7_SIN2 * b1_im - C7_SIN3 * b2_im - C7_SIN1 * b3_im;
    let rot2_im = C7_SIN2 * b1_re - C7_SIN3 * b2_re - C7_SIN1 * b3_re;
    let rot3_re = C7_SIN3 * b1_im - C7_SIN1 * b2_im + C7_SIN2 * b3_im;
    let rot3_im = C7_SIN3 * b1_re - C7_SIN1 * b2_re + C7_SIN2 * b3_re;

    // Inverse: X[k] = (base.re - rot.re, base.im + rot.im); X[7-k] mirrors the signs.
    let out_re = vec![
        dc_re,
        base1_re - rot1_re,
        base2_re - rot2_re,
        base3_re - rot3_re,
        base3_re + rot3_re,
        base2_re + rot2_re,
        base1_re + rot1_re,
    ];
    let out_im = vec![
        dc_im,
        base1_im + rot1_im,
        base2_im + rot2_im,
        base3_im + rot3_im,
        base3_im - rot3_im,
        base2_im - rot2_im,
        base1_im - rot1_im,
    ];
    (out_re, out_im)
}
727
728// ============================================================================
729// Tests
730// ============================================================================
731
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance applied by [`assert_close`] to every comparison.
    const TOL: f64 = 1e-12;

    // Arbitrary fixed complex inputs (split into real / imaginary parts) that
    // are shared by the "vs naive" and round-trip tests of each size.
    const X3_RE: [f64; 3] = [1.3, -0.7, 0.4];
    const X3_IM: [f64; 3] = [0.2, 1.1, -0.5];
    const X5_RE: [f64; 5] = [0.5, -1.2, 0.8, 0.3, -0.6];
    const X5_IM: [f64; 5] = [0.1, 0.4, -0.9, 0.7, -0.2];
    const X7_RE: [f64; 7] = [0.5, -1.2, 0.8, 0.3, -0.6, 1.4, -0.1];
    const X7_IM: [f64; 7] = [0.1, 0.4, -0.9, 0.7, -0.2, 0.5, 0.3];

    /// Panics unless `a` and `b` have the same length and every pair of
    /// corresponding elements differs by less than `TOL`.
    ///
    /// `label` prefixes the failure message so the failing vector (and the
    /// offending index) can be identified from the panic output.
    fn assert_close(a: &[f64], b: &[f64], label: &str) {
        assert_eq!(a.len(), b.len(), "{label}: length mismatch");
        for (i, (x, y)) in a.iter().zip(b).enumerate() {
            let diff = (x - y).abs();
            assert!(
                diff < TOL,
                "{label}[{i}]: got {x}, expected {y}, diff = {}",
                diff
            );
        }
    }

    /// Returns a copy of `values` with every element divided by `n`.
    ///
    /// The round-trip tests use this to apply the 1/N scaling after a forward
    /// pass followed by an inverse pass.
    fn scaled(values: &[f64], n: f64) -> Vec<f64> {
        values.iter().map(|&v| v / n).collect()
    }

    /// Parses `literal` into a token stream and feeds it to
    /// `generate_from_macro`, returning that function's result unchanged.
    fn expand(literal: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
        let input = literal
            .parse::<proc_macro2::TokenStream>()
            .expect("parse literal");
        generate_from_macro(input)
    }

    // ──────────────────────────────────────────────────────────────────────────
    // DFT-3 tests
    // ──────────────────────────────────────────────────────────────────────────

    /// A unit impulse at index 0 must transform to 1+0i in every output bin.
    #[test]
    fn test_dft3_forward_f64_impulse() {
        let (got_re, got_im) = winograd_dft3_fwd(&[1.0, 0.0, 0.0], &[0.0; 3]);
        assert_close(&got_re, &[1.0; 3], "dft3_impulse_re");
        assert_close(&got_im, &[0.0; 3], "dft3_impulse_im");
    }

    /// The forward size-3 transform must agree with `naive_dft_fwd`.
    #[test]
    fn test_dft3_forward_vs_naive() {
        let (got_re, got_im) = winograd_dft3_fwd(&X3_RE, &X3_IM);
        let (ref_re, ref_im) = naive_dft_fwd(&X3_RE, &X3_IM);
        assert_close(&got_re, &ref_re, "dft3_fwd_re");
        assert_close(&got_im, &ref_im, "dft3_fwd_im");
    }

    /// The inverse size-3 transform must agree with `naive_dft_inv`.
    #[test]
    fn test_dft3_inverse_vs_naive() {
        let (got_re, got_im) = winograd_dft3_inv(&X3_RE, &X3_IM);
        let (ref_re, ref_im) = naive_dft_inv(&X3_RE, &X3_IM);
        assert_close(&got_re, &ref_re, "dft3_inv_re");
        assert_close(&got_im, &ref_im, "dft3_inv_im");
    }

    /// fwd → inv → scale by 1/3 must recover the original signal.
    #[test]
    fn test_roundtrip_dft3() {
        let (fwd_re, fwd_im) = winograd_dft3_fwd(&X3_RE, &X3_IM);
        let (inv_re, inv_im) = winograd_dft3_inv(&fwd_re, &fwd_im);
        let n = 3.0_f64;
        assert_close(&scaled(&inv_re, n), &X3_RE, "roundtrip_dft3_re");
        assert_close(&scaled(&inv_im, n), &X3_IM, "roundtrip_dft3_im");
    }

    // ──────────────────────────────────────────────────────────────────────────
    // DFT-5 tests
    // ──────────────────────────────────────────────────────────────────────────

    /// A unit impulse at index 0 must transform to 1+0i in every output bin.
    #[test]
    fn test_dft5_forward_f64_impulse() {
        let (got_re, got_im) = winograd_dft5_fwd(&[1.0, 0.0, 0.0, 0.0, 0.0], &[0.0; 5]);
        assert_close(&got_re, &[1.0; 5], "dft5_impulse_re");
        assert_close(&got_im, &[0.0; 5], "dft5_impulse_im");
    }

    /// The forward size-5 transform must agree with `naive_dft_fwd`.
    #[test]
    fn test_dft5_forward_vs_naive() {
        let (got_re, got_im) = winograd_dft5_fwd(&X5_RE, &X5_IM);
        let (ref_re, ref_im) = naive_dft_fwd(&X5_RE, &X5_IM);
        assert_close(&got_re, &ref_re, "dft5_fwd_re");
        assert_close(&got_im, &ref_im, "dft5_fwd_im");
    }

    /// The inverse size-5 transform must agree with `naive_dft_inv`.
    #[test]
    fn test_dft5_inverse_vs_naive() {
        let (got_re, got_im) = winograd_dft5_inv(&X5_RE, &X5_IM);
        let (ref_re, ref_im) = naive_dft_inv(&X5_RE, &X5_IM);
        assert_close(&got_re, &ref_re, "dft5_inv_re");
        assert_close(&got_im, &ref_im, "dft5_inv_im");
    }

    /// fwd → inv → scale by 1/5 must recover the original signal.
    #[test]
    fn test_roundtrip_dft5() {
        let (fwd_re, fwd_im) = winograd_dft5_fwd(&X5_RE, &X5_IM);
        let (inv_re, inv_im) = winograd_dft5_inv(&fwd_re, &fwd_im);
        let n = 5.0_f64;
        assert_close(&scaled(&inv_re, n), &X5_RE, "roundtrip_dft5_re");
        assert_close(&scaled(&inv_im, n), &X5_IM, "roundtrip_dft5_im");
    }

    // ──────────────────────────────────────────────────────────────────────────
    // DFT-7 tests
    // ──────────────────────────────────────────────────────────────────────────

    /// A unit impulse at index 0 must transform to 1+0i in every output bin.
    #[test]
    fn test_dft7_forward_f64_impulse() {
        let (got_re, got_im) =
            winograd_dft7_fwd(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[0.0; 7]);
        assert_close(&got_re, &[1.0; 7], "dft7_impulse_re");
        assert_close(&got_im, &[0.0; 7], "dft7_impulse_im");
    }

    /// The forward size-7 transform must agree with `naive_dft_fwd`.
    #[test]
    fn test_dft7_forward_vs_naive() {
        let (got_re, got_im) = winograd_dft7_fwd(&X7_RE, &X7_IM);
        let (ref_re, ref_im) = naive_dft_fwd(&X7_RE, &X7_IM);
        assert_close(&got_re, &ref_re, "dft7_fwd_re");
        assert_close(&got_im, &ref_im, "dft7_fwd_im");
    }

    /// The inverse size-7 transform must agree with `naive_dft_inv`.
    #[test]
    fn test_dft7_inverse_vs_naive() {
        let (got_re, got_im) = winograd_dft7_inv(&X7_RE, &X7_IM);
        let (ref_re, ref_im) = naive_dft_inv(&X7_RE, &X7_IM);
        assert_close(&got_re, &ref_re, "dft7_inv_re");
        assert_close(&got_im, &ref_im, "dft7_inv_im");
    }

    /// fwd → inv → scale by 1/7 must recover the original signal.
    #[test]
    fn test_roundtrip_dft7() {
        let (fwd_re, fwd_im) = winograd_dft7_fwd(&X7_RE, &X7_IM);
        let (inv_re, inv_im) = winograd_dft7_inv(&fwd_re, &fwd_im);
        let n = 7.0_f64;
        assert_close(&scaled(&inv_re, n), &X7_RE, "roundtrip_dft7_re");
        assert_close(&scaled(&inv_im, n), &X7_IM, "roundtrip_dft7_im");
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Winograd constants cross-validation
    // ──────────────────────────────────────────────────────────────────────────

    /// Delegates to the constants module's own runtime cross-check.
    #[test]
    fn test_winograd_constants_match_runtime() {
        crate::winograd_constants::verify_constants_match_runtime();
    }

    // ──────────────────────────────────────────────────────────────────────────
    // TokenStream generation (structural checks)
    // ──────────────────────────────────────────────────────────────────────────

    /// Size 3 must expand successfully and mention both the codelet name and
    /// the `sign` parameter in the rendered tokens.
    #[test]
    fn test_generate_from_macro_size3() {
        let result = expand("3");
        assert!(result.is_ok(), "gen_odd_codelet!(3) should succeed");
        let s = result.expect("TokenStream for size 3").to_string();
        assert!(
            s.contains("codelet_notw_3"),
            "should contain codelet_notw_3"
        );
        assert!(s.contains("sign"), "should contain sign parameter");
    }

    /// Size 5 must expand successfully and mention the codelet name.
    #[test]
    fn test_generate_from_macro_size5() {
        let result = expand("5");
        assert!(result.is_ok(), "gen_odd_codelet!(5) should succeed");
        let s = result.expect("TokenStream for size 5").to_string();
        assert!(
            s.contains("codelet_notw_5"),
            "should contain codelet_notw_5"
        );
    }

    /// Size 7 must expand successfully and mention the codelet name.
    #[test]
    fn test_generate_from_macro_size7() {
        let result = expand("7");
        assert!(result.is_ok(), "gen_odd_codelet!(7) should succeed");
        let s = result.expect("TokenStream for size 7").to_string();
        assert!(
            s.contains("codelet_notw_7"),
            "should contain codelet_notw_7"
        );
    }

    /// A size outside {3, 5, 7} must be rejected with an error.
    #[test]
    fn test_generate_from_macro_unsupported() {
        assert!(
            expand("4").is_err(),
            "gen_odd_codelet!(4) should fail with unsupported size"
        );
    }
}