use std::collections::HashMap;
use egg::{rewrite, CostFunction, Id, Language, RecExpr, Rewrite, Runner, Symbol};
use crate::expr::{ExprLang, Expression};
/// Expands `expr` in two passes.
///
/// The first pass (`expand_powers`) rewrites squares as explicit products;
/// the second pass (`distribute_fully`) distributes multiplication over
/// addition in the result of the first.
pub fn expand(expr: &Expression) -> Expression {
    distribute_fully(&expand_powers(expr))
}
/// Recursively distributes multiplication over addition inside `expr`.
///
/// * Products: both factors are processed first, then multiplied out with
///   `distribute_product`.
/// * Sums: both summands are processed and re-added.
/// * Negations: the operand is processed and negated again.
/// * Powers: only the base is processed; the exponent is reused untouched.
/// * Anything else is returned as a clone.
///
/// Only the operands at index 0 and 1 of a sum or product are read.
fn distribute_fully(expr: &Expression) -> Expression {
    if expr.is_mul() {
        let factors = expr.as_mul().expect("is_mul() was true");
        let lhs = distribute_fully(&factors[0]);
        let rhs = distribute_fully(&factors[1]);
        return distribute_product(&lhs, &rhs);
    }

    if expr.is_add() {
        let summands = expr.as_add().expect("is_add() was true");
        let lhs = distribute_fully(&summands[0]);
        let rhs = distribute_fully(&summands[1]);
        return lhs + rhs;
    }

    if expr.is_neg() {
        let operand = expr.as_neg().expect("is_neg() was true");
        return -distribute_fully(&operand);
    }

    if expr.is_pow() {
        let (base, exponent) = expr.as_pow().expect("is_pow() was true");
        let distributed_base = distribute_fully(&base);
        return distributed_base.pow(&exponent);
    }

    expr.clone()
}
/// Multiplies two sums term by term and adds up the pairwise products.
///
/// Both arguments are flattened with `collect_addends`; every addend of
/// `left` is multiplied (via `multiply_terms`) with every addend of `right`,
/// iterating `left` in the outer position. The products are summed from left
/// to right. `Expression::zero()` is returned if there are no products at all.
fn distribute_product(left: &Expression, right: &Expression) -> Expression {
    let lhs_addends = collect_addends(left);
    let rhs_addends = collect_addends(right);

    let pairwise_products = lhs_addends
        .iter()
        .flat_map(|l| rhs_addends.iter().map(move |r| multiply_terms(l, r)));

    match pairwise_products.reduce(|sum, term| sum + term) {
        Some(total) => total,
        None => Expression::zero(),
    }
}
/// Flattens `expr` into the list of terms whose sum equals `expr`.
///
/// Nested additions are flattened left to right. Negation is pushed through
/// sums, so `-(a + b)` contributes `-a` and `-b` instead of a single opaque
/// `-(a + b)` term; without this a product such as `-(a + b) * c` could never
/// be distributed. Stacked negations cancel pairwise. An expression that is
/// neither a sum nor a negation yields a single-element list.
fn collect_addends(expr: &Expression) -> Vec<Expression> {
    let mut terms = Vec::new();
    push_signed_addends(expr, false, &mut terms);
    terms
}

