Skip to main content

oxifft_codegen_impl/
gen_rader.rs

1//! Rader prime codelet generation for primes 11 and 13.
2//!
3//! Rader's algorithm reduces a prime-length DFT to a cyclic convolution.
4//! For prime p with generator g of (ℤ/pℤ)*:
5//!
6//! ```text
7//! X[0]      = Σ_{n=0}^{p-1} x[n]
8//! X[g^b]    = x[0] + (A * B)[b]     for b = 0..p-2
9//! ```
10//!
11//! where:
12//! - `A[c] = x[g^{-c} mod p]`   (input permuted by inverse generator powers)
13//! - `B[m] = e^{sign·2πi·g^m/p}` (precomputed twiddles — hardcoded at codegen time)
14//! - `*` denotes cyclic convolution of length (p-1)
15//!
//! The cyclic convolution is evaluated directly, as (p-1)² complex multiply-adds
//! in fixed-bound loops with no sub-FFT calls, consistent with the `gen_odd.rs`
//! Winograd codelet pattern.
18//!
19//! # DFT Convention
20//!
21//! Forward DFT: `sign < 0`, W = e^{-2πi/p}
22//! Inverse DFT: `sign > 0`, W = e^{+2πi/p} (unnormalized)
23
24use proc_macro2::TokenStream;
25use quote::quote;
26use syn::LitInt;
27
28// ============================================================================
29// Compile-time number-theory helpers
30// ============================================================================
31
/// Modular exponentiation: computes `base^exp mod m` by binary square-and-multiply.
///
/// Every intermediate product is widened to `u128`, so the result is exact for
/// any `u64` modulus (a plain `u64` product overflows as soon as `m` exceeds
/// 2^32). The accumulator is seeded with `1 % m`, so the degenerate modulus
/// `m == 1` yields `0` even when `exp == 0`.
///
/// # Panics
///
/// Panics if `m == 0` (remainder by zero).
// The narrowing casts are lossless: each value is a remainder modulo `m <= u64::MAX`.
#[allow(clippy::cast_possible_truncation)]
const fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let modulus = m as u128;
    let mut result = 1u64 % m;
    base %= m;
    while exp > 0 {
        // Multiply the current power of `base` in whenever the low exponent bit is set.
        if exp & 1 == 1 {
            result = ((result as u128 * base as u128) % modulus) as u64;
        }
        base = ((base as u128 * base as u128) % modulus) as u64;
        exp >>= 1;
    }
    result
}
45
/// Collects the distinct prime divisors of `n` by trial division.
///
/// Each prime is reported once, regardless of multiplicity, in ascending
/// order. Inputs `0` and `1` produce an empty vector.
fn prime_factors(mut n: u64) -> Vec<u64> {
    let mut distinct = Vec::new();
    let mut candidate = 2u64;
    loop {
        // Once candidate² exceeds the remaining cofactor, no smaller divisor is left.
        if candidate * candidate > n {
            break;
        }
        if n % candidate == 0 {
            distinct.push(candidate);
            // Strip every power of this prime so it is never reported twice.
            n /= candidate;
            while n % candidate == 0 {
                n /= candidate;
            }
        }
        candidate += 1;
    }
    // Whatever remains above 1 is a single prime larger than every divisor tried.
    if n > 1 {
        distinct.push(n);
    }
    distinct
}
64
65/// Check whether g is a primitive root mod p (p must be prime).
66///
67/// g is a primitive root mod p iff g^((p-1)/q) ≢ 1 (mod p) for every prime q | (p-1).
68#[must_use]
69pub fn is_primitive_root(g: u64, p: u64) -> bool {
70    let pm1 = p - 1;
71    for q in prime_factors(pm1) {
72        if mod_pow(g, pm1 / q, p) == 1 {
73            return false;
74        }
75    }
76    true
77}
78
79/// Find the smallest primitive root of prime p.
80///
81/// # Panics
82///
83/// Panics if no primitive root is found below p (impossible for actual primes —
84/// every prime has a primitive root by the theory of cyclic groups).
85#[must_use]
86pub fn find_generator(p: u64) -> u64 {
87    for g in 2..p {
88        if is_primitive_root(g, p) {
89            return g;
90        }
91    }
92    panic!("find_generator: no primitive root found for prime {p}");
93}
94
95// ============================================================================
96// Precomputed constant tables (codegen time)
97// ============================================================================
98
/// Generator powers for p=11: `g^k mod 11` for k = 0..9 (g = 2).
///
/// Used as the output permutation: convolution slot `b` is stored at `X[g^b]`.
const G11_POWERS: [usize; 10] = [1, 2, 4, 8, 5, 10, 9, 7, 3, 6];
/// Inverse generator powers for p=11: `g^{-k} mod 11` for k = 0..9 (g⁻¹ = 6).
///
/// Used as the input permutation: `A[c] = x[g^{-c}]`.
const G11_INV_POWERS: [usize; 10] = [1, 6, 3, 7, 9, 10, 5, 8, 4, 2];

