#![allow(clippy::match_same_arms)]
use std::collections::HashMap;
use serde::Deserialize;
use serde::Serialize;
use crate::symbolic::calculus::substitute;
use crate::symbolic::core::Expr;
use crate::symbolic::polynomial::contains_var;
use crate::symbolic::simplify_dag::pattern_match;
use crate::symbolic::simplify_dag::substitute_patterns;
/// A directed rewrite rule `lhs → rhs`.
///
/// When `lhs` matches an expression (via `pattern_match`), the pattern
/// variables it binds are substituted into `rhs` (via `substitute_patterns`)
/// to produce the replacement expression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewriteRule {
/// Pattern an expression must match for the rule to fire.
pub lhs: Expr,
/// Template of the replacement; may reference pattern variables bound by `lhs`.
pub rhs: Expr,
}
/// Reports whether `e1` is strictly more complex than `e2` according to
/// [`complexity`].
///
/// This is the ordering used to decide which side of an equation becomes the
/// left-hand side of a rewrite rule.
pub(crate) fn is_greater(
    e1: &Expr,
    e2: &Expr,
) -> bool {
    let (weight_left, weight_right) = (complexity(e1), complexity(e2));
    weight_right < weight_left
}
/// Rewrites `expr` with `rules` until no rule applies anywhere in it.
///
/// Each step is a single call to `apply_rules_once`; the loop stops at the
/// first step in which no rule fires and returns the expression reached so
/// far.
///
/// No step limit is enforced: if `rules` admit an infinite rewrite sequence
/// starting from `expr`, this function does not return.
///
/// # Panics
///
/// Panics if `apply_rules_once` panics (it does so when a `Dag` node cannot be
/// converted to its tree form).
#[must_use]
pub fn apply_rules_to_normal_form(
    expr: &Expr,
    rules: &[RewriteRule],
) -> Expr {
    let mut current_expr = expr.clone();
    loop {
        let (next_expr, applied) = apply_rules_once(&current_expr, rules);
        if !applied {
            return current_expr;
        }
        current_expr = next_expr;
    }
}
/// Performs at most one rewrite step on `expr`.
///
/// The rules are first tried at the root, in slice order, and the first rule
/// whose left-hand side matches (via `pattern_match`) is applied. If none
/// matches, the children are searched left to right, recursively, and the
/// first subterm that admits a rewrite is rewritten; all other subterms are
/// left untouched. `Dag` nodes are searched through their tree form.
///
/// Children are searched for `Add`, `Mul`, `Sub`, `Div`, `Power`, `Sin`,
/// `Cos`, `Tan`, `Log`, `Exp` and `Neg` nodes. Any other node kind is only
/// matched at its root.
///
/// Returns the rewritten expression together with `true` when a rule fired.
/// Otherwise it returns the expression unchanged (a `Dag` input comes back in
/// its tree form) together with `false`.
///
/// # Panics
///
/// Panics if a `Dag` node cannot be converted to its tree form.
pub(crate) fn apply_rules_once(
    expr: &Expr,
    rules: &[RewriteRule],
) -> (Expr, bool) {
    for rule in rules {
        if let Some(assignments) = pattern_match(expr, &rule.lhs) {
            return (substitute_patterns(&rule.rhs, &assignments), true);
        }
    }
    let rewritten = match expr {
        | Expr::Dag(node) => {
            return apply_rules_once(&node.to_expr().expect("Apply rules once"), rules);
        },
        | Expr::Add(a, b) => {
            let (na, ca) = apply_rules_once(a, rules);
            if ca {
                Some(Expr::new_add(na, b.clone()))
            } else {
                let (nb, cb) = apply_rules_once(b, rules);
                cb.then(|| Expr::new_add(a.clone(), nb))
            }
        },
        | Expr::Mul(a, b) => {
            let (na, ca) = apply_rules_once(a, rules);
            if ca {
                Some(Expr::new_mul(na, b.clone()))
            } else {
                let (nb, cb) = apply_rules_once(b, rules);
                cb.then(|| Expr::new_mul(a.clone(), nb))
            }
        },
        | Expr::Sub(a, b) => rewrite_first_of_two(a, b, rules).map(|(x, y)| Expr::Sub(x, y)),
        | Expr::Div(a, b) => rewrite_first_of_two(a, b, rules).map(|(x, y)| Expr::Div(x, y)),
        | Expr::Power(a, b) => {
            rewrite_first_of_two(a, b, rules).map(|(x, y)| Expr::Power(x, y))
        },
        | Expr::Sin(a) => rewrite_only_child(a, rules).map(Expr::Sin),
        | Expr::Cos(a) => rewrite_only_child(a, rules).map(Expr::Cos),
        | Expr::Tan(a) => rewrite_only_child(a, rules).map(Expr::Tan),
        | Expr::Log(a) => rewrite_only_child(a, rules).map(Expr::Log),
        | Expr::Exp(a) => rewrite_only_child(a, rules).map(Expr::Exp),
        | Expr::Neg(a) => rewrite_only_child(a, rules).map(Expr::Neg),
        | _ => None,
    };
    match rewritten {
        | Some(next) => (next, true),
        | None => (expr.clone(), false),
    }
}

