use std::collections::{BinaryHeap, HashMap, HashSet};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::{ConstantFolder, Expr, SymbolicFFT};
/// Common-subexpression eliminator for [`Expr`] trees.
///
/// Used in two phases:
/// 1. Feed every root expression to [`RecursiveCse::count_recursive`], which
///    tallies how often each compound subexpression occurs.
/// 2. Call [`RecursiveCse::rewrite`] / [`RecursiveCse::rewrite_assignment_rhs`].
///    These replace every subexpression seen at least twice with an
///    `Expr::Temp` reference.
///
/// [`RecursiveCse::get_assignments`] returns the definitions of those temps.
///
/// NOTE(review): entries are keyed only by `Expr::structural_hash()` (a `u64`).
/// Two structurally different expressions that collide on that hash would
/// silently share one temp. Confirm the hash makes this negligible.
pub(super) struct RecursiveCse {
    /// Structural hash -> (first-seen expression, temp name, occurrence count).
    cache: HashMap<u64, (Expr, String, usize)>,
    /// Index used for the next temp name (`t{counter}`).
    counter: usize,
}

impl RecursiveCse {
    /// Occurrence count at which a subexpression is hoisted into a temp.
    const MIN_USES: usize = 2;

    /// Creates an eliminator with no recorded subexpressions.
    pub(super) fn new() -> Self {
        Self {
            cache: HashMap::new(),
            counter: 0,
        }
    }

