1use std::collections::{BinaryHeap, HashMap, HashSet};
8
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11
12use super::{ConstantFolder, Expr, SymbolicFFT};
13
14pub(super) struct RecursiveCse {
25 cache: HashMap<u64, (Expr, String, usize)>,
27 counter: usize,
28}
29
30impl RecursiveCse {
31 pub(super) fn new() -> Self {
32 Self {
33 cache: HashMap::new(),
34 counter: 0,
35 }
36 }
37
38 pub(super) fn count_recursive(&mut self, expr: &Expr) {
40 match expr {
41 Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => {}
42 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) => {
43 self.count_recursive(a);
44 self.count_recursive(b);
45 let hash = expr.structural_hash();
46 let entry = self.cache.entry(hash).or_insert_with(|| {
47 let name = format!("t{}", self.counter);
48 self.counter += 1;
49 (expr.clone(), name, 0)
50 });
51 entry.2 += 1;
52 }
53 Expr::Neg(a) => {
54 self.count_recursive(a);
55 let hash = expr.structural_hash();
56 let entry = self.cache.entry(hash).or_insert_with(|| {
57 let name = format!("t{}", self.counter);
58 self.counter += 1;
59 (expr.clone(), name, 0)
60 });
61 entry.2 += 1;
62 }
63 }
64 }
65
66 fn rewrite_inner(&self, expr: &Expr, exclude_hash: Option<u64>) -> Expr {
73 match expr {
74 Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
75 Expr::Add(a, b) => {
76 let hash = expr.structural_hash();
77 if exclude_hash != Some(hash) {
78 if let Some((_, name, count)) = self.cache.get(&hash) {
79 if *count >= 2 {
80 return Expr::Temp(name.clone());
81 }
82 }
83 }
84 Expr::Add(
85 Box::new(self.rewrite_inner(a, None)),
86 Box::new(self.rewrite_inner(b, None)),
87 )
88 }
89 Expr::Sub(a, b) => {
90 let hash = expr.structural_hash();
91 if exclude_hash != Some(hash) {
92 if let Some((_, name, count)) = self.cache.get(&hash) {
93 if *count >= 2 {
94 return Expr::Temp(name.clone());
95 }
96 }
97 }
98 Expr::Sub(
99 Box::new(self.rewrite_inner(a, None)),
100 Box::new(self.rewrite_inner(b, None)),
101 )
102 }
103 Expr::Mul(a, b) => {
104 let hash = expr.structural_hash();
105 if exclude_hash != Some(hash) {
106 if let Some((_, name, count)) = self.cache.get(&hash) {
107 if *count >= 2 {
108 return Expr::Temp(name.clone());
109 }
110 }
111 }
112 Expr::Mul(
113 Box::new(self.rewrite_inner(a, None)),
114 Box::new(self.rewrite_inner(b, None)),
115 )
116 }
117 Expr::Neg(a) => {
118 let hash = expr.structural_hash();
119 if exclude_hash != Some(hash) {
120 if let Some((_, name, count)) = self.cache.get(&hash) {
121 if *count >= 2 {
122 return Expr::Temp(name.clone());
123 }
124 }
125 }
126 Expr::Neg(Box::new(self.rewrite_inner(a, None)))
127 }
128 }
129 }
130
131 pub(super) fn rewrite(&self, expr: &Expr) -> Expr {
133 self.rewrite_inner(expr, None)
134 }
135
136 pub(super) fn rewrite_assignment_rhs(&self, name: &str, expr: &Expr) -> Expr {
139 let hash = self
141 .cache
142 .iter()
143 .find(|(_, (_, n, _))| n == name)
144 .map(|(h, _)| *h);
145 self.rewrite_inner(expr, hash)
146 }
147
148 pub(super) fn get_assignments(&self) -> Vec<(String, Expr)> {
150 let mut result: Vec<(String, Expr)> = self
151 .cache
152 .values()
153 .filter(|(_, _, count)| *count >= 2)
154 .map(|(expr, name, _)| (name.clone(), expr.clone()))
155 .collect();
156 result.sort_by(|a, b| {
159 let na: usize = a.0[1..].parse().unwrap_or(0);
160 let nb: usize = b.0[1..].parse().unwrap_or(0);
161 na.cmp(&nb)
162 });
163 result
164 }
165}
166
167#[must_use]
180pub fn emit_body_from_symbolic(n: usize, forward: bool) -> TokenStream {
181 let fft = SymbolicFFT::radix2_dit(n, forward);
182
183 let folded_outputs: Vec<(Expr, Expr)> = fft
185 .outputs
186 .iter()
187 .map(|c| (ConstantFolder::fold(&c.re), ConstantFolder::fold(&c.im)))
188 .collect();
189
190 let ops_before = fft.op_count();
191
192 let mut cse = RecursiveCse::new();
194 for (re, im) in &folded_outputs {
195 cse.count_recursive(re);
196 cse.count_recursive(im);
197 }
198
199 let rewritten_outputs: Vec<(Expr, Expr)> = folded_outputs
201 .iter()
202 .map(|(re, im)| (cse.rewrite(re), cse.rewrite(im)))
203 .collect();
204
205 let mut assignments: Vec<(String, Expr)> = cse
208 .get_assignments()
209 .into_iter()
210 .map(|(name, expr)| {
211 let rewritten = cse.rewrite_assignment_rhs(&name, &expr);
212 (name, rewritten)
213 })
214 .collect();
215
216 assignments = topological_sort_assignments(assignments);
218
219 if std::env::var("OXIFFT_CODEGEN_DEBUG").is_ok() {
220 let ops_after: usize = assignments.iter().map(|(_, e)| e.op_count()).sum::<usize>()
221 + rewritten_outputs
222 .iter()
223 .map(|(re, im)| re.op_count() + im.op_count())
224 .sum::<usize>();
225 let pct = if ops_before > 0 {
226 (ops_after as f64 - ops_before as f64) / ops_before as f64 * 100.0
227 } else {
228 0.0
229 };
230 eprintln!(
231 "[oxifft-codegen] n={n} forward={forward}: {ops_before} ops → {ops_after} ops ({pct:+.1}%)",
232 );
233 }
234
235 schedule_instructions(&mut assignments);
240
241 emit_folded_body(n, &assignments, &rewritten_outputs)
242}
243
244pub fn schedule_instructions(stmts: &mut Vec<(String, Expr)>) {
267 let n = stmts.len();
268 if n <= 1 {
269 return;
270 }
271
272 let index_of: std::collections::HashMap<String, usize> = stmts
274 .iter()
275 .enumerate()
276 .map(|(i, (name, _))| (name.clone(), i))
277 .collect();
278
279 let predecessors: Vec<Vec<usize>> = stmts
281 .iter()
282 .map(|(_, expr)| {
283 let mut refs = HashSet::new();
284 expr.collect_temp_refs(&mut refs);
285 refs.iter()
286 .filter_map(|r| index_of.get(r).copied())
287 .collect()
288 })
289 .collect();
290
291 let mut depth = vec![0usize; n];
295 for (i, preds) in predecessors.iter().enumerate() {
296 for &pred in preds {
297 let candidate = depth[pred] + 1;
298 if candidate > depth[i] {
299 depth[i] = candidate;
300 }
301 }
302 }
303
304 let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
306 for (i, preds) in predecessors.iter().enumerate() {
307 for &pred in preds {
308 successors[pred].push(i);
309 }
310 }
311
312 let mut in_degree: Vec<usize> = predecessors.iter().map(Vec::len).collect();
316 let mut emitted = vec![false; n];
317 let mut order: Vec<usize> = Vec::with_capacity(n);
318
319 let mut ready: BinaryHeap<(usize, usize)> = BinaryHeap::new();
321 for (i, °) in in_degree.iter().enumerate() {
322 if deg == 0 {
323 ready.push((depth[i], i));
324 }
325 }
326
327 while let Some((_, idx)) = ready.pop() {
328 if emitted[idx] {
329 continue; }
331 emitted[idx] = true;
332 order.push(idx);
333 for &succ in &successors[idx] {
335 if in_degree[succ] > 0 {
336 in_degree[succ] -= 1;
337 }
338 if in_degree[succ] == 0 && !emitted[succ] {
339 ready.push((depth[succ], succ));
340 }
341 }
342 }
343
344 if order.len() < n {
346 for (i, &already_emitted) in emitted.iter().enumerate() {
347 if !already_emitted {
348 order.push(i);
349 }
350 }
351 }
352
353 let mut positioned: Vec<Option<(String, Expr)>> = stmts.drain(..).map(Some).collect();
357 let reordered: Vec<(String, Expr)> = order
358 .into_iter()
359 .filter_map(|i| positioned[i].take())
360 .collect();
361 *stmts = reordered;
362}
363
364fn topological_sort_assignments(assignments: Vec<(String, Expr)>) -> Vec<(String, Expr)> {
366 let mut defined: HashSet<String> = HashSet::new();
367 let mut result: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
368 let mut remaining = assignments;
369
370 loop {
373 let before_len = result.len();
374 let mut next_remaining = Vec::new();
375 for (name, expr) in remaining {
376 let mut refs: HashSet<String> = HashSet::new();
377 expr.collect_temp_refs(&mut refs);
378 if refs.iter().all(|r| defined.contains(r)) {
379 defined.insert(name.clone());
380 result.push((name, expr));
381 } else {
382 next_remaining.push((name, expr));
383 }
384 }
385 remaining = next_remaining;
386 if remaining.is_empty() || result.len() == before_len {
387 result.extend(remaining);
389 break;
390 }
391 }
392 result
393}
394
395fn emit_folded_body(
402 n: usize,
403 assignments: &[(String, Expr)],
404 outputs: &[(Expr, Expr)],
405) -> TokenStream {
406 assert_eq!(
407 outputs.len(),
408 n,
409 "expected n outputs for n-point complex FFT, got {}",
410 outputs.len()
411 );
412
413 let mut body = TokenStream::new();
414
415 for i in 0..n {
417 let re_name = format_ident!("x{i}_re");
418 let im_name = format_ident!("x{i}_im");
419 body.extend(quote! {
420 let #re_name = x[#i].re;
421 let #im_name = x[#i].im;
422 });
423 }
424
425 for (name, expr) in assignments {
427 let id = format_ident!("{name}");
428 let tok = emit_scalar_expr(expr);
429 body.extend(quote! { let #id = #tok; });
430 }
431
432 for (k, (re_expr, im_expr)) in outputs.iter().enumerate() {
434 let re_tok = emit_scalar_expr(re_expr);
435 let im_tok = emit_scalar_expr(im_expr);
436 body.extend(quote! {
437 x[#k] = crate::kernel::Complex::new(#re_tok, #im_tok);
438 });
439 }
440
441 body
442}
443
444fn emit_scalar_expr(expr: &Expr) -> TokenStream {
446 match expr {
447 Expr::Input { index, is_real } => {
448 let name = if *is_real {
449 format_ident!("x{index}_re")
450 } else {
451 format_ident!("x{index}_im")
452 };
453 quote! { #name }
454 }
455 Expr::Const(v) => {
456 if (*v - 0.0_f64).abs() < f64::EPSILON {
457 quote! { T::ZERO }
458 } else if (*v - 1.0_f64).abs() < f64::EPSILON {
459 quote! { T::ONE }
460 } else if (*v - (-1.0_f64)).abs() < f64::EPSILON {
461 quote! { (-T::ONE) }
462 } else {
463 let v = *v;
464 quote! { T::from_f64(#v) }
465 }
466 }
467 Expr::Add(a, b) => {
468 let a = emit_scalar_expr(a);
469 let b = emit_scalar_expr(b);
470 quote! { (#a + #b) }
471 }
472 Expr::Sub(a, b) => {
473 let a = emit_scalar_expr(a);
474 let b = emit_scalar_expr(b);
475 quote! { (#a - #b) }
476 }
477 Expr::Mul(a, b) => {
478 let a = emit_scalar_expr(a);
479 let b = emit_scalar_expr(b);
480 quote! { (#a * #b) }
481 }
482 Expr::Neg(a) => {
483 let a = emit_scalar_expr(a);
484 quote! { (-#a) }
485 }
486 Expr::Temp(name) => {
487 let id = format_ident!("{name}");
488 quote! { #id }
489 }
490 }
491}