/// Appends the addends of `expr` to `terms`, negating each when `negate` is set.
fn push_signed_addends(expr: &Expression, negate: bool, terms: &mut Vec<Expression>) {
    if expr.is_add() {
        let operands = expr.as_add().expect("is_add() was true");
        push_signed_addends(&operands[0], negate, terms);
        push_signed_addends(&operands[1], negate, terms);
    } else if expr.is_neg() {
        // A negation flips the sign applied to everything beneath it.
        let inner = expr.as_neg().expect("is_neg() was true");
        push_signed_addends(&inner, !negate, terms);
    } else if negate {
        terms.push(-expr.clone());
    } else {
        terms.push(expr.clone());
    }
}
/// Multiplies two terms while keeping at most one negation on the outside.
///
/// Leading negations are stripped from both operands with `unwrap_neg`, the
/// bare operands are multiplied, and the product is negated only when exactly
/// one of the operands carried a negative sign.
fn multiply_terms(a: &Expression, b: &Expression) -> Expression {
    let (a_is_negative, a_core) = unwrap_neg(a);
    let (b_is_negative, b_core) = unwrap_neg(b);
    let unsigned_product = a_core * b_core;

    // Signs differ => the overall product is negative.
    if a_is_negative != b_is_negative {
        -unsigned_product
    } else {
        unsigned_product
    }
}
/// Strips every leading negation from `expr`.
///
/// Returns `(is_negative, core)`, where `core` is the expression found below
/// all stacked negations and `is_negative` is `true` when an odd number of
/// negations was removed.
fn unwrap_neg(expr: &Expression) -> (bool, Expression) {
    if !expr.is_neg() {
        return (false, expr.clone());
    }

    let operand = expr.as_neg().expect("is_neg() was true");
    let (operand_is_negative, core) = unwrap_neg(&operand);
    // This layer of negation flips whatever sign the operand already had.
    (!operand_is_negative, core)
}
/// Rewrites squares as explicit products, recursing through the expression.
///
/// * Powers: the base is processed first. If the exponent is a number whose
///   `f64` value is within `1e-10` of 2, the result is `base * base`;
///   otherwise the power is rebuilt with the processed base and the original
///   exponent. Exponents other than 2 are never expanded.
/// * Sums, products and negations: operands are processed and recombined.
/// * Anything else is returned as a clone.
///
/// Only the operands at index 0 and 1 of a sum or product are read.
fn expand_powers(expr: &Expression) -> Expression {
    if expr.is_pow() {
        let (base, exponent) = expr.as_pow().expect("is_pow() was true");
        let processed_base = expand_powers(&base);
        let is_square = exponent.is_number()
            && matches!(exponent.to_f64(), Some(value) if (value - 2.0).abs() < 1e-10);
        if is_square {
            processed_base.clone() * processed_base
        } else {
            processed_base.pow(&exponent)
        }
    } else if expr.is_add() {
        let summands = expr.as_add().expect("is_add() was true");
        let lhs = expand_powers(&summands[0]);
        let rhs = expand_powers(&summands[1]);
        lhs + rhs
    } else if expr.is_mul() {
        let factors = expr.as_mul().expect("is_mul() was true");
        let lhs = expand_powers(&factors[0]);
        let rhs = expand_powers(&factors[1]);
        lhs * rhs
    } else if expr.is_neg() {
        let operand = expr.as_neg().expect("is_neg() was true");
        -expand_powers(&operand)
    } else {
        expr.clone()
    }
}
/// Simplifies `expr` with the general algebraic rule set.
///
/// Runs an egg `Runner` over the rules from `get_simplification_rules` with an
/// iteration limit of 20, then extracts the cheapest equivalent expression
/// according to the `AstSize` cost function.
pub fn simplify(expr: &Expression) -> Expression {
    let rewrites = get_simplification_rules();
    let saturated = Runner::default()
        .with_expr(expr.as_rec_expr())
        .with_iter_limit(20)
        .run(&rewrites);

    let root_class = saturated.roots[0];
    let (_, cheapest) = egg::Extractor::new(&saturated.egraph, AstSize).find_best(root_class);
    Expression::from_rec_expr(cheapest)
}
/// Replaces every occurrence of the symbol `var` in `expr` with `value`.
///
/// If `var` is not a symbol (`as_symbol()` returns `None`) the expression is
/// returned unchanged. The same happens when the underlying `RecExpr` of
/// `expr` has no nodes, which previously caused an integer underflow while
/// computing the root index.
pub fn substitute(expr: &Expression, var: &Expression, value: &Expression) -> Expression {
    let var_name = match var.as_symbol() {
        Some(name) => name.to_string(),
        None => return expr.clone(),
    };

    let rec_expr = expr.as_rec_expr();
    // The traversal starts at the last node of the `RecExpr`, which is its root.
    let root_idx = match rec_expr.as_ref().len().checked_sub(1) {
        Some(idx) => idx,
        None => return expr.clone(),
    };

    let value_expr = value.as_rec_expr();
    let mut new_expr = RecExpr::default();
    // Maps node indices of the source expression to ids in `new_expr`.
    let mut id_map: HashMap<usize, Id> = HashMap::new();
    substitute_rec(
        rec_expr,
        root_idx,
        &var_name,
        value_expr,
        &mut new_expr,
        &mut id_map,
    );
    Expression::from_rec_expr(new_expr)
}
/// Copies the sub-expression rooted at `idx` from `expr` into `new_expr`,
/// replacing leaves named `var_name` with a copy of `value`.
///
/// # Parameters
/// * `expr` - source expression being traversed.
/// * `idx` - index of the node in `expr` to process.
/// * `var_name` - name compared against the payload of `ExprLang::Num` leaves.
/// * `value` - replacement expression spliced in for matching leaves.
/// * `new_expr` - output expression that nodes are appended to.
/// * `id_map` - memo from source node index to the id already emitted for it,
///   so a node reached more than once is only processed once.
///
/// # Returns
/// The id in `new_expr` of the node that corresponds to `idx`.
///
/// If `value` has no nodes there is nothing to splice in, so matching leaves
/// are copied unchanged instead of underflowing or pointing at an unrelated
/// node.
fn substitute_rec(
    expr: &RecExpr<ExprLang>,
    idx: usize,
    var_name: &str,
    value: &RecExpr<ExprLang>,
    new_expr: &mut RecExpr<ExprLang>,
    id_map: &mut HashMap<usize, Id>,
) -> Id {
    if let Some(&new_id) = id_map.get(&idx) {
        return new_id;
    }

    let node = &expr[Id::from(idx)];
    if let ExprLang::Num(s) = node {
        if s.as_str() == var_name && !value.as_ref().is_empty() {
            // Child ids inside `value` are relative to its own node list;
            // shift them by the number of nodes already present in `new_expr`.
            let offset = new_expr.as_ref().len();
            for n in value.as_ref().iter() {
                let mapped_node = n
                    .clone()
                    .map_children(|child_id| Id::from(usize::from(child_id) + offset));
                new_expr.add(mapped_node);
            }
            // The root of the copied `value` is the node appended last.
            let new_id = Id::from(new_expr.as_ref().len() - 1);
            id_map.insert(idx, new_id);
            return new_id;
        }
    }

    // Not a substitution target: rebuild the node on top of rewritten children.
    let new_node = node.clone().map_children(|child_id| {
        substitute_rec(
            expr,
            usize::from(child_id),
            var_name,
            value,
            new_expr,
            id_map,
        )
    });
    let new_id = new_expr.add(new_node);
    id_map.insert(idx, new_id);
    new_id
}
/// Extraction cost that charges 1 for every `Num` node and 3 for any other
/// node.
struct AstSize;

