#![allow(clippy::cast_precision_loss)]
#[cfg(test)]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt;
/// A real-valued symbolic expression tree.
///
/// Complex values are modelled as a pair of these trees (`ComplexExpr`), so an
/// `Input` leaf names one component of one input element.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// The real (`is_real == true`) or imaginary component of input `index`.
    Input { index: usize, is_real: bool },
    /// A floating-point literal.
    Const(f64),
    /// `lhs + rhs`.
    Add(Box<Self>, Box<Self>),
    /// `lhs - rhs`.
    Sub(Box<Self>, Box<Self>),
    /// `lhs * rhs`.
    Mul(Box<Self>, Box<Self>),
    /// `-operand`.
    Neg(Box<Self>),
    /// A reference to a named temporary.
    Temp(String),
}
impl Expr {
    /// Leaf for the real component of input `index`.
    #[must_use]
    pub const fn input_re(index: usize) -> Self {
        Self::Input {
            index,
            is_real: true,
        }
    }

    /// Leaf for the imaginary component of input `index`.
    #[must_use]
    pub const fn input_im(index: usize) -> Self {
        Self::Input {
            index,
            is_real: false,
        }
    }

    /// Literal leaf holding `value`.
    #[must_use]
    pub const fn constant(value: f64) -> Self {
        Self::Const(value)
    }

