1#![allow(clippy::cast_precision_loss)] #[cfg(test)]
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::fmt;
16
/// Node of a scalar symbolic expression tree.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// One scalar component of input element `index`: the real part when
    /// `is_real` is set, the imaginary part otherwise.
    Input { index: usize, is_real: bool },
    /// Floating-point literal.
    Const(f64),
    /// Sum of the two operands.
    Add(Box<Self>, Box<Self>),
    /// First operand minus the second.
    Sub(Box<Self>, Box<Self>),
    /// Product of the two operands.
    Mul(Box<Self>, Box<Self>),
    /// Arithmetic negation of the operand.
    Neg(Box<Self>),
    /// Reference to a temporary identified by name.
    Temp(String),
}
35
36impl Expr {
37 #[must_use]
39 pub const fn input_re(index: usize) -> Self {
40 Self::Input {
41 index,
42 is_real: true,
43 }
44 }
45
46 #[must_use]
48 pub const fn input_im(index: usize) -> Self {
49 Self::Input {
50 index,
51 is_real: false,
52 }
53 }
54
55 #[must_use]
57 pub const fn constant(value: f64) -> Self {
58 Self::Const(value)
59 }
60
61 #[must_use]
63 #[allow(clippy::should_implement_trait)]
64 pub fn add(self, other: Self) -> Self {
65 Self::Add(Box::new(self), Box::new(other))
66 }
67
68 #[must_use]
70 #[allow(clippy::should_implement_trait)]
71 pub fn sub(self, other: Self) -> Self {
72 Self::Sub(Box::new(self), Box::new(other))
73 }
74
75 #[must_use]
77 #[allow(clippy::should_implement_trait)]
78 pub fn mul(self, other: Self) -> Self {
79 Self::Mul(Box::new(self), Box::new(other))
80 }
81
82 #[must_use]
84 pub const fn const_value(&self) -> Option<f64> {
85 match self {
86 Self::Const(v) => Some(*v),
87 _ => None,
88 }
89 }
90
91 #[must_use]
93 pub fn structural_hash(&self) -> u64 {
94 use std::collections::hash_map::DefaultHasher;
95 use std::hash::Hasher;
96
97 let mut hasher = DefaultHasher::new();
98 self.hash_recursive(&mut hasher);
99 hasher.finish()
100 }
101
102 fn hash_recursive<H: std::hash::Hasher>(&self, hasher: &mut H) {
103 use std::hash::Hash;
104 match self {
105 Self::Input { index, is_real } => {
106 0u8.hash(hasher);
107 index.hash(hasher);
108 is_real.hash(hasher);
109 }
110 Self::Const(v) => {
111 1u8.hash(hasher);
112 v.to_bits().hash(hasher);
113 }
114 Self::Add(a, b) => {
115 2u8.hash(hasher);
116 a.hash_recursive(hasher);
117 b.hash_recursive(hasher);
118 }
119 Self::Sub(a, b) => {
120 3u8.hash(hasher);
121 a.hash_recursive(hasher);
122 b.hash_recursive(hasher);
123 }
124 Self::Mul(a, b) => {
125 4u8.hash(hasher);
126 a.hash_recursive(hasher);
127 b.hash_recursive(hasher);
128 }
129 Self::Neg(a) => {
130 5u8.hash(hasher);
131 a.hash_recursive(hasher);
132 }
133 Self::Temp(name) => {
134 6u8.hash(hasher);
135 name.hash(hasher);
136 }
137 }
138 }
139
140 pub fn collect_temp_refs(&self, refs: &mut HashSet<String>) {
142 match self {
143 Self::Temp(name) => {
144 refs.insert(name.clone());
145 }
146 Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) => {
147 a.collect_temp_refs(refs);
148 b.collect_temp_refs(refs);
149 }
150 Self::Neg(a) => a.collect_temp_refs(refs),
151 Self::Input { .. } | Self::Const(_) => {}
152 }
153 }
154
155 #[must_use]
157 pub fn op_count(&self) -> usize {
158 match self {
159 Self::Input { .. } | Self::Const(_) | Self::Temp(_) => 0,
160 Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) => 1 + a.op_count() + b.op_count(),
161 Self::Neg(a) => 1 + a.op_count(),
162 }
163 }
164}
165
#[cfg(test)]
impl Expr {
    /// Builds the node `-self`.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        let operand = Box::new(self);
        Self::Neg(operand)
    }
}
175
176impl fmt::Display for Expr {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 match self {
179 Self::Input { index, is_real } => {
180 write!(f, "x[{}].{}", index, if *is_real { "re" } else { "im" })
181 }
182 Self::Const(v) => write!(f, "{v}"),
183 Self::Add(a, b) => write!(f, "({a} + {b})"),
184 Self::Sub(a, b) => write!(f, "({a} - {b})"),
185 Self::Mul(a, b) => write!(f, "({a} * {b})"),
186 Self::Neg(a) => write!(f, "(-{a})"),
187 Self::Temp(name) => write!(f, "{name}"),
188 }
189 }
190}
191
192#[derive(Clone, Debug)]
194pub struct ComplexExpr {
195 pub re: Expr,
196 pub im: Expr,
197}
198
199impl ComplexExpr {
200 #[must_use]
202 pub const fn input(index: usize) -> Self {
203 Self {
204 re: Expr::input_re(index),
205 im: Expr::input_im(index),
206 }
207 }
208
209 #[must_use]
211 pub const fn constant(re: f64, im: f64) -> Self {
212 Self {
213 re: Expr::constant(re),
214 im: Expr::constant(im),
215 }
216 }
217
218 #[must_use]
220 #[allow(clippy::should_implement_trait)]
221 pub fn add(&self, other: &Self) -> Self {
222 Self {
223 re: self.re.clone().add(other.re.clone()),
224 im: self.im.clone().add(other.im.clone()),
225 }
226 }
227
228 #[must_use]
230 #[allow(clippy::should_implement_trait)]
231 pub fn sub(&self, other: &Self) -> Self {
232 Self {
233 re: self.re.clone().sub(other.re.clone()),
234 im: self.im.clone().sub(other.im.clone()),
235 }
236 }
237
238 #[must_use]
240 #[allow(clippy::should_implement_trait)]
241 pub fn mul(&self, other: &Self) -> Self {
242 Self {
244 re: self
245 .re
246 .clone()
247 .mul(other.re.clone())
248 .sub(self.im.clone().mul(other.im.clone())),
249 im: self
250 .re
251 .clone()
252 .mul(other.im.clone())
253 .add(self.im.clone().mul(other.re.clone())),
254 }
255 }
256}
257
#[cfg(test)]
impl ComplexExpr {
    /// Multiplies by `+j`: `(a + ib) * j = -b + ia`.
    #[must_use]
    pub fn mul_j(&self) -> Self {
        let rotated_re = self.im.clone().neg();
        let rotated_im = self.re.clone();
        Self {
            re: rotated_re,
            im: rotated_im,
        }
    }

