Skip to main content

oxifft_codegen_impl/
symbolic.rs

1//! Symbolic FFT operation representation.
2//!
3//! This module provides a symbolic representation of FFT operations that can be
4//! optimized through common subexpression elimination (CSE) and strength reduction
5//! before code generation.
6//!
7//! Note: These types are infrastructure for the codegen proc-macros and are tested
8//! but not directly exported (proc-macro crates can only export proc-macro functions).
9
10#![allow(clippy::cast_precision_loss)] // FFT sizes fit comfortably in f64 mantissa
11
12#[cfg(test)]
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::fmt;
16
/// A symbolic scalar expression describing FFT arithmetic.
///
/// Expressions form a tree: leaves are inputs, constants and named
/// temporaries; interior nodes are arithmetic operators. Equality is derived,
/// so constants compare with ordinary `f64` equality.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// Input variable: `x[index].re` when `is_real` is set, otherwise `x[index].im`.
    Input { index: usize, is_real: bool },
    /// Floating-point constant.
    Const(f64),
    /// Sum of the two operands.
    Add(Box<Self>, Box<Self>),
    /// Left operand minus right operand.
    Sub(Box<Self>, Box<Self>),
    /// Product of the two operands.
    Mul(Box<Self>, Box<Self>),
    /// Arithmetic negation of the operand.
    Neg(Box<Self>),
    /// Reference to a named temporary (result of CSE).
    Temp(String),
}

impl Expr {
    /// Build a reference to the real part of input `index`.
    #[must_use]
    pub const fn input_re(index: usize) -> Self {
        Self::Input { index, is_real: true }
    }

    /// Build a reference to the imaginary part of input `index`.
    #[must_use]
    pub const fn input_im(index: usize) -> Self {
        Self::Input { index, is_real: false }
    }

    /// Wrap a floating-point value as a constant expression.
    #[must_use]
    pub const fn constant(value: f64) -> Self {
        Self::Const(value)
    }

    /// Build the node `self + other`; no simplification is performed.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Self::Add(lhs, rhs)
    }

    /// Build the node `self - other`; no simplification is performed.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Self::Sub(lhs, rhs)
    }

    /// Build the node `self * other`; no simplification is performed.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Self::Mul(lhs, rhs)
    }

    /// Return the wrapped value for a `Const` node and `None` for anything else.
    #[must_use]
    pub const fn const_value(&self) -> Option<f64> {
        if let Self::Const(value) = self {
            Some(*value)
        } else {
            None
        }
    }

    /// Hash the whole expression tree (used as the CSE lookup key).
    ///
    /// Constants are hashed by bit pattern, so `0.0` and `-0.0` hash
    /// differently even though they compare equal.
    #[must_use]
    pub fn structural_hash(&self) -> u64 {
        use std::hash::Hasher as _;

        let mut state = std::collections::hash_map::DefaultHasher::new();
        self.hash_recursive(&mut state);
        state.finish()
    }

    /// Feed this node into `hasher`: a one-byte variant tag first, then the
    /// payload (leaf data, or the operands in left-to-right order).
    fn hash_recursive<H: std::hash::Hasher>(&self, hasher: &mut H) {
        use std::hash::Hash;

        // The tag keeps differently-shaped nodes with equal payloads apart.
        let tag: u8 = match self {
            Self::Input { .. } => 0,
            Self::Const(_) => 1,
            Self::Add(..) => 2,
            Self::Sub(..) => 3,
            Self::Mul(..) => 4,
            Self::Neg(_) => 5,
            Self::Temp(_) => 6,
        };
        tag.hash(hasher);

        match self {
            Self::Input { index, is_real } => {
                index.hash(hasher);
                is_real.hash(hasher);
            }
            Self::Const(value) => value.to_bits().hash(hasher),
            Self::Add(lhs, rhs) | Self::Sub(lhs, rhs) | Self::Mul(lhs, rhs) => {
                lhs.hash_recursive(hasher);
                rhs.hash_recursive(hasher);
            }
            Self::Neg(operand) => operand.hash_recursive(hasher),
            Self::Temp(name) => name.hash(hasher),
        }
    }

    /// Insert the name of every `Temp` node reachable from this expression
    /// into `refs`. Names already present in `refs` are left in place.
    pub fn collect_temp_refs(&self, refs: &mut HashSet<String>) {
        // Explicit work stack instead of recursion; the result is a set, so
        // visiting order is irrelevant.
        let mut pending: Vec<&Self> = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                Self::Temp(name) => {
                    refs.insert(name.clone());
                }
                Self::Add(lhs, rhs) | Self::Sub(lhs, rhs) | Self::Mul(lhs, rhs) => {
                    pending.push(&**lhs);
                    pending.push(&**rhs);
                }
                Self::Neg(operand) => pending.push(&**operand),
                Self::Input { .. } | Self::Const(_) => {}
            }
        }
    }

    /// Number of arithmetic operator nodes (`Add`, `Sub`, `Mul`, `Neg`) in the
    /// tree. Leaves — including `Temp` references — count as zero.
    #[must_use]
    pub fn op_count(&self) -> usize {
        match self {
            Self::Add(lhs, rhs) | Self::Sub(lhs, rhs) | Self::Mul(lhs, rhs) => {
                lhs.op_count() + rhs.op_count() + 1
            }
            Self::Neg(operand) => operand.op_count() + 1,
            Self::Input { .. } | Self::Const(_) | Self::Temp(_) => 0,
        }
    }
}