/// Forward Rader twiddle real parts for p=11: `cos(-2π·g^m/11)`, m = 0..9.
// Literals keep every digit that was printed at table-generation time.
#[allow(clippy::excessive_precision)]
const B11_FWD_RE: [f64; 10] = [
    0.841_253_532_831_181_2_f64,
    0.415_415_013_001_886_4_f64,
    -0.654_860_733_945_285_1_f64,
    -0.142_314_838_273_285_23_f64,
    -0.959_492_973_614_497_4_f64,
    0.841_253_532_831_181_2_f64,
    0.415_415_013_001_886_05_f64,
    -0.654_860_733_945_285_2_f64,
    -0.142_314_838_273_285_01_f64,
    -0.959_492_973_614_497_5_f64,
];
/// Forward Rader twiddle imaginary parts for p=11: `sin(-2π·g^m/11)`, m = 0..9.
// Same lint exemption as the real-part table: some entries carry 17 digits.
#[allow(clippy::excessive_precision)]
const B11_FWD_IM: [f64; 10] = [
    -0.540_640_817_455_597_6_f64,
    -0.909_631_995_354_518_3_f64,
    -0.755_749_574_354_258_3_f64,
    0.989_821_441_880_932_7_f64,
    -0.281_732_556_841_429_67_f64,
    0.540_640_817_455_597_6_f64,
    0.909_631_995_354_518_6_f64,
    0.755_749_574_354_258_2_f64,
    -0.989_821_441_880_932_8_f64,
    0.281_732_556_841_429_4_f64,
];
/// Backward Rader twiddle real parts for p=11: `cos(+2π·g^m/11)` — identical to
/// the forward table because cosine is even.
const B11_BWD_RE: [f64; 10] = B11_FWD_RE;
/// Backward Rader twiddle imaginary parts for p=11: `sin(+2π·g^m/11)` — the
/// forward imaginary parts with the sign flipped.
// Same lint exemption as the forward tables.
#[allow(clippy::excessive_precision)]
const B11_BWD_IM: [f64; 10] = [
    0.540_640_817_455_597_6_f64,
    0.909_631_995_354_518_3_f64,
    0.755_749_574_354_258_3_f64,
    -0.989_821_441_880_932_7_f64,
    0.281_732_556_841_429_67_f64,
    -0.540_640_817_455_597_6_f64,
    -0.909_631_995_354_518_6_f64,
    -0.755_749_574_354_258_2_f64,
    0.989_821_441_880_932_8_f64,
    -0.281_732_556_841_429_4_f64,
];
146
/// Output permutation for p=13: `g^k mod 13`, k = 0..11, with generator g = 2.
const G13_POWERS: [usize; 12] = [1, 2, 4, 8, 3, 6, 12, 11, 9, 5, 10, 7];
/// Input permutation for p=13: `g^{-k} mod 13`, k = 0..11 (g⁻¹ = 7).
const G13_INV_POWERS: [usize; 12] = [1, 7, 10, 5, 9, 11, 12, 6, 3, 8, 4, 2];