    /// Multiplies by `-j`: `(a + ib) * (-j) = b - ia`.
    #[must_use]
    pub fn mul_neg_j(&self) -> Self {
        let rotated_re = self.im.clone();
        let rotated_im = self.re.clone().neg();
        Self {
            re: rotated_re,
            im: rotated_im,
        }
    }

    /// Negates both components.
    #[must_use]
    pub fn neg(&self) -> Self {
        let Self { re, im } = self;
        Self {
            re: re.clone().neg(),
            im: im.clone().neg(),
        }
    }
}
289
/// Common-subexpression tracker: hands out one temporary name per distinct
/// registered expression and counts how often each was registered.
///
/// Only whole expressions passed to [`CseOptimizer::register`] are tracked;
/// their sub-expressions are not visited.
#[cfg(test)]
pub struct CseOptimizer {
    /// Structural hash -> (expression, temporary name, registration count).
    expr_cache: HashMap<u64, (Expr, String, usize)>,
    /// Numeric suffix for the next temporary name.
    temp_counter: usize,
    /// Minimum registration count for a temporary to be reported by
    /// [`CseOptimizer::get_temporaries`].
    min_uses: usize,
}

#[cfg(test)]
impl CseOptimizer {
    /// Creates an empty optimizer with `min_uses` set to 2.
    #[must_use]
    pub fn new() -> Self {
        Self {
            expr_cache: HashMap::new(),
            temp_counter: 0,
            min_uses: 2,
        }
    }