    /// Counts every compound (`Add`/`Sub`/`Mul`/`Neg`) node of `expr`.
    ///
    /// Children are visited before their parent. A subexpression therefore
    /// always receives a lower temp index than any expression that contains it.
    /// Every occurrence is counted, including occurrences nested inside other
    /// repeated subexpressions. Leaves (`Input`, `Const`, `Temp`) are ignored.
    pub(super) fn count_recursive(&mut self, expr: &Expr) {
        match expr {
            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => return,
            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) => {
                self.count_recursive(a);
                self.count_recursive(b);
            }
            Expr::Neg(a) => self.count_recursive(a),
        }
        self.record(expr);
    }

    /// Bumps the occurrence count for `expr`. The first time its structural
    /// hash is seen, it also allocates the next `t{index}` name.
    fn record(&mut self, expr: &Expr) {
        let counter = &mut self.counter;
        let entry = self
            .cache
            .entry(expr.structural_hash())
            .or_insert_with(|| {
                let name = format!("t{}", *counter);
                *counter += 1;
                (expr.clone(), name, 0)
            });
        entry.2 += 1;
    }

    /// Returns the `Expr::Temp` that should replace `expr`. That is the case
    /// when `expr` occurred at least [`Self::MIN_USES`] times and its hash is
    /// not `exclude_hash`.
    fn shared_temp(&self, expr: &Expr, exclude_hash: Option<u64>) -> Option<Expr> {
        let hash = expr.structural_hash();
        if exclude_hash == Some(hash) {
            return None;
        }
        self.cache
            .get(&hash)
            .filter(|(_, _, count)| *count >= Self::MIN_USES)
            .map(|(_, name, _)| Expr::Temp(name.clone()))
    }

    /// Rewrites one child operand. Only the top-level node is ever excluded,
    /// so children are always eligible for replacement.
    fn rewrite_child(&self, expr: &Expr) -> Box<Expr> {
        Box::new(self.rewrite_inner(expr, None))
    }

    /// Rebuilds `expr`, replacing shared subexpressions with temp references.
    ///
    /// The root node is left in place when its hash equals `exclude_hash`.
    /// This is how a temp's own definition avoids referring to itself.
    fn rewrite_inner(&self, expr: &Expr, exclude_hash: Option<u64>) -> Expr {
        match expr {
            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
            Expr::Add(a, b) => self
                .shared_temp(expr, exclude_hash)
                .unwrap_or_else(|| Expr::Add(self.rewrite_child(a), self.rewrite_child(b))),
            Expr::Sub(a, b) => self
                .shared_temp(expr, exclude_hash)
                .unwrap_or_else(|| Expr::Sub(self.rewrite_child(a), self.rewrite_child(b))),
            Expr::Mul(a, b) => self
                .shared_temp(expr, exclude_hash)
                .unwrap_or_else(|| Expr::Mul(self.rewrite_child(a), self.rewrite_child(b))),
            Expr::Neg(a) => self
                .shared_temp(expr, exclude_hash)
                .unwrap_or_else(|| Expr::Neg(self.rewrite_child(a))),
        }
    }

    /// Rewrites a root expression (e.g. an output), replacing every shared
    /// subexpression, including the root itself, with its temp.
    pub(super) fn rewrite(&self, expr: &Expr) -> Expr {
        self.rewrite_inner(expr, None)
    }

    /// Rewrites the right-hand side of temp `name`'s definition.
    ///
    /// The entry named `name` is excluded at the root, so the definition is
    /// not rewritten into a reference to itself.
    pub(super) fn rewrite_assignment_rhs(&self, name: &str, expr: &Expr) -> Expr {
        // Fast path: `expr` is normally the cached representative for `name`
        // (as returned by `get_assignments`), so its own hash locates the entry
        // without scanning the whole cache. Temp names are unique, so a name
        // match here is the same entry the scan would find.
        let direct = expr.structural_hash();
        let hash = match self.cache.get(&direct) {
            Some((_, n, _)) if n == name => Some(direct),
            _ => self
                .cache
                .iter()
                .find(|(_, (_, n, _))| n == name)
                .map(|(h, _)| *h),
        };
        self.rewrite_inner(expr, hash)
    }

    /// Returns `(name, original expression)` for every hoisted temp.
    ///
    /// Results are sorted by temp index, i.e. first-seen order. Because
    /// children are recorded before parents, every temp comes after the
    /// temps for its subexpressions.
    pub(super) fn get_assignments(&self) -> Vec<(String, Expr)> {
        let mut result: Vec<(String, Expr)> = self
            .cache
            .values()
            .filter(|(_, _, count)| *count >= Self::MIN_USES)
            .map(|(expr, name, _)| (name.clone(), expr.clone()))
            .collect();
        // Names are `t{index}`; `get(1..)` avoids a panic on unexpected names.
        result.sort_by_key(|(name, _)| {
            name.get(1..)
                .and_then(|digits| digits.parse::<usize>().ok())
                .unwrap_or(0)
        });
        result
    }
}
/// Generates the body of an `n`-point complex FFT kernel from the symbolic
/// radix-2 DIT expansion. The kernel is a forward transform when `forward`
/// is true.
///
/// Steps:
/// 1. Build the symbolic FFT.
/// 2. Constant-fold each output's real and imaginary parts.
/// 3. Run [`RecursiveCse`] over all outputs to hoist repeated subexpressions
///    into temps.
/// 4. Order the temp definitions so each appears before use, then schedule
///    them.
/// 5. Emit the token stream.
///
/// When the `OXIFFT_CODEGEN_DEBUG` environment variable is set, an op-count
/// comparison (before vs. after folding + CSE) is printed to stderr.
#[must_use]
pub fn emit_body_from_symbolic(n: usize, forward: bool) -> TokenStream {
    let fft = SymbolicFFT::radix2_dit(n, forward);
    let folded: Vec<(Expr, Expr)> = fft
        .outputs
        .iter()
        .map(|c| (ConstantFolder::fold(&c.re), ConstantFolder::fold(&c.im)))
        .collect();
    let ops_before = fft.op_count();

    // Phase 1: tally subexpression occurrences across every output component.
    let mut cse = RecursiveCse::new();
    folded.iter().for_each(|(re, im)| {
        cse.count_recursive(re);
        cse.count_recursive(im);
    });

    // Phase 2: swap shared subexpressions for temp references.
    let outputs: Vec<(Expr, Expr)> = folded
        .iter()
        .map(|(re, im)| (cse.rewrite(re), cse.rewrite(im)))
        .collect();
    let temps: Vec<(String, Expr)> = cse
        .get_assignments()
        .into_iter()
        .map(|(name, rhs)| {
            let rhs = cse.rewrite_assignment_rhs(&name, &rhs);
            (name, rhs)
        })
        .collect();
    let mut assignments = topological_sort_assignments(temps);

    if std::env::var("OXIFFT_CODEGEN_DEBUG").is_ok() {
        let temp_ops: usize = assignments.iter().map(|(_, e)| e.op_count()).sum();
        let output_ops: usize = outputs
            .iter()
            .map(|(re, im)| re.op_count() + im.op_count())
            .sum();
        let ops_after = temp_ops + output_ops;
        let pct = if ops_before == 0 {
            0.0
        } else {
            (ops_after as f64 - ops_before as f64) / ops_before as f64 * 100.0
        };
        eprintln!(
            "[oxifft-codegen] n={n} forward={forward}: {ops_before} ops → {ops_after} ops ({pct:+.1}%)",
        );
    }

    schedule_instructions(&mut assignments);
    emit_folded_body(n, &assignments, &outputs)
}
pub fn schedule_instructions(stmts: &mut Vec<(String, Expr)>) {
let n = stmts.len();
if n <= 1 {
return;
}
let index_of: std::collections::HashMap<String, usize> = stmts
.iter()
.enumerate()
.map(|(i, (name, _))| (name.clone(), i))
.collect();
let predecessors: Vec<Vec<usize>> = stmts
.iter()
.map(|(_, expr)| {
let mut refs = HashSet::new();
expr.collect_temp_refs(&mut refs);
refs.iter()
.filter_map(|r| index_of.get(r).copied())
.collect()
})
.collect();
let mut depth = vec![0usize; n];
for (i, preds) in predecessors.iter().enumerate() {
for &pred in preds {
let candidate = depth[pred] + 1;
if candidate > depth[i] {
depth[i] = candidate;
}
}
}
let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
for (i, preds) in predecessors.iter().enumerate() {
for &pred in preds {
successors[pred].push(i);
}
}
let mut in_degree: Vec<usize> = predecessors.iter().map(Vec::len).collect();
let mut emitted = vec![false; n];
let mut order: Vec<usize> = Vec::with_capacity(n);
let mut ready: BinaryHeap<(usize, usize)> = BinaryHeap::new();
for (i, °) in in_degree.iter().enumerate() {
if deg == 0 {
ready.push((depth[i], i));
}
}
while let Some((_, idx)) = ready.pop() {
if emitted[idx] {
continue; }
emitted[idx] = true;
order.push(idx);
for &succ in &successors[idx] {
if in_degree[succ] > 0 {
in_degree[succ] -= 1;
}
if in_degree[succ] == 0 && !emitted[succ] {
ready.push((depth[succ], succ));
}
}
}
if order.len() < n {
for (i, &already_emitted) in emitted.iter().enumerate() {
if !already_emitted {
order.push(i);
}
}
}
let mut positioned: Vec<Option<(String, Expr)>> = stmts.drain(..).map(Some).collect();
let reordered: Vec<(String, Expr)> = order
.into_iter()
.filter_map(|i| positioned[i].take())
.collect();
*stmts = reordered;
}
/// Orders `assignments` so each temp is defined before any assignment that
/// references it.
///
/// Works in passes:
/// - Each pass scans the still-pending assignments in order.
/// - It emits every assignment whose referenced temps (per
///   `Expr::collect_temp_refs`) have all been emitted already, including
///   temps emitted earlier in the same pass.
/// - Passes repeat until nothing is pending or a pass makes no progress.
///
/// Leftovers (unresolvable references or cycles) are appended in their current
/// relative order.
fn topological_sort_assignments(assignments: Vec<(String, Expr)>) -> Vec<(String, Expr)> {
    let mut sorted: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
    let mut available: HashSet<String> = HashSet::new();
    let mut pending = assignments;
    loop {
        // `partition` visits items in order. The closure marks each accepted
        // name as available immediately, so later items in this pass can use it.
        let (ready, blocked): (Vec<(String, Expr)>, Vec<(String, Expr)>) =
            pending.into_iter().partition(|(name, expr)| {
                let mut refs: HashSet<String> = HashSet::new();
                expr.collect_temp_refs(&mut refs);
                let satisfied = refs.iter().all(|r| available.contains(r));
                if satisfied {
                    available.insert(name.clone());
                }
                satisfied
            });
        let stalled = ready.is_empty();
        sorted.extend(ready);
        if stalled || blocked.is_empty() {
            sorted.extend(blocked);
            return sorted;
        }
        pending = blocked;
    }
}
/// Assembles the kernel body token stream in three phases:
/// 1. Bind `xK_re` / `xK_im` locals from every input `x[K]`.
/// 2. Emit one `let` per hoisted temp in `assignments`, in the given order.
/// 3. Store each output as `x[K] = crate::kernel::Complex::new(re, im)`.
///
/// # Panics
/// Panics if `outputs.len() != n`.
fn emit_folded_body(n: usize, assignments: &[(String, Expr)], outputs: &[(Expr, Expr)]) -> TokenStream {
    assert_eq!(
        outputs.len(),
        n,
        "expected n outputs for n-point complex FFT, got {}",
        outputs.len()
    );

    let loads = (0..n).map(|i| {
        let re = format_ident!("x{i}_re");
        let im = format_ident!("x{i}_im");
        quote! {
            let #re = x[#i].re;
            let #im = x[#i].im;
        }
    });
    let temps = assignments.iter().map(|(name, rhs)| {
        let ident = format_ident!("{name}");
        let value = emit_scalar_expr(rhs);
        quote! { let #ident = #value; }
    });
    let stores = outputs.iter().enumerate().map(|(k, (re, im))| {
        let re_tok = emit_scalar_expr(re);
        let im_tok = emit_scalar_expr(im);
        quote! {
            x[#k] = crate::kernel::Complex::new(#re_tok, #im_tok);
        }
    });

    let mut body = TokenStream::new();
    body.extend(loads);
    body.extend(temps);
    body.extend(stores);
    body
}
/// Lowers a scalar [`Expr`] into Rust tokens for the generic element type `T`.
///
/// - Inputs map to the `xK_re` / `xK_im` locals; temps map to their names.
/// - Constants within `f64::EPSILON` of 0, 1 or -1 become `T::ZERO`, `T::ONE`
///   or `(-T::ONE)`. Any other constant becomes `T::from_f64(<literal>)`.
/// - Every operator result is parenthesised, so the emitted code keeps the
///   tree's grouping regardless of Rust operator precedence.
fn emit_scalar_expr(expr: &Expr) -> TokenStream {
    /// Emits `(lhs op rhs)`, lowering the left operand first.
    fn binary(lhs: &Expr, op: TokenStream, rhs: &Expr) -> TokenStream {
        let lhs = emit_scalar_expr(lhs);
        let rhs = emit_scalar_expr(rhs);
        quote! { (#lhs #op #rhs) }
    }

    match expr {
        Expr::Input { index, is_real } => {
            let part = if *is_real { "re" } else { "im" };
            let ident = format_ident!("x{index}_{part}");
            quote! { #ident }
        }
        Expr::Const(value) => {
            let value = *value;
            let is_near = |target: f64| (value - target).abs() < f64::EPSILON;
            if is_near(0.0) {
                quote! { T::ZERO }
            } else if is_near(1.0) {
                quote! { T::ONE }
            } else if is_near(-1.0) {
                quote! { (-T::ONE) }
            } else {
                quote! { T::from_f64(#value) }
            }
        }
        Expr::Add(a, b) => binary(a, quote! { + }, b),
        Expr::Sub(a, b) => binary(a, quote! { - }, b),
        Expr::Mul(a, b) => binary(a, quote! { * }, b),
        Expr::Neg(inner) => {
            let inner = emit_scalar_expr(inner);
            quote! { (-#inner) }
        }
        Expr::Temp(name) => {
            let ident = format_ident!("{name}");
            quote! { #ident }
        }
    }
}