Skip to main content

oxifft_codegen_impl/
gen_rdft.rs

1//! RDFT codelet generation — R2HC and HC2R.
2//!
3//! Generates optimized codelets matching the hand-written signatures in
4//! `oxifft/src/rdft/codelets/mod.rs`:
5//!
6//! - `r2hc_N<T: Float>(x: &[T], y: &mut [Complex<T>])` — real to half-complex
7//! - `hc2r_N<T: Float>(y: &[Complex<T>], x: &mut [T])` — half-complex to real (unnormalized)
8//!
9//! R2HC output stores N/2+1 complex bins: Y\[0\]…Y\[N/2\].
10//! Y\[0\].im and Y\[N/2\].im are always zero for real inputs.
11//! HC2R is the exact inverse butterfly; caller divides by N for true inverse.
12
13#![allow(clippy::cast_precision_loss)] // small FFT sizes (≤8) fit safely in f64 mantissa
14
15use std::collections::HashMap;
16
17use proc_macro2::{Span, TokenStream};
18use quote::{format_ident, quote};
19use syn::{parse::ParseStream, Ident, LitInt, Token};
20
21use crate::symbolic::{ConstantFolder, Expr, StrengthReducer};
22
23// ============================================================================
24// Input parsing
25// ============================================================================
26
27/// Parsed arguments for `gen_rdft_codelet!(size = N, kind = R2hc | Hc2r)`.
28pub struct RdftInput {
29    pub size: usize,
30    pub kind: RdftKind,
31}
32
/// Direction of the RDFT codelet requested from the generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RdftKind {
    /// Real input to half-complex output (written `R2hc` in the macro).
    R2hc,
    /// Half-complex input to real output (written `Hc2r` in the macro).
    Hc2r,
}
39
40impl syn::parse::Parse for RdftInput {
41    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
42        // Parse: size = <int>, kind = <Ident>
43        let kw_size: Ident = input.parse()?;
44        if kw_size != "size" {
45            return Err(syn::Error::new(
46                kw_size.span(),
47                "expected `size = N, kind = R2hc | Hc2r`",
48            ));
49        }
50        let _eq: Token![=] = input.parse()?;
51        let size_lit: LitInt = input.parse()?;
52        let size: usize = size_lit
53            .base10_parse()
54            .map_err(|_| syn::Error::new(size_lit.span(), "expected integer size literal"))?;
55
56        let _comma: Token![,] = input.parse()?;
57
58        let kw_kind: Ident = input.parse()?;
59        if kw_kind != "kind" {
60            return Err(syn::Error::new(
61                kw_kind.span(),
62                "expected `kind = R2hc | Hc2r`",
63            ));
64        }
65        let _eq2: Token![=] = input.parse()?;
66        let kind_ident: Ident = input.parse()?;
67
68        let kind = match kind_ident.to_string().as_str() {
69            "R2hc" => RdftKind::R2hc,
70            "Hc2r" => RdftKind::Hc2r,
71            other => {
72                return Err(syn::Error::new(
73                    kind_ident.span(),
74                    format!("unknown RDFT kind `{other}`, expected `R2hc` or `Hc2r`"),
75                ))
76            }
77        };
78
79        Ok(Self { size, kind })
80    }
81}
82
83// ============================================================================
84// Public entry point
85// ============================================================================
86
87/// Generate a `gen_rdft_codelet!(size = N, kind = R2hc|Hc2r)` codelet.
88///
89/// # Errors
90/// Returns `syn::Error` if the input fails to parse or the size is unsupported.
91pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
92    let parsed: RdftInput = syn::parse2(input)?;
93    match parsed.kind {
94        RdftKind::R2hc => gen_r2hc(parsed.size),
95        RdftKind::Hc2r => gen_hc2r(parsed.size),
96    }
97}
98
99// ============================================================================
100// R2HC generation
101// ============================================================================
102
103fn gen_r2hc(n: usize) -> Result<TokenStream, syn::Error> {
104    match n {
105        2 | 4 | 8 => Ok(emit_r2hc_codelet(n)),
106        _ => Err(syn::Error::new(
107            Span::call_site(),
108            format!("gen_rdft_codelet: unsupported size {n} for R2hc (expected 2, 4, or 8)"),
109        )),
110    }
111}
112
113/// Build symbolic R2HC expressions for all k in 0..=n/2.
114///
115/// Y[k].re = Σ_{j=0..n-1} x[j] · cos(-2π·j·k/N)  = Σ x[j]·cos(2π·j·k/N)
116/// Y[k].im = Σ_{j=0..n-1} x[j] · sin(-2π·j·k/N)  = -Σ x[j]·sin(2π·j·k/N)
117///
118/// (Uses the DFT kernel e^{-2πi·j·k/N} = cos - i·sin, so im component accumulates
119///  sin of the negative angle.)
120fn symbolic_r2hc(n: usize) -> Vec<(Expr, Expr)> {
121    let half = n / 2;
122    let mut outputs = Vec::with_capacity(half + 1);
123
124    for k in 0..=half {
125        let mut re_acc = Expr::Const(0.0);
126        let mut im_acc = Expr::Const(0.0);
127        for j in 0..n {
128            // angle for e^{-2πi·j·k/N}
129            let angle = -2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
130            let cos_val = angle.cos();
131            let sin_val = angle.sin(); // = -sin(2π·j·k/N)
132            let xj = Expr::input_re(j);
133            re_acc = re_acc.add(xj.clone().mul(Expr::Const(cos_val)));
134            im_acc = im_acc.add(xj.mul(Expr::Const(sin_val)));
135        }
136        let re_red = ConstantFolder::fold(&StrengthReducer::reduce(&re_acc));
137        let im_red = ConstantFolder::fold(&StrengthReducer::reduce(&im_acc));
138        outputs.push((re_red, im_red));
139    }
140    outputs
141}
142
143/// Emit the full R2HC codelet function as a `TokenStream`.
144fn emit_r2hc_codelet(n: usize) -> TokenStream {
145    let outputs = symbolic_r2hc(n); // len = n/2 + 1
146    let half = n / 2;
147    let min_out = half + 1;
148    let fn_name = format_ident!("r2hc_{n}_gen");
149    let body = emit_r2hc_body(n, &outputs);
150
151    quote! {
152        /// Generated R2HC (real to half-complex) codelet.
153        ///
154        /// Input: `x` — N real samples.
155        /// Output: `y` — N/2+1 complex bins Y[0]..Y[N/2].
156        #[inline(always)]
157        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
158        pub fn #fn_name<T: crate::kernel::Float>(x: &[T], y: &mut [crate::kernel::Complex<T>]) {
159            debug_assert_eq!(x.len(), #n);
160            debug_assert!(y.len() >= #min_out);
161            #body
162        }
163    }
164}
165
166/// Build the body for R2HC.
167fn emit_r2hc_body(n: usize, outputs: &[(Expr, Expr)]) -> TokenStream {
168    // Collect all expressions for CSE
169    let all_exprs: Vec<&Expr> = outputs.iter().flat_map(|(re, im)| [re, im]).collect();
170
171    let mut cse = LocalCse::new();
172    for expr in &all_exprs {
173        cse.count_recursive(expr);
174    }
175
176    let mut body = TokenStream::new();
177
178    // Input extraction: `let x0 = x[0];` ...
179    for i in 0..n {
180        let var = format_ident!("x{i}");
181        body.extend(quote! { let #var = x[#i]; });
182    }
183
184    // CSE temporaries: emit original expr (no sub-CSE of assignment bodies).
185    // Assignments are topologically sorted by Temp-ref dependencies.
186    // Since the original exprs contain no Temp nodes (they are raw symbolic trees),
187    // topological_sort leaves them in name order (t0, t1, ...) which is correct
188    // because count_recursive assigns names in traversal order.
189    let assignments = cse.get_assignments();
190    for (name, expr) in &assignments {
191        let id = format_ident!("{name}");
192        // Emit the original expr without further CSE rewriting to avoid
193        // self-referential or forward Temp dependencies.
194        let tok = emit_real_scalar(expr);
195        body.extend(quote! { let #id = #tok; });
196    }
197
198    // Output assignments: `y[k] = Complex::new(re, im);`
199    for (k, (re_expr, im_expr)) in outputs.iter().enumerate() {
200        let re_tok = emit_real_scalar(&cse.rewrite(re_expr));
201        let im_tok = emit_real_scalar(&cse.rewrite(im_expr));
202        body.extend(quote! {
203            y[#k] = crate::kernel::Complex::new(#re_tok, #im_tok);
204        });
205    }
206
207    body
208}
209
210// ============================================================================
211// HC2R generation
212// ============================================================================
213
214fn gen_hc2r(n: usize) -> Result<TokenStream, syn::Error> {
215    match n {
216        2 | 4 | 8 => Ok(emit_hc2r_codelet(n)),
217        _ => Err(syn::Error::new(
218            Span::call_site(),
219            format!("gen_rdft_codelet: unsupported size {n} for Hc2r (expected 2, 4, or 8)"),
220        )),
221    }
222}
223
224/// Build symbolic HC2R expressions for all j in 0..n.
225///
226/// Unnormalized inverse (no 1/N factor):
227///   x[j] = Y[0].re
228///         + Y[N/2].re · cos(π·j)                              (Nyquist)
229///         + 2 · Σ_{k=1..N/2-1} (Y[k].re·cos(2π·j·k/N) - Y[k].im·sin(2π·j·k/N))
230///
231/// Input indices: `Expr::input_re(k)` = Y[k].re, `Expr::input_im(k)` = Y[k].im
232/// (only k in 0..=N/2 are provided as inputs).
233fn symbolic_hc2r(n: usize) -> Vec<Expr> {
234    let half = n / 2;
235    let mut outputs = Vec::with_capacity(n);
236
237    for j in 0..n {
238        // DC component: Y[0].re (always 1× contribution)
239        let mut acc = Expr::input_re(0);
240
241        // Interior bins k=1..N/2-1: factor-of-2 from conjugate symmetry
242        for k in 1..half {
243            let angle = 2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
244            let cos_val = angle.cos();
245            let sin_val = angle.sin();
246
247            let yk_re = Expr::input_re(k);
248            let yk_im = Expr::input_im(k);
249
250            let term_re = yk_re.mul(Expr::Const(cos_val));
251            let term_im = yk_im.mul(Expr::Const(sin_val));
252            let term = term_re.sub(term_im);
253            // Multiply by 2 for the conjugate symmetric pair
254            let term2 = term.mul(Expr::Const(2.0));
255            acc = acc.add(term2);
256        }
257
258        // Nyquist bin: Y[N/2].re · cos(π·j) = Y[N/2].re · (-1)^j
259        let nyquist_angle = std::f64::consts::PI * j as f64;
260        let nyquist_cos = nyquist_angle.cos(); // exactly +1.0 or -1.0
261        let nyquist_term = Expr::input_re(half).mul(Expr::Const(nyquist_cos));
262        acc = acc.add(nyquist_term);
263
264        let reduced = ConstantFolder::fold(&StrengthReducer::reduce(&acc));
265        outputs.push(reduced);
266    }
267    outputs
268}
269
270/// Emit the full HC2R codelet function as a `TokenStream`.
271fn emit_hc2r_codelet(n: usize) -> TokenStream {
272    let outputs = symbolic_hc2r(n);
273    let half = n / 2;
274    let min_in = half + 1;
275    let fn_name = format_ident!("hc2r_{n}_gen");
276    let body = emit_hc2r_body(n, &outputs, half);
277
278    quote! {
279        /// Generated HC2R (half-complex to real) codelet.
280        ///
281        /// Input: `y` — N/2+1 complex bins Y[0]..Y[N/2].
282        /// Output: `x` — N real samples (unnormalized; caller divides by N).
283        #[inline(always)]
284        #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
285        pub fn #fn_name<T: crate::kernel::Float>(y: &[crate::kernel::Complex<T>], x: &mut [T]) {
286            debug_assert!(y.len() >= #min_in);
287            debug_assert_eq!(x.len(), #n);
288            #body
289        }
290    }
291}
292
293/// Build the body for HC2R.
294fn emit_hc2r_body(_n: usize, outputs: &[Expr], half: usize) -> TokenStream {
295    let mut cse = LocalCse::new();
296    for expr in outputs {
297        cse.count_recursive(expr);
298    }
299
300    let mut body = TokenStream::new();
301
302    // Input extraction: `let y0_re = y[0].re; let y0_im = y[0].im;` ...
303    for k in 0..=half {
304        let re_var = format_ident!("y{k}_re");
305        let im_var = format_ident!("y{k}_im");
306        body.extend(quote! {
307            let #re_var = y[#k].re;
308            let #im_var = y[#k].im;
309        });
310    }
311
312    // CSE temporaries: emit original exprs directly (no sub-CSE of bodies).
313    let assignments = cse.get_assignments();
314    for (name, expr) in &assignments {
315        let id = format_ident!("{name}");
316        let tok = emit_hc2r_scalar(expr);
317        body.extend(quote! { let #id = #tok; });
318    }
319
320    // Output assignments: `x[j] = <expr>;`
321    for (j, expr) in outputs.iter().enumerate() {
322        let val_tok = emit_hc2r_scalar(&cse.rewrite(expr));
323        body.extend(quote! { x[#j] = #val_tok; });
324    }
325
326    body
327}
328
329// ============================================================================
330// Local CSE (recursive, mirrors symbolic_emit::RecursiveCse)
331// ============================================================================
332
333/// Local recursive CSE for RDFT expressions.
334///
335/// Counts subexpression usages across all outputs, then rewrites
336/// shared subexpressions (used ≥ 2 times) as `Temp` references.
337struct LocalCse {
338    /// `structural_hash` → (original expr, temp name, use count)
339    cache: HashMap<u64, (Expr, String, usize)>,
340    counter: usize,
341}
342
343impl LocalCse {
344    fn new() -> Self {
345        Self {
346            cache: HashMap::new(),
347            counter: 0,
348        }
349    }
350
351    /// Count subexpression usages (bottom-up traversal).
352    fn count_recursive(&mut self, expr: &Expr) {
353        match expr {
354            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => {}
355            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) => {
356                self.count_recursive(a);
357                self.count_recursive(b);
358                let hash = expr.structural_hash();
359                let entry = self.cache.entry(hash).or_insert_with(|| {
360                    let name = format!("t{}", self.counter);
361                    self.counter += 1;
362                    (expr.clone(), name, 0)
363                });
364                entry.2 += 1;
365            }
366            Expr::Neg(a) => {
367                self.count_recursive(a);
368                let hash = expr.structural_hash();
369                let entry = self.cache.entry(hash).or_insert_with(|| {
370                    let name = format!("t{}", self.counter);
371                    self.counter += 1;
372                    (expr.clone(), name, 0)
373                });
374                entry.2 += 1;
375            }
376        }
377    }
378
379    /// Rewrite an expression, replacing shared subexpressions with `Temp` refs.
380    fn rewrite(&self, expr: &Expr) -> Expr {
381        self.rewrite_inner(expr, None)
382    }
383
384    fn rewrite_inner(&self, expr: &Expr, exclude_hash: Option<u64>) -> Expr {
385        match expr {
386            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
387            Expr::Add(a, b) => {
388                let hash = expr.structural_hash();
389                if exclude_hash != Some(hash) {
390                    if let Some((_, name, count)) = self.cache.get(&hash) {
391                        if *count >= 2 {
392                            return Expr::Temp(name.clone());
393                        }
394                    }
395                }
396                Expr::Add(
397                    Box::new(self.rewrite_inner(a, None)),
398                    Box::new(self.rewrite_inner(b, None)),
399                )
400            }
401            Expr::Sub(a, b) => {
402                let hash = expr.structural_hash();
403                if exclude_hash != Some(hash) {
404                    if let Some((_, name, count)) = self.cache.get(&hash) {
405                        if *count >= 2 {
406                            return Expr::Temp(name.clone());
407                        }
408                    }
409                }
410                Expr::Sub(
411                    Box::new(self.rewrite_inner(a, None)),
412                    Box::new(self.rewrite_inner(b, None)),
413                )
414            }
415            Expr::Mul(a, b) => {
416                let hash = expr.structural_hash();
417                if exclude_hash != Some(hash) {
418                    if let Some((_, name, count)) = self.cache.get(&hash) {
419                        if *count >= 2 {
420                            return Expr::Temp(name.clone());
421                        }
422                    }
423                }
424                Expr::Mul(
425                    Box::new(self.rewrite_inner(a, None)),
426                    Box::new(self.rewrite_inner(b, None)),
427                )
428            }
429            Expr::Neg(a) => {
430                let hash = expr.structural_hash();
431                if exclude_hash != Some(hash) {
432                    if let Some((_, name, count)) = self.cache.get(&hash) {
433                        if *count >= 2 {
434                            return Expr::Temp(name.clone());
435                        }
436                    }
437                }
438                Expr::Neg(Box::new(self.rewrite_inner(a, None)))
439            }
440        }
441    }
442
443    /// Return sorted assignments for temps used ≥ 2 times.
444    fn get_assignments(&self) -> Vec<(String, Expr)> {
445        let mut result: Vec<(String, Expr)> = self
446            .cache
447            .values()
448            .filter(|(_, _, count)| *count >= 2)
449            .map(|(expr, name, _)| (name.clone(), expr.clone()))
450            .collect();
451        result.sort_by(|a, b| {
452            let na: usize = a.0[1..].parse().unwrap_or(0);
453            let nb: usize = b.0[1..].parse().unwrap_or(0);
454            na.cmp(&nb)
455        });
456        result
457    }
458}
459
460// ============================================================================
461// Scalar emission
462// ============================================================================
463
464/// Emit a scalar `Expr` for R2HC.
465///
466/// `Input { index, is_real: true }` → `x{index}` (real input)
467/// `Input { index, is_real: false }` → should not occur in R2HC; emits
468///   `y{index}_im` as a fallback (occurs in shared emitter paths only).
469fn emit_real_scalar(expr: &Expr) -> TokenStream {
470    match expr {
471        Expr::Input { index, is_real } => {
472            if *is_real {
473                let name = format_ident!("x{index}");
474                quote! { #name }
475            } else {
476                // Not expected in R2HC but safe fallback
477                let name = format_ident!("y{index}_im");
478                quote! { #name }
479            }
480        }
481        Expr::Const(v) => emit_const(*v),
482        Expr::Add(a, b) => {
483            let a = emit_real_scalar(a);
484            let b = emit_real_scalar(b);
485            quote! { (#a + #b) }
486        }
487        Expr::Sub(a, b) => {
488            let a = emit_real_scalar(a);
489            let b = emit_real_scalar(b);
490            quote! { (#a - #b) }
491        }
492        Expr::Mul(a, b) => {
493            let a = emit_real_scalar(a);
494            let b = emit_real_scalar(b);
495            quote! { (#a * #b) }
496        }
497        Expr::Neg(a) => {
498            let a = emit_real_scalar(a);
499            quote! { (-#a) }
500        }
501        Expr::Temp(name) => {
502            let id = format_ident!("{name}");
503            quote! { #id }
504        }
505    }
506}
507
508/// Emit a scalar `Expr` for HC2R (complex inputs `y{k}_re` / `y{k}_im`).
509fn emit_hc2r_scalar(expr: &Expr) -> TokenStream {
510    match expr {
511        Expr::Input { index, is_real } => {
512            let name = if *is_real {
513                format_ident!("y{index}_re")
514            } else {
515                format_ident!("y{index}_im")
516            };
517            quote! { #name }
518        }
519        Expr::Const(v) => emit_const(*v),
520        Expr::Add(a, b) => {
521            let a = emit_hc2r_scalar(a);
522            let b = emit_hc2r_scalar(b);
523            quote! { (#a + #b) }
524        }
525        Expr::Sub(a, b) => {
526            let a = emit_hc2r_scalar(a);
527            let b = emit_hc2r_scalar(b);
528            quote! { (#a - #b) }
529        }
530        Expr::Mul(a, b) => {
531            let a = emit_hc2r_scalar(a);
532            let b = emit_hc2r_scalar(b);
533            quote! { (#a * #b) }
534        }
535        Expr::Neg(a) => {
536            let a = emit_hc2r_scalar(a);
537            quote! { (-#a) }
538        }
539        Expr::Temp(name) => {
540            let id = format_ident!("{name}");
541            quote! { #id }
542        }
543    }
544}
545
546/// Emit a constant value as `T::ZERO`, `T::ONE`, `(-T::ONE)`, or `T::from_f64(v)`.
547fn emit_const(v: f64) -> TokenStream {
548    if (v - 0.0_f64).abs() < f64::EPSILON {
549        quote! { T::ZERO }
550    } else if (v - 1.0_f64).abs() < f64::EPSILON {
551        quote! { T::ONE }
552    } else if (v - (-1.0_f64)).abs() < f64::EPSILON {
553        quote! { (-T::ONE) }
554    } else {
555        quote! { T::from_f64(#v) }
556    }
557}