use std::collections::HashMap;
use crate::{E, Expr, symbol};
/// Counts the operator nodes in `e`.
///
/// Symbols and constants cost nothing; every unary or binary operator node
/// adds one on top of the cost of its operand(s).
fn expr_cost(e: &E) -> usize {
    let operand_cost = match e.as_ref() {
        Expr::Sym(_) | Expr::Const(_) => return 0,
        Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Div(lhs, rhs)
        | Expr::Pow(lhs, rhs)
        | Expr::Atan2(lhs, rhs) => expr_cost(lhs) + expr_cost(rhs),
        Expr::Neg(arg)
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Asin(arg)
        | Expr::Acos(arg)
        | Expr::Atan(arg)
        | Expr::Sinh(arg)
        | Expr::Cosh(arg)
        | Expr::Tanh(arg)
        | Expr::Exp(arg)
        | Expr::Ln(arg)
        | Expr::Log2(arg)
        | Expr::Log10(arg)
        | Expr::Sqrt(arg)
        | Expr::Abs(arg) => expr_cost(arg),
    };
    operand_cost + 1
}
/// Height of the expression tree rooted at `e`.
///
/// Symbols and constants sit at depth 0; every operator node is one level
/// above the deeper of its operands.
fn expr_depth(e: &E) -> usize {
    let deepest_operand = match e.as_ref() {
        Expr::Sym(_) | Expr::Const(_) => return 0,
        Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Div(lhs, rhs)
        | Expr::Pow(lhs, rhs)
        | Expr::Atan2(lhs, rhs) => usize::max(expr_depth(lhs), expr_depth(rhs)),
        Expr::Neg(arg)
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Asin(arg)
        | Expr::Acos(arg)
        | Expr::Atan(arg)
        | Expr::Sinh(arg)
        | Expr::Cosh(arg)
        | Expr::Tanh(arg)
        | Expr::Exp(arg)
        | Expr::Ln(arg)
        | Expr::Log2(arg)
        | Expr::Log10(arg)
        | Expr::Sqrt(arg)
        | Expr::Abs(arg) => expr_depth(arg),
    };
    deepest_operand + 1
}
/// Adds one to `counts` for every node of the tree rooted at `e`, keyed by the
/// node's expression.
///
/// Nodes inside a repeated subtree are counted once per occurrence of that
/// subtree, so descendants of a twice-used subtree receive two counts.
fn count_subexprs(e: &E, counts: &mut HashMap<E, usize>) {
    *counts.entry(e.clone()).or_default() += 1;
    match e.as_ref() {
        Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Div(lhs, rhs)
        | Expr::Pow(lhs, rhs)
        | Expr::Atan2(lhs, rhs) => {
            count_subexprs(lhs, counts);
            count_subexprs(rhs, counts);
        }
        Expr::Neg(arg)
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Asin(arg)
        | Expr::Acos(arg)
        | Expr::Atan(arg)
        | Expr::Sinh(arg)
        | Expr::Cosh(arg)
        | Expr::Tanh(arg)
        | Expr::Exp(arg)
        | Expr::Ln(arg)
        | Expr::Log2(arg)
        | Expr::Log10(arg)
        | Expr::Sqrt(arg)
        | Expr::Abs(arg) => count_subexprs(arg, counts),
        Expr::Sym(_) | Expr::Const(_) => (),
    }
}
pub fn replace_pub(e: &E, target: &E, replacement: &E) -> E {
replace(e, target, replacement)
}
/// Returns a copy of `e` in which every occurrence of `target` is replaced by
/// `replacement`.
///
/// Matching is structural (`==`), with one extension for products: when
/// `target` is a `Mul` whose flattened coefficient is exactly `1.0` and which
/// has at least one non-constant factor, any product in `e` whose flattened
/// factor list contains the target's factors (as a multiset, in any order) has
/// each such group swapped for one `replacement` factor, while `e`'s own
/// coefficient is kept. For example, with target `a*b`, the product `2*b*c*a`
/// becomes `2*(c*R)` where `R` is the replacement. As in the structural case,
/// the inserted `replacement` is never searched again.
fn replace(e: &E, target: &E, replacement: &E) -> E {
    if e == target {
        return replacement.clone();
    }
    let go = |x: &E| replace(x, target, replacement);
    match e.as_ref() {
        Expr::Sym(_) | Expr::Const(_) => e.clone(),
        Expr::Mul(a, b) => match replace_product_factors(e, target, replacement) {
            Some(rewritten) => rewritten,
            None => E::new(Expr::Mul(go(a), go(b))),
        },
        Expr::Neg(a) => E::new(Expr::Neg(go(a))),
        Expr::Add(a, b) => E::new(Expr::Add(go(a), go(b))),
        Expr::Sub(a, b) => E::new(Expr::Sub(go(a), go(b))),
        Expr::Div(a, b) => E::new(Expr::Div(go(a), go(b))),
        Expr::Pow(a, b) => E::new(Expr::Pow(go(a), go(b))),
        Expr::Atan2(a, b) => E::new(Expr::Atan2(go(a), go(b))),
        Expr::Sin(a) => E::new(Expr::Sin(go(a))),
        Expr::Cos(a) => E::new(Expr::Cos(go(a))),
        Expr::Tan(a) => E::new(Expr::Tan(go(a))),
        Expr::Asin(a) => E::new(Expr::Asin(go(a))),
        Expr::Acos(a) => E::new(Expr::Acos(go(a))),
        Expr::Atan(a) => E::new(Expr::Atan(go(a))),
        Expr::Sinh(a) => E::new(Expr::Sinh(go(a))),
        Expr::Cosh(a) => E::new(Expr::Cosh(go(a))),
        Expr::Tanh(a) => E::new(Expr::Tanh(go(a))),
        Expr::Exp(a) => E::new(Expr::Exp(go(a))),
        Expr::Ln(a) => E::new(Expr::Ln(go(a))),
        Expr::Log2(a) => E::new(Expr::Log2(go(a))),
        Expr::Log10(a) => E::new(Expr::Log10(go(a))),
        Expr::Sqrt(a) => E::new(Expr::Sqrt(go(a))),
        Expr::Abs(a) => E::new(Expr::Abs(go(a))),
    }
}

