use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::One;
use num_traits::Zero;
use ordered_float::OrderedFloat;
#[allow(unused_imports)]
use super::core::DAG_MANAGER;
#[allow(unused_imports)]
use super::core::DagManager;
#[allow(unused_imports)]
use super::core::DagNode;
#[allow(unused_imports)]
use super::core::DagOp;
#[allow(unused_imports)]
use super::core::Expr;
/// Simplifies `expr` by running bottom-up rewrite passes over its DAG form until a
/// pass reports no change or `MAX_ITERATIONS` passes have run.
///
/// Returns the final DAG wrapped in `Expr::Dag`. If `expr` cannot be converted into
/// a DAG node, a clone of the input is returned unchanged.
#[must_use]
pub fn simplify(expr: &Expr) -> Expr {
    const MAX_ITERATIONS: usize = 1000;
    let mut current = match DAG_MANAGER.get_or_create(expr) {
        | Ok(node) => node,
        | Err(_) => return expr.clone(),
    };
    for _ in 0..MAX_ITERATIONS {
        let (next, changed) = bottom_up_simplify_pass(current);
        current = next;
        if !changed {
            break;
        }
    }
    Expr::Dag(current)
}
/// Runs one bottom-up rewrite pass over the DAG rooted at `root`.
///
/// Nodes are visited children-first using an explicit work stack. Each node is
/// rebuilt from its already-simplified children, passed through `apply_rules`, and
/// memoized by hash, so a shared sub-expression is processed once per pass.
///
/// Returns the simplified root and whether any node's result differs (by hash) from
/// the original. If the pass stops at `MAX_NODES_PER_PASS` before the root has been
/// processed, `root` itself is returned.
pub(crate) fn bottom_up_simplify_pass(root: Arc<DagNode>) -> (Arc<DagNode>, bool) {
    const MAX_NODES_PER_PASS: usize = 10000;
    let mut memo: HashMap<u64, Arc<DagNode>> = HashMap::new();
    let mut work_stack: Vec<Arc<DagNode>> = vec![root.clone()];
    // Nodes whose children have already been pushed once.
    let mut visited: HashMap<u64, bool> = HashMap::new();
    let mut changed_in_pass = false;
    let mut processed_nodes = 0;
    while let Some(node) = work_stack.pop() {
        if memo.contains_key(&node.hash) {
            continue;
        }
        processed_nodes += 1;
        if processed_nodes >= MAX_NODES_PER_PASS {
            break;
        }
        let children_simplified = node
            .children
            .iter()
            .all(|child| memo.contains_key(&child.hash));
        if children_simplified {
            let new_children: Vec<Arc<DagNode>> = node
                .children
                .iter()
                .map(|child| memo.get(&child.hash).cloned().unwrap_or_else(|| child.clone()))
                .collect();
            let simplified_node =
                match DAG_MANAGER.get_or_create_normalized(node.op.clone(), new_children) {
                    | Ok(rebuilt_node) => apply_rules(&rebuilt_node),
                    // Memoize the node as its own result. Without a memo entry every
                    // ancestor would wait on this child forever: it is re-pushed on each
                    // pop (its children only the first time), so the loop would spin until
                    // MAX_NODES_PER_PASS and the whole pass would be discarded.
                    | Err(_) => node.clone(),
                };
            if simplified_node.hash != node.hash {
                changed_in_pass = true;
            }
            memo.insert(node.hash, simplified_node);
        } else {
            // Revisit this node after its children; push the children only once.
            work_stack.push(node.clone());
            if visited.insert(node.hash, true).is_none() {
                work_stack.extend(node.children.iter().rev().cloned());
            }
        }
    }
    let new_root = memo.get(&root.hash).cloned().unwrap_or(root);
    (new_root, changed_in_pass)
}
pub(crate) fn apply_rules(node: &Arc<DagNode>) -> Arc<DagNode> {
if let Some(folded) = fold_constants(node) {
return folded;
}
match &node.op {
| DagOp::Add => {
if let Some(value) = apply_rules_add(node) {
return value;
}
},
| DagOp::Sub => {
if let Some(value) = apply_rules_sub(node) {
return value;
}
},
| DagOp::Mul => {
if let Some(value) = apply_rules_mul(node) {
return value;
}
},
| DagOp::Div => {
if let Some(value) = apply_rules_div(node) {
return value;
}
},
| DagOp::Neg => {
if let Some(value) = apply_rules_neg(node) {
return value;
}
},
| DagOp::Power => {
if let Some(value) = apply_rules_power(node) {
return value;
}
},
| DagOp::Log => {
if let Some(value) = apply_rules_log(node) {
return value;
}
},
| DagOp::LogBase => {
if let Some(value) = apply_rules_logbase(node) {
return value;
}
},
| DagOp::Exp => {
if let Some(value) = apply_rules_exp(node) {
return value;
}
},
| DagOp::Sin => {
if let Some(value) = apply_rules_sin(node) {
return value;
}
},
| DagOp::Cos => {
if let Some(value) = apply_rules_cos(node) {
return value;
}
},
| DagOp::Tan => {
if let Some(value) = apply_rules_tan(node) {
return value;
}
},
| DagOp::Sec => {
if let Some(value) = apply_rules_sec(node) {
return value;
}
},
| DagOp::Csc => {
if let Some(value) = apply_rules_csc(node) {
return value;
}
},
| DagOp::Cot => {
if let Some(value) = apply_rules_cot(node) {
return value;
}
},
| _ => {}, }
node.clone()
}
/// Rewrites `cot(x)` as `cos(x) / sin(x)`.
///
/// A node without an argument, or any failure while building the quotient, yields
/// the node unchanged. Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_cot(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let arg = match node.children.first() {
        | Some(arg) => arg,
        | None => return Some(node.clone()),
    };
    let quotient = || -> Option<Arc<DagNode>> {
        let cos_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Cos, vec![arg.clone()])
            .ok()?;
        let sin_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Sin, vec![arg.clone()])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Div, vec![cos_x, sin_x])
            .ok()
    };
    Some(quotient().unwrap_or_else(|| node.clone()))
}
/// Rewrites `csc(x)` as `1.0 / sin(x)`.
///
/// A node without an argument, or any failure while building the quotient, yields
/// the node unchanged. Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_csc(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let arg = match node.children.first() {
        | Some(arg) => arg,
        | None => return Some(node.clone()),
    };
    let reciprocal_of_sin = || -> Option<Arc<DagNode>> {
        let one = DAG_MANAGER.get_or_create(&Expr::Constant(1.0)).ok()?;
        let sin_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Sin, vec![arg.clone()])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Div, vec![one, sin_x])
            .ok()
    };
    Some(reciprocal_of_sin().unwrap_or_else(|| node.clone()))
}
/// Rewrites `sec(x)` as `1.0 / cos(x)`.
///
/// A node without an argument, or any failure while building the quotient, yields
/// the node unchanged. Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_sec(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let arg = match node.children.first() {
        | Some(arg) => arg,
        | None => return Some(node.clone()),
    };
    let reciprocal_of_cos = || -> Option<Arc<DagNode>> {
        let one = DAG_MANAGER.get_or_create(&Expr::Constant(1.0)).ok()?;
        let cos_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Cos, vec![arg.clone()])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Div, vec![one, cos_x])
            .ok()
    };
    Some(reciprocal_of_cos().unwrap_or_else(|| node.clone()))
}
/// Simplifies `tan(x)`: `tan(0) -> 0.0`, otherwise rewrites to `sin(x) / cos(x)`.
///
/// A node without an argument, or any failure while building nodes, yields the node
/// unchanged. Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_tan(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let arg = match node.children.first() {
        | Some(arg) => arg,
        | None => return Some(node.clone()),
    };
    if is_zero_node(arg) {
        let zero = DAG_MANAGER
            .get_or_create(&Expr::Constant(0.0))
            .unwrap_or_else(|_| node.clone());
        return Some(zero);
    }
    let as_quotient = || -> Option<Arc<DagNode>> {
        let sin_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Sin, vec![arg.clone()])
            .ok()?;
        let cos_x = DAG_MANAGER
            .get_or_create_normalized(DagOp::Cos, vec![arg.clone()])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Div, vec![sin_x, cos_x])
            .ok()
    };
    Some(as_quotient().unwrap_or_else(|| node.clone()))
}
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_cos(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.is_empty() {
return Some(node.clone()); }
let arg = &node.children[0];
if is_zero_node(arg) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if is_pi_node(arg) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(-1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if matches!(&arg.op, DagOp::Neg) {
if arg.children.is_empty() {
return Some(node.clone()); }
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![arg.children[0].clone()]) {
| Ok(result) => result,
| Err(_) => node.clone(), },
);
}
if matches!(&arg.op, DagOp::Add) {
if arg.children.len() >= 2 {
let a = &arg.children[0];
let b = &arg.children[1];
if is_pi_node(b) {
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![a.clone()]) {
| Ok(cos_a) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Neg, vec![cos_a]) {
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
}
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![a.clone()]) {
| Ok(cos_a) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![b.clone()]) {
| Ok(cos_b) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Sin, vec![a.clone()])
{
| Ok(sin_a) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Sin, vec![b.clone()])
{
| Ok(sin_b) => {
match DAG_MANAGER.get_or_create_normalized(
DagOp::Mul,
vec![cos_a, cos_b],
) {
| Ok(term1) => {
match DAG_MANAGER.get_or_create_normalized(
DagOp::Mul,
vec![sin_a, sin_b],
) {
| Ok(term2) => {
match DAG_MANAGER
.get_or_create_normalized(
DagOp::Sub,
vec![term1, term2],
) {
| Ok(result) => {
return Some(result);
},
| Err(_) => {
return Some(node.clone());
},
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()), }
},
| Err(_) => return Some(node.clone()), }
}
return Some(node.clone()); }
None
}
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_sin(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.is_empty() {
return Some(node.clone()); }
let arg = &node.children[0];
if is_zero_node(arg) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if is_pi_node(arg) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if matches!(&arg.op, DagOp::Neg) {
if arg.children.is_empty() {
return Some(node.clone()); }
match DAG_MANAGER.get_or_create_normalized(DagOp::Sin, vec![arg.children[0].clone()]) {
| Ok(new_sin) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Neg, vec![new_sin]) {
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()), }
},
| Err(_) => {
return Some(node.clone());
}, }
}
if matches!(&arg.op, DagOp::Add) {
if arg.children.len() >= 2 {
let a = &arg.children[0];
let b = &arg.children[1];
if is_pi_node(b) {
match DAG_MANAGER.get_or_create_normalized(DagOp::Sin, vec![a.clone()]) {
| Ok(sin_a) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Neg, vec![sin_a]) {
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()), }
}
match DAG_MANAGER.get_or_create_normalized(DagOp::Sin, vec![a.clone()]) {
| Ok(sin_a) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![b.clone()]) {
| Ok(cos_b) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Cos, vec![a.clone()])
{
| Ok(cos_a) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Sin, vec![b.clone()])
{
| Ok(sin_b) => {
match DAG_MANAGER.get_or_create_normalized(
DagOp::Mul,
vec![sin_a, cos_b],
) {
| Ok(term1) => {
match DAG_MANAGER.get_or_create_normalized(
DagOp::Mul,
vec![cos_a, sin_b],
) {
| Ok(term2) => {
match DAG_MANAGER
.get_or_create_normalized(
DagOp::Add,
vec![term1, term2],
) {
| Ok(result) => {
return Some(result);
},
| Err(_) => {
return Some(node.clone());
},
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()), }
},
| Err(_) => return Some(node.clone()), }
}
return Some(node.clone()); }
None
}
/// Rewrite rules for `exp(x)`: `exp(0) -> 1.0` and `exp(log(y)) -> y`.
///
/// A missing argument or a `log` node without an argument yields the node unchanged.
/// Returns `None` for any other argument.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_exp(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let arg = match node.children.first() {
        | Some(arg) => arg,
        | None => return Some(node.clone()),
    };
    if is_zero_node(arg) {
        let one = DAG_MANAGER
            .get_or_create(&Expr::Constant(1.0))
            .unwrap_or_else(|_| node.clone());
        return Some(one);
    }
    match &arg.op {
        | DagOp::Log => Some(arg.children.first().cloned().unwrap_or_else(|| node.clone())),
        | _ => None,
    }
}
/// Rewrite rules for a logarithm with explicit base, with children `[base, arg]`.
///
/// `log_b(b) -> 1.0` when both children are the same node; otherwise the change of
/// base `log(arg) / log(base)` is produced. Fewer than two children, or any failure
/// while building nodes, yields the node unchanged. Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_logbase(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    if node.children.len() < 2 {
        return Some(node.clone());
    }
    let (base, arg) = (&node.children[0], &node.children[1]);
    if base.hash == arg.hash {
        let one = DAG_MANAGER
            .get_or_create(&Expr::Constant(1.0))
            .unwrap_or_else(|_| node.clone());
        return Some(one);
    }
    let change_of_base = || -> Option<Arc<DagNode>> {
        let log_arg = DAG_MANAGER
            .get_or_create_normalized(DagOp::Log, vec![arg.clone()])
            .ok()?;
        let log_base = DAG_MANAGER
            .get_or_create_normalized(DagOp::Log, vec![base.clone()])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Div, vec![log_arg, log_base])
            .ok()
    };
    Some(change_of_base().unwrap_or_else(|| node.clone()))
}
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_log(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.is_empty() {
return Some(node.clone()); }
let arg = &node.children[0];
if is_one_node(arg) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if matches!(&arg.op, DagOp::E) {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if matches!(&arg.op, DagOp::Exp) {
if arg.children.is_empty() {
return Some(node.clone()); }
return Some(arg.children[0].clone());
}
if matches!(&arg.op, DagOp::Mul) {
if arg.children.len() >= 2 {
let a = &arg.children[0];
let b = &arg.children[1];
match DAG_MANAGER.get_or_create_normalized(DagOp::Log, vec![a.clone()]) {
| Ok(log_a) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Log, vec![b.clone()]) {
| Ok(log_b) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Add, vec![log_a, log_b])
{
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()), }
},
| Err(_) => return Some(node.clone()), }
}
return Some(node.clone()); }
if matches!(&arg.op, DagOp::Power) {
if arg.children.len() >= 2 {
let b = arg.children[1].clone();
let log_a = &arg.children[0];
match DAG_MANAGER.get_or_create_normalized(DagOp::Log, vec![log_a.clone()]) {
| Ok(log_a_node) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![b, log_a_node]) {
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => return Some(node.clone()), }
}
return Some(node.clone()); }
None
}
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_power(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.len() < 2 {
return Some(node.clone()); }
let base = &node.children[0];
let exp = &node.children[1];
if is_one_node(exp) {
return Some(base.clone());
}
if is_zero_node(exp) {
if is_zero_node(base) || is_infinite_node(base) {
return Some(node.clone());
}
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if is_one_node(base) {
if is_infinite_node(exp) {
return Some(node.clone());
}
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if is_zero_node(base) {
match &exp.op {
| DagOp::Constant(c) if c.0 < 0.0 => {
return Some(match DAG_MANAGER.get_or_create(&Expr::Infinity) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
},
| DagOp::BigInt(b) if *b < BigInt::zero() => {
return Some(match DAG_MANAGER.get_or_create(&Expr::Infinity) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
},
| DagOp::Rational(r) if *r < BigRational::zero() => {
return Some(match DAG_MANAGER.get_or_create(&Expr::Infinity) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
},
| _ => {},
}
}
if matches!(&base.op, DagOp::Variable(name) if name == "i")
&& (is_const_node(exp, 2.0) || matches!(&exp.op, DagOp::BigInt(b) if *b == BigInt::from(2)))
{
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(-1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
}
if matches!(&base.op, DagOp::Sqrt)
&& (is_const_node(exp, 2.0) || matches!(&exp.op, DagOp::BigInt(b) if *b == BigInt::from(2)))
&& !base.children.is_empty()
{
return Some(base.children[0].clone());
}
if matches!(&base.op, DagOp::Power) && base.children.len() >= 2 {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Mul, vec![base.children[1].clone(), exp.clone()])
{
| Ok(new_exp) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Power, vec![base.children[0].clone(), new_exp])
{
| Ok(result) => return Some(result),
| Err(_) => return Some(node.clone()),
}
},
| Err(_) => {}, }
}
None
}
/// Rewrite rule for negation: `-(-x) -> x`.
///
/// A missing operand, or an inner negation without an operand, yields the node
/// unchanged. Returns `None` for any other operand.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_neg(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let inner = match node.children.first() {
        | Some(inner) => inner,
        | None => return Some(node.clone()),
    };
    if !matches!(&inner.op, DagOp::Neg) {
        return None;
    }
    Some(inner.children.first().cloned().unwrap_or_else(|| node.clone()))
}
/// Rewrite rules for a quotient `lhs / rhs`.
///
/// Tried in order:
/// - `x / 1 -> x`
/// - `0 / x -> 0.0`; `0/0` and `0/inf` are left unchanged
/// - `inf / inf` (any mix of infinite nodes) is left unchanged
/// - `x / x -> 1.0`
/// - otherwise `lhs / rhs -> lhs * rhs^(-1)` (exponent `BigInt(-1)`)
///
/// The zero-numerator and infinity guards must run before the `x / x` rule, otherwise
/// `0/0` and `inf/inf` (which `div_em` refuses to fold) would cancel to 1. Fewer
/// than two children, or any failure while building nodes, yields the node unchanged.
/// Always returns `Some`.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_div(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    if node.children.len() < 2 {
        return Some(node.clone());
    }
    let lhs = &node.children[0];
    let rhs = &node.children[1];
    // Interns a float literal, falling back to the unchanged node.
    let literal = |value: f64| -> Option<Arc<DagNode>> {
        Some(
            DAG_MANAGER
                .get_or_create(&Expr::Constant(value))
                .unwrap_or_else(|_| node.clone()),
        )
    };
    if is_one_node(rhs) {
        return Some(lhs.clone());
    }
    if is_zero_node(lhs) {
        if is_zero_node(rhs) || is_infinite_node(rhs) {
            return Some(node.clone());
        }
        return literal(0.0);
    }
    // Returning early also avoids the product rewrite below, where `simplify_mul`
    // may merge `inf * inf^-1` back into 1.
    if is_infinite_node(lhs) && is_infinite_node(rhs) {
        return Some(node.clone());
    }
    if lhs.hash == rhs.hash {
        return literal(1.0);
    }
    let as_product = || -> Option<Arc<DagNode>> {
        let neg_one = DAG_MANAGER
            .get_or_create(&Expr::BigInt(BigInt::from(-1)))
            .ok()?;
        let reciprocal = DAG_MANAGER
            .get_or_create_normalized(DagOp::Power, vec![rhs.clone(), neg_one])
            .ok()?;
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Mul, vec![lhs.clone(), reciprocal])
            .ok()
    };
    Some(as_product().unwrap_or_else(|| node.clone()))
}
/// Rewrite rules for a binary product `lhs * rhs`.
///
/// Tried in order:
/// - `x*1` / `1*x -> x`
/// - `x*0` / `0*x -> 0.0`; the node is left unchanged when the other factor is infinite
/// - `x*(-1)` / `(-1)*x -> -x` (`is_neg_one_node` only recognizes `Constant(-1.0)`)
/// - `x*x -> x^2.0`
/// - distribution over a sum on either side: `a*(b+c) -> a*b + a*c` and
///   `(a+b)*c -> a*c + b*c`
/// - otherwise `simplify_mul`
///
/// A node with fewer than two children is returned unchanged. Every path returns `Some`.
///
/// NOTE(review): `apply_rules_add` factors `a*x + b*x` into `(a+b)*x`, and the
/// distribution rule here expands `(a+b)*x` back into `a*x + b*x`. For non-numeric
/// `a`, `b` the two rules may undo each other on successive passes until `simplify`
/// hits its iteration cap. This depends on how `get_or_create_normalized` orders
/// children — confirm.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_mul(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.len() < 2 {
return Some(node.clone()); }
let lhs = &node.children[0];
let rhs = &node.children[1];
// Multiplicative identity.
if is_one_node(rhs) {
return Some(lhs.clone());
}
if is_one_node(lhs) {
return Some(rhs.clone());
}
// Zero absorbs every finite factor; 0 * inf is deliberately left unevaluated.
if is_zero_node(rhs) {
if is_infinite_node(lhs) {
return Some(node.clone());
}
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
}
if is_zero_node(lhs) {
if is_infinite_node(rhs) {
return Some(node.clone());
}
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(),
});
}
// Multiplying by -1 becomes a negation.
if is_neg_one_node(rhs) {
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Neg, vec![lhs.clone()]) {
| Ok(node) => node,
| Err(_) => node.clone(), },
);
}
if is_neg_one_node(lhs) {
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Neg, vec![rhs.clone()]) {
| Ok(node) => node,
| Err(_) => node.clone(), },
);
}
// x * x -> x^2.0; on construction failure fall through to the remaining rules.
if lhs.hash == rhs.hash {
match DAG_MANAGER.get_or_create(&Expr::Constant(2.0)) {
| Ok(two) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Power, vec![lhs.clone(), two]) {
| Ok(result) => return Some(result),
| Err(_) => {}, }
},
| Err(_) => {}, }
}
// a * (b + c) -> a*b + a*c
if matches!(&rhs.op, DagOp::Add) && rhs.children.len() >= 2 {
let a = lhs;
let b = &rhs.children[0];
let c = &rhs.children[1];
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![a.clone(), b.clone()]) {
| Ok(ab) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![a.clone(), c.clone()]) {
| Ok(ac) => {
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Add, vec![ab, ac]) {
| Ok(result) => result,
| Err(_) => node.clone(), },
);
},
| Err(_) => {}, }
},
| Err(_) => {}, }
}
// (a + b) * c -> a*c + b*c
if matches!(&lhs.op, DagOp::Add) && lhs.children.len() >= 2 {
let a = &lhs.children[0];
let b = &lhs.children[1];
let c = rhs;
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![a.clone(), c.clone()]) {
| Ok(ac) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![b.clone(), c.clone()]) {
| Ok(bc) => {
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Add, vec![ac, bc]) {
| Ok(result) => result,
| Err(_) => node.clone(), },
);
},
| Err(_) => {}, }
},
| Err(_) => {}, }
}
// No local rule fired: canonicalize the whole (possibly nested) product.
Some(simplify_mul(node))
}
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_sub(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.len() < 2 {
return Some(node.clone()); }
let lhs = &node.children[0];
let rhs = &node.children[1];
if is_zero_node(rhs) {
return Some(lhs.clone());
}
if matches!(&rhs.op, DagOp::Neg) && !rhs.children.is_empty() {
let b = &rhs.children[0];
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Add, vec![lhs.clone(), b.clone()]) {
| Ok(result) => result,
| Err(_) => node.clone(),
},
);
}
if lhs.hash == rhs.hash {
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(0.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
if matches!(&lhs.op, DagOp::Mul) && matches!(&rhs.op, DagOp::Mul) {
let mut terms_lhs = Vec::new();
flatten_terms(lhs, &mut terms_lhs);
let mut terms_rhs = Vec::new();
flatten_terms(rhs, &mut terms_rhs);
let one_node_a = DAG_MANAGER
.get_or_create_normalized(
DagOp::Constant(OrderedFloat(1.0)), vec![], )
.unwrap_or_else(|_| node.clone());
let (a, b) = if lhs.children.len() < 2 || terms_lhs.len() < 2 {
(one_node_a, terms_lhs[0].clone())
} else {
(terms_lhs[0].clone(), terms_lhs[1].clone())
};
let one_node_c = DAG_MANAGER
.get_or_create_normalized(
DagOp::Constant(OrderedFloat(1.0)), vec![], )
.unwrap_or_else(|_| node.clone());
let (c, d) = if rhs.children.len() < 2 || terms_rhs.len() < 2 {
(one_node_c, terms_rhs[0].clone())
} else {
(terms_rhs[0].clone(), terms_rhs[1].clone())
};
if b.hash == d.hash {
return Some(
match DAG_MANAGER.get_or_create_normalized(DagOp::Sub, vec![a, c]) {
| Ok(result) => result,
| Err(_) => node.clone(),
},
);
}
}
Some(simplify_add(node))
}
/// Rewrite rules for a binary sum `lhs + rhs`.
///
/// Tried in order:
/// - `x + 0` / `0 + x -> x`
/// - `x + x -> 2.0 * x`
/// - factoring two products with a shared factor: `a*x + b*x -> (a+b)*x`
///   (shared second factor) and `a*x + a*y -> (x+y)*a` (shared first factor)
/// - the Pythagorean identity `sin(x)^2 + cos(x)^2 -> 1.0`, in either order
/// - otherwise `simplify_add`
///
/// A node with fewer than two children is returned unchanged. Every path returns `Some`.
///
/// NOTE(review): the Pythagorean check uses `is_const_node(.., 2.0)`, so squares with a
/// `BigInt(2)` exponent are not matched, unlike the `i^2` / `sqrt(x)^2` rules in
/// `apply_rules_power`. Confirm this is intended.
#[inline(always)]
#[allow(clippy::inline_always)]
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn apply_rules_add(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
if node.children.len() < 2 {
return Some(node.clone()); }
let lhs = &node.children[0];
let rhs = &node.children[1];
// Additive identity.
if is_zero_node(rhs) {
return Some(lhs.clone());
}
if is_zero_node(lhs) {
return Some(rhs.clone());
}
// x + x -> 2.0 * x; on construction failure fall through to the remaining rules.
if lhs.hash == rhs.hash {
match DAG_MANAGER.get_or_create(&Expr::Constant(2.0)) {
| Ok(two) => {
match DAG_MANAGER.get_or_create_normalized(DagOp::Mul, vec![two, lhs.clone()]) {
| Ok(result) => return Some(result),
| Err(_) => {}, }
},
| Err(_) => {}, }
}
// Factor a common multiplicand out of two binary products.
if matches!((&lhs.op, &rhs.op), (DagOp::Mul, DagOp::Mul))
&& lhs.children.len() >= 2
&& rhs.children.len() >= 2
{
let a = &lhs.children[0];
let x1 = &lhs.children[1];
let b = &rhs.children[0];
let x2 = &rhs.children[1];
// a*x + b*x -> (a + b) * x
if x1.hash == x2.hash {
match DAG_MANAGER.get_or_create_normalized(DagOp::Add, vec![a.clone(), b.clone()]) {
| Ok(a_plus_b) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Mul, vec![a_plus_b, x1.clone()])
{
| Ok(result) => return Some(result),
| Err(_) => {}, }
},
| Err(_) => {}, }
}
// a*x + a*y -> (x + y) * a
if a.hash == b.hash {
match DAG_MANAGER.get_or_create_normalized(DagOp::Add, vec![x1.clone(), x2.clone()]) {
| Ok(x_plus_y) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Mul, vec![x_plus_y, a.clone()])
{
| Ok(result) => return Some(result),
| Err(_) => {}, }
},
| Err(_) => {}, }
}
}
// sin(x)^2 + cos(x)^2 -> 1, matched only for Constant(2.0) exponents.
if matches!((&lhs.op, &rhs.op), (DagOp::Power, DagOp::Power))
&& lhs.children.len() >= 2
&& rhs.children.len() >= 2
&& is_const_node(&lhs.children[1], 2.0)
&& is_const_node(&rhs.children[1], 2.0)
{
if matches!(
(&lhs.children[0].op, &rhs.children[0].op),
(DagOp::Sin, DagOp::Cos)
) && !lhs.children[0].children.is_empty()
&& !rhs.children[0].children.is_empty()
&& lhs.children[0].children[0].hash == rhs.children[0].children[0].hash
{
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
// Same identity with the addends in the opposite order.
if matches!(
(&lhs.children[0].op, &rhs.children[0].op),
(DagOp::Cos, DagOp::Sin)
) && !lhs.children[0].children.is_empty()
&& !rhs.children[0].children.is_empty()
&& lhs.children[0].children[0].hash == rhs.children[0].children[0].hash
{
return Some(match DAG_MANAGER.get_or_create(&Expr::Constant(1.0)) {
| Ok(node) => node,
| Err(_) => node.clone(), });
}
}
// No local rule fired: canonicalize the whole (possibly nested) sum.
Some(simplify_add(node))
}
/// Evaluates `node` when every child is a numeric literal.
///
/// Handles:
/// - binary `Add`/`Sub`/`Mul`/`Div` via the `*_em` helpers
/// - `Power` of two float constants
/// - unary `Neg`
/// - `Sqrt` of a non-negative value whose root is an integer within `1e-12`
///
/// Returns the interned result node. Returns `None` when the node is not foldable or
/// the result cannot be interned.
///
/// Float powers are not folded when the result would be NaN (e.g. a negative base with
/// a fractional exponent). They are also not folded when finite operands produce an
/// infinity (overflow, or `0^negative`, which `apply_rules_power` maps to `Infinity`).
/// The forms `0^0`, `inf^0` and `1^inf`, which `apply_rules_power` deliberately keeps
/// unevaluated, are likewise skipped; IEEE `powf` would return 1.0 for them.
pub(crate) fn fold_constants(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let values: Vec<Expr> = node
        .children
        .iter()
        .map(get_numeric_value)
        .collect::<Option<Vec<Expr>>>()?;
    let folded = match (&node.op, values.as_slice()) {
        | (DagOp::Add, [a, b]) => Some(add_em(a, b)),
        | (DagOp::Sub, [a, b]) => Some(sub_em(a, b)),
        | (DagOp::Mul, [a, b]) => Some(mul_em(a, b)),
        | (DagOp::Div, [a, b]) => div_em(a, b),
        | (DagOp::Power, [Expr::Constant(base), Expr::Constant(exponent)]) => {
            let (base, exponent) = (*base, *exponent);
            let indeterminate = (exponent == 0.0 && (base == 0.0 || base.is_infinite()))
                || (base == 1.0 && exponent.is_infinite());
            let value = base.powf(exponent);
            let unrepresentable = value.is_nan()
                || (value.is_infinite() && base.is_finite() && exponent.is_finite());
            if indeterminate || unrepresentable {
                None
            } else {
                Some(Expr::Constant(value))
            }
        },
        | (DagOp::Neg, [a]) => Some(neg_em(a)),
        | (DagOp::Sqrt, [a]) => {
            match a.to_f64() {
                | Some(val) if val >= 0.0 => {
                    let root = val.sqrt();
                    if (root.round() - root).abs() < 1e-12 {
                        Some(Expr::Constant(root.round()))
                    } else {
                        None
                    }
                },
                | _ => None,
            }
        },
        | _ => None,
    }?;
    DAG_MANAGER.get_or_create(&folded).ok()
}
/// Returns the numeric literal held by `node` as an `Expr` (`Constant`, `BigInt` or
/// `Rational`), or `None` when the node is not a numeric literal.
#[inline]
pub(crate) fn get_numeric_value(node: &Arc<DagNode>) -> Option<Expr> {
    let value = match &node.op {
        | DagOp::Constant(c) => Expr::Constant(c.into_inner()),
        | DagOp::BigInt(i) => Expr::BigInt(i.clone()),
        | DagOp::Rational(r) => Expr::Rational(r.clone()),
        | _ => return None,
    };
    Some(value)
}
/// Adds two numeric literals.
///
/// - Same-type `BigInt` and `Rational` operands use exact arithmetic.
/// - Two `Constant`s, or mixed types convertible via `to_f64`, are added as `f64`.
/// - A NaN result (`inf + -inf`) is returned unevaluated as `Expr::new_add(a, b)`.
/// - An infinite result produced only by overflow of finite operands is also returned
///   unevaluated. An infinite result with an infinite `Constant` operand is kept,
///   so `1 + inf == inf + 1 == inf`.
/// - Operands that cannot be converted are returned unevaluated.
#[inline]
pub(crate) fn add_em(
    a: &Expr,
    b: &Expr,
) -> Expr {
    match (a, b) {
        | (Expr::Constant(va), Expr::Constant(vb)) => {
            let result = va + vb;
            // Previously this returned `Constant(*va)`, silently dropping `b`.
            if result.is_nan() || (result.is_infinite() && va.is_finite() && vb.is_finite()) {
                Expr::new_add(a, b)
            } else {
                Expr::Constant(result)
            }
        },
        | (Expr::BigInt(ia), Expr::BigInt(ib)) => Expr::BigInt(ia + ib),
        | (Expr::Rational(ra), Expr::Rational(rb)) => Expr::Rational(ra + rb),
        | _ => {
            match (a.to_f64(), b.to_f64()) {
                | (Some(va), Some(vb)) => {
                    let result = va + vb;
                    if result.is_infinite() || result.is_nan() {
                        Expr::new_add(a, b)
                    } else {
                        Expr::Constant(result)
                    }
                },
                // Not convertible to f64: keep the sum unevaluated.
                | _ => Expr::new_add(a, b),
            }
        },
    }
}
/// Subtracts two numeric literals (`a - b`).
///
/// - Same-type `BigInt` and `Rational` operands use exact arithmetic.
/// - Two `Constant`s, or mixed types convertible via `to_f64`, are subtracted as `f64`.
/// - A NaN result (`inf - inf`) is returned unevaluated as `Expr::new_sub(a, b)`.
/// - An infinite result produced only by overflow of finite operands is also returned
///   unevaluated. An infinite result with an infinite `Constant` operand is kept,
///   so `1 - inf == -inf`.
/// - Operands that cannot be converted are returned unevaluated.
#[inline]
pub(crate) fn sub_em(
    a: &Expr,
    b: &Expr,
) -> Expr {
    match (a, b) {
        | (Expr::Constant(va), Expr::Constant(vb)) => {
            let result = va - vb;
            // Previously this returned `Constant(*va)`, silently dropping `b`.
            if result.is_nan() || (result.is_infinite() && va.is_finite() && vb.is_finite()) {
                Expr::new_sub(a, b)
            } else {
                Expr::Constant(result)
            }
        },
        | (Expr::BigInt(ia), Expr::BigInt(ib)) => Expr::BigInt(ia - ib),
        | (Expr::Rational(ra), Expr::Rational(rb)) => Expr::Rational(ra - rb),
        | _ => {
            match (a.to_f64(), b.to_f64()) {
                | (Some(va), Some(vb)) => {
                    let result = va - vb;
                    if result.is_infinite() || result.is_nan() {
                        Expr::new_sub(a, b)
                    } else {
                        Expr::Constant(result)
                    }
                },
                // Not convertible to f64: keep the difference unevaluated.
                | _ => Expr::new_sub(a, b),
            }
        },
    }
}
/// Multiplies two numeric literals.
///
/// - Same-type `BigInt` and `Rational` operands use exact arithmetic.
/// - Two `Constant`s, or mixed types convertible via `to_f64`, are multiplied as `f64`.
/// - For two `Constant`s:
///   - `inf * 0` folds to `0.0` (existing behaviour, kept).
///   - Any other NaN, or an infinity produced only by overflow of finite operands, is
///     returned unevaluated as `Expr::new_mul(a, b)`.
///   - An infinity with an infinite operand keeps its IEEE sign, e.g. `2 * inf == inf`
///     and `-inf * -2 == inf`.
/// - Operands that cannot be converted are returned unevaluated.
///
/// NOTE(review): `apply_rules_mul` leaves `0 * inf` unevaluated for symbolic infinities,
/// while this literal path folds it to 0. Confirm which behaviour is intended.
#[inline]
pub(crate) fn mul_em(
    a: &Expr,
    b: &Expr,
) -> Expr {
    match (a, b) {
        | (Expr::Constant(va), Expr::Constant(vb)) => {
            let result = va * vb;
            if result.is_nan() {
                if (va.is_infinite() && is_zero_expr(b)) || (vb.is_infinite() && is_zero_expr(a)) {
                    Expr::Constant(0.0)
                } else {
                    Expr::new_mul(a, b)
                }
            } else if result.is_infinite() && va.is_finite() && vb.is_finite() {
                // Previously this returned `Constant(*va)`, silently dropping `b`.
                Expr::new_mul(a, b)
            } else {
                Expr::Constant(result)
            }
        },
        | (Expr::BigInt(ia), Expr::BigInt(ib)) => Expr::BigInt(ia * ib),
        | (Expr::Rational(ra), Expr::Rational(rb)) => Expr::Rational(ra * rb),
        | _ => {
            match (a.to_f64(), b.to_f64()) {
                | (Some(va), Some(vb)) => {
                    let result = va * vb;
                    if result.is_infinite() || result.is_nan() {
                        Expr::new_mul(a, b)
                    } else {
                        Expr::Constant(result)
                    }
                },
                // Not convertible to f64: keep the product unevaluated.
                | _ => Expr::new_mul(a, b),
            }
        },
    }
}
/// Divides two numeric literals (`a / b`).
///
/// - `0 / 0` returns `None` (not folded); `x / 0` for non-zero `x` returns `Infinity`.
/// - `BigInt / BigInt` produces an exact `Rational`; `Rational / Rational` stays exact.
/// - Float quotients (two `Constant`s, or mixed types convertible via `to_f64`):
///   - an infinite quotient becomes `Infinity`;
///   - a NaN quotient returns `None`;
///   - anything else becomes a `Constant`.
/// - Operands not convertible to `f64` yield `Expr::new_div(a, b)`.
///
/// NOTE(review): every infinite quotient maps to `Infinity` regardless of sign
/// (e.g. `-inf / 2`) — confirm whether a negative-infinity result is intended.
#[inline]
pub(crate) fn div_em(
    a: &Expr,
    b: &Expr,
) -> Option<Expr> {
    if is_zero_expr(b) {
        return if is_zero_expr(a) { None } else { Some(Expr::Infinity) };
    }
    let classify = |quotient: f64| -> Option<Expr> {
        if quotient.is_infinite() {
            Some(Expr::Infinity)
        } else if quotient.is_nan() {
            None
        } else {
            Some(Expr::Constant(quotient))
        }
    };
    match (a, b) {
        | (Expr::Constant(va), Expr::Constant(vb)) => classify(va / vb),
        | (Expr::BigInt(ia), Expr::BigInt(ib)) => {
            Some(Expr::Rational(BigRational::new(ia.clone(), ib.clone())))
        },
        | (Expr::Rational(ra), Expr::Rational(rb)) => Some(Expr::Rational(ra / rb)),
        | _ => {
            match (a.to_f64(), b.to_f64()) {
                | (Some(va), Some(vb)) => classify(va / vb),
                | _ => Some(Expr::new_div(a, b)),
            }
        },
    }
}
/// Negates a numeric literal (`Constant`, `BigInt` or `Rational`).
///
/// # Panics
/// Panics via `unreachable!()` on any other `Expr` variant.
#[inline]
pub(crate) fn neg_em(a: &Expr) -> Expr {
    if let Expr::Constant(v) = a {
        return Expr::Constant(-v);
    }
    if let Expr::BigInt(i) = a {
        return Expr::BigInt(-i);
    }
    if let Expr::Rational(r) = a {
        return Expr::Rational(-r);
    }
    unreachable!()
}
/// Reports whether `node` is a numeric literal: a `Constant`, `BigInt` or `Rational`.
#[inline]
pub(crate) fn is_numeric_node(node: &Arc<DagNode>) -> bool {
    match &node.op {
        | DagOp::Constant(_) | DagOp::BigInt(_) | DagOp::Rational(_) => true,
        | _ => false,
    }
}
/// Reports whether `expr` is a numeric literal equal to zero (exact comparison).
#[inline]
pub(crate) fn is_zero_expr(expr: &Expr) -> bool {
    match expr {
        | Expr::Constant(c) => *c == 0.0,
        | Expr::BigInt(i) => i.is_zero(),
        | Expr::Rational(r) => r.is_zero(),
        | _ => false,
    }
}
/// Reports whether `expr` is a numeric literal equal to one.
///
/// A float `Constant` counts as one when it is within `f64::EPSILON` of 1.0.
#[inline]
pub(crate) fn is_one_expr(expr: &Expr) -> bool {
    match expr {
        | Expr::Constant(c) => (*c - 1.0).abs() < f64::EPSILON,
        | Expr::BigInt(i) => i.is_one(),
        | Expr::Rational(r) => r.is_one(),
        | _ => false,
    }
}
/// Returns the interned zero node: `BigInt(0)`, or `Constant(0.0)` if that fails.
///
/// # Panics
/// Panics if neither node can be created.
#[inline]
pub(crate) fn zero_node() -> Arc<DagNode> {
    DAG_MANAGER
        .get_or_create(&Expr::BigInt(BigInt::zero()))
        .or_else(|_| DAG_MANAGER.get_or_create(&Expr::Constant(0.0)))
        .unwrap_or_else(|_| panic!("Failed to create zero node"))
}
/// Returns the interned one node: `BigInt(1)`, or `Constant(1.0)` if that fails.
///
/// # Panics
/// Panics if neither node can be created.
#[inline]
#[allow(dead_code)]
pub(crate) fn one_node() -> Arc<DagNode> {
    DAG_MANAGER
        .get_or_create(&Expr::BigInt(BigInt::one()))
        .or_else(|_| DAG_MANAGER.get_or_create(&Expr::Constant(1.0)))
        .unwrap_or_else(|_| panic!("Failed to create one node"))
}
/// Reports whether `node` is a float `Constant` within `f64::EPSILON` of `val`.
///
/// `BigInt` and `Rational` literals never match.
#[inline]
pub(crate) fn is_const_node(
    node: &Arc<DagNode>,
    val: f64,
) -> bool {
    if let DagOp::Constant(c) = &node.op {
        (c.into_inner() - val).abs() < f64::EPSILON
    } else {
        false
    }
}
/// Reports whether `node` is a numeric literal (`Constant`, `BigInt` or `Rational`)
/// equal to zero.
#[inline]
pub(crate) fn is_zero_node(node: &Arc<DagNode>) -> bool {
    match &node.op {
        | DagOp::Constant(c) => c.is_zero(),
        | DagOp::BigInt(i) => i.is_zero(),
        | DagOp::Rational(r) => r.is_zero(),
        | _ => false,
    }
}
/// Reports whether `node` denotes an infinity.
///
/// Matches the symbolic `Infinity` / `NegativeInfinity` nodes and float constants
/// holding `±inf`.
pub(crate) fn is_infinite_node(node: &Arc<DagNode>) -> bool {
    let symbolic = matches!(&node.op, DagOp::Infinity | DagOp::NegativeInfinity);
    symbolic || matches!(&node.op, DagOp::Constant(c) if c.0.is_infinite())
}
/// Reports whether `node` is a numeric literal (`Constant`, `BigInt` or `Rational`)
/// exactly equal to one.
#[inline]
pub(crate) fn is_one_node(node: &Arc<DagNode>) -> bool {
    match &node.op {
        | DagOp::Constant(c) => c.is_one(),
        | DagOp::BigInt(i) => i.is_one(),
        | DagOp::Rational(r) => r.is_one(),
        | _ => false,
    }
}
/// Reports whether `node` is a float `Constant` within `f64::EPSILON` of -1.0.
///
/// `BigInt(-1)` and `Rational(-1)` are not recognized.
#[inline]
pub(crate) fn is_neg_one_node(node: &Arc<DagNode>) -> bool {
    if let DagOp::Constant(c) = &node.op {
        (c.into_inner() + 1.0).abs() < f64::EPSILON
    } else {
        false
    }
}
/// Reports whether `node` is the symbolic constant `pi` (`DagOp::Pi`).
///
/// A float constant approximating pi does not match.
#[inline]
pub(crate) fn is_pi_node(node: &Arc<DagNode>) -> bool {
    matches!(node.op, DagOp::Pi)
}
/// Appends the factors of a (possibly nested) product to `terms`, left to right.
///
/// Every child of a `Mul` node is flattened recursively, so nested and n-ary
/// products are handled. A `Mul` node with no children, and any non-`Mul` node,
/// is pushed as a single factor. Indexing `children[0]` / `children[1]` directly
/// would panic on malformed nodes and silently drop any extra children.
#[inline]
pub(crate) fn flatten_mul_terms(
    node: &Arc<DagNode>,
    terms: &mut Vec<Arc<DagNode>>,
) {
    if matches!(&node.op, DagOp::Mul) && !node.children.is_empty() {
        for child in &node.children {
            flatten_mul_terms(child, terms);
        }
    } else {
        terms.push(node.clone());
    }
}
/// Canonicalises a product by folding numeric factors and combining like bases.
///
/// The (possibly nested) `Mul` tree rooted at `node` is flattened with
/// `flatten_mul_terms`. Factors for which `get_numeric_value` yields a value are
/// multiplied (via `mul_em`) into one constant. Every other factor is split into
/// `base ^ exponent`, and the exponents of identical bases are accumulated with
/// `add_em`:
/// * a non-`Power` factor has exponent 1;
/// * a malformed `Power`, or one whose exponent cannot be converted back to an
///   `Expr`, is kept whole as an opaque base.
///
/// Then:
/// * a zero constant returns `zero_node()`;
/// * bases whose exponent sums to zero are dropped;
/// * an exponent of one yields the bare base;
/// * a non-one constant is added as a factor;
/// * an empty product returns `one_node()`.
///
/// The remaining factors are sorted by node hash and folded left-to-right into
/// binary `Mul` nodes.
///
/// Simplification is best-effort: if any DAG node cannot be created, the
/// original `node` is returned unchanged. Returning a partially rebuilt product
/// would silently drop factors, exponents or the constant.
pub(crate) fn simplify_mul(node: &Arc<DagNode>) -> Arc<DagNode> {
    try_simplify_mul_inner(node).unwrap_or_else(|| node.clone())
}