#[cfg(test)]
impl Expr {
    /// Build the node `-self`. (test helper)
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        let operand = Box::new(self);
        Self::Neg(operand)
    }
}
175
176impl fmt::Display for Expr {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        match self {
179            Self::Input { index, is_real } => {
180                write!(f, "x[{}].{}", index, if *is_real { "re" } else { "im" })
181            }
182            Self::Const(v) => write!(f, "{v}"),
183            Self::Add(a, b) => write!(f, "({a} + {b})"),
184            Self::Sub(a, b) => write!(f, "({a} - {b})"),
185            Self::Mul(a, b) => write!(f, "({a} * {b})"),
186            Self::Neg(a) => write!(f, "(-{a})"),
187            Self::Temp(name) => write!(f, "{name}"),
188        }
189    }
190}
191
192/// A complex symbolic expression (real, imaginary pair).
193#[derive(Clone, Debug)]
194pub struct ComplexExpr {
195    pub re: Expr,
196    pub im: Expr,
197}
198
199impl ComplexExpr {
200    /// Create from input index.
201    #[must_use]
202    pub const fn input(index: usize) -> Self {
203        Self {
204            re: Expr::input_re(index),
205            im: Expr::input_im(index),
206        }
207    }
208
209    /// Create from constant.
210    #[must_use]
211    pub const fn constant(re: f64, im: f64) -> Self {
212        Self {
213            re: Expr::constant(re),
214            im: Expr::constant(im),
215        }
216    }
217
218    /// Complex addition.
219    #[must_use]
220    #[allow(clippy::should_implement_trait)]
221    pub fn add(&self, other: &Self) -> Self {
222        Self {
223            re: self.re.clone().add(other.re.clone()),
224            im: self.im.clone().add(other.im.clone()),
225        }
226    }
227
228    /// Complex subtraction.
229    #[must_use]
230    #[allow(clippy::should_implement_trait)]
231    pub fn sub(&self, other: &Self) -> Self {
232        Self {
233            re: self.re.clone().sub(other.re.clone()),
234            im: self.im.clone().sub(other.im.clone()),
235        }
236    }
237
238    /// Complex multiplication.
239    #[must_use]
240    #[allow(clippy::should_implement_trait)]
241    pub fn mul(&self, other: &Self) -> Self {
242        // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
243        Self {
244            re: self
245                .re
246                .clone()
247                .mul(other.re.clone())
248                .sub(self.im.clone().mul(other.im.clone())),
249            im: self
250                .re
251                .clone()
252                .mul(other.im.clone())
253                .add(self.im.clone().mul(other.re.clone())),
254        }
255    }
256}
257
258#[cfg(test)]
259impl ComplexExpr {
260    /// Multiply by j = sqrt(-1). (test helper)
261    #[must_use]
262    pub fn mul_j(&self) -> Self {
263        // (a + bi) * i = -b + ai
264        Self {
265            re: self.im.clone().neg(),
266            im: self.re.clone(),
267        }
268    }
269
270    /// Multiply by -j = -sqrt(-1). (test helper)
271    #[must_use]
272    pub fn mul_neg_j(&self) -> Self {
273        // (a + bi) * (-i) = b - ai
274        Self {
275            re: self.im.clone(),
276            im: self.re.clone().neg(),
277        }
278    }
279
280    /// Negation. (test helper)
281    #[must_use]
282    pub fn neg(&self) -> Self {
283        Self {
284            re: self.re.clone().neg(),
285            im: self.im.clone().neg(),
286        }
287    }
288}
289
/// Common Subexpression Elimination optimizer. (used only in tests)
///
/// Expressions are registered one at a time; structurally identical
/// expressions share a single temporary name (`t0`, `t1`, …).
#[cfg(test)]
pub struct CseOptimizer {
    /// Map from expression hash to (expression, temp name, use count).
    expr_cache: HashMap<u64, (Expr, String, usize)>,
    /// Counter for generating temp names.
    temp_counter: usize,
    /// Threshold for CSE (min uses to create temp).
    min_uses: usize,
}