/// Factor-multiset matching for the product `e` against a product `target`.
///
/// Returns `None` when `target` is not an eligible product or when none of its
/// factor groups occur in `e`. Otherwise it returns the rebuilt product: the
/// leftover factors (each recursively passed through `replace`), followed by
/// one `replacement` per matched group, under `e`'s original coefficient.
fn replace_product_factors(e: &E, target: &E, replacement: &E) -> Option<E> {
    if !matches!(target.as_ref(), Expr::Mul(_, _)) {
        return None;
    }
    let (t_coeff, t_factors) = flatten_mul_factors(target);
    // A coefficient other than 1 cannot be matched by removing factors alone.
    // An empty factor list (a product of constants) would match vacuously on
    // every product, so it must never take this path.
    if t_coeff != 1.0 || t_factors.is_empty() {
        return None;
    }
    let (e_coeff, mut leftover) = flatten_mul_factors(e);
    // Each successful round removes at least one factor, so this terminates.
    let mut hits = 0usize;
    while remove_factors(&mut leftover, &t_factors) {
        hits += 1;
    }
    if hits == 0 {
        return None;
    }
    let mut factors: Vec<E> = leftover
        .iter()
        .map(|f| replace(f, target, replacement))
        .collect();
    factors.extend(std::iter::repeat(replacement.clone()).take(hits));
    Some(build_mul_from_factors(e_coeff, factors))
}