    /// Replaces the reporting threshold used by `get_temporaries`.
    #[must_use]
    pub const fn with_min_uses(mut self, min_uses: usize) -> Self {
        self.min_uses = min_uses;
        self
    }

    /// Registers `expr` and returns what should stand in for it.
    ///
    /// * Leaves (`Input`, `Const`, `Temp`) are returned unchanged.
    /// * Any other expression yields `Expr::Temp(name)`; registering a
    ///   structurally identical expression again returns the same name and
    ///   bumps its count.
    /// * If a *different* expression happens to share the 64-bit hash of a
    ///   cached one, it is returned unchanged rather than being aliased to the
    ///   wrong temporary.
    ///
    /// Note that a `Temp` is returned on the first registration already, but
    /// `get_temporaries` only reports temporaries that reached `min_uses`;
    /// callers must be prepared for references to unreported temporaries.
    #[must_use]
    pub fn register(&mut self, expr: &Expr) -> Expr {
        use std::collections::hash_map::Entry;

        if matches!(expr, Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_)) {
            return expr.clone();
        }

        match self.expr_cache.entry(expr.structural_hash()) {
            Entry::Occupied(mut slot) => {
                let (cached, name, count) = slot.get_mut();
                if !Self::same_tree(cached, expr) {
                    // Hash collision: sharing the temp would change meaning.
                    return expr.clone();
                }
                *count += 1;
                Expr::Temp(name.clone())
            }
            Entry::Vacant(slot) => {
                let name = format!("t{}", self.temp_counter);
                self.temp_counter += 1;
                slot.insert((expr.clone(), name.clone(), 1));
                Expr::Temp(name)
            }
        }
    }

    /// Returns `(name, expression)` for every temporary registered at least
    /// `min_uses` times, ordered by creation (`t2` before `t10`).
    #[must_use]
    pub fn get_temporaries(&self) -> Vec<(String, Expr)> {
        let mut temps: Vec<_> = self
            .expr_cache
            .values()
            .filter(|(_, _, count)| *count >= self.min_uses)
            .map(|(expr, name, _)| (name.clone(), expr.clone()))
            .collect();
        // Names are "t<counter>", so shorter-then-lexicographic is numeric
        // order; a plain string sort would put "t10" ahead of "t2".
        temps.sort_unstable_by(|a, b| a.0.len().cmp(&b.0.len()).then_with(|| a.0.cmp(&b.0)));
        temps
    }

    /// Structural equality that compares constants by bit pattern, matching
    /// exactly what the structural hash distinguishes (so NaN equals itself
    /// and `0.0` differs from `-0.0`).
    fn same_tree(a: &Expr, b: &Expr) -> bool {
        match (a, b) {
            (
                Expr::Input { index: ia, is_real: ra },
                Expr::Input { index: ib, is_real: rb },
            ) => ia == ib && ra == rb,
            (Expr::Const(x), Expr::Const(y)) => x.to_bits() == y.to_bits(),
            (Expr::Temp(x), Expr::Temp(y)) => x == y,
            (Expr::Neg(x), Expr::Neg(y)) => Self::same_tree(x, y),
            (Expr::Add(l1, r1), Expr::Add(l2, r2))
            | (Expr::Sub(l1, r1), Expr::Sub(l2, r2))
            | (Expr::Mul(l1, r1), Expr::Mul(l2, r2)) => {
                Self::same_tree(l1, l2) && Self::same_tree(r1, r2)
            }
            _ => false,
        }
    }
}