#[cfg(test)]
impl CseOptimizer {
    /// Create a new CSE optimizer with a minimum-use threshold of 2.
    #[must_use]
    pub fn new() -> Self {
        Self {
            expr_cache: HashMap::new(),
            temp_counter: 0,
            min_uses: 2,
        }
    }

    /// Set the minimum number of registrations an expression needs before
    /// [`Self::get_temporaries`] reports it.
    #[must_use]
    pub const fn with_min_uses(mut self, min_uses: usize) -> Self {
        self.min_uses = min_uses;
        self
    }

    /// Register an expression and return the optimized version.
    ///
    /// * Leaves (`Input`, `Const`, `Temp`) are returned unchanged.
    /// * Any other expression is replaced by `Expr::Temp(name)`, where `name`
    ///   is shared by every structurally equal expression registered so far.
    ///
    /// Note that a `Temp` is returned even on the first registration, i.e.
    /// before the use count is known to reach the threshold. Callers must
    /// check [`Self::get_temporaries`] to learn which names are actually
    /// defined and fall back to the original expression for the others.
    ///
    /// The cache is keyed by a 64-bit structural hash, so a hit is confirmed
    /// with a full equality check. On a mismatch (hash collision, or an
    /// expression containing NaN, which never equals itself) the expression is
    /// returned unshared rather than aliased to an unrelated temporary.
    #[must_use]
    pub fn register(&mut self, expr: &Expr) -> Expr {
        use std::collections::hash_map::Entry;

        // Don't CSE simple expressions
        if matches!(expr, Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_)) {
            return expr.clone();
        }