/// Rewrites the first child of a binary node that admits a rewrite (left
/// before right) and returns the rebuilt child pair, or `None` if neither
/// child changes.
fn rewrite_first_of_two<P>(
    left: &P,
    right: &P,
    rules: &[RewriteRule],
) -> Option<(P, P)>
where
    P: std::ops::Deref<Target = Expr> + Clone + From<Expr>,
{
    let (new_left, changed) = apply_rules_once(left, rules);
    if changed {
        return Some((P::from(new_left), right.clone()));
    }
    let (new_right, changed) = apply_rules_once(right, rules);
    changed.then(|| (left.clone(), P::from(new_right)))
}

/// Rewrites the single child of a unary node, or returns `None` if no rule
/// applies inside it.
fn rewrite_only_child<P>(
    child: &P,
    rules: &[RewriteRule],
) -> Option<P>
where
    P: std::ops::Deref<Target = Expr> + From<Expr>,
{
    let (new_child, changed) = apply_rules_once(child, rules);
    changed.then(|| P::from(new_child))
}
/// Runs a naive Knuth–Bendix completion on `equations`.
///
/// Every element of `equations` must be an `Expr::Eq`. Each equation is
/// oriented into a [`RewriteRule`] whose left-hand side is strictly more
/// complex (see `is_greater`). Equations whose two sides are equally complex
/// cannot be oriented and are skipped.
///
/// Completion then visits every pair of rules: each rule with itself and with
/// every earlier rule. For each pair it computes the critical pairs of *both*
/// overlap directions with `find_critical_pairs`. Both terms of a critical
/// pair are reduced with [`apply_rules_to_normal_form`] under the current rule
/// set. If the normal forms differ, they are oriented into a new rule and
/// appended, unless a rule with the same left-hand side already exists. When
/// neither normal form is more complex, the second one becomes the left-hand
/// side. Appended rules are in turn paired with every earlier rule once the
/// outer scan reaches them, so each pair is examined exactly once.
///
/// No iteration bound is enforced. For equation sets whose completion does
/// not terminate, this function does not return.
///
/// # Errors
///
/// Returns `Err` if any element of `equations` is not an `Expr::Eq`.
pub fn knuth_bendix(equations: &[Expr]) -> Result<Vec<RewriteRule>, String> {
    let mut rules: Vec<RewriteRule> = Vec::new();
    for eq in equations {
        if let Expr::Eq(lhs, rhs) = eq {
            if is_greater(lhs, rhs) {
                rules.push(RewriteRule {
                    lhs: lhs.as_ref().clone(),
                    rhs: rhs.as_ref().clone(),
                });
            } else if is_greater(rhs, lhs) {
                rules.push(RewriteRule {
                    lhs: rhs.as_ref().clone(),
                    rhs: lhs.as_ref().clone(),
                });
            }
            // Equal complexity: the equation cannot be oriented and is skipped.
        } else {
            return Err("Input must be a list of equations (Expr::Eq).".to_string());
        }
    }

    // Joinability of an already-examined pair survives the addition of rules
    // (its old reduction steps remain valid), so there is no need to restart
    // the scan after a rule is appended: the outer index reaches new rules and
    // pairs them with every predecessor.
    let mut i = 0;
    while i < rules.len() {
        for j in 0..=i {
            // `find_critical_pairs` returns owned terms, so the borrow of
            // `rules` ends here and new rules can be pushed below.
            let mut critical_pairs = find_critical_pairs(&rules[i], &rules[j]);
            if i != j {
                // Overlaps are directional: also overlap rule i into rule j.
                critical_pairs.extend(find_critical_pairs(&rules[j], &rules[i]));
            }
            for (t1, t2) in critical_pairs {
                let n1 = apply_rules_to_normal_form(&t1, &rules);
                let n2 = apply_rules_to_normal_form(&t2, &rules);
                if n1 == n2 {
                    continue;
                }
                let new_rule = if is_greater(&n1, &n2) {
                    RewriteRule { lhs: n1, rhs: n2 }
                } else {
                    RewriteRule { lhs: n2, rhs: n1 }
                };
                if !rules.iter().any(|r| r.lhs == new_rule.lhs) {
                    rules.push(new_rule);
                }
            }
        }
        i += 1;
    }
    Ok(rules)
}
pub(crate) fn find_critical_pairs(
r1: &RewriteRule,
r2: &RewriteRule,
) -> Vec<(Expr, Expr)> {
let mut pairs = Vec::new();
let mut sub_expressions = Vec::new();
r1.lhs.pre_order_walk(&mut |sub_expr| {
sub_expressions.push(sub_expr.clone());
});
for sub_expr in &sub_expressions {
if let Some(subst) = unify(sub_expr, &r2.lhs) {
let t1 = substitute(&r1.lhs, &sub_expr.to_string(), &r2.rhs);
let t1_subst = substitute_patterns(&t1, &subst);
let t2 = substitute_patterns(&r1.rhs, &subst);
if t1_subst != t2 {
pairs.push((t1_subst, t2));
}
}
}
pairs
}
/// Tries to find a substitution of pattern variables under which `e1` and
/// `e2` become equal.
///
/// Returns the bindings on success and `None` on failure. Any partial
/// bindings built up by a failed attempt are discarded.
pub(crate) fn unify(
    e1: &Expr,
    e2: &Expr,
) -> Option<HashMap<String, Expr>> {
    let mut bindings: HashMap<String, Expr> = HashMap::new();
    let unified = unify_recursive(e1, e2, &mut bindings);
    unified.then_some(bindings)
}
/// Extends `subst` so that `e1` and `e2` become equal under it, if possible.
///
/// Matching rules:
/// - A pattern variable on either side is bound to the opposite term. A
///   variable that is already bound must map to a term equal to the opposite
///   side. A binding is refused when `contains_var` reports that the term
///   contains the variable.
/// - `Dag` nodes are expanded to their tree form before comparison.
/// - `Add` and `Mul` are treated as commutative. If unifying the children in
///   order fails, `subst` is reset to its prior state and the swapped order is
///   tried.
/// - `Sub`, `Div`, `Power`, `Sin`, `Cos`, `Tan`, `Log`, `Exp` and `Neg`
///   unify child by child.
/// - Any other pair unifies only if the two nodes are equal.
///
/// On `false`, `subst` may still hold partial bindings and should be
/// discarded.
///
/// # Panics
///
/// Panics if a `Dag` node cannot be converted to its tree form.
pub(crate) fn unify_recursive(
    e1: &Expr,
    e2: &Expr,
    subst: &mut HashMap<String, Expr>,
) -> bool {
    match (e1, e2) {
        | (Expr::Pattern(p), _) => {
            if let Some(val) = subst.get(p) {
                return val == e2;
            }
            if contains_var(e2, p) {
                return false;
            }
            subst.insert(p.clone(), e2.clone());
            true
        },
        | (_, Expr::Pattern(p)) => {
            if let Some(val) = subst.get(p) {
                return val == e1;
            }
            if contains_var(e1, p) {
                return false;
            }
            subst.insert(p.clone(), e1.clone());
            true
        },
        // Look through DAG nodes so that shared-graph expressions unify the
        // same way as their tree form (complexity and rewriting already
        // expand `Dag` nodes); otherwise they would only unify when equal.
        | (Expr::Dag(node), _) => {
            let tree = node.to_expr().expect("Unify");
            unify_recursive(&tree, e2, subst)
        },
        | (_, Expr::Dag(node)) => {
            let tree = node.to_expr().expect("Unify");
            unify_recursive(e1, &tree, subst)
        },
        | (Expr::Add(a1, b1), Expr::Add(a2, b2)) | (Expr::Mul(a1, b1), Expr::Mul(a2, b2)) => {
            let original_subst = subst.clone();
            if unify_recursive(a1, a2, subst) && unify_recursive(b1, b2, subst) {
                true
            } else {
                // Commutativity: retry with the operands swapped.
                *subst = original_subst;
                unify_recursive(a1, b2, subst) && unify_recursive(b1, a2, subst)
            }
        },
        | (Expr::Sub(a1, b1), Expr::Sub(a2, b2))
        | (Expr::Div(a1, b1), Expr::Div(a2, b2))
        | (Expr::Power(a1, b1), Expr::Power(a2, b2)) => {
            unify_recursive(a1, a2, subst) && unify_recursive(b1, b2, subst)
        },
        | (Expr::Sin(a1), Expr::Sin(a2))
        | (Expr::Cos(a1), Expr::Cos(a2))
        | (Expr::Tan(a1), Expr::Tan(a2))
        | (Expr::Log(a1), Expr::Log(a2))
        | (Expr::Exp(a1), Expr::Exp(a2))
        | (Expr::Neg(a1), Expr::Neg(a2)) => unify_recursive(a1, a2, subst),
        | _ => e1 == e2,
    }
}
/// Structural weight of an expression, used to orient equations into rules.
///
/// Node weights:
/// - Each `Add`, `Sub`, `Mul`, `Div`, `Neg`, `Sin`, `Cos`, `Tan`, `Exp`,
///   `Log`, `UnaryList`, `BinaryList` and `NaryList` node weighs 1, plus the
///   weights of its operands.
/// - A `Power` node weighs 2, plus the weights of base and exponent.
/// - A `Dag` node is weighed through its tree form.
/// - Every other variant weighs 1, without looking inside it.
///
/// # Panics
///
/// Panics if a `Dag` node cannot be converted to its tree form.
pub(crate) fn complexity(expr: &Expr) -> usize {
    match expr {
        | Expr::Power(base, exponent) => 2 + complexity(base) + complexity(exponent),
        | Expr::Add(lhs, rhs) | Expr::Sub(lhs, rhs) | Expr::Mul(lhs, rhs) | Expr::Div(lhs, rhs) => {
            1 + complexity(lhs) + complexity(rhs)
        },
        | Expr::Neg(inner)
        | Expr::Sin(inner)
        | Expr::Cos(inner)
        | Expr::Tan(inner)
        | Expr::Exp(inner)
        | Expr::Log(inner) => 1 + complexity(inner),
        | Expr::UnaryList(_, inner) => 1 + complexity(inner),
        | Expr::BinaryList(_, lhs, rhs) => 1 + complexity(lhs) + complexity(rhs),
        | Expr::NaryList(_, items) => items.iter().fold(1, |total, item| total + complexity(item)),
        | Expr::Dag(node) => {
            let tree = node.to_expr().expect("Complexity");
            complexity(&tree)
        },
        | _ => 1,
    }
}