    /// Builds `self + other` without any simplification.
    #[must_use]
    // Deliberately an owned-tree builder rather than an `std::ops::Add` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        Self::Add(Box::new(self), Box::new(other))
    }

    /// Builds `self - other` without any simplification.
    #[must_use]
    // Deliberately an owned-tree builder rather than an `std::ops::Sub` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Self {
        Self::Sub(Box::new(self), Box::new(other))
    }

    /// Builds `self * other` without any simplification.
    #[must_use]
    // Deliberately an owned-tree builder rather than an `std::ops::Mul` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        Self::Mul(Box::new(self), Box::new(other))
    }

    /// Returns the literal value if `self` is a `Const` leaf, otherwise `None`.
    #[must_use]
    pub const fn const_value(&self) -> Option<f64> {
        match self {
            Self::Const(v) => Some(*v),
            _ => None,
        }
    }

    /// Hashes the tree's structure: variant tags, input indices/components,
    /// constants and temporary names.
    ///
    /// Trees that compare `==` always hash equal; in particular `-0.0` and
    /// `0.0` constants hash identically. The converse does not hold for `NaN`
    /// constants, which hash by bit pattern but never compare equal. Uses
    /// `DefaultHasher::new()`, so values are reproducible within one build but
    /// should not be persisted.
    #[must_use]
    pub fn structural_hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;
        let mut hasher = DefaultHasher::new();
        self.hash_recursive(&mut hasher);
        hasher.finish()
    }

    /// Feeds a tag byte per node plus the node payload into `hasher`, pre-order.
    fn hash_recursive<H: std::hash::Hasher>(&self, hasher: &mut H) {
        use std::hash::Hash;
        match self {
            Self::Input { index, is_real } => {
                0u8.hash(hasher);
                index.hash(hasher);
                is_real.hash(hasher);
            }
            Self::Const(v) => {
                1u8.hash(hasher);
                // `-0.0 == 0.0` under `PartialEq`, so both must hash alike:
                // canonicalise the zero to `+0.0` before taking its bits.
                let bits = if *v == 0.0 { 0.0f64.to_bits() } else { v.to_bits() };
                bits.hash(hasher);
            }
            Self::Add(a, b) => {
                2u8.hash(hasher);
                a.hash_recursive(hasher);
                b.hash_recursive(hasher);
            }
            Self::Sub(a, b) => {
                3u8.hash(hasher);
                a.hash_recursive(hasher);
                b.hash_recursive(hasher);
            }
            Self::Mul(a, b) => {
                4u8.hash(hasher);
                a.hash_recursive(hasher);
                b.hash_recursive(hasher);
            }
            Self::Neg(a) => {
                5u8.hash(hasher);
                a.hash_recursive(hasher);
            }
            Self::Temp(name) => {
                6u8.hash(hasher);
                name.hash(hasher);
            }
        }
    }

    /// Inserts the name of every `Temp` leaf in this tree into `refs`.
    pub fn collect_temp_refs(&self, refs: &mut HashSet<String>) {
        match self {
            Self::Temp(name) => {
                refs.insert(name.clone());
            }
            Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) => {
                a.collect_temp_refs(refs);
                b.collect_temp_refs(refs);
            }
            Self::Neg(a) => a.collect_temp_refs(refs),
            Self::Input { .. } | Self::Const(_) => {}
        }
    }

    /// Number of `Add`/`Sub`/`Mul`/`Neg` nodes in the tree (leaves count 0).
    #[must_use]
    pub fn op_count(&self) -> usize {
        match self {
            Self::Input { .. } | Self::Const(_) | Self::Temp(_) => 0,
            Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) => 1 + a.op_count() + b.op_count(),
            Self::Neg(a) => 1 + a.op_count(),
        }
    }
}
#[cfg(test)]
impl Expr {
    /// Wraps `self` in a `Neg` node without any simplification.
    #[must_use]
    // Deliberately an owned-tree builder rather than an `std::ops::Neg` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        let operand: Box<Self> = self.into();
        Self::Neg(operand)
    }
}
impl fmt::Display for Expr {
    /// Renders the tree fully parenthesised, e.g. `(x[0].re + (-t1))`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Leaves and negation are written directly; the three binary
        // operators share one format and differ only in their symbol.
        let (op, lhs, rhs) = match self {
            Self::Input { index, is_real } => {
                let part = if *is_real { "re" } else { "im" };
                return write!(f, "x[{index}].{part}");
            }
            Self::Const(v) => return write!(f, "{v}"),
            Self::Neg(a) => return write!(f, "(-{a})"),
            Self::Temp(name) => return write!(f, "{name}"),
            Self::Add(a, b) => ("+", a, b),
            Self::Sub(a, b) => ("-", a, b),
            Self::Mul(a, b) => ("*", a, b),
        };
        write!(f, "({lhs} {op} {rhs})")
    }
}
/// A complex symbolic value held as separate real and imaginary trees.
#[derive(Clone, Debug)]
pub struct ComplexExpr {
    /// Real part.
    pub re: Expr,
    /// Imaginary part.
    pub im: Expr,
}
impl ComplexExpr {
    /// The complex input `x[index]`, i.e. `x[index].re + i·x[index].im`.
    #[must_use]
    pub const fn input(index: usize) -> Self {
        let re = Expr::input_re(index);
        let im = Expr::input_im(index);
        Self { re, im }
    }

    /// The literal `re + i·im`.
    #[must_use]
    pub const fn constant(re: f64, im: f64) -> Self {
        Self {
            re: Expr::Const(re),
            im: Expr::Const(im),
        }
    }

    /// Component-wise `self + other`, unsimplified.
    #[must_use]
    // Builder taking borrowed operands; not an `std::ops::Add` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn add(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.clone(), other.clone());
        Self {
            re: lhs.re.add(rhs.re),
            im: lhs.im.add(rhs.im),
        }
    }

    /// Component-wise `self - other`, unsimplified.
    #[must_use]
    // Builder taking borrowed operands; not an `std::ops::Sub` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn sub(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.clone(), other.clone());
        Self {
            re: lhs.re.sub(rhs.re),
            im: lhs.im.sub(rhs.im),
        }
    }

    /// Complex product `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`, unsimplified.
    #[must_use]
    // Builder taking borrowed operands; not an `std::ops::Mul` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(&self, other: &Self) -> Self {
        let (a, b) = (&self.re, &self.im);
        let (c, d) = (&other.re, &other.im);
        let ac = a.clone().mul(c.clone());
        let bd = b.clone().mul(d.clone());
        let ad = a.clone().mul(d.clone());
        let bc = b.clone().mul(c.clone());
        Self {
            re: ac.sub(bd),
            im: ad.add(bc),
        }
    }
}
#[cfg(test)]
impl ComplexExpr {
    /// Multiplies by `j`: `j(a + bi) = -b + ai`.
    #[must_use]
    pub fn mul_j(&self) -> Self {
        let Self { re, im } = self.clone();
        Self { re: im.neg(), im: re }
    }