        match self.expr_cache.entry(expr.structural_hash()) {
            Entry::Occupied(mut slot) => {
                let (cached, name, count) = slot.get_mut();
                if *cached != *expr {
                    // Same hash, different expression: sharing would be wrong.
                    return expr.clone();
                }
                *count += 1;
                Expr::Temp(name.clone())
            }
            Entry::Vacant(slot) => {
                let name = format!("t{}", self.temp_counter);
                self.temp_counter += 1;
                slot.insert((expr.clone(), name.clone(), 1));
                Expr::Temp(name)
            }
        }
    }

    /// Get all temporaries that should be generated: every cached expression
    /// registered at least `min_uses` times, as `(name, expression)` pairs.
    ///
    /// The result is ordered by creation index (`t2` before `t10`). Names are
    /// always `t<counter>`, so comparing length first and text second yields
    /// numeric order without parsing.
    #[must_use]
    pub fn get_temporaries(&self) -> Vec<(String, Expr)> {
        let mut temps: Vec<_> = self
            .expr_cache
            .values()
            .filter(|(_, _, count)| *count >= self.min_uses)
            .map(|(expr, name, _)| (name.clone(), expr.clone()))
            .collect();
        temps.sort_by(|a, b| a.0.len().cmp(&b.0.len()).then_with(|| a.0.cmp(&b.0)));
        temps
    }
}

#[cfg(test)]
impl Default for CseOptimizer {
    /// Equivalent to [`CseOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
362
363/// Strength reduction optimizer.
364pub struct StrengthReducer;
365
366impl StrengthReducer {
367    /// Apply strength reduction to an expression.
368    /// Reduces recursively from bottom up.
369    #[must_use]
370    pub fn reduce(expr: &Expr) -> Expr {
371        match expr {
372            // Mul: reduce children first, then simplify
373            Expr::Mul(a, b) => {
374                let ra = Self::reduce(a);
375                let rb = Self::reduce(b);
376
377                // Mul by 0 -> 0
378                if ra.const_value() == Some(0.0) || rb.const_value() == Some(0.0) {
379                    return Expr::Const(0.0);
380                }
381                // Mul by 1 -> identity
382                if ra.const_value() == Some(1.0) {
383                    return rb;
384                }
385                if rb.const_value() == Some(1.0) {
386                    return ra;
387                }
388                // Mul by -1 -> negation
389                if ra.const_value() == Some(-1.0) {
390                    return Expr::Neg(Box::new(rb));
391                }
392                if rb.const_value() == Some(-1.0) {
393                    return Expr::Neg(Box::new(ra));
394                }
395                // Const * Const -> Const
396                if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
397                    return Expr::Const(va * vb);
398                }
399                Expr::Mul(Box::new(ra), Box::new(rb))
400            }
401
402            // Add: reduce children first, then simplify
403            Expr::Add(a, b) => {
404                let ra = Self::reduce(a);
405                let rb = Self::reduce(b);
406
407                // Add with 0 -> identity
408                if ra.const_value() == Some(0.0) {
409                    return rb;
410                }
411                if rb.const_value() == Some(0.0) {
412                    return ra;
413                }
414                // Const + Const -> Const
415                if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
416                    return Expr::Const(va + vb);
417                }
418                Expr::Add(Box::new(ra), Box::new(rb))
419            }
420
421            // Sub: reduce children first, then simplify
422            Expr::Sub(a, b) => {
423                let ra = Self::reduce(a);
424                let rb = Self::reduce(b);
425
426                // x - x -> 0 (structural equality)
427                if ra == rb {
428                    return Expr::Const(0.0);
429                }
430                // Sub with 0 -> identity/negation
431                if rb.const_value() == Some(0.0) {
432                    return ra;
433                }
434                if ra.const_value() == Some(0.0) {
435                    return Expr::Neg(Box::new(rb));
436                }
437                // Const - Const -> Const
438                if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
439                    return Expr::Const(va - vb);
440                }
441                Expr::Sub(Box::new(ra), Box::new(rb))
442            }
443
444            // Neg: reduce child first, then simplify
445            Expr::Neg(a) => {
446                let ra = Self::reduce(a);
447
448                // Neg of Neg -> identity
449                if let Expr::Neg(inner) = &ra {
450                    return *inner.clone();
451                }
452                // Neg of Const -> Const
453                if let Some(v) = ra.const_value() {
454                    return Expr::Const(-v);
455                }
456                Expr::Neg(Box::new(ra))
457            }
458
459            // Terminals
460            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
461        }
462    }
463}
464
465/// Constant folder that applies algebraic simplifications to fixpoint.
466///
467/// This wraps [`StrengthReducer`] and applies it repeatedly until the expression
468/// no longer changes, ensuring all nested constant folding opportunities are caught.
469pub struct ConstantFolder;
470
471impl ConstantFolder {
472    /// Apply constant folding to an expression until fixpoint.
473    ///
474    /// This applies strength reduction (which includes constant folding rules)
475    /// repeatedly until the expression stabilizes.
476    #[must_use]
477    pub fn fold(expr: &Expr) -> Expr {
478        let mut current = expr.clone();
479        loop {
480            let folded = StrengthReducer::reduce(&current);
481            if folded == current {
482                return current;
483            }
484            current = folded;
485        }
486    }
487}
488
#[cfg(test)]
impl ConstantFolder {
    /// Fold every expression of `program` in place: first each assignment's
    /// right-hand side, then each output. Names and ordering are untouched.
    /// (test helper)
    pub fn fold_program(program: &mut Program) {
        let assigned = program.assignments.iter_mut().map(|(_, rhs)| rhs);
        for slot in assigned.chain(program.outputs.iter_mut()) {
            *slot = Self::fold(slot);
        }
    }
}
501
/// Dead code eliminator for symbolic programs. (used only in tests)
#[cfg(test)]
pub struct DeadCodeEliminator;