impl CostFunction<ExprLang> for AstSize {
    type Cost = usize;

    /// Returns the node's own weight plus the costs of all of its children.
    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        let own_weight = if matches!(node, ExprLang::Num(_)) { 1 } else { 3 };
        node.fold(own_weight, |total, child| total + costs(child))
    }
}
/// Extraction cost that makes additions cheaper than multiplications.
///
/// Weights: `Num` = 1, `Add` = 2, `Mul` = 4, every other node = 3.
#[allow(dead_code)]
struct ExpandedSize;

impl CostFunction<ExprLang> for ExpandedSize {
    type Cost = usize;

    /// Returns the node's own weight plus the costs of all of its children.
    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        let own_weight = match node {
            ExprLang::Num(_) => 1,
            ExprLang::Add(_) => 2,
            ExprLang::Mul(_) => 4,
            _ => 3,
        };
        let children_cost = |running_total, child| running_total + costs(child);
        node.fold(own_weight, children_cost)
    }
}
/// Builds the rewrite rules for distributing products over sums.
///
/// The set also contains sign handling, associativity and commutativity of
/// `+` and `*`, and the identities for the constants 0 and 1. The function is
/// marked `dead_code`; nothing in the visible part of this file calls it.
#[allow(dead_code)]
fn get_distribution_rules() -> Vec<Rewrite<ExprLang, ()>> {
    vec![
        // Distribution of `*` over `+`.
        rewrite!("distrib-left"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
        rewrite!("distrib-right"; "(* (+ ?a ?b) ?c)" => "(+ (* ?a ?c) (* ?b ?c))"),
        // Moving negation out of products and into sums.
        rewrite!("neg-mul-left"; "(* (neg ?a) ?b)" => "(neg (* ?a ?b))"),
        rewrite!("neg-mul-right"; "(* ?a (neg ?b))" => "(neg (* ?a ?b))"),
        rewrite!("neg-neg-mul"; "(* (neg ?a) (neg ?b))" => "(* ?a ?b)"),
        rewrite!("neg-add"; "(neg (+ ?a ?b))" => "(+ (neg ?a) (neg ?b))"),
        rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
        // Associativity, in both directions.
        rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
        rewrite!("mul-assoc-rev"; "(* (* ?a ?b) ?c)" => "(* ?a (* ?b ?c))"),
        rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
        rewrite!("add-assoc-rev"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"),
        // Commutativity.
        rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
        rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
        // Identities for 0 and 1.
        rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
        rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
        rewrite!("mul-one"; "(* ?a 1)" => "?a"),
        rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
        rewrite!("mul-zero"; "(* ?a 0)" => "0"),
        rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
        rewrite!("neg-zero"; "(neg 0)" => "0"),
    ]
}
/// Builds the general-purpose algebraic rewrite rules.
///
/// This is the base rule set of `simplify`; `simplify_quantum` and
/// `simplify_trig` extend it with their domain-specific rules.
fn get_simplification_rules() -> Vec<Rewrite<ExprLang, ()>> {
    vec![
        // Identities for 0 and 1.
        rewrite!("add-zero"; "(+ ?a 0)" => "?a"),
        rewrite!("zero-add"; "(+ 0 ?a)" => "?a"),
        rewrite!("mul-one"; "(* ?a 1)" => "?a"),
        rewrite!("one-mul"; "(* 1 ?a)" => "?a"),
        rewrite!("mul-zero"; "(* ?a 0)" => "0"),
        rewrite!("zero-mul"; "(* 0 ?a)" => "0"),
        // Double negation and trivial powers.
        rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
        rewrite!("pow-zero"; "(^ ?a 0)" => "1"),
        rewrite!("pow-one"; "(^ ?a 1)" => "?a"),
        // Commutativity and (one direction of) associativity.
        rewrite!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
        rewrite!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
        rewrite!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
        rewrite!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
        // Left distribution of `*` over `+`.
        rewrite!("distrib"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
        // Cancellation of inverse function pairs.
        rewrite!("exp-log"; "(exp (log ?a))" => "?a"),
        rewrite!("log-exp"; "(log (exp ?a))" => "?a"),
        rewrite!("sqrt-sq"; "(sqrt (^ ?a 2))" => "(abs ?a)"),
    ]
}
/// Builds the rewrite rules for operator algebra: commutators,
/// anticommutators, adjoints (`dagger`), traces, tensor products,
/// determinants and transposes.
///
/// Two anticommutator identities were corrected relative to the earlier
/// version of this table:
/// * `{A, A} = A·A + A·A = 2·A·A` (it used to rewrite to `2·A`);
/// * `{0, A} = 0·A + A·0 = 0` (it used to rewrite to `A`).
pub fn get_quantum_rules() -> Vec<Rewrite<ExprLang, ()>> {
    vec![
        // Commutator [A, B] = AB - BA.
        rewrite!("comm-self"; "(comm ?a ?a)" => "0"),
        rewrite!("comm-antisym"; "(comm ?a ?b)" => "(neg (comm ?b ?a))"),
        rewrite!("comm-zero-left"; "(comm 0 ?a)" => "0"),
        rewrite!("comm-zero-right"; "(comm ?a 0)" => "0"),
        // Anticommutator {A, B} = AB + BA.
        rewrite!("anticomm-self"; "(anticomm ?a ?a)" => "(* 2 (* ?a ?a))"),
        rewrite!("anticomm-sym"; "(anticomm ?a ?b)" => "(anticomm ?b ?a)"),
        rewrite!("anticomm-zero"; "(anticomm 0 ?a)" => "0"),
        // Adjoint: an involution that reverses the order of products.
        rewrite!("dagger-dagger"; "(dagger (dagger ?a))" => "?a"),
        rewrite!("dagger-mul"; "(dagger (* ?a ?b))" => "(* (dagger ?b) (dagger ?a))"),
        rewrite!("dagger-add"; "(dagger (+ ?a ?b))" => "(+ (dagger ?a) (dagger ?b))"),
        rewrite!("dagger-zero"; "(dagger 0)" => "0"),
        rewrite!("dagger-one"; "(dagger 1)" => "1"),
        // Trace.
        rewrite!("trace-add"; "(trace (+ ?a ?b))" => "(+ (trace ?a) (trace ?b))"),
        // NOTE(review): pulling `?c` out of the trace is only valid when `?c`
        // is a scalar; the pattern itself does not enforce that — confirm.
        rewrite!("trace-scale"; "(trace (* ?c ?a))" => "(* ?c (trace ?a))"),
        rewrite!("trace-zero"; "(trace 0)" => "0"),
        // Tensor product.
        rewrite!("tensor-mul"; "(* (tensor ?a ?b) (tensor ?c ?d))" => "(tensor (* ?a ?c) (* ?b ?d))"),
        rewrite!("tensor-one-right"; "(tensor ?a 1)" => "?a"),
        rewrite!("tensor-one-left"; "(tensor 1 ?a)" => "?a"),
        rewrite!("tensor-zero"; "(tensor ?a 0)" => "0"),
        rewrite!("tensor-zero-left"; "(tensor 0 ?a)" => "0"),
        // Determinant.
        rewrite!("det-one"; "(det 1)" => "1"),
        // Transpose: an involution that reverses the order of products.
        rewrite!("transpose-transpose"; "(transpose (transpose ?a))" => "?a"),
        rewrite!("transpose-mul"; "(transpose (* ?a ?b))" => "(* (transpose ?b) (transpose ?a))"),
        rewrite!("transpose-add"; "(transpose (+ ?a ?b))" => "(+ (transpose ?a) (transpose ?b))"),
    ]
}
pub fn simplify_quantum(expr: &Expression) -> Expression {
let mut rules = get_simplification_rules();
rules.extend(get_quantum_rules());
let runner = Runner::default()
.with_expr(expr.as_rec_expr())
.with_iter_limit(30)
.run(&rules);
let root = runner.roots[0];
let extractor = egg::Extractor::new(&runner.egraph, AstSize);
let (_, best) = extractor.find_best(root);
Expression::from_rec_expr(best)
}
/// Builds the rewrite rules for trigonometric, exponential, logarithmic and
/// square-root identities.
///
/// Every rule name here is distinct from the names used by
/// `get_simplification_rules`, because `simplify_trig` runs both sets
/// together. Previously `exp-log`, `log-exp` and `sqrt-sq` appeared in both
/// lists, and `sqrt-sq` named a different rewrite in each. The names now
/// describe the pattern they match: `sq-of-sqrt` is `(sqrt x)^2`,
/// `sqrt-of-sq` is `sqrt(x^2)`.
pub fn get_trig_rules() -> Vec<Rewrite<ExprLang, ()>> {
    vec![
        // Values at fixed points.
        rewrite!("sin-zero"; "(sin 0)" => "0"),
        rewrite!("cos-zero"; "(cos 0)" => "1"),
        rewrite!("tan-zero"; "(tan 0)" => "0"),
        rewrite!("exp-zero"; "(exp 0)" => "1"),
        rewrite!("log-one"; "(log 1)" => "0"),
        // Parity: sin and tan are odd, cos is even.
        rewrite!("sin-neg"; "(sin (neg ?x))" => "(neg (sin ?x))"),
        rewrite!("cos-neg"; "(cos (neg ?x))" => "(cos ?x)"),
        rewrite!("tan-neg"; "(tan (neg ?x))" => "(neg (tan ?x))"),
        // exp / log turn sums into products and vice versa.
        rewrite!("exp-add"; "(exp (+ ?a ?b))" => "(* (exp ?a) (exp ?b))"),
        rewrite!("log-mul"; "(log (* ?a ?b))" => "(+ (log ?a) (log ?b))"),
        // Cancellation of inverse function pairs.
        rewrite!("exp-of-log"; "(exp (log ?x))" => "?x"),
        rewrite!("log-of-exp"; "(log (exp ?x))" => "?x"),
        rewrite!("sq-of-sqrt"; "(^ (sqrt ?x) 2)" => "?x"),
        rewrite!("sqrt-of-sq"; "(sqrt (^ ?x 2))" => "(abs ?x)"),
    ]
}
/// Simplifies `expr` with the general algebraic rules followed by the rules
/// from `get_trig_rules`.
///
/// The rewrite run uses an iteration limit of 30; the cheapest equivalent
/// expression under `AstSize` is returned.
pub fn simplify_trig(expr: &Expression) -> Expression {
    let mut rewrites = get_simplification_rules();
    rewrites.append(&mut get_trig_rules());

    let saturated = Runner::default()
        .with_expr(expr.as_rec_expr())
        .with_iter_limit(30)
        .run(&rewrites);

    let root_class = saturated.roots[0];
    let cost_model = AstSize;
    let (_, cheapest) = egg::Extractor::new(&saturated.egraph, cost_model).find_best(root_class);
    Expression::from_rec_expr(cheapest)
}
/// Expands `expr` and then simplifies the result.
///
/// `var` is accepted so callers can name the variable of interest, but it is
/// not consulted: the result is `simplify(expand(expr))` whatever `var` is,
/// and terms are not grouped by powers of `var`.
pub fn collect(expr: &Expression, var: &Expression) -> Expression {
    // Acknowledge the parameter explicitly so it does not trigger an
    // unused-variable warning while keeping the public signature unchanged.
    let _ = var;
    let expanded = expand(expr);
    simplify(&expanded)
}
pub fn factor(expr: &Expression) -> Expression {
let factor_rules = vec![
rewrite!("factor-left"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
rewrite!("factor-right"; "(+ (* ?a ?c) (* ?b ?c))" => "(* (+ ?a ?b) ?c)"),
rewrite!("add-same"; "(+ ?a ?a)" => "(* 2 ?a)"),
rewrite!("mul-one"; "(* ?a 1)" => "?a"),
rewrite!("mul-zero"; "(* ?a 0)" => "0"),
];
let runner: Runner<ExprLang, ()> = Runner::default()
.with_expr(expr.as_rec_expr())
.with_iter_limit(20)
.run(&factor_rules);
let root = runner.roots[0];
let extractor = egg::Extractor::new(&runner.egraph, FactoredSize);
let (_, best) = extractor.find_best(root);
Expression::from_rec_expr(best)
}
/// Extraction cost that makes multiplications cheaper than additions.
///
/// Weights: `Num` = 1, `Mul` = 2, `Add` = 4, every other node = 3.
struct FactoredSize;

impl CostFunction<ExprLang> for FactoredSize {
    type Cost = usize;

    /// Returns the node's own weight plus the costs of all of its children.
    fn cost<C>(&mut self, node: &ExprLang, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        let own_weight = match node {
            ExprLang::Num(_) => 1,
            ExprLang::Mul(_) => 2,
            ExprLang::Add(_) => 4,
            _ => 3,
        };
        node.fold(own_weight, |accumulated, child| accumulated + costs(child))
    }
}
// Unit tests for the simplify / substitute / expand / factor / collect entry
// points. Apart from the first three simplify tests and the rule-count test,
// they check results by numeric evaluation, not by expression structure.
#[cfg(test)]
// NOTE(review): presumably allowed because the tests clone symbols freely for
// readability — confirm.
#[allow(clippy::redundant_clone)]
mod tests {
use super::*;
// `x + 0` must simplify to a bare symbol.
#[test]
fn test_simplify_add_zero() {
let x = Expression::symbol("x");
let zero = Expression::zero();
let expr = x + zero;
let simplified = simplify(&expr);
assert!(simplified.as_symbol().is_some());
}
// `x * 1` must simplify to a bare symbol.
#[test]
fn test_simplify_mul_one() {
let x = Expression::symbol("x");
let one = Expression::one();
let expr = x * one;
let simplified = simplify(&expr);
assert!(simplified.as_symbol().is_some());
}
// `x * 0` must simplify to zero.
#[test]
fn test_simplify_mul_zero() {
let x = Expression::symbol("x");
let zero = Expression::zero();
let expr = x * zero;
let simplified = simplify(&expr);
assert!(simplified.is_zero());
}
// Substituting x := 2 in `x + y` and evaluating at y = 3 must give 5.
#[test]
fn test_substitute_simple() {
let x = Expression::symbol("x");
let y = Expression::symbol("y");
let two = Expression::int(2);
let expr = x.clone() + y;
let result = substitute(&expr, &x, &two);
let mut values = std::collections::HashMap::new();
values.insert("y".to_string(), 3.0);
let eval_result = result.eval(&values);
assert!(eval_result.is_ok());
assert!((eval_result.expect("eval") - 5.0).abs() < 1e-10);
}
// Substituting x := y in `x * x` must replace both occurrences: 3 * 3 = 9.
#[test]
fn test_substitute_nested() {
let x = Expression::symbol("x");
let y = Expression::symbol("y");
let expr = x.clone() * x.clone();
let result = substitute(&expr, &x, &y);
let mut values = std::collections::HashMap::new();
values.insert("y".to_string(), 3.0);
let eval_result = result.eval(&values);
assert!(eval_result.is_ok());
assert!((eval_result.expect("eval") - 9.0).abs() < 1e-10);
}
// Expanding `x * (y + z)` must preserve its value: 2 * (3 + 4) = 14.
#[test]
fn test_expand_distribution() {
let x = Expression::symbol("x");
let y = Expression::symbol("y");
let z = Expression::symbol("z");
let expr = x * (y + z);
let expanded = expand(&expr);
let mut values = std::collections::HashMap::new();
values.insert("x".to_string(), 2.0);
values.insert("y".to_string(), 3.0);
values.insert("z".to_string(), 4.0);
let orig_val = expr.eval(&values).expect("eval original");
let exp_val = expanded.eval(&values).expect("eval expanded");
assert!((orig_val - exp_val).abs() < 1e-10);
assert!((exp_val - 14.0).abs() < 1e-10); }
// Factoring `a*x + a*y` must preserve its value: 2*3 + 2*4 = 14.
#[test]
fn test_factor_common_terms() {
let a = Expression::symbol("a");
let x = Expression::symbol("x");
let y = Expression::symbol("y");
let expr = a.clone() * x.clone() + a.clone() * y.clone();
let factored = factor(&expr);
let mut values = std::collections::HashMap::new();
values.insert("a".to_string(), 2.0);
values.insert("x".to_string(), 3.0);
values.insert("y".to_string(), 4.0);
let orig_val = expr.eval(&values).expect("eval original");
let fact_val = factored.eval(&values).expect("eval factored");
assert!((orig_val - fact_val).abs() < 1e-10);
assert!((fact_val - 14.0).abs() < 1e-10); }
// `sin(0)` simplified with the trig rules must evaluate to 0.
#[test]
fn test_simplify_trig() {
let zero = Expression::zero();
let sin_zero = crate::ops::trig::sin(&zero);
let simplified = simplify_trig(&sin_zero);
let result = simplified.eval(&std::collections::HashMap::new());
assert!(result.is_ok());
assert!(result.expect("eval").abs() < 1e-10);
}
// Despite its name, this only checks that the quantum rule table is
// non-empty and has at least 15 entries; no dagger rewrite is exercised.
#[test]
fn test_simplify_quantum_dagger() {
let rules = get_quantum_rules();
assert!(!rules.is_empty());
assert!(rules.len() >= 15);
}
// `collect` on `x + x` must preserve its value: 5 + 5 = 10.
#[test]
fn test_collect() {
let x = Expression::symbol("x");
let expr = x.clone() + x.clone();
let collected = collect(&expr, &x);
let mut values = std::collections::HashMap::new();
values.insert("x".to_string(), 5.0);
let orig_val = expr.eval(&values).expect("eval original");
let coll_val = collected.eval(&values).expect("eval collected");
assert!((orig_val - coll_val).abs() < 1e-10);
assert!((coll_val - 10.0).abs() < 1e-10); }
// Expanding `a^2` must still evaluate to 9 at a = 3.
#[test]
fn test_expand_simple_pow2() {
let a = Expression::symbol("a");
let two = Expression::from(2);
let expr = a.clone().pow(&two);
let expanded = expand(&expr);
let mut values = std::collections::HashMap::new();
values.insert("a".to_string(), 3.0);
let exp_val = expanded.eval(&values).expect("eval");
assert!((exp_val - 9.0).abs() < 1e-10);
}
// Expanding `(a + b)^2` must agree with the original and with the
// directly computed value at several sample points.
#[test]
fn test_expand_binomial_squared() {
let a = Expression::symbol("a");
let b = Expression::symbol("b");
let two = Expression::from(2);
let expr = (a.clone() + b.clone()).pow(&two);
let expanded = expand(&expr);
for (a_val, b_val) in [(2.0, 3.0), (1.0, 1.0), (0.0, 5.0)] {
let mut values = std::collections::HashMap::new();
values.insert("a".to_string(), a_val);
values.insert("b".to_string(), b_val);
let orig_val = expr.eval(&values).expect("eval original");
let exp_val = expanded.eval(&values).expect("eval expanded");
assert!(
(orig_val - exp_val).abs() < 1e-10,
"Mismatch at a={a_val}, b={b_val}: orig={orig_val}, expanded={exp_val}"
);
let expected = (a_val + b_val).powi(2);
assert!(
(exp_val - expected).abs() < 1e-10,
"Unexpected value at a={a_val}, b={b_val}: got {exp_val}, expected {expected}"
);
}
}
// Expanding `(x + y + z - 1)^2` must agree with the original and with the
// directly computed value at several sample points.
#[test]
fn test_expand_polynomial_constraint() {
let x = Expression::symbol("x");
let y = Expression::symbol("y");
let z = Expression::symbol("z");
let one = Expression::from(1);
let two = Expression::from(2);
let expr = (x.clone() + y.clone() + z.clone() - one).pow(&two);
let expanded = expand(&expr);
for (x_val, y_val, z_val) in [
(0.0, 0.0, 0.0),
(1.0, 0.0, 0.0),
(1.0, 1.0, 0.0),
(0.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
(0.5, 0.5, 0.0),
] {
let mut values = std::collections::HashMap::new();
values.insert("x".to_string(), x_val);
values.insert("y".to_string(), y_val);
values.insert("z".to_string(), z_val);
let orig_val = expr.eval(&values).expect("eval original");
let exp_val = expanded.eval(&values).expect("eval expanded");
assert!(
(orig_val - exp_val).abs() < 1e-10,
"Mismatch at x={x_val}, y={y_val}, z={z_val}: orig={orig_val}, expanded={exp_val}"
);
let expected = (x_val + y_val + z_val - 1.0).powi(2);
assert!(
(exp_val - expected).abs() < 1e-10,
"Unexpected value at x={x_val}, y={y_val}, z={z_val}: got {exp_val}, expected {expected}"
);
}
}
}