    /// Multiplies by `-j`: `-j(a + bi) = b - ai`.
    #[must_use]
    pub fn mul_neg_j(&self) -> Self {
        let Self { re, im } = self.clone();
        Self { re: im, im: re.neg() }
    }

    /// Negates both components.
    #[must_use]
    pub fn neg(&self) -> Self {
        let Self { re, im } = self.clone();
        Self {
            re: re.neg(),
            im: im.neg(),
        }
    }
}
/// Whole-expression common-subexpression eliminator.
///
/// Every compound expression passed to `register` is mapped to a temporary
/// name (`t0`, `t1`, …) plus a registration count. `get_temporaries` reports
/// only the temporaries registered at least `min_uses` times (2 by default).
#[cfg(test)]
pub struct CseOptimizer {
    /// Slot key -> (expression, temporary name, registration count). A key
    /// starts at the expression's structural hash and is advanced past slots
    /// holding a different expression, so colliding hashes never alias.
    expr_cache: HashMap<u64, (Expr, String, usize)>,
    /// Numeric suffix for the next temporary name.
    temp_counter: usize,
    /// Minimum registration count for a temporary to be reported.
    min_uses: usize,
}
#[cfg(test)]
impl CseOptimizer {
    /// Creates an optimizer with nothing registered and `min_uses == 2`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            expr_cache: HashMap::new(),
            temp_counter: 0,
            min_uses: 2,
        }
    }

    /// Sets how many registrations a temporary needs before
    /// `get_temporaries` reports it.
    #[must_use]
    pub const fn with_min_uses(mut self, min_uses: usize) -> Self {
        self.min_uses = min_uses;
        self
    }

    /// Registers `expr` and returns the expression to use in its place.
    ///
    /// Leaves (`Input`, `Const`, `Temp`) are returned unchanged. Any other
    /// expression becomes `Expr::Temp(name)`, and expressions that compare `==`
    /// share one name. Only the root is registered; subexpressions are not
    /// visited.
    ///
    /// The returned temporary is defined by `get_temporaries` only once its
    /// count reaches `min_uses`. Below that, callers must keep the original
    /// expression.
    #[must_use]
    pub fn register(&mut self, expr: &Expr) -> Expr {
        if matches!(expr, Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_)) {
            return expr.clone();
        }
        let mut slot = expr.structural_hash();
        loop {
            match self.expr_cache.get_mut(&slot) {
                Some((cached, name, count)) if *cached == *expr => {
                    *count += 1;
                    return Expr::Temp(name.clone());
                }
                // A *different* expression with the same hash: probe the next
                // key rather than silently aliasing the two to one temporary.
                Some(_) => slot = slot.wrapping_add(1),
                None => break,
            }
        }
        let name = format!("t{}", self.temp_counter);
        self.temp_counter += 1;
        self.expr_cache.insert(slot, (expr.clone(), name.clone(), 1));
        Expr::Temp(name)
    }

    /// Returns `(name, expression)` for every temporary registered at least
    /// `min_uses` times, in creation order (`t2` before `t10`).
    #[must_use]
    pub fn get_temporaries(&self) -> Vec<(String, Expr)> {
        let mut temps: Vec<(String, Expr)> = self
            .expr_cache
            .values()
            .filter(|(_, _, count)| *count >= self.min_uses)
            .map(|(expr, name, _)| (name.clone(), expr.clone()))
            .collect();
        // Names are all `t<decimal>`, so ordering by (length, text) is numeric
        // order, i.e. the order in which `register` created them.
        temps.sort_unstable_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));
        temps
    }
}
#[cfg(test)]
impl Default for CseOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
/// Bottom-up algebraic simplifier for [`Expr`] trees.
pub struct StrengthReducer;
impl StrengthReducer {
    /// Simplifies `expr` in a single bottom-up pass.
    ///
    /// Children are reduced first. Then, at each node:
    ///
    /// * `x * 0` / `0 * x` → `0`, `x * 1` / `1 * x` → `x`, `x * -1` / `-1 * x` → `-x`;
    /// * `x + 0` / `0 + x` → `x`, `x + (-y)` → `x - y`, `(-x) + y` → `y - x`;
    /// * `x - x` → `0`, `x - 0` → `x`, `0 - x` → `-x`, `x - (-y)` → `x + y`;
    /// * `-(-x)` → `x`, and `-c` → the negated constant;
    /// * an operation whose operands are both constants is evaluated.
    ///
    /// Every negation is built through `negate`, which cancels double
    /// negations and folds negated constants. Without it, the `-1` and `0 - x`
    /// rules could emit `Neg(Neg(_))`.
    ///
    /// The sign/negation rewrites are exact in IEEE arithmetic. `x * 0 → 0` and
    /// `x - x → 0` assume finite operands. `ConstantFolder::fold` iterates this
    /// pass to a fixed point.
    #[must_use]
    pub fn reduce(expr: &Expr) -> Expr {
        match expr {
            Expr::Mul(a, b) => Self::reduce_mul(Self::reduce(a), Self::reduce(b)),
            Expr::Add(a, b) => Self::reduce_add(Self::reduce(a), Self::reduce(b)),
            Expr::Sub(a, b) => Self::reduce_sub(Self::reduce(a), Self::reduce(b)),
            Expr::Neg(a) => Self::negate(Self::reduce(a)),
            Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
        }
    }