#[cfg(test)]
impl DeadCodeEliminator {
    /// Drop every assignment whose temporary is not needed by the outputs.
    ///
    /// A temporary is live when an output references it, or when the
    /// definition of another live temporary references it. Surviving
    /// assignments keep their relative order. References to names that have
    /// no assignment are simply ignored.
    pub fn eliminate(program: &mut Program) {
        // Name -> defining expression, for following references transitively.
        let definitions: HashMap<&str, &Expr> = program
            .assignments
            .iter()
            .map(|(name, rhs)| (name.as_str(), rhs))
            .collect();

        // Seed the live set with everything the outputs mention directly.
        let mut live: HashSet<String> = HashSet::new();
        program
            .outputs
            .iter()
            .for_each(|output| output.collect_temp_refs(&mut live));

        // Worklist closure: each name is pushed at most once, when it first
        // becomes live.
        let mut pending: Vec<String> = live.iter().cloned().collect();
        while let Some(current) = pending.pop() {
            if let Some(rhs) = definitions.get(current.as_str()) {
                let mut referenced = HashSet::new();
                rhs.collect_temp_refs(&mut referenced);
                pending.extend(
                    referenced
                        .into_iter()
                        .filter(|name| live.insert(name.clone())),
                );
            }
        }

        program.assignments.retain(|(name, _)| live.contains(name));
    }
}
545
/// A symbolic program: ordered temporary assignments followed by the output
/// expressions that may refer to them.
///
/// This type is used in tests for the optimization pipeline infrastructure.
/// For code generation, `emit_body_from_symbolic` uses `RecursiveCse` directly.
#[cfg(test)]
#[derive(Clone, Debug)]
pub struct Program {
    /// Temporary definitions, in order, as `(name, expression)` pairs.
    pub assignments: Vec<(String, Expr)>,
    /// Output expressions; these may reference temporaries by name.
    pub outputs: Vec<Expr>,
}

#[cfg(test)]
impl Program {
    /// Program with no assignments and no outputs.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            assignments: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Build a program whose assignments are the temporaries currently
    /// reported by `cse` and whose outputs are the given expressions.
    #[must_use]
    pub fn from_cse(cse: &CseOptimizer, outputs: Vec<Expr>) -> Self {
        let assignments = cse.get_temporaries();
        Self {
            assignments,
            outputs,
        }
    }

    /// Sum of [`Expr::op_count`] over every assignment right-hand side and
    /// every output.
    #[must_use]
    pub fn op_count(&self) -> usize {
        self.assignments
            .iter()
            .map(|(_, rhs)| rhs)
            .chain(self.outputs.iter())
            .map(Expr::op_count)
            .sum()
    }
}