/// Removes one occurrence of every element of `wanted` from `pool` (multiset
/// difference). Returns `false` and leaves `pool` untouched if any element is
/// missing.
fn remove_factors(pool: &mut Vec<E>, wanted: &[E]) -> bool {
    let mut trial = pool.clone();
    for w in wanted {
        match trial.iter().position(|f| f == w) {
            Some(pos) => {
                trial.remove(pos);
            }
            None => return false,
        }
    }
    *pool = trial;
    true
}
/// Splits a product into `(coefficient, factors)`.
///
/// Constants fold into the coefficient. A `Neg` flips the coefficient's sign.
/// A `Mul` concatenates the factors of its left then right operand and
/// multiplies their coefficients. Any other node is a single factor with
/// coefficient `1.0`.
fn flatten_mul_factors(e: &E) -> (f64, Vec<E>) {
    match e.as_ref() {
        Expr::Const(v) => (*v, Vec::new()),
        Expr::Neg(inner) => {
            let (coeff, factors) = flatten_mul_factors(inner);
            (-coeff, factors)
        }
        Expr::Mul(lhs, rhs) => {
            let (lhs_coeff, lhs_factors) = flatten_mul_factors(lhs);
            let (rhs_coeff, rhs_factors) = flatten_mul_factors(rhs);
            let factors: Vec<E> = lhs_factors.into_iter().chain(rhs_factors).collect();
            (lhs_coeff * rhs_coeff, factors)
        }
        _ => (1.0, vec![e.clone()]),
    }
}
/// Rebuilds a product from a numeric coefficient and a list of factors.
///
/// The factors are folded left to right into nested `Mul` nodes
/// (`((f0 * f1) * f2) * ...`). The coefficient is then attached:
/// - `1.0` is dropped;
/// - `-1.0` becomes a `Neg` wrapper;
/// - any other value becomes a leading `Const` multiplicand.
///
/// With no factors at all, the result is just `Const(coeff)`.
fn build_mul_from_factors(coeff: f64, factors: Vec<E>) -> E {
    let product = match factors
        .into_iter()
        .reduce(|acc, factor| E::new(Expr::Mul(acc, factor)))
    {
        Some(product) => product,
        None => return E::new(Expr::Const(coeff)),
    };
    match coeff {
        c if c == 1.0 => product,
        c if c == -1.0 => E::new(Expr::Neg(product)),
        c => E::new(Expr::Mul(E::new(Expr::Const(c)), product)),
    }
}
/// Performs common-subexpression elimination across a batch of expressions.
///
/// Returns `(intermediates, results)`:
/// - `results` has one entry per input expression, in the same order,
///   rewritten in terms of freshly named symbols `__x0`, `__x1`, ...
/// - `intermediates` binds each of those names to its defining expression. It
///   is passed through `topo_sort_intermediates` so bindings come after the
///   bindings they reference.
///
/// An empty input yields two empty vectors.
///
/// Two passes run in sequence:
/// 1. Subexpression pass: while some subexpression with at least one
///    operation occurs two or more times (across results and intermediates),
///    bind the one with the highest `(count - 1) * expr_cost` to a new
///    symbol and substitute it everywhere via `replace`. Ties are broken by
///    greater `expr_depth`, then by earliest first occurrence.
/// 2. Reciprocal pass: every expression used as the divisor of two or more
///    `Div` nodes gets a binding `1 / divisor`, and each `a / divisor` is
///    rewritten as `a * symbol`.
///
/// Candidates are examined in first-occurrence order (results, then
/// intermediates, pre-order) rather than `HashMap` iteration order, which is
/// randomized per process. This keeps the chosen subexpressions and the
/// symbol numbering reproducible from run to run.
pub fn cse(exprs: &[E]) -> (Vec<(String, E)>, Vec<E>) {
    if exprs.is_empty() {
        return (vec![], vec![]);
    }
    let mut results = exprs.to_vec();
    let mut intermediates: Vec<(String, E)> = Vec::new();
    let mut var_counter = 0usize;

    // Pass 1. The loop yields the first-occurrence order of the final state,
    // which the reciprocal pass reuses to visit divisors deterministically.
    let final_order: Vec<E> = loop {
        let mut counts: HashMap<E, usize> = HashMap::new();
        for r in &results {
            count_subexprs(r, &mut counts);
        }
        for (_, expr) in &intermediates {
            count_subexprs(expr, &mut counts);
        }
        let order = first_occurrence_order(&results, &intermediates);
        let best = order
            .iter()
            .filter_map(|e| {
                let count = counts.get(e).copied().unwrap_or(0);
                if count < 2 {
                    return None;
                }
                let cost = expr_cost(e);
                if cost == 0 {
                    return None;
                }
                Some((e, ((count - 1) * cost, expr_depth(e))))
            })
            // `max_by_key` returns the last of several equal maxima, so scan
            // backwards to let the earliest occurrence win ties.
            .rev()
            .max_by_key(|&(_, key)| key)
            .map(|(e, _)| e.clone());
        let subexpr = match best {
            Some(s) => s,
            None => break order,
        };
        let var_name = format!("__x{}", var_counter);
        var_counter += 1;
        let var_sym = symbol(&var_name);
        for r in results.iter_mut() {
            *r = replace(r, &subexpr, &var_sym);
        }
        for (_, expr) in intermediates.iter_mut() {
            *expr = replace(expr, &subexpr, &var_sym);
        }
        intermediates.push((var_name, subexpr));
    };

    // Pass 2: hoist reciprocals of repeated divisors.
    let mut divisor_counts: HashMap<E, usize> = HashMap::new();
    for r in &results {
        count_divisors(r, &mut divisor_counts);
    }
    for (_, expr) in &intermediates {
        count_divisors(expr, &mut divisor_counts);
    }
    // Every divisor is a subexpression of the final state, so `final_order`
    // lists each one exactly once, in a stable order.
    let repeated_divisors: Vec<E> = final_order
        .into_iter()
        .filter(|d| matches!(divisor_counts.get(d), Some(&c) if c >= 2))
        .collect();
    for divisor in repeated_divisors {
        let var_name = format!("__x{}", var_counter);
        var_counter += 1;
        let var_sym = symbol(&var_name);
        for r in results.iter_mut() {
            *r = replace_divisor(r, &divisor, &var_sym);
        }
        for (_, expr) in intermediates.iter_mut() {
            *expr = replace_divisor(expr, &divisor, &var_sym);
        }
        let recip = E::new(Expr::Div(E::new(Expr::Const(1.0)), divisor));
        intermediates.push((var_name, recip));
    }

    let intermediates = topo_sort_intermediates(intermediates);
    (intermediates, results)
}

