Skip to main content

oxifft_codegen_impl/
symbolic_emit.rs

1//! Code emission: symbolic FFT expressions → `proc_macro2::TokenStream`.
2//!
3//! This module contains the recursive CSE optimizer used for code generation
4//! and the emission functions that convert symbolic FFT expressions into Rust
5//! token streams.
6
7use std::collections::{BinaryHeap, HashMap, HashSet};
8
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11
12use super::{ConstantFolder, Expr, SymbolicFFT};
13
14// ============================================================================
15// Recursive CSE for code generation
16// ============================================================================
17
18/// Recursive CSE optimizer for code emission.
19///
20/// Unlike `CseOptimizer::register()` which only hashes the top-level expression,
21/// this walker traverses the entire expression tree, extracts shared
22/// subexpressions into named temporaries, and replaces their occurrences with
23/// `Expr::Temp` references.  Expressions used only once are left inline.
24pub(super) struct RecursiveCse {
25    /// Map from structural hash → (original expr, temp name, use count).
26    cache: HashMap<u64, (Expr, String, usize)>,
27    counter: usize,
28}
29
30impl RecursiveCse {
31    pub(super) fn new() -> Self {
32        Self {
33            cache: HashMap::new(),
34            counter: 0,
35        }
36    }
37
38    /// Count usages of each subexpression across all outputs (bottom-up).
39    pub(super) fn count_recursive(&mut self, expr: &Expr) {
40        match expr {
41            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => {}
42            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) => {
43                self.count_recursive(a);
44                self.count_recursive(b);
45                let hash = expr.structural_hash();
46                let entry = self.cache.entry(hash).or_insert_with(|| {
47                    let name = format!("t{}", self.counter);
48                    self.counter += 1;
49                    (expr.clone(), name, 0)
50                });
51                entry.2 += 1;
52            }
53            Expr::Neg(a) => {
54                self.count_recursive(a);
55                let hash = expr.structural_hash();
56                let entry = self.cache.entry(hash).or_insert_with(|| {
57                    let name = format!("t{}", self.counter);
58                    self.counter += 1;
59                    (expr.clone(), name, 0)
60                });
61                entry.2 += 1;
62            }
63        }
64    }
65
66    /// Rewrite an expression replacing shared subexpressions with `Temp` refs.
67    /// Only extracts subexpressions used >= 2 times.
68    ///
69    /// Set `top_level_name` to `Some(name)` when rewriting the RHS of an
70    /// assignment — this prevents the expression from replacing itself with
71    /// its own temp name (self-reference).
72    fn rewrite_inner(&self, expr: &Expr, exclude_hash: Option<u64>) -> Expr {
73        match expr {
74            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
75            Expr::Add(a, b) => {
76                let hash = expr.structural_hash();
77                if exclude_hash != Some(hash) {
78                    if let Some((_, name, count)) = self.cache.get(&hash) {
79                        if *count >= 2 {
80                            return Expr::Temp(name.clone());
81                        }
82                    }
83                }
84                Expr::Add(
85                    Box::new(self.rewrite_inner(a, None)),
86                    Box::new(self.rewrite_inner(b, None)),
87                )
88            }
89            Expr::Sub(a, b) => {
90                let hash = expr.structural_hash();
91                if exclude_hash != Some(hash) {
92                    if let Some((_, name, count)) = self.cache.get(&hash) {
93                        if *count >= 2 {
94                            return Expr::Temp(name.clone());
95                        }
96                    }
97                }
98                Expr::Sub(
99                    Box::new(self.rewrite_inner(a, None)),
100                    Box::new(self.rewrite_inner(b, None)),
101                )
102            }
103            Expr::Mul(a, b) => {
104                let hash = expr.structural_hash();
105                if exclude_hash != Some(hash) {
106                    if let Some((_, name, count)) = self.cache.get(&hash) {
107                        if *count >= 2 {
108                            return Expr::Temp(name.clone());
109                        }
110                    }
111                }
112                Expr::Mul(
113                    Box::new(self.rewrite_inner(a, None)),
114                    Box::new(self.rewrite_inner(b, None)),
115                )
116            }
117            Expr::Neg(a) => {
118                let hash = expr.structural_hash();
119                if exclude_hash != Some(hash) {
120                    if let Some((_, name, count)) = self.cache.get(&hash) {
121                        if *count >= 2 {
122                            return Expr::Temp(name.clone());
123                        }
124                    }
125                }
126                Expr::Neg(Box::new(self.rewrite_inner(a, None)))
127            }
128        }
129    }
130
131    /// Rewrite an output expression (replacing shared subexpressions with Temp refs).
132    pub(super) fn rewrite(&self, expr: &Expr) -> Expr {
133        self.rewrite_inner(expr, None)
134    }
135
136    /// Rewrite the RHS of an assignment, excluding the assignment itself from
137    /// self-reference replacement.
138    pub(super) fn rewrite_assignment_rhs(&self, name: &str, expr: &Expr) -> Expr {
139        // Find the hash for this assignment's expression
140        let hash = self
141            .cache
142            .iter()
143            .find(|(_, (_, n, _))| n == name)
144            .map(|(h, _)| *h);
145        self.rewrite_inner(expr, hash)
146    }
147
148    /// Return sorted assignments for temps used >= 2 times.
149    pub(super) fn get_assignments(&self) -> Vec<(String, Expr)> {
150        let mut result: Vec<(String, Expr)> = self
151            .cache
152            .values()
153            .filter(|(_, _, count)| *count >= 2)
154            .map(|(expr, name, _)| (name.clone(), expr.clone()))
155            .collect();
156        // Sort by the numeric suffix for deterministic output.
157        // Names are "t0", "t1", ..., "t99", "t100", etc.
158        result.sort_by(|a, b| {
159            let na: usize = a.0[1..].parse().unwrap_or(0);
160            let nb: usize = b.0[1..].parse().unwrap_or(0);
161            na.cmp(&nb)
162        });
163        result
164    }
165}
166
167// ============================================================================
168// Code emission: symbolic FFT → proc_macro2::TokenStream
169// ============================================================================
170
171/// Build the body `TokenStream` for one direction of an n-point FFT from
172/// symbolic computation and optimization passes.
173///
174/// This function:
175/// 1. Builds the symbolic DAG via `SymbolicFFT::radix2_dit(n, forward)`.
176/// 2. Applies constant folding to each output expression.
177/// 3. Performs recursive CSE to extract shared subexpressions into named temps.
178/// 4. Emits: input extractions, CSE temporaries, output assignments.
179#[must_use]
180pub fn emit_body_from_symbolic(n: usize, forward: bool) -> TokenStream {
181    let fft = SymbolicFFT::radix2_dit(n, forward);
182
183    // Step 1: constant-fold each output
184    let folded_outputs: Vec<(Expr, Expr)> = fft
185        .outputs
186        .iter()
187        .map(|c| (ConstantFolder::fold(&c.re), ConstantFolder::fold(&c.im)))
188        .collect();
189
190    let ops_before = fft.op_count();
191
192    // Step 2: recursive CSE across all folded expressions
193    let mut cse = RecursiveCse::new();
194    for (re, im) in &folded_outputs {
195        cse.count_recursive(re);
196        cse.count_recursive(im);
197    }
198
199    // Step 3: rewrite outputs to replace shared subexpressions with Temp refs
200    let rewritten_outputs: Vec<(Expr, Expr)> = folded_outputs
201        .iter()
202        .map(|(re, im)| (cse.rewrite(re), cse.rewrite(im)))
203        .collect();
204
205    // Also rewrite the CSE assignment RHS (their children may be shared too).
206    // Use rewrite_assignment_rhs to avoid self-reference replacement.
207    let mut assignments: Vec<(String, Expr)> = cse
208        .get_assignments()
209        .into_iter()
210        .map(|(name, expr)| {
211            let rewritten = cse.rewrite_assignment_rhs(&name, &expr);
212            (name, rewritten)
213        })
214        .collect();
215
216    // Topological sort: assignments that reference other assignments must come later
217    assignments = topological_sort_assignments(assignments);
218
219    if std::env::var("OXIFFT_CODEGEN_DEBUG").is_ok() {
220        let ops_after: usize = assignments.iter().map(|(_, e)| e.op_count()).sum::<usize>()
221            + rewritten_outputs
222                .iter()
223                .map(|(re, im)| re.op_count() + im.op_count())
224                .sum::<usize>();
225        let pct = if ops_before > 0 {
226            (ops_after as f64 - ops_before as f64) / ops_before as f64 * 100.0
227        } else {
228            0.0
229        };
230        eprintln!(
231            "[oxifft-codegen] n={n} forward={forward}: {ops_before} ops → {ops_after} ops ({pct:+.1}%)",
232        );
233    }
234
235    // Step 4: instruction-scheduling optimizer pass
236    // Re-orders assignments by critical-path priority (Sethi-Ullman heuristic):
237    // leaves (depth=0) are emitted first; among equal depths, prefer statements
238    // whose results are consumed by the longest remaining critical path.
239    schedule_instructions(&mut assignments);
240
241    emit_folded_body(n, &assignments, &rewritten_outputs)
242}
243
244// ============================================================================
245// Instruction-scheduling optimizer pass
246// ============================================================================
247
248/// Schedule assignment statements to maximise instruction-level parallelism (ILP).
249///
250/// Algorithm (Sethi-Ullman critical-path heuristic):
251/// 1. Build a def-use dependency graph: for each statement `(name, expr)`, record
252///    all prior statements whose results `expr` references (via `Expr::Temp`).
253/// 2. Compute critical-path depth per statement via longest-path from leaves:
254///    - Statements with no temp-ref dependencies → depth 0.
255///    - Each dependent statement → `1 + max(deps' depths)`.
256/// 3. Topological re-ordering: maintain a ready-queue of statements whose all
257///    dependencies have already been emitted.  Among ready candidates, prefer
258///    those with the **largest** critical-path depth (i.e., the ones that unblock
259///    the longest remaining work) — this is the "greedy critical-path first" rule.
260/// 4. Guaranteed correctness: no statement is emitted before all its deps are done.
261///
262/// The pass operates in-place on the assignment vector. It will not reorder
263/// statements that were placed in a topologically invalid order beforehand —
264/// call `topological_sort_assignments` first if needed.  In practice,
265/// `emit_body_from_symbolic` calls both in sequence.
266pub fn schedule_instructions(stmts: &mut Vec<(String, Expr)>) {
267    let n = stmts.len();
268    if n <= 1 {
269        return;
270    }
271
272    // Build name → index map for O(1) predecessor lookup.
273    let index_of: std::collections::HashMap<String, usize> = stmts
274        .iter()
275        .enumerate()
276        .map(|(i, (name, _))| (name.clone(), i))
277        .collect();
278
279    // For each statement, collect its direct predecessor indices (statements it depends on).
280    let predecessors: Vec<Vec<usize>> = stmts
281        .iter()
282        .map(|(_, expr)| {
283            let mut refs = HashSet::new();
284            expr.collect_temp_refs(&mut refs);
285            refs.iter()
286                .filter_map(|r| index_of.get(r).copied())
287                .collect()
288        })
289        .collect();
290
291    // Compute critical-path depth per statement (longest path from a leaf).
292    // Leaves (no deps) have depth 0.  We process in topological order (guaranteed
293    // by the caller's prior topological sort).
294    let mut depth = vec![0usize; n];
295    for (i, preds) in predecessors.iter().enumerate() {
296        for &pred in preds {
297            let candidate = depth[pred] + 1;
298            if candidate > depth[i] {
299                depth[i] = candidate;
300            }
301        }
302    }
303
304    // Build successor sets: for each statement i, which statements directly use it?
305    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
306    for (i, preds) in predecessors.iter().enumerate() {
307        for &pred in preds {
308            successors[pred].push(i);
309        }
310    }
311
312    // Greedy critical-path scheduler.
313    // ready_queue: statements all of whose predecessors have been emitted, stored
314    // as (depth, original_index) — highest depth first (max-heap via BinaryHeap).
315    let mut in_degree: Vec<usize> = predecessors.iter().map(Vec::len).collect();
316    let mut emitted = vec![false; n];
317    let mut order: Vec<usize> = Vec::with_capacity(n);
318
319    // Seed ready queue with all depth-0 (no-predecessor) statements.
320    let mut ready: BinaryHeap<(usize, usize)> = BinaryHeap::new();
321    for (i, &deg) in in_degree.iter().enumerate() {
322        if deg == 0 {
323            ready.push((depth[i], i));
324        }
325    }
326
327    while let Some((_, idx)) = ready.pop() {
328        if emitted[idx] {
329            continue; // guard against duplicate insertions
330        }
331        emitted[idx] = true;
332        order.push(idx);
333        // Decrement in-degree for each successor; push newly ready ones.
334        for &succ in &successors[idx] {
335            if in_degree[succ] > 0 {
336                in_degree[succ] -= 1;
337            }
338            if in_degree[succ] == 0 && !emitted[succ] {
339                ready.push((depth[succ], succ));
340            }
341        }
342    }
343
344    // If scheduling was incomplete (cycle or bug), preserve original order for stragglers.
345    if order.len() < n {
346        for (i, &already_emitted) in emitted.iter().enumerate() {
347            if !already_emitted {
348                order.push(i);
349            }
350        }
351    }
352
353    // Reorder stmts according to the computed schedule.
354    // We need to physically rearrange the Vec without cloning Expr trees.
355    // Build a temporary vec of (old_index, new_position) pairs, then permute.
356    let mut positioned: Vec<Option<(String, Expr)>> = stmts.drain(..).map(Some).collect();
357    let reordered: Vec<(String, Expr)> = order
358        .into_iter()
359        .filter_map(|i| positioned[i].take())
360        .collect();
361    *stmts = reordered;
362}
363
364/// Topologically sort assignments so that each temp is defined before use.
365fn topological_sort_assignments(assignments: Vec<(String, Expr)>) -> Vec<(String, Expr)> {
366    let mut defined: HashSet<String> = HashSet::new();
367    let mut result: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
368    let mut remaining = assignments;
369
370    // Iterative pass: on each iteration, move all assignments whose dependencies
371    // are fully satisfied into `result`.  Repeat until stable.
372    loop {
373        let before_len = result.len();
374        let mut next_remaining = Vec::new();
375        for (name, expr) in remaining {
376            let mut refs: HashSet<String> = HashSet::new();
377            expr.collect_temp_refs(&mut refs);
378            if refs.iter().all(|r| defined.contains(r)) {
379                defined.insert(name.clone());
380                result.push((name, expr));
381            } else {
382                next_remaining.push((name, expr));
383            }
384        }
385        remaining = next_remaining;
386        if remaining.is_empty() || result.len() == before_len {
387            // Either done, or there's a cycle (shouldn't happen in acyclic DAG)
388            result.extend(remaining);
389            break;
390        }
391    }
392    result
393}
394
395/// Emit the inner body statements from constant-folded and CSE-optimized outputs.
396///
397/// Emits:
398/// - input extraction: `let x{i}_re = x[{i}].re; let x{i}_im = x[{i}].im;`
399/// - CSE temporaries: `let {name} = {expr};`
400/// - output assignments: `x[{k}] = crate::kernel::Complex::new({re}, {im});`
401fn emit_folded_body(
402    n: usize,
403    assignments: &[(String, Expr)],
404    outputs: &[(Expr, Expr)],
405) -> TokenStream {
406    assert_eq!(
407        outputs.len(),
408        n,
409        "expected n outputs for n-point complex FFT, got {}",
410        outputs.len()
411    );
412
413    let mut body = TokenStream::new();
414
415    // Extract inputs
416    for i in 0..n {
417        let re_name = format_ident!("x{i}_re");
418        let im_name = format_ident!("x{i}_im");
419        body.extend(quote! {
420            let #re_name = x[#i].re;
421            let #im_name = x[#i].im;
422        });
423    }
424
425    // Emit CSE temporaries
426    for (name, expr) in assignments {
427        let id = format_ident!("{name}");
428        let tok = emit_scalar_expr(expr);
429        body.extend(quote! { let #id = #tok; });
430    }
431
432    // Emit outputs
433    for (k, (re_expr, im_expr)) in outputs.iter().enumerate() {
434        let re_tok = emit_scalar_expr(re_expr);
435        let im_tok = emit_scalar_expr(im_expr);
436        body.extend(quote! {
437            x[#k] = crate::kernel::Complex::new(#re_tok, #im_tok);
438        });
439    }
440
441    body
442}
443
444/// Emit a single scalar `Expr` as a `TokenStream`.
445fn emit_scalar_expr(expr: &Expr) -> TokenStream {
446    match expr {
447        Expr::Input { index, is_real } => {
448            let name = if *is_real {
449                format_ident!("x{index}_re")
450            } else {
451                format_ident!("x{index}_im")
452            };
453            quote! { #name }
454        }
455        Expr::Const(v) => {
456            if (*v - 0.0_f64).abs() < f64::EPSILON {
457                quote! { T::ZERO }
458            } else if (*v - 1.0_f64).abs() < f64::EPSILON {
459                quote! { T::ONE }
460            } else if (*v - (-1.0_f64)).abs() < f64::EPSILON {
461                quote! { (-T::ONE) }
462            } else {
463                let v = *v;
464                quote! { T::from_f64(#v) }
465            }
466        }
467        Expr::Add(a, b) => {
468            let a = emit_scalar_expr(a);
469            let b = emit_scalar_expr(b);
470            quote! { (#a + #b) }
471        }
472        Expr::Sub(a, b) => {
473            let a = emit_scalar_expr(a);
474            let b = emit_scalar_expr(b);
475            quote! { (#a - #b) }
476        }
477        Expr::Mul(a, b) => {
478            let a = emit_scalar_expr(a);
479            let b = emit_scalar_expr(b);
480            quote! { (#a * #b) }
481        }
482        Expr::Neg(a) => {
483            let a = emit_scalar_expr(a);
484            quote! { (-#a) }
485        }
486        Expr::Temp(name) => {
487            let id = format_ident!("{name}");
488            quote! { #id }
489        }
490    }
491}