#[cfg(test)]
impl Default for CseOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
362
363pub struct StrengthReducer;
365
366impl StrengthReducer {
367 #[must_use]
370 pub fn reduce(expr: &Expr) -> Expr {
371 match expr {
372 Expr::Mul(a, b) => {
374 let ra = Self::reduce(a);
375 let rb = Self::reduce(b);
376
377 if ra.const_value() == Some(0.0) || rb.const_value() == Some(0.0) {
379 return Expr::Const(0.0);
380 }
381 if ra.const_value() == Some(1.0) {
383 return rb;
384 }
385 if rb.const_value() == Some(1.0) {
386 return ra;
387 }
388 if ra.const_value() == Some(-1.0) {
390 return Expr::Neg(Box::new(rb));
391 }
392 if rb.const_value() == Some(-1.0) {
393 return Expr::Neg(Box::new(ra));
394 }
395 if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
397 return Expr::Const(va * vb);
398 }
399 Expr::Mul(Box::new(ra), Box::new(rb))
400 }
401
402 Expr::Add(a, b) => {
404 let ra = Self::reduce(a);
405 let rb = Self::reduce(b);
406
407 if ra.const_value() == Some(0.0) {
409 return rb;
410 }
411 if rb.const_value() == Some(0.0) {
412 return ra;
413 }
414 if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
416 return Expr::Const(va + vb);
417 }
418 Expr::Add(Box::new(ra), Box::new(rb))
419 }
420
421 Expr::Sub(a, b) => {
423 let ra = Self::reduce(a);
424 let rb = Self::reduce(b);
425
426 if ra == rb {
428 return Expr::Const(0.0);
429 }
430 if rb.const_value() == Some(0.0) {
432 return ra;
433 }
434 if ra.const_value() == Some(0.0) {
435 return Expr::Neg(Box::new(rb));
436 }
437 if let (Some(va), Some(vb)) = (ra.const_value(), rb.const_value()) {
439 return Expr::Const(va - vb);
440 }
441 Expr::Sub(Box::new(ra), Box::new(rb))
442 }
443
444 Expr::Neg(a) => {
446 let ra = Self::reduce(a);
447
448 if let Expr::Neg(inner) = &ra {
450 return *inner.clone();
451 }
452 if let Some(v) = ra.const_value() {
454 return Expr::Const(-v);
455 }
456 Expr::Neg(Box::new(ra))
457 }
458
459 Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
461 }
462 }
463}
464
465pub struct ConstantFolder;
470
471impl ConstantFolder {
472 #[must_use]
477 pub fn fold(expr: &Expr) -> Expr {
478 let mut current = expr.clone();
479 loop {
480 let folded = StrengthReducer::reduce(¤t);
481 if folded == current {
482 return current;
483 }
484 current = folded;
485 }
486 }
487}
488
#[cfg(test)]
impl ConstantFolder {
    /// Folds, in place, the right-hand side of every assignment and then every
    /// output of `program`.
    pub fn fold_program(program: &mut Program) {
        let assigned = program.assignments.iter_mut().map(|(_, rhs)| rhs);
        for slot in assigned.chain(program.outputs.iter_mut()) {
            *slot = Self::fold(slot);
        }
    }
}
501
/// Removes assignments that no output depends on.
#[cfg(test)]
pub struct DeadCodeEliminator;

#[cfg(test)]
impl DeadCodeEliminator {
    /// Drops every assignment of `program` whose name is not reachable from
    /// the outputs, following `Temp` references through assignment bodies.
    /// The relative order of the surviving assignments is kept.
    pub fn eliminate(program: &mut Program) {
        // Seed the live set with temporaries the outputs mention directly.
        let mut live: HashSet<String> = HashSet::new();
        program
            .outputs
            .iter()
            .for_each(|output| output.collect_temp_refs(&mut live));

        let definitions: HashMap<&str, &Expr> = program
            .assignments
            .iter()
            .map(|(name, body)| (name.as_str(), body))
            .collect();

        // Propagate liveness through the bodies of live temporaries.
        let mut pending: Vec<String> = live.iter().cloned().collect();
        while let Some(current) = pending.pop() {
            if let Some(body) = definitions.get(current.as_str()) {
                let mut referenced = HashSet::new();
                body.collect_temp_refs(&mut referenced);
                for dependency in referenced {
                    if !live.contains(&dependency) {
                        live.insert(dependency.clone());
                        pending.push(dependency);
                    }
                }
            }
        }

        program.assignments.retain(|(name, _)| live.contains(name));
    }
}
545
/// Straight-line program: named assignments followed by output expressions.
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct Program {
    /// `(temporary name, defining expression)` pairs.
    pub assignments: Vec<(String, Expr)>,
    /// Output expressions.
    pub outputs: Vec<Expr>,
}