#[cfg(test)]
impl Default for Program {
    /// Equivalent to [`Program::new`].
    fn default() -> Self {
        Self::new()
    }
}
594
/// Apply all optimization passes to a program.
///
/// The optimization pipeline is:
/// 1. **Constant folding** — simplify constant expressions and algebraic identities
/// 2. **CSE** — extract common subexpressions into temporaries
/// 3. **Dead code elimination** — remove unused temporaries
///
/// Only expressions that occur often enough for the CSE optimizer to define a
/// temporary are replaced by a reference to it. Every other output or
/// assignment keeps its folded expression, so the result never refers to a
/// temporary that has no definition.
///
/// CSE temporaries are placed before the program's own assignments. If an
/// existing assignment has the same name as a CSE temporary (`t0`, `t1`, …),
/// the CSE temporary wins and the existing assignment is dropped; among
/// existing assignments with a repeated name, the first one is kept.
///
/// Returns the optimized program.
#[cfg(test)]
#[must_use]
pub fn optimize(mut program: Program) -> Program {
    // Pass 1: Constant folding
    ConstantFolder::fold_program(&mut program);

    // Pass 2: CSE on the folded expressions. Outputs are registered before
    // assignment right-hand sides so temp numbering follows the outputs.
    let mut cse = CseOptimizer::new();
    let shared_outputs: Vec<Expr> = program
        .outputs
        .iter()
        .map(|expr| cse.register(expr))
        .collect();
    let shared_assignments: Vec<Expr> = program
        .assignments
        .iter()
        .map(|(_, expr)| cse.register(expr))
        .collect();

    // Temporaries that met the use threshold and therefore get a definition.
    let mut all_assignments = cse.get_temporaries();
    let materialized: HashSet<String> = all_assignments
        .iter()
        .map(|(name, _)| name.clone())
        .collect();

    // `register` hands out a `Temp` for every compound expression, even one
    // seen only once. Keep the reference only if that temp is really defined;
    // otherwise restore the folded original.
    let keep_if_defined = |original: &Expr, shared: Expr| -> Expr {
        let is_defined = matches!(&shared, Expr::Temp(name) if materialized.contains(name));
        if is_defined {
            shared
        } else {
            original.clone()
        }
    };

    let new_outputs: Vec<Expr> = program
        .outputs
        .iter()
        .zip(shared_outputs)
        .map(|(original, shared)| keep_if_defined(original, shared))
        .collect();

    // Merge CSE-generated temporaries with existing ones (first name wins).
    let mut defined = materialized.clone();
    for ((name, original), shared) in program.assignments.iter().zip(shared_assignments) {
        if defined.insert(name.clone()) {
            all_assignments.push((name.clone(), keep_if_defined(original, shared)));
        }
    }

    program.assignments = all_assignments;
    program.outputs = new_outputs;

    // Pass 3: Dead code elimination
    DeadCodeEliminator::eliminate(&mut program);

    program
}
640
/// Run constant folding followed by dead code elimination, skipping CSE
/// (for cases where CSE is handled separately). Returns the reduced program.
#[cfg(test)]
#[must_use]
pub fn optimize_fold_and_dce(program: Program) -> Program {
    let mut reduced = program;
    ConstantFolder::fold_program(&mut reduced);
    DeadCodeEliminator::eliminate(&mut reduced);
    reduced
}
649
650/// FFT symbolic computation.
651pub struct SymbolicFFT {
652    /// Output expressions (real, imag pairs). Length equals the transform size.
653    pub outputs: Vec<ComplexExpr>,
654}
655
656impl SymbolicFFT {
657    /// Generate radix-2 Cooley-Tukey FFT symbolically.
658    ///
659    /// # Panics
660    /// Panics if `n` is not a power of two.
661    #[must_use]
662    pub fn radix2_dit(n: usize, forward: bool) -> Self {
663        assert!(n.is_power_of_two(), "n must be power of 2");
664
665        let sign = if forward { -1.0 } else { 1.0 };
666
667        // Start with inputs
668        let mut data: Vec<ComplexExpr> = (0..n).map(ComplexExpr::input).collect();
669
670        // Bit-reversal permutation
671        let mut j = 0;
672        for i in 0..n {
673            if i < j {
674                data.swap(i, j);
675            }
676            let mut m = n >> 1;
677            while m >= 1 && j >= m {
678                j -= m;
679                m >>= 1;
680            }
681            j += m;
682        }
683
684        // Cooley-Tukey stages
685        let mut len = 2;
686        while len <= n {
687            let half = len / 2;
688            let angle_step = sign * 2.0 * std::f64::consts::PI / len as f64;
689
690            for start in (0..n).step_by(len) {
691                for k in 0..half {
692                    let angle = angle_step * k as f64;
693                    let twiddle = ComplexExpr::constant(angle.cos(), angle.sin());
694
695                    let u = data[start + k].clone();
696                    let t = data[start + k + half].mul(&twiddle);
697
698                    data[start + k] = u.add(&t);
699                    data[start + k + half] = u.sub(&t);
700                }
701            }
702
703            len *= 2;
704        }
705
706        // Apply strength reduction to all outputs
707        let outputs: Vec<ComplexExpr> = data
708            .into_iter()
709            .map(|c| ComplexExpr {
710                re: StrengthReducer::reduce(&c.re),
711                im: StrengthReducer::reduce(&c.im),
712            })
713            .collect();
714
715        Self { outputs }
716    }
717
718    /// Total operation count.
719    #[must_use]
720    pub fn op_count(&self) -> usize {
721        self.outputs
722            .iter()
723            .map(|c| c.re.op_count() + c.im.op_count())
724            .sum()
725    }
726}
727
#[cfg(test)]
impl SymbolicFFT {
    /// Transform size, i.e. the number of output expressions.
    #[must_use]
    pub fn n(&self) -> usize {
        self.outputs.len()
    }