    /// Simplifies `lhs * rhs`; both operands are already reduced.
    fn reduce_mul(lhs: Expr, rhs: Expr) -> Expr {
        let (l, r) = (lhs.const_value(), rhs.const_value());
        if l == Some(0.0) || r == Some(0.0) {
            return Expr::Const(0.0);
        }
        if l == Some(1.0) {
            return rhs;
        }
        if r == Some(1.0) {
            return lhs;
        }
        if l == Some(-1.0) {
            return Self::negate(rhs);
        }
        if r == Some(-1.0) {
            return Self::negate(lhs);
        }
        if let (Some(a), Some(b)) = (l, r) {
            return Expr::Const(a * b);
        }
        Expr::Mul(Box::new(lhs), Box::new(rhs))
    }

    /// Simplifies `lhs + rhs`; both operands are already reduced.
    fn reduce_add(lhs: Expr, rhs: Expr) -> Expr {
        let (l, r) = (lhs.const_value(), rhs.const_value());
        if l == Some(0.0) {
            return rhs;
        }
        if r == Some(0.0) {
            return lhs;
        }
        if let (Some(a), Some(b)) = (l, r) {
            return Expr::Const(a + b);
        }
        // Absorb a negated operand into a subtraction: one op instead of two.
        // Each recursive call strips a `Neg`, so the recursion terminates.
        match (lhs, rhs) {
            (lhs, Expr::Neg(inner)) => Self::reduce_sub(lhs, *inner),
            (Expr::Neg(inner), rhs) => Self::reduce_sub(rhs, *inner),
            (lhs, rhs) => Expr::Add(Box::new(lhs), Box::new(rhs)),
        }
    }

    /// Simplifies `lhs - rhs`; both operands are already reduced.
    fn reduce_sub(lhs: Expr, rhs: Expr) -> Expr {
        if lhs == rhs {
            return Expr::Const(0.0);
        }
        let (l, r) = (lhs.const_value(), rhs.const_value());
        if r == Some(0.0) {
            return lhs;
        }
        if l == Some(0.0) {
            return Self::negate(rhs);
        }
        if let (Some(a), Some(b)) = (l, r) {
            return Expr::Const(a - b);
        }
        match rhs {
            // `x - (-y)` is `x + y`; `reduce_add` strips no further `Neg` here
            // unless `lhs` is itself negated.
            Expr::Neg(inner) => Self::reduce_add(lhs, *inner),
            rhs => Expr::Sub(Box::new(lhs), Box::new(rhs)),
        }
    }

    /// Negates an already-reduced operand, cancelling `-(-x)` and folding `-c`.
    fn negate(operand: Expr) -> Expr {
        match operand {
            Expr::Neg(inner) => *inner,
            Expr::Const(v) => Expr::Const(-v),
            other => Expr::Neg(Box::new(other)),
        }
    }
}
pub struct ConstantFolder;
impl ConstantFolder {
#[must_use]
pub fn fold(expr: &Expr) -> Expr {
let mut current = expr.clone();
loop {
let folded = StrengthReducer::reduce(¤t);
if folded == current {
return current;
}
current = folded;
}
}
}
#[cfg(test)]
impl ConstantFolder {
    /// Folds, in place, every assignment right-hand side and then every
    /// output of `program`.
    pub fn fold_program(program: &mut Program) {
        program
            .assignments
            .iter_mut()
            .for_each(|(_, rhs)| *rhs = Self::fold(rhs));
        program
            .outputs
            .iter_mut()
            .for_each(|out| *out = Self::fold(out));
    }
}
/// Removes assignments that no output depends on, directly or transitively.
#[cfg(test)]
pub struct DeadCodeEliminator;
#[cfg(test)]
impl DeadCodeEliminator {
    /// Keeps only the assignments whose names are reachable from
    /// `program.outputs` through `Expr::Temp` references.
    ///
    /// Surviving assignments keep their relative order. If a name is assigned
    /// more than once, dependencies are traced through its last assignment
    /// only, but every assignment with a live name is kept.
    pub fn eliminate(program: &mut Program) {
        // Later duplicates overwrite earlier ones when collected into the map.
        let definitions: HashMap<&str, &Expr> = program
            .assignments
            .iter()
            .map(|(name, rhs)| (name.as_str(), rhs))
            .collect();

        let mut live: HashSet<String> = HashSet::new();
        for out in &program.outputs {
            out.collect_temp_refs(&mut live);
        }

        let mut pending: Vec<String> = live.iter().cloned().collect();
        while let Some(current) = pending.pop() {
            if let Some(definition) = definitions.get(current.as_str()) {
                let mut deps = HashSet::new();
                definition.collect_temp_refs(&mut deps);
                for dep in deps {
                    if !live.contains(&dep) {
                        live.insert(dep.clone());
                        pending.push(dep);
                    }
                }
            }
        }

        program.assignments.retain(|(name, _)| live.contains(name));
    }
}
/// A straight-line program: named assignments plus output expressions.
#[cfg(test)]
#[derive(Clone, Debug)]
pub struct Program {
    /// `(name, expression)` pairs, referenced elsewhere via `Expr::Temp(name)`.
    pub assignments: Vec<(String, Expr)>,
    /// The program's result expressions.
    pub outputs: Vec<Expr>,
}
#[cfg(test)]
impl Program {
    /// An empty program.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            outputs: Vec::new(),
            assignments: Vec::new(),
        }
    }

    /// Builds a program whose assignments are the temporaries reported by
    /// `cse.get_temporaries()` and whose outputs are `outputs`.
    #[must_use]
    pub fn from_cse(cse: &CseOptimizer, outputs: Vec<Expr>) -> Self {
        let assignments = cse.get_temporaries();
        Self {
            assignments,
            outputs,
        }
    }

    /// Total arithmetic-node count (`Expr::op_count`) over every assignment
    /// right-hand side and every output.
    #[must_use]
    pub fn op_count(&self) -> usize {
        self.assignments
            .iter()
            .map(|(_, rhs)| rhs)
            .chain(&self.outputs)
            .map(Expr::op_count)
            .sum()
    }
}
#[cfg(test)]
impl Default for Program {
    fn default() -> Self {
        Self::new()
    }
}
/// Optimizes `program` with constant folding, whole-expression CSE and
/// dead-code elimination, in that order.
///
/// Each output, then each assignment right-hand side, is registered with a
/// fresh `CseOptimizer`. Expressions registered at least twice (its default
/// `min_uses`) are hoisted into CSE temporaries, which are placed before the
/// surviving user assignments. Expressions seen only once stay inline.
/// Previously they were replaced by temporaries that were never defined.
///
/// NOTE(review): CSE names (`t0`, `t1`, …) are not checked against existing
/// assignment names. As before, a user assignment whose name matches an
/// emitted CSE temporary is dropped. Repeated user assignments to one name
/// keep only the first.
#[cfg(test)]
#[must_use]
pub fn optimize(mut program: Program) -> Program {
    ConstantFolder::fold_program(&mut program);

    let mut cse = CseOptimizer::new();
    let registered_outputs: Vec<Expr> = program
        .outputs
        .iter()
        .map(|expr| cse.register(expr))
        .collect();
    let registered_assignments: Vec<Expr> = program
        .assignments
        .iter()
        .map(|(_, expr)| cse.register(expr))
        .collect();

    let temporaries = cse.get_temporaries();
    let emitted: HashSet<String> = temporaries.iter().map(|(name, _)| name.clone()).collect();

    // `register` returns a `Temp` for every compound expression, but
    // `get_temporaries` only defines those that reached `min_uses`. Fall back
    // to the original expression wherever the temporary would be undefined.
    let resolve = |original: &Expr, registered: Expr| {
        if matches!(&registered, Expr::Temp(name) if !emitted.contains(name)) {
            original.clone()
        } else {
            registered
        }
    };

    let outputs: Vec<Expr> = program
        .outputs
        .iter()
        .zip(registered_outputs)
        .map(|(original, registered)| resolve(original, registered))
        .collect();

    // `seen` holds emitted CSE names plus user names already kept, matching
    // the original "skip if the name is already present" rule in O(1) per item.
    let mut seen = emitted.clone();
    let mut assignments = temporaries;
    for ((name, original), registered) in program.assignments.iter().zip(registered_assignments) {
        if seen.insert(name.clone()) {
            assignments.push((name.clone(), resolve(original, registered)));
        }
    }

    program.assignments = assignments;
    program.outputs = outputs;
    DeadCodeEliminator::eliminate(&mut program);
    program
}
/// Runs only constant folding followed by dead-code elimination (no CSE).
#[cfg(test)]
#[must_use]
pub fn optimize_fold_and_dce(program: Program) -> Program {
    let mut optimized = program;
    ConstantFolder::fold_program(&mut optimized);
    DeadCodeEliminator::eliminate(&mut optimized);
    optimized
}
/// A fully unrolled transform in which each output bin is a symbolic
/// [`ComplexExpr`] over the complex inputs `x[0..n]`.
pub struct SymbolicFFT {
    /// One complex expression per output bin `k`.
    pub outputs: Vec<ComplexExpr>,
}
impl SymbolicFFT {
    /// Builds an iterative radix-2 decimation-in-time FFT of size `n`.
    ///
    /// Inputs are permuted into bit-reversed order, then `log2(n)` stages of
    /// butterflies combine them. `forward` selects twiddles
    /// `exp(-2πi·k/len)`; otherwise `exp(+2πi·k/len)` is used. No `1/n`
    /// scaling is applied. Each output component is passed once through
    /// `StrengthReducer::reduce`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two (this includes `n == 0`).
    #[must_use]
    pub fn radix2_dit(n: usize, forward: bool) -> Self {
        assert!(n.is_power_of_two(), "n must be power of 2");
        let sign = if forward { -1.0 } else { 1.0 };
        let mut data: Vec<ComplexExpr> = (0..n).map(ComplexExpr::input).collect();

        // In-place bit-reversal permutation; `j` tracks the reversed index of `i`.
        let mut j = 0;
        for i in 0..n {
            if i < j {
                data.swap(i, j);
            }
            let mut m = n >> 1;
            while m >= 1 && j >= m {
                j -= m;
                m >>= 1;
            }
            j += m;
        }

        // Butterfly stages of span `len = 2, 4, …, n`.
        let mut len = 2;
        while len <= n {
            let half = len / 2;
            for start in (0..n).step_by(len) {
                for k in 0..half {
                    let twiddle = Self::twiddle(k, len, sign);
                    let u = data[start + k].clone();
                    let t = data[start + k + half].mul(&twiddle);
                    data[start + k] = u.add(&t);
                    data[start + k + half] = u.sub(&t);
                }
            }
            len *= 2;
        }

        let outputs: Vec<ComplexExpr> = data
            .into_iter()
            .map(|c| ComplexExpr {
                re: StrengthReducer::reduce(&c.re),
                im: StrengthReducer::reduce(&c.im),
            })
            .collect();
        Self { outputs }
    }

    /// The twiddle factor `exp(sign·2πi·k/len)` as a constant.
    ///
    /// Factors on the real/imaginary axes are returned exactly. Evaluating
    /// `cos`/`sin` at the rounded angle leaves residues such as
    /// `cos(-π/2) ≈ 6.1e-17`, which `StrengthReducer` cannot recognise as zero.
    /// Those residues keep spurious multiplies and add error. All other `k`
    /// use the same floating-point expression as before.
    fn twiddle(k: usize, len: usize, sign: f64) -> ComplexExpr {
        let k = k % len;
        let quarter_turns = len % 4 == 0;
        if k == 0 {
            return ComplexExpr::constant(1.0, 0.0);
        }
        if len % 2 == 0 && k == len / 2 {
            return ComplexExpr::constant(-1.0, 0.0);
        }
        if quarter_turns && k == len / 4 {
            return ComplexExpr::constant(0.0, sign);
        }
        if quarter_turns && k == 3 * (len / 4) {
            return ComplexExpr::constant(0.0, -sign);
        }
        let angle_step = sign * 2.0 * std::f64::consts::PI / len as f64;
        let angle = angle_step * k as f64;
        ComplexExpr::constant(angle.cos(), angle.sin())
    }

    /// Total arithmetic-node count over the real and imaginary parts of all outputs.
    #[must_use]
    pub fn op_count(&self) -> usize {
        self.outputs
            .iter()
            .map(|c| c.re.op_count() + c.im.op_count())
            .sum()
    }
}
#[cfg(test)]
impl SymbolicFFT {
    /// Transform size (number of output bins).
    #[must_use]
    pub fn n(&self) -> usize {
        self.outputs.len()
    }