/// Real parts of the forward Rader twiddles for p=13: `cos(-2π·g^m/13)`, m = 0..11.
#[allow(clippy::excessive_precision)]
const B13_FWD_RE: [f64; 12] = [
    0.885_456_025_653_209_9_f64,
    0.568_064_746_731_155_8_f64,
    -0.354_604_887_042_535_5_f64,
    -0.748_510_748_171_101_3_f64,
    0.120_536_680_255_323_01_f64,
    -0.970_941_817_426_052_f64,
    0.885_456_025_653_210_f64,
    0.568_064_746_731_154_8_f64,
    -0.354_604_887_042_535_9_f64,
    -0.748_510_748_171_101_2_f64,
    0.120_536_680_255_323_2_f64,
    -0.970_941_817_426_052_1_f64,
];
/// Imaginary parts of the forward Rader twiddles for p=13: `sin(-2π·g^m/13)`, m = 0..11.
#[allow(clippy::excessive_precision)]
const B13_FWD_IM: [f64; 12] = [
    -0.464_723_172_043_768_5_f64,
    -0.822_983_865_893_656_4_f64,
    -0.935_016_242_685_414_8_f64,
    0.663_122_658_240_795_f64,
    -0.992_708_874_098_054_f64,
    -0.239_315_664_287_557_68_f64,
    0.464_723_172_043_768_4_f64,
    0.822_983_865_893_657_f64,
    0.935_016_242_685_414_7_f64,
    -0.663_122_658_240_795_2_f64,
    0.992_708_874_098_054_f64,
    0.239_315_664_287_557_43_f64,
];
/// Real parts of the backward twiddles for p=13: `cos(+2π·g^m/13)`; cosine is
/// even, so the forward table is reused unchanged.
const B13_BWD_RE: [f64; 12] = B13_FWD_RE;
/// Imaginary parts of the backward twiddles for p=13: `sin(+2π·g^m/13)`, i.e.
/// the forward imaginary parts negated entry by entry.
#[allow(clippy::excessive_precision)]
const B13_BWD_IM: [f64; 12] = [
    0.464_723_172_043_768_5_f64,
    0.822_983_865_893_656_4_f64,
    0.935_016_242_685_414_8_f64,
    -0.663_122_658_240_795_f64,
    0.992_708_874_098_054_f64,
    0.239_315_664_287_557_68_f64,
    -0.464_723_172_043_768_4_f64,
    -0.822_983_865_893_657_f64,
    -0.935_016_242_685_414_7_f64,
    0.663_122_658_240_795_2_f64,
    -0.992_708_874_098_054_f64,
    -0.239_315_664_287_557_43_f64,
];
202
203// ============================================================================
204// Public entry points
205// ============================================================================
206
207/// Parse `gen_rader_codelet!(N)` input and dispatch.
208///
209/// # Errors
210/// Returns `syn::Error` if the input is not a valid integer literal or the
211/// prime is not in {11, 13}.
212pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
213    let size: LitInt = syn::parse2(input)?;
214    let prime: usize = size.base10_parse().map_err(|_| {
215        syn::Error::new(
216            size.span(),
217            "gen_rader_codelet: expected an integer prime literal",
218        )
219    })?;
220
221    match prime {
222        11 => Ok(gen_size_11()),
223        13 => Ok(gen_size_13()),
224        _ => Err(syn::Error::new(
225            size.span(),
226            format!("gen_rader_codelet: unsupported prime {prime} (expected one of 11, 13)"),
227        )),
228    }
229}
230
231/// Generate a Rader-form codelet `TokenStream` for the given prime ∈ {11, 13}.
232///
233/// This is the non-proc-macro entry point used by benchmark/test harnesses.
234///
235/// # Panics
236///
237/// Panics if `prime` is not 11 or 13.
238#[must_use]
239pub fn generate_rader(prime: usize) -> TokenStream {
240    match prime {
241        11 => gen_size_11(),
242        13 => gen_size_13(),
243        _ => panic!("gen_rader: unsupported prime {prime} (expected 11 or 13)"),
244    }
245}
246
247// ============================================================================
248// DFT-11 codelet (Rader, straight-line cyclic convolution of length 10)
249// ============================================================================
250
251#[allow(clippy::similar_names)]
252fn gen_size_11() -> TokenStream {
253    // Emit g_powers and g_inv_powers as literal arrays so the quote! can use them.
254    let g_pows: Vec<proc_macro2::Literal> = G11_POWERS
255        .iter()
256        .map(|&v| proc_macro2::Literal::usize_suffixed(v))
257        .collect();
258    let g_inv_pows: Vec<proc_macro2::Literal> = G11_INV_POWERS
259        .iter()
260        .map(|&v| proc_macro2::Literal::usize_suffixed(v))
261        .collect();
262
263    // Build forward and backward twiddle literal arrays with distinct names.
264    let twd11_fwd_re: Vec<proc_macro2::Literal> = B11_FWD_RE
265        .iter()
266        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
267        .collect();
268    let twd11_fwd_im: Vec<proc_macro2::Literal> = B11_FWD_IM
269        .iter()
270        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
271        .collect();
272    let twd11_bwd_re: Vec<proc_macro2::Literal> = B11_BWD_RE
273        .iter()
274        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
275        .collect();
276    let twd11_bwd_im: Vec<proc_macro2::Literal> = B11_BWD_IM
277        .iter()
278        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
279        .collect();
280
281    quote! {
282        /// Size-11 DFT codelet using Rader's algorithm.
283        ///
284        /// Reduces the prime-11 DFT to a cyclic convolution of length 10,
285        /// computed as straight-line code.  Generator g = 2.
286        ///
287        /// `sign < 0` → forward transform (W = e^{-2πi/11});
288        /// `sign > 0` → inverse (unnormalized, W = e^{+2πi/11}).
289        #[inline(always)]
290        #[allow(
291            clippy::too_many_lines,
292            clippy::approx_constant,
293            clippy::suboptimal_flops,
294            clippy::unreadable_literal
295        )]
296        pub fn codelet_notw_11<T: crate::kernel::Float>(
297            x: &mut [crate::kernel::Complex<T>],
298            sign: i32,
299        ) {
300            debug_assert!(x.len() >= 11);
301
302            // ── Step 1: X[0] = sum of all inputs ──────────────────────────
303            let mut sum_re = T::zero();
304            let mut sum_im = T::zero();
305            for i in 0..11usize {
306                sum_re = sum_re + x[i].re;
307                sum_im = sum_im + x[i].im;
308            }
309
310            // ── Step 2: A[c] = x[g^{-c} mod 11] ──────────────────────────
311            // g_inv_powers[c] for c = 0..9
312            let g_inv_powers: [usize; 10] = [#(#g_inv_pows),*];
313            let mut a_re = [T::zero(); 10];
314            let mut a_im = [T::zero(); 10];
315            for c in 0..10usize {
316                let idx = g_inv_powers[c];
317                a_re[c] = x[idx].re;
318                a_im[c] = x[idx].im;
319            }
320
321            // ── Step 3: Select twiddle factors based on sign ───────────────
322            // Forward B[m] = e^{-2πi·g^m/11},  Inverse B[m] = e^{+2πi·g^m/11}
323            let tw_re: [T; 10];
324            let tw_im: [T; 10];
325            if sign < 0 {
326                tw_re = [#(T::from_f64(#twd11_fwd_re)),*];
327                tw_im = [#(T::from_f64(#twd11_fwd_im)),*];
328            } else {
329                tw_re = [#(T::from_f64(#twd11_bwd_re)),*];
330                tw_im = [#(T::from_f64(#twd11_bwd_im)),*];
331            }
332
333            // ── Step 4: Cyclic convolution conv[b] = Σ_c A[c]·B[(b-c)%10] ─
334            let mut conv_re = [T::zero(); 10];
335            let mut conv_im = [T::zero(); 10];
336            for b in 0..10usize {
337                let mut cr = T::zero();
338                let mut ci = T::zero();
339                for c in 0..10usize {
340                    let bc = (10 + b - c) % 10;
341                    // complex mul: A[c] * B[bc]
342                    cr = cr + a_re[c] * tw_re[bc] - a_im[c] * tw_im[bc];
343                    ci = ci + a_re[c] * tw_im[bc] + a_im[c] * tw_re[bc];
344                }
345                conv_re[b] = cr;
346                conv_im[b] = ci;
347            }
348
349            // ── Step 5: Assemble output ────────────────────────────────────
350            // X[0]     = sum
351            // X[g^b]   = x[0] + conv[b]  for b = 0..9
352            let x0_re = x[0].re;
353            let x0_im = x[0].im;
354            x[0] = crate::kernel::Complex::new(sum_re, sum_im);
355
356            let g_powers: [usize; 10] = [#(#g_pows),*];
357            for b in 0..10usize {
358                let idx = g_powers[b];
359                x[idx] = crate::kernel::Complex::new(x0_re + conv_re[b], x0_im + conv_im[b]);
360            }
361        }
362    }
363}
364
365// ============================================================================
366// DFT-13 codelet (Rader, straight-line cyclic convolution of length 12)
367// ============================================================================
368
369#[allow(clippy::similar_names)]
370fn gen_size_13() -> TokenStream {
371    let g_pows: Vec<proc_macro2::Literal> = G13_POWERS
372        .iter()
373        .map(|&v| proc_macro2::Literal::usize_suffixed(v))
374        .collect();
375    let g_inv_pows: Vec<proc_macro2::Literal> = G13_INV_POWERS
376        .iter()
377        .map(|&v| proc_macro2::Literal::usize_suffixed(v))
378        .collect();
379
380    let twd13_fwd_re: Vec<proc_macro2::Literal> = B13_FWD_RE
381        .iter()
382        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
383        .collect();
384    let twd13_fwd_im: Vec<proc_macro2::Literal> = B13_FWD_IM
385        .iter()
386        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
387        .collect();
388    let twd13_bwd_re: Vec<proc_macro2::Literal> = B13_BWD_RE
389        .iter()
390        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
391        .collect();
392    let twd13_bwd_im: Vec<proc_macro2::Literal> = B13_BWD_IM
393        .iter()
394        .map(|&v| proc_macro2::Literal::f64_suffixed(v))
395        .collect();
396
397    quote! {
398        /// Size-13 DFT codelet using Rader's algorithm.
399        ///
400        /// Reduces the prime-13 DFT to a cyclic convolution of length 12,
401        /// computed as straight-line code.  Generator g = 2.
402        ///
403        /// `sign < 0` → forward transform (W = e^{-2πi/13});
404        /// `sign > 0` → inverse (unnormalized, W = e^{+2πi/13}).
405        #[inline(always)]
406        #[allow(
407            clippy::too_many_lines,
408            clippy::approx_constant,
409            clippy::suboptimal_flops,
410            clippy::unreadable_literal
411        )]
412        pub fn codelet_notw_13<T: crate::kernel::Float>(
413            x: &mut [crate::kernel::Complex<T>],
414            sign: i32,
415        ) {
416            debug_assert!(x.len() >= 13);
417
418            // ── Step 1: X[0] = sum of all inputs ──────────────────────────
419            let mut sum_re = T::zero();
420            let mut sum_im = T::zero();
421            for i in 0..13usize {
422                sum_re = sum_re + x[i].re;
423                sum_im = sum_im + x[i].im;
424            }
425
426            // ── Step 2: A[c] = x[g^{-c} mod 13] ──────────────────────────
427            let g_inv_powers: [usize; 12] = [#(#g_inv_pows),*];
428            let mut a_re = [T::zero(); 12];
429            let mut a_im = [T::zero(); 12];
430            for c in 0..12usize {
431                let idx = g_inv_powers[c];
432                a_re[c] = x[idx].re;
433                a_im[c] = x[idx].im;
434            }
435
436            // ── Step 3: Select twiddle factors based on sign ───────────────
437            let tw_re: [T; 12];
438            let tw_im: [T; 12];
439            if sign < 0 {
440                tw_re = [#(T::from_f64(#twd13_fwd_re)),*];
441                tw_im = [#(T::from_f64(#twd13_fwd_im)),*];
442            } else {
443                tw_re = [#(T::from_f64(#twd13_bwd_re)),*];
444                tw_im = [#(T::from_f64(#twd13_bwd_im)),*];
445            }
446
447            // ── Step 4: Cyclic convolution conv[b] = Σ_c A[c]·B[(b-c)%12] ─
448            let mut conv_re = [T::zero(); 12];
449            let mut conv_im = [T::zero(); 12];
450            for b in 0..12usize {
451                let mut cr = T::zero();
452                let mut ci = T::zero();
453                for c in 0..12usize {
454                    let bc = (12 + b - c) % 12;
455                    cr = cr + a_re[c] * tw_re[bc] - a_im[c] * tw_im[bc];
456                    ci = ci + a_re[c] * tw_im[bc] + a_im[c] * tw_re[bc];
457                }
458                conv_re[b] = cr;
459                conv_im[b] = ci;
460            }
461
462            // ── Step 5: Assemble output ────────────────────────────────────
463            let x0_re = x[0].re;
464            let x0_im = x[0].im;
465            x[0] = crate::kernel::Complex::new(sum_re, sum_im);
466
467            let g_powers: [usize; 12] = [#(#g_pows),*];
468            for b in 0..12usize {
469                let idx = g_powers[b];
470                x[idx] = crate::kernel::Complex::new(x0_re + conv_re[b], x0_im + conv_im[b]);
471            }
472        }
473    }
474}
475
476// ============================================================================
477// Pure-f64 reference implementations for #[cfg(test)]
478// ============================================================================
479
/// Reference forward DFT evaluated straight from the definition in O(N²).
///
/// Computes `X[k] = Σ_j x[j]·e^{-2πi·kj/N}` from split real/imaginary inputs
/// and returns the real and imaginary outputs as two vectors of length N.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn naive_dft_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let len = x_re.len();
    debug_assert_eq!(x_im.len(), len);
    (0..len)
        .map(|bin| {
            // Accumulate one output bin, visiting the samples in index order.
            (0..len).fold((0.0_f64, 0.0_f64), |(acc_re, acc_im), sample| {
                let phase = -2.0 * std::f64::consts::PI * (bin * sample) as f64 / len as f64;
                let (sin_p, cos_p) = phase.sin_cos();
                (
                    acc_re + (x_re[sample] * cos_p - x_im[sample] * sin_p),
                    acc_im + (x_re[sample] * sin_p + x_im[sample] * cos_p),
                )
            })
        })
        .unzip()
}
498
/// Reference inverse DFT evaluated straight from the definition in O(N²).
///
/// Computes `X[k] = Σ_j x[j]·e^{+2πi·kj/N}` with no `1/N` scaling, from split
/// real/imaginary inputs; returns the real and imaginary outputs of length N.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
pub(crate) fn naive_dft_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let len = x_re.len();
    debug_assert_eq!(x_im.len(), len);
    (0..len)
        .map(|bin| {
            // Accumulate one output bin, visiting the samples in index order.
            (0..len).fold((0.0_f64, 0.0_f64), |(acc_re, acc_im), sample| {
                let phase = 2.0 * std::f64::consts::PI * (bin * sample) as f64 / len as f64;
                let (sin_p, cos_p) = phase.sin_cos();
                (
                    acc_re + (x_re[sample] * cos_p - x_im[sample] * sin_p),
                    acc_im + (x_re[sample] * sin_p + x_im[sample] * cos_p),
                )
            })
        })
        .unzip()
}
517
/// Forward length-11 DFT via the Rader decomposition, in plain `f64`.
///
/// Test-only counterpart of the generated codelet's forward path: it feeds the
/// p=11 permutation tables and forward twiddles to the generic routine.
#[cfg(test)]
pub(crate) fn rader_dft11_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 11);
    let (perm, inv_perm) = (&G11_POWERS, &G11_INV_POWERS);
    let (tw_re, tw_im) = (&B11_FWD_RE, &B11_FWD_IM);
    rader_dft_generic(x_re, x_im, perm, inv_perm, tw_re, tw_im)
}
531
/// Unnormalized inverse length-11 DFT via the Rader decomposition, in plain `f64`.
///
/// Same as the forward variant except that the backward twiddle tables are used.
#[cfg(test)]
pub(crate) fn rader_dft11_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 11);
    let (perm, inv_perm) = (&G11_POWERS, &G11_INV_POWERS);
    let (tw_re, tw_im) = (&B11_BWD_RE, &B11_BWD_IM);
    rader_dft_generic(x_re, x_im, perm, inv_perm, tw_re, tw_im)
}
545
/// Forward length-13 DFT via the Rader decomposition, in plain `f64`.
///
/// Feeds the p=13 permutation tables and forward twiddles to the generic routine.
#[cfg(test)]
pub(crate) fn rader_dft13_fwd(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 13);
    let (perm, inv_perm) = (&G13_POWERS, &G13_INV_POWERS);
    let (tw_re, tw_im) = (&B13_FWD_RE, &B13_FWD_IM);
    rader_dft_generic(x_re, x_im, perm, inv_perm, tw_re, tw_im)
}
559
/// Unnormalized inverse length-13 DFT via the Rader decomposition, in plain `f64`.
///
/// Same as the forward variant except that the backward twiddle tables are used.
#[cfg(test)]
pub(crate) fn rader_dft13_inv(x_re: &[f64], x_im: &[f64]) -> (Vec<f64>, Vec<f64>) {
    debug_assert_eq!(x_re.len(), 13);
    let (perm, inv_perm) = (&G13_POWERS, &G13_INV_POWERS);
    let (tw_re, tw_im) = (&B13_BWD_RE, &B13_BWD_IM);
    rader_dft_generic(x_re, x_im, perm, inv_perm, tw_re, tw_im)
}
573
/// Table-driven Rader DFT in plain `f64`, compiled for tests only.
///
/// The transform length `p` is `x_re.len()`. All four tables must have length
/// `p - 1`: `g_powers[b]` is the output slot for convolution index `b`,
/// `g_inv_powers[c]` is the input slot gathered into position `c`, and
/// `twd_re`/`twd_im` hold the twiddle sequence. The cyclic convolution is
/// evaluated directly with a double loop.
#[cfg(test)]
#[allow(clippy::suboptimal_flops)]
fn rader_dft_generic(
    x_re: &[f64],
    x_im: &[f64],
    g_powers: &[usize],
    g_inv_powers: &[usize],
    twd_re: &[f64],
    twd_im: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let prime = x_re.len();
    let conv_len = prime - 1;
    debug_assert_eq!(g_powers.len(), conv_len);
    debug_assert_eq!(g_inv_powers.len(), conv_len);
    debug_assert_eq!(twd_re.len(), conv_len);
    debug_assert_eq!(twd_im.len(), conv_len);

    // DC bin: the plain sum of every input sample.
    let dc_re: f64 = x_re.iter().sum();
    let dc_im: f64 = x_im.iter().sum();

    // Permuted input sequence A[c] = x[g^{-c}].
    let gather = |src: &[f64]| -> Vec<f64> {
        g_inv_powers[..conv_len].iter().map(|&slot| src[slot]).collect()
    };
    let a_re = gather(x_re);
    let a_im = gather(x_im);

    let mut out_re = vec![0.0_f64; prime];
    let mut out_im = vec![0.0_f64; prime];
    out_re[0] = dc_re;
    out_im[0] = dc_im;

    // For each shift b, compute conv[b] = Σ_c A[c]·B[(b-c) mod n] and store
    // x[0] + conv[b] at output slot g^b.
    for (b, &dst) in g_powers[..conv_len].iter().enumerate() {
        let (acc_re, acc_im) = (0..conv_len).fold((0.0_f64, 0.0_f64), |(sr, si), c| {
            let k = (conv_len + b - c) % conv_len;
            (
                sr + (a_re[c] * twd_re[k] - a_im[c] * twd_im[k]),
                si + (a_re[c] * twd_im[k] + a_im[c] * twd_re[k]),
            )
        });
        out_re[dst] = x_re[0] + acc_re;
        out_im[dst] = x_im[0] + acc_im;
    }

    (out_re, out_im)
}
626
627// ============================================================================
628// Tests
629// ============================================================================
630
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute per-element tolerance used by every numeric comparison below.
    const TOL: f64 = 1e-12;

    /// Signature shared by the Rader and naive reference transforms.
    type Transform = fn(&[f64], &[f64]) -> (Vec<f64>, Vec<f64>);

    /// Deterministic 13-point sample, real parts; its first 11 entries double
    /// as the 11-point sample.
    const SAMPLE_RE: [f64; 13] = [
        0.493_581, -1.234_567, 0.812_345, -0.456_789, 1.123_456, -0.234_567, 0.678_901,
        -0.890_123, 0.345_678, -0.567_890, 0.901_234, -0.123_456, 0.789_012,
    ];
    /// Deterministic 13-point sample, imaginary parts.
    const SAMPLE_IM: [f64; 13] = [
        0.234_567, 0.678_901, -0.456_789, 0.890_123, -0.123_456, 0.567_890, -0.789_012,
        0.234_567, -0.678_901, 0.456_789, -0.890_123, 0.123_456, -0.567_890,
    ];

    /// Element-wise comparison of two equally long slices within `TOL`.
    fn assert_close(a: &[f64], b: &[f64], label: &str) {
        assert_eq!(a.len(), b.len(), "{label}: length mismatch");
        for (i, x) in a.iter().enumerate() {
            let y = &b[i];
            assert!(
                (x - y).abs() < TOL,
                "{label}[{i}]: got {x}, expected {y}, diff = {}",
                (x - y).abs()
            );
        }
    }

    // ── Deterministic test vectors ────────────────────────────────────────────

    fn test_vec_11() -> ([f64; 11], [f64; 11]) {
        let mut x_re = [0.0_f64; 11];
        let mut x_im = [0.0_f64; 11];
        x_re.copy_from_slice(&SAMPLE_RE[..11]);
        x_im.copy_from_slice(&SAMPLE_IM[..11]);
        (x_re, x_im)
    }

    fn test_vec_13() -> ([f64; 13], [f64; 13]) {
        (SAMPLE_RE, SAMPLE_IM)
    }

    // ── Shared drivers ────────────────────────────────────────────────────────

    /// Runs `rader` and `reference` on the same input and compares both components.
    fn check_against_reference(
        rader: Transform,
        reference: Transform,
        x_re: &[f64],
        x_im: &[f64],
        label: &str,
    ) {
        let (got_re, got_im) = rader(x_re, x_im);
        let (ref_re, ref_im) = reference(x_re, x_im);
        assert_close(&got_re, &ref_re, &format!("{label}_re"));
        assert_close(&got_im, &ref_im, &format!("{label}_im"));
    }

    /// A unit impulse at index 0 must transform to 1+0i in every bin.
    fn check_impulse(forward: Transform, len: usize, label: &str) {
        let mut x_re = vec![0.0_f64; len];
        x_re[0] = 1.0;
        let x_im = vec![0.0_f64; len];
        let (got_re, got_im) = forward(&x_re, &x_im);
        assert_close(&got_re, &vec![1.0_f64; len], &format!("{label}_re"));
        assert_close(&got_im, &vec![0.0_f64; len], &format!("{label}_im"));
    }

    /// Forward, then unnormalized inverse, then divide by N must reproduce the input.
    fn check_roundtrip(
        forward: Transform,
        inverse: Transform,
        x_re: &[f64],
        x_im: &[f64],
        label: &str,
    ) {
        let (fwd_re, fwd_im) = forward(x_re, x_im);
        let (inv_re, inv_im) = inverse(&fwd_re, &fwd_im);
        let n = x_re.len() as f64;
        let scaled_re: Vec<f64> = inv_re.iter().map(|&v| v / n).collect();
        let scaled_im: Vec<f64> = inv_im.iter().map(|&v| v / n).collect();
        assert_close(&scaled_re, x_re, &format!("{label}_re"));
        assert_close(&scaled_im, x_im, &format!("{label}_im"));
    }

    /// Parses `literal` as macro input and runs the macro entry point on it.
    fn expand(literal: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
        let input: proc_macro2::TokenStream = literal.parse().expect("parse literal");
        generate_from_macro(input)
    }

    // ── Number-theory helpers ─────────────────────────────────────────────────

    #[test]
    fn test_generator_11() {
        assert!(
            is_primitive_root(2, 11),
            "2 should be a primitive root mod 11"
        );
        assert!(
            !is_primitive_root(10, 11),
            "10 should NOT be a primitive root mod 11"
        );
        assert_eq!(find_generator(11), 2);
    }

    #[test]
    fn test_generator_13() {
        assert!(
            is_primitive_root(2, 13),
            "2 should be a primitive root mod 13"
        );
        assert_eq!(find_generator(13), 2);
    }

    #[test]
    fn test_mod_pow_basic() {
        // Fermat's little theorem: g^(p-1) ≡ 1 (mod p).
        for (g, p) in [(2u64, 11u64), (2, 13)] {
            assert_eq!(mod_pow(g, p - 1, p), 1);
        }
    }

    // ── Impulse tests (catch sign-convention bugs) ────────────────────────────

    #[test]
    fn test_dft11_forward_f64_impulse() {
        check_impulse(rader_dft11_fwd, 11, "dft11_impulse");
    }

    #[test]
    fn test_dft13_forward_f64_impulse() {
        check_impulse(rader_dft13_fwd, 13, "dft13_impulse");
    }

    // ── Forward vs. naive DFT ─────────────────────────────────────────────────

    #[test]
    fn test_rader11_forward_vs_naive() {
        let (x_re, x_im) = test_vec_11();
        check_against_reference(rader_dft11_fwd, naive_dft_fwd, &x_re, &x_im, "rader11_fwd");
    }

    #[test]
    fn test_rader13_forward_vs_naive() {
        let (x_re, x_im) = test_vec_13();
        check_against_reference(rader_dft13_fwd, naive_dft_fwd, &x_re, &x_im, "rader13_fwd");
    }

    // ── Inverse vs. naive IDFT ────────────────────────────────────────────────

    #[test]
    fn test_rader11_inverse_vs_naive() {
        let (x_re, x_im) = test_vec_11();
        check_against_reference(rader_dft11_inv, naive_dft_inv, &x_re, &x_im, "rader11_inv");
    }

    #[test]
    fn test_rader13_inverse_vs_naive() {
        let (x_re, x_im) = test_vec_13();
        check_against_reference(rader_dft13_inv, naive_dft_inv, &x_re, &x_im, "rader13_inv");
    }

    // ── Round-trip: fwd → inv → scale → original ─────────────────────────────

    #[test]
    fn test_roundtrip_rader11() {
        let (x_re, x_im) = test_vec_11();
        check_roundtrip(
            rader_dft11_fwd,
            rader_dft11_inv,
            &x_re,
            &x_im,
            "roundtrip_rader11",
        );
    }

    #[test]
    fn test_roundtrip_rader13() {
        let (x_re, x_im) = test_vec_13();
        check_roundtrip(
            rader_dft13_fwd,
            rader_dft13_inv,
            &x_re,
            &x_im,
            "roundtrip_rader13",
        );
    }

    // ── TokenStream structural checks ─────────────────────────────────────────

    #[test]
    fn test_generate_from_macro_prime11() {
        let result = expand("11");
        assert!(result.is_ok(), "gen_rader_codelet!(11) should succeed");
        let s = result.expect("TokenStream for prime 11").to_string();
        assert!(
            s.contains("codelet_notw_11"),
            "should contain codelet_notw_11"
        );
        assert!(s.contains("sign"), "should contain sign parameter");
    }

    #[test]
    fn test_generate_from_macro_prime13() {
        let result = expand("13");
        assert!(result.is_ok(), "gen_rader_codelet!(13) should succeed");
        let s = result.expect("TokenStream for prime 13").to_string();
        assert!(
            s.contains("codelet_notw_13"),
            "should contain codelet_notw_13"
        );
    }

    #[test]
    fn test_generate_from_macro_unsupported() {
        assert!(
            expand("17").is_err(),
            "gen_rader_codelet!(17) should fail with unsupported prime"
        );
    }
}
833}