#[cfg(test)]
impl Program {
    /// Creates a program with no assignments and no outputs.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            assignments: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Builds a program whose assignments are the temporaries reported by
    /// `cse` and whose outputs are `outputs`.
    #[must_use]
    pub fn from_cse(cse: &CseOptimizer, outputs: Vec<Expr>) -> Self {
        let assignments = cse.get_temporaries();
        Self {
            assignments,
            outputs,
        }
    }

    /// Total operation count over all assignment bodies and all outputs.
    #[must_use]
    pub fn op_count(&self) -> usize {
        self.assignments
            .iter()
            .map(|(_, body)| body)
            .chain(self.outputs.iter())
            .map(Expr::op_count)
            .sum()
    }
}

#[cfg(test)]
impl Default for Program {
    fn default() -> Self {
        Program::new()
    }
}
594
/// Runs constant folding, top-level common-subexpression elimination and
/// dead-code elimination over `program`.
///
/// Every output and assignment body is registered with a fresh
/// [`CseOptimizer`]. A registered expression is replaced by a temporary only
/// when that temporary is actually emitted, i.e. when it
///
/// * was registered often enough for `get_temporaries` to report it, and
/// * does not reuse the name of an assignment already present in `program`.
///
/// Otherwise the folded original expression is kept, so the result never
/// refers to a temporary without a definition and no existing assignment is
/// silently replaced. Emitted temporaries precede the existing assignments.
#[cfg(test)]
#[must_use]
pub fn optimize(mut program: Program) -> Program {
    ConstantFolder::fold_program(&mut program);

    // Phase 1: register everything (outputs first, then assignments) so the
    // use counts are complete before any substitution is decided.
    let mut cse = CseOptimizer::new();
    let registered_outputs: Vec<Expr> = program
        .outputs
        .iter()
        .map(|expr| cse.register(expr))
        .collect();
    let registered_bodies: Vec<Expr> = program
        .assignments
        .iter()
        .map(|(_, expr)| cse.register(expr))
        .collect();

    // Phase 2: decide which new temporaries get a definition.
    let existing: HashSet<&str> = program
        .assignments
        .iter()
        .map(|(name, _)| name.as_str())
        .collect();
    let temporaries: Vec<(String, Expr)> = cse
        .get_temporaries()
        .into_iter()
        .filter(|(name, _)| !existing.contains(name.as_str()))
        .collect();
    let defined: HashSet<&str> = temporaries.iter().map(|(name, _)| name.as_str()).collect();

    // Phase 3: keep the substitution only where the temporary is defined.
    let resolve = |original: &Expr, registered: Expr| -> Expr {
        // A `Temp` original is a leaf that `register` returned unchanged.
        let dangling = !matches!(original, Expr::Temp(_))
            && matches!(&registered, Expr::Temp(name) if !defined.contains(name.as_str()));
        if dangling {
            original.clone()
        } else {
            registered
        }
    };

    let new_outputs: Vec<Expr> = program
        .outputs
        .iter()
        .zip(registered_outputs)
        .map(|(original, registered)| resolve(original, registered))
        .collect();
    let kept_assignments: Vec<(String, Expr)> = program
        .assignments
        .iter()
        .zip(registered_bodies)
        .map(|((name, original), registered)| (name.clone(), resolve(original, registered)))
        .collect();

    let mut all_assignments = temporaries;
    all_assignments.extend(kept_assignments);

    program.assignments = all_assignments;
    program.outputs = new_outputs;

    DeadCodeEliminator::eliminate(&mut program);

    program
}
640
/// Runs constant folding followed by dead-code elimination on `program`,
/// without any common-subexpression elimination.
#[cfg(test)]
#[must_use]
pub fn optimize_fold_and_dce(program: Program) -> Program {
    let mut optimized = program;
    ConstantFolder::fold_program(&mut optimized);
    DeadCodeEliminator::eliminate(&mut optimized);
    optimized
}
649
650pub struct SymbolicFFT {
652 pub outputs: Vec<ComplexExpr>,
654}
655
656impl SymbolicFFT {
657 #[must_use]
662 pub fn radix2_dit(n: usize, forward: bool) -> Self {
663 assert!(n.is_power_of_two(), "n must be power of 2");
664
665 let sign = if forward { -1.0 } else { 1.0 };
666
667 let mut data: Vec<ComplexExpr> = (0..n).map(ComplexExpr::input).collect();
669
670 let mut j = 0;
672 for i in 0..n {
673 if i < j {
674 data.swap(i, j);
675 }
676 let mut m = n >> 1;
677 while m >= 1 && j >= m {
678 j -= m;
679 m >>= 1;
680 }
681 j += m;
682 }
683
684 let mut len = 2;
686 while len <= n {
687 let half = len / 2;
688 let angle_step = sign * 2.0 * std::f64::consts::PI / len as f64;
689
690 for start in (0..n).step_by(len) {
691 for k in 0..half {
692 let angle = angle_step * k as f64;
693 let twiddle = ComplexExpr::constant(angle.cos(), angle.sin());
694
695 let u = data[start + k].clone();
696 let t = data[start + k + half].mul(&twiddle);
697
698 data[start + k] = u.add(&t);
699 data[start + k + half] = u.sub(&t);
700 }
701 }
702
703 len *= 2;
704 }
705
706 let outputs: Vec<ComplexExpr> = data
708 .into_iter()
709 .map(|c| ComplexExpr {
710 re: StrengthReducer::reduce(&c.re),
711 im: StrengthReducer::reduce(&c.im),
712 })
713 .collect();
714
715 Self { outputs }
716 }
717
718 #[must_use]
720 pub fn op_count(&self) -> usize {
721 self.outputs
722 .iter()
723 .map(|c| c.re.op_count() + c.im.op_count())
724 .sum()
725 }
726}
727
#[cfg(test)]
impl SymbolicFFT {
    /// Number of output bins.
    #[must_use]
    pub fn n(&self) -> usize {
        let Self { outputs } = self;
        outputs.len()
    }