/// Fallible core of `simplify_mul`; `None` means a DAG node could not be built.
fn try_simplify_mul_inner(node: &Arc<DagNode>) -> Option<Arc<DagNode>> {
    let mut flat_factors = Vec::new();
    flatten_mul_terms(node, &mut flat_factors);

    let mut exponents: BTreeMap<Arc<DagNode>, Expr> = BTreeMap::new();
    let mut constant = Expr::BigInt(BigInt::one());
    for factor in flat_factors {
        if let Some(val) = get_numeric_value(&factor) {
            constant = mul_em(&constant, &val);
            continue;
        }
        let (base_node, exponent_expr) = split_mul_factor_power(&factor);
        let entry = exponents
            .entry(base_node)
            .or_insert_with(|| Expr::BigInt(BigInt::zero()));
        *entry = add_em(entry, &exponent_expr);
    }

    if is_zero_expr(&constant) {
        return Some(zero_node());
    }

    let mut new_factors = Vec::with_capacity(exponents.len() + 1);
    for (base, exponent) in exponents {
        if is_zero_expr(&exponent) {
            continue;
        }
        if is_one_expr(&exponent) {
            new_factors.push(base);
        } else {
            let exp_node = DAG_MANAGER.get_or_create(&exponent).ok()?;
            let power_node = DAG_MANAGER
                .get_or_create_normalized(DagOp::Power, vec![base, exp_node])
                .ok()?;
            new_factors.push(power_node);
        }
    }
    if !is_one_expr(&constant) {
        let constant_node = DAG_MANAGER.get_or_create(&constant).ok()?;
        new_factors.insert(0, constant_node);
    }
    if new_factors.is_empty() {
        return Some(one_node());
    }

    // Hash order gives every equal product the same canonical factor order.
    new_factors.sort_by_key(|n| n.hash);
    let mut rest = new_factors.into_iter();
    let first = rest.next()?;
    rest.try_fold(first, |acc, factor| {
        DAG_MANAGER
            .get_or_create_normalized(DagOp::Mul, vec![acc, factor])
            .ok()
    })
}