    /// Build the expressions of a naive O(n²) DFT. (test helper)
    ///
    /// Output `k` is the sum over `j` of `x[j]` times the twiddle with angle
    /// `±2π·k·j/n` (negative when `forward`). Each accumulated component goes
    /// through one pass of [`StrengthReducer::reduce`].
    #[must_use]
    pub fn dft(n: usize, forward: bool) -> Self {
        let sign: f64 = if forward { -1.0 } else { 1.0 };

        let outputs: Vec<ComplexExpr> = (0..n)
            .map(|k| {
                // Running sum starts at 0 + 0i; the reducer removes the zeros.
                let mut acc = ComplexExpr::constant(0.0, 0.0);

                for j in 0..n {
                    let angle = sign * 2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
                    let twiddle = ComplexExpr::constant(angle.cos(), angle.sin());
                    let term = ComplexExpr::input(j).mul(&twiddle);

                    acc = ComplexExpr {
                        re: acc.re.add(term.re),
                        im: acc.im.add(term.im),
                    };
                }

                ComplexExpr {
                    re: StrengthReducer::reduce(&acc.re),
                    im: StrengthReducer::reduce(&acc.im),
                }
            })
            .collect();

        Self { outputs }
    }
}
768
769// ============================================================================
770// Code emission: symbolic FFT → proc_macro2::TokenStream
771// (implementation lives in symbolic_emit.rs to keep this file under 2000 lines)
772// ============================================================================
773
774#[path = "symbolic_emit.rs"]
775mod symbolic_emit;
776pub use symbolic_emit::{emit_body_from_symbolic, schedule_instructions};
777
778// ============================================================================
779// Tests
780// (implementation lives in symbolic_tests.rs to keep this file under 2000 lines)
781// ============================================================================
782
783#[cfg(test)]
784#[path = "symbolic_tests.rs"]
785mod tests;