    /// Builds the expressions of a direct `n`-point DFT: output `k` is the sum
    /// over `j` of input `j` times the twiddle at angle
    /// `sign * 2π * k * j / n`, where `sign` is negative when `forward` is set.
    /// Each output component is reduced once with `StrengthReducer::reduce`.
    #[must_use]
    pub fn dft(n: usize, forward: bool) -> Self {
        let sign = if forward { -1.0 } else { 1.0 };

        let outputs: Vec<ComplexExpr> = (0..n)
            .map(|k| {
                let mut re_sum = Expr::Const(0.0);
                let mut im_sum = Expr::Const(0.0);

                for j in 0..n {
                    let angle = sign * 2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
                    let twiddle = ComplexExpr::constant(angle.cos(), angle.sin());
                    let term = ComplexExpr::input(j).mul(&twiddle);

                    re_sum = re_sum.add(term.re);
                    im_sum = im_sum.add(term.im);
                }

                ComplexExpr {
                    re: StrengthReducer::reduce(&re_sum),
                    im: StrengthReducer::reduce(&im_sum),
                }
            })
            .collect();

        Self { outputs }
    }
}
768
769#[path = "symbolic_emit.rs"]
775mod symbolic_emit;
776pub use symbolic_emit::{emit_body_from_symbolic, schedule_instructions};
777
778#[cfg(test)]
784#[path = "symbolic_tests.rs"]
785mod tests;