/// Splits one product factor into `(base, exponent)`.
///
/// Only a well-formed `Power` node (at least two children) whose exponent child
/// converts to an `Expr` is split. Anything else is returned whole with
/// exponent 1, so no factor is ever lost or altered.
fn split_mul_factor_power(factor: &Arc<DagNode>) -> (Arc<DagNode>, Expr) {
    if matches!(&factor.op, DagOp::Power) && factor.children.len() >= 2 {
        if let Ok(exponent) = factor.children[1].to_expr() {
            return (factor.children[0].clone(), exponent);
        }
    }
    (factor.clone(), Expr::BigInt(BigInt::one()))
}
/// Appends the additive terms of a (possibly nested) sum to `terms`.
///
/// * Every child of an `Add` node is flattened recursively, left to right.
/// * A `Sub(a, b)` node contributes the flattened terms of `a` followed by a
///   `Neg(b)` node.
/// * If the `Neg(b)` node cannot be created, the whole `Sub` node is kept as a
///   single term, and `a` is not flattened first. Flattening it first would
///   count `a` twice.
/// * Any other node, including degenerate `Add`/`Sub` nodes with too few
///   children, is pushed as a single term.
#[inline]
pub(crate) fn flatten_terms(
    node: &Arc<DagNode>,
    terms: &mut Vec<Arc<DagNode>>,
) {
    match &node.op {
        | DagOp::Add if !node.children.is_empty() => {
            for child in node.children.iter() {
                flatten_terms(child, terms);
            }
        },
        | DagOp::Sub if node.children.len() >= 2 => {
            match DAG_MANAGER
                .get_or_create_normalized(DagOp::Neg, vec![node.children[1].clone()])
            {
                | Ok(neg_node) => {
                    flatten_terms(&node.children[0], terms);
                    terms.push(neg_node);
                },
                | Err(_) => {
                    terms.push(node.clone());
                },
            }
        },
        | _ => {
            terms.push(node.clone());
        },
    }
}
pub(crate) fn simplify_add(node: &Arc<DagNode>) -> Arc<DagNode> {
let mut terms = Vec::new();
flatten_terms(node, &mut terms);
let mut coeffs: BTreeMap<Arc<DagNode>, Expr> = BTreeMap::new(); let mut constant = Expr::BigInt(BigInt::zero());
for term in terms {
if let Some(val) = get_numeric_value(&term) {
constant = add_em(&constant, &val);
continue;
}
let simplified_term = if matches!(&term.op, DagOp::Mul) {
simplify_mul(&term)
} else {
term.clone()
};
let (coeff_expr, base_node) = if matches!(&simplified_term.op, DagOp::Neg) {
if simplified_term.children.is_empty() {
(Expr::BigInt(BigInt::one()), simplified_term.clone())
} else {
let child = &simplified_term.children[0];
if matches!(&child.op, DagOp::Mul) && child.children.len() >= 2 {
let c = &child.children[0];
let b = &child.children[1];
if is_numeric_node(c) {
get_numeric_value(c).map_or_else(
|| (Expr::Constant(-1.0), child.clone()),
|val| (neg_em(&val), b.clone()),
)
} else {
(Expr::Constant(-1.0), child.clone())
}
} else {
(Expr::Constant(-1.0), child.clone())
}
}
} else if matches!(&simplified_term.op, DagOp::Mul) {
if simplified_term.children.len() < 2 {
(Expr::BigInt(BigInt::one()), simplified_term.clone())
} else {
let c = &simplified_term.children[0];
let b = &simplified_term.children[1];
if is_numeric_node(c) {
get_numeric_value(c).map_or_else(
|| (Expr::BigInt(BigInt::one()), simplified_term.clone()),
|val| (val, b.clone()),
)
} else if is_numeric_node(b) {
get_numeric_value(b).map_or_else(
|| (Expr::BigInt(BigInt::one()), simplified_term.clone()),
|val| (val, c.clone()),
)
} else {
(Expr::BigInt(BigInt::one()), simplified_term.clone())
}
}
} else {
(Expr::BigInt(BigInt::one()), simplified_term.clone())
};
let entry = coeffs
.entry(base_node)
.or_insert(Expr::BigInt(BigInt::zero()));
*entry = add_em(entry, &coeff_expr);
}
let mut new_terms = Vec::new();
for (base, coeff) in coeffs {
if is_zero_expr(&coeff) {
continue; }
if is_one_expr(&coeff) {
new_terms.push(base.clone()); } else {
match DAG_MANAGER.get_or_create(&coeff) {
| Ok(coeff_node) => {
match DAG_MANAGER
.get_or_create_normalized(DagOp::Mul, vec![base.clone(), coeff_node])
{
| Ok(mul_node) => new_terms.push(mul_node),
| Err(_) => {
new_terms.push(base.clone());
},
}
},
| Err(_) => {
new_terms.push(base.clone());
},
}
}
}
if !is_zero_expr(&constant) {
if let Ok(constant_node) = DAG_MANAGER.get_or_create(&constant) {
new_terms.push(constant_node);
}
}
if new_terms.is_empty() {
return zero_node();
}
new_terms.sort_by_key(|n| n.hash);
let mut result = new_terms[0].clone();
for term in new_terms.iter().skip(1) {
result = match DAG_MANAGER
.get_or_create_normalized(DagOp::Add, vec![result.clone(), term.clone()])
{
| Ok(node) => node,
| Err(_) => {
break;
},
};
}
result
}
/// Structurally matches `expr` against `pattern`.
///
/// Returns `Some(bindings)` when the match succeeds, where `bindings` maps each
/// `Expr::Pattern` variable name in `pattern` to the sub-expression of `expr` it
/// was bound to. Returns `None` when the structures do not match.
#[must_use]
pub fn pattern_match(
    expr: &Expr,
    pattern: &Expr,
) -> Option<HashMap<String, Expr>> {
    let mut bindings = HashMap::new();
    let matched = pattern_match_recursive(expr, pattern, &mut bindings);
    matched.then_some(bindings)
}
pub(crate) fn pattern_match_recursive(
expr: &Expr,
pattern: &Expr,
assignments: &mut HashMap<String, Expr>,
) -> bool {
let expr_unwrapped = match expr {
| Expr::Dag(node) => node.to_expr().unwrap_or_else(|_| expr.clone()),
| _ => expr.clone(),
};
let pattern_unwrapped = match pattern {
| Expr::Dag(node) => node.to_expr().unwrap_or_else(|_| pattern.clone()),
| _ => pattern.clone(),
};
match (&expr_unwrapped, &pattern_unwrapped) {
| (_, Expr::Pattern(name)) => {
if let Some(existing) = assignments.get(name) {
return existing == expr;
}
assignments.insert(name.clone(), expr.clone());
true
},
| (Expr::Add(e1, e2), Expr::Add(p1, p2)) | (Expr::Mul(e1, e2), Expr::Mul(p1, p2)) => {
let original_assignments = assignments.clone();
if pattern_match_recursive(e1, p1, assignments)
&& pattern_match_recursive(e2, p2, assignments)
{
return true;
}
*assignments = original_assignments;
pattern_match_recursive(e1, p2, assignments)
&& pattern_match_recursive(e2, p1, assignments)
},
| (Expr::Sub(e1, e2), Expr::Sub(p1, p2))
| (Expr::Div(e1, e2), Expr::Div(p1, p2))
| (Expr::Power(e1, e2), Expr::Power(p1, p2)) => {
pattern_match_recursive(e1, p1, assignments)
&& pattern_match_recursive(e2, p2, assignments)
},
| (Expr::Sin(e), Expr::Sin(p))
| (Expr::Cos(e), Expr::Cos(p))
| (Expr::Tan(e), Expr::Tan(p))
| (Expr::Exp(e), Expr::Exp(p))
| (Expr::Log(e), Expr::Log(p))
| (Expr::Neg(e), Expr::Neg(p))
| (Expr::Abs(e), Expr::Abs(p))
| (Expr::Sqrt(e), Expr::Sqrt(p)) => pattern_match_recursive(e, p, assignments),
| (Expr::NaryList(s1, v1), Expr::NaryList(s2, v2)) => {
if s1 != s2 || v1.len() != v2.len() {
return false;
}
let original_assignments = assignments.clone();
for (e, p) in v1.iter().zip(v2.iter()) {
if !pattern_match_recursive(e, p, assignments) {
*assignments = original_assignments;
return false;
}
}
true
},
| (Expr::UnaryList(s1, e1), Expr::UnaryList(s2, p1)) => {
s1 == s2 && pattern_match_recursive(e1, p1, assignments)
},
| (Expr::BinaryList(s1, e1a, e1b), Expr::BinaryList(s2, p1a, p1b)) => {
s1 == s2
&& pattern_match_recursive(e1a, p1a, assignments)
&& pattern_match_recursive(e1b, p1b, assignments)
},
| _ => expr_unwrapped == pattern_unwrapped,
}
}
/// Instantiates `template` by replacing pattern variables with their bindings.
///
/// An `Expr::Dag` template is first expanded with `to_expr`; if that fails, the
/// template itself is used. The result is rebuilt as follows:
/// * Each `Expr::Pattern(name)` becomes a clone of `assignments[name]`. Unbound
///   variables are left as the original `template`.
/// * Supported compound nodes are rebuilt from their substituted children: the
///   arithmetic and function nodes through their `Expr::new_*` constructors, and
///   the list variants directly.
/// * Every other node kind is returned as a clone of the original `template`,
///   without descending into it.
#[must_use]
pub fn substitute_patterns<S: std::hash::BuildHasher>(
    template: &Expr,
    assignments: &HashMap<String, Expr, S>,
) -> Expr {
    let recurse = |child: &Expr| substitute_patterns(child, assignments);
    let expanded = if let Expr::Dag(dag) = template {
        dag.to_expr().unwrap_or_else(|_| template.clone())
    } else {
        template.clone()
    };
    match expanded {
        | Expr::Pattern(ref var) => {
            match assignments.get(var) {
                | Some(bound) => bound.clone(),
                | None => template.clone(),
            }
        },
        | Expr::Add(lhs, rhs) => Expr::new_add(recurse(&lhs), recurse(&rhs)),
        | Expr::Sub(lhs, rhs) => Expr::new_sub(recurse(&lhs), recurse(&rhs)),
        | Expr::Mul(lhs, rhs) => Expr::new_mul(recurse(&lhs), recurse(&rhs)),
        | Expr::Div(lhs, rhs) => Expr::new_div(recurse(&lhs), recurse(&rhs)),
        | Expr::Power(base, exp) => Expr::new_pow(recurse(&base), recurse(&exp)),
        | Expr::Sin(arg) => Expr::new_sin(recurse(&arg)),
        | Expr::Cos(arg) => Expr::new_cos(recurse(&arg)),
        | Expr::Tan(arg) => Expr::new_tan(recurse(&arg)),
        | Expr::Exp(arg) => Expr::new_exp(recurse(&arg)),
        | Expr::Log(arg) => Expr::new_log(recurse(&arg)),
        | Expr::Neg(arg) => Expr::new_neg(recurse(&arg)),
        | Expr::Abs(arg) => Expr::new_abs(recurse(&arg)),
        | Expr::Sqrt(arg) => Expr::new_sqrt(recurse(&arg)),
        | Expr::NaryList(tag, items) => {
            let substituted = items.iter().map(&recurse).collect();
            Expr::NaryList(tag, substituted)
        },
        | Expr::UnaryList(tag, inner) => Expr::UnaryList(tag, Arc::new(recurse(&inner))),
        | Expr::BinaryList(tag, lhs, rhs) => {
            let left = Arc::new(recurse(&lhs));
            let right = Arc::new(recurse(&rhs));
            Expr::BinaryList(tag, left, right)
        },
        | _ => template.clone(),
    }
}