/// Lists every distinct subexpression of `results` followed by those of
/// `intermediates`, in pre-order of first occurrence.
fn first_occurrence_order(results: &[E], intermediates: &[(String, E)]) -> Vec<E> {
    let mut seen = std::collections::HashSet::new();
    let mut order = Vec::new();
    for expr in results.iter().chain(intermediates.iter().map(|(_, x)| x)) {
        record_first_occurrences(expr, &mut seen, &mut order);
    }
    order
}

/// Pre-order visitor for `first_occurrence_order`. A subtree already in `seen`
/// is skipped, because all of its descendants were recorded on its first
/// visit.
fn record_first_occurrences(e: &E, seen: &mut std::collections::HashSet<E>, order: &mut Vec<E>) {
    if !seen.insert(e.clone()) {
        return;
    }
    order.push(e.clone());
    match e.as_ref() {
        Expr::Sym(_) | Expr::Const(_) => {}
        Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
        | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
        | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
        | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
        | Expr::Sqrt(a) | Expr::Abs(a) => record_first_occurrences(a, seen, order),
        Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
        | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
            record_first_occurrences(a, seen, order);
            record_first_occurrences(b, seen, order);
        }
    }
}
/// For every `Div` node in `e`, adds one to `counts` under the node's
/// denominator expression.
///
/// Both numerator and denominator are searched, so nested divisions are
/// counted too.
fn count_divisors(e: &E, counts: &mut HashMap<E, usize>) {
    match e.as_ref() {
        Expr::Div(numerator, denominator) => {
            *counts.entry(denominator.clone()).or_default() += 1;
            count_divisors(numerator, counts);
            count_divisors(denominator, counts);
        }
        Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Pow(lhs, rhs)
        | Expr::Atan2(lhs, rhs) => {
            count_divisors(lhs, counts);
            count_divisors(rhs, counts);
        }
        Expr::Neg(arg)
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Asin(arg)
        | Expr::Acos(arg)
        | Expr::Atan(arg)
        | Expr::Sinh(arg)
        | Expr::Cosh(arg)
        | Expr::Tanh(arg)
        | Expr::Exp(arg)
        | Expr::Ln(arg)
        | Expr::Log2(arg)
        | Expr::Log10(arg)
        | Expr::Sqrt(arg)
        | Expr::Abs(arg) => count_divisors(arg, counts),
        Expr::Sym(_) | Expr::Const(_) => (),
    }
}
/// Rewrites every `a / divisor` in `e` (denominator structurally equal to
/// `divisor`) as `a * replacement`.
///
/// The numerator of a matched division is searched recursively; its
/// denominator is not, since it is the divisor itself. Every other node is
/// rebuilt with its children rewritten.
fn replace_divisor(e: &E, divisor: &E, replacement: &E) -> E {
    let rec = |child: &E| replace_divisor(child, divisor, replacement);
    match e.as_ref() {
        Expr::Div(num, den) if den == divisor => E::new(Expr::Mul(rec(num), replacement.clone())),
        Expr::Div(num, den) => E::new(Expr::Div(rec(num), rec(den))),
        Expr::Add(lhs, rhs) => E::new(Expr::Add(rec(lhs), rec(rhs))),
        Expr::Sub(lhs, rhs) => E::new(Expr::Sub(rec(lhs), rec(rhs))),
        Expr::Mul(lhs, rhs) => E::new(Expr::Mul(rec(lhs), rec(rhs))),
        Expr::Pow(lhs, rhs) => E::new(Expr::Pow(rec(lhs), rec(rhs))),
        Expr::Atan2(lhs, rhs) => E::new(Expr::Atan2(rec(lhs), rec(rhs))),
        Expr::Neg(arg) => E::new(Expr::Neg(rec(arg))),
        Expr::Sin(arg) => E::new(Expr::Sin(rec(arg))),
        Expr::Cos(arg) => E::new(Expr::Cos(rec(arg))),
        Expr::Tan(arg) => E::new(Expr::Tan(rec(arg))),
        Expr::Asin(arg) => E::new(Expr::Asin(rec(arg))),
        Expr::Acos(arg) => E::new(Expr::Acos(rec(arg))),
        Expr::Atan(arg) => E::new(Expr::Atan(rec(arg))),
        Expr::Sinh(arg) => E::new(Expr::Sinh(rec(arg))),
        Expr::Cosh(arg) => E::new(Expr::Cosh(rec(arg))),
        Expr::Tanh(arg) => E::new(Expr::Tanh(rec(arg))),
        Expr::Exp(arg) => E::new(Expr::Exp(rec(arg))),
        Expr::Ln(arg) => E::new(Expr::Ln(rec(arg))),
        Expr::Log2(arg) => E::new(Expr::Log2(rec(arg))),
        Expr::Log10(arg) => E::new(Expr::Log10(rec(arg))),
        Expr::Sqrt(arg) => E::new(Expr::Sqrt(rec(arg))),
        Expr::Abs(arg) => E::new(Expr::Abs(rec(arg))),
        Expr::Sym(_) | Expr::Const(_) => e.clone(),
    }
}
/// Orders `(name, expr)` bindings so that each binding comes after every
/// binding whose name appears among its expression's `free_vars()`.
///
/// This uses Kahn's algorithm with a LIFO ready list. Free variables that are
/// not binding names (e.g. user inputs) impose no constraint. Bindings that
/// sit on a dependency cycle cannot be ordered. Rather than being silently
/// dropped, which would leave the generated code referencing an undefined
/// symbol, they are appended at the end in their original relative order.
fn topo_sort_intermediates(intermediates: Vec<(String, E)>) -> Vec<(String, E)> {
    use std::collections::HashSet;
    let n = intermediates.len();
    let name_to_idx: HashMap<String, usize> = intermediates
        .iter()
        .enumerate()
        .map(|(i, (name, _))| (name.clone(), i))
        .collect();
    let mut in_degree = vec![0usize; n];
    let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (i, (_, expr)) in intermediates.iter().enumerate() {
        // Deduplicate so a binding referenced several times is one edge.
        let deps: HashSet<usize> = expr
            .free_vars()
            .into_iter()
            .filter_map(|v| name_to_idx.get(&v).copied())
            .collect();
        for j in deps {
            in_degree[i] += 1;
            dependents[j].push(i);
        }
    }
    let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(idx) = ready.pop() {
        order.push(idx);
        for &dep in &dependents[idx] {
            in_degree[dep] -= 1;
            if in_degree[dep] == 0 {
                ready.push(dep);
            }
        }
    }
    // Move entries out instead of cloning them; any slot still `Some` afterwards
    // belongs to a cycle and is kept rather than lost.
    let mut slots: Vec<Option<(String, E)>> = intermediates.into_iter().map(Some).collect();
    let mut sorted: Vec<(String, E)> = order.into_iter().filter_map(|i| slots[i].take()).collect();
    sorted.extend(slots.into_iter().flatten());
    sorted
}