    /// Builds the naive `O(n²)` DFT, `X[k] = Σ_j x[j]·exp(sign·2πi·k·j/n)`,
    /// with `sign = -1` when `forward`. Each sum is accumulated left-to-right
    /// from a `0` constant, then each output component is passed once through
    /// `StrengthReducer::reduce`.
    #[must_use]
    pub fn dft(n: usize, forward: bool) -> Self {
        let sign = if forward { -1.0 } else { 1.0 };
        let outputs = (0..n)
            .map(|k| {
                let (re, im) = (0..n).fold(
                    (Expr::Const(0.0), Expr::Const(0.0)),
                    |(acc_re, acc_im), j| {
                        let angle =
                            sign * 2.0 * std::f64::consts::PI * (k * j) as f64 / n as f64;
                        let twiddle = ComplexExpr::constant(angle.cos(), angle.sin());
                        let term = ComplexExpr::input(j).mul(&twiddle);
                        (acc_re.add(term.re), acc_im.add(term.im))
                    },
                );
                ComplexExpr {
                    re: StrengthReducer::reduce(&re),
                    im: StrengthReducer::reduce(&im),
                }
            })
            .collect();
        Self { outputs }
    }
}
#[path = "symbolic_emit.rs"]
mod symbolic_emit;
pub use symbolic_emit::{emit_body_from_symbolic, schedule_instructions};
#[cfg(test)]
#[path = "symbolic_tests.rs"]
mod tests;