use crate::Symbol;
use crate::{
util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id,
Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var,
};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
use std::rc::Rc;
use symbolic_expressions::Sexp;
// Tuning constants for explanation-length optimization.
// NOTE(review): their use sites are outside this chunk — presumably a cap on
// congruence-edge exploration and the number of greedy shortening passes;
// confirm at the use sites.
const CONGRUENCE_LIMIT: usize = 2;
const GREEDY_NUM_ITERS: usize = 2;
/// The reason two e-nodes were recorded as equal in the explanation graph.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub enum Justification {
    /// Equal because a rewrite rule (identified by name) was applied.
    Rule(Symbol),
    /// Equal by congruence: the nodes match and their children are equal.
    Congruence,
}
/// A directed edge in the explanation graph between two e-nodes.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct Connection {
    // The node this edge points to.
    next: Id,
    // The node this edge starts from.
    current: Id,
    // Why `current` and `next` are equal.
    justification: Justification,
    // True when the justifying rewrite runs from `current` toward `next`;
    // flipped when the edge is traversed in the opposite direction.
    is_rewrite_forward: bool,
}
/// Per-e-node record in the explanation graph.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode<L: Language> {
    // The e-node itself.
    node: L,
    // All equality edges incident to this node.
    neighbors: Vec<Connection>,
    // Edge toward this node's parent in the explanation spanning tree;
    // points back at itself when the node is a root.
    parent_connection: Connection,
    // The node whose addition caused this node to exist ("existance" is the
    // codebase's established spelling).
    existance_node: Id,
}
/// The explanation graph: one entry per uncanonicalized e-node, with
/// union-find-style parent pointers plus extra equality edges.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
    // Indexed by `Id`; the whole graph lives in this vector.
    explainfind: Vec<ExplainNode<L>>,
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    // Maps each e-node back to its (uncanonical) id.
    pub uncanon_memo: HashMap<L, Id>,
    // When true, run the shortest-explanation search before producing proofs.
    pub optimize_explanation_lengths: bool,
    // Memo of (from, to) -> (distance, next hop on a shortest path).
    shortest_explanation_memo: HashMap<(Id, Id), (usize, Id)>,
}
/// Scratch state for distance queries over the explanation spanning tree.
#[derive(Default)]
struct DistanceMemo {
    // For each node: (an ancestor, accumulated distance to it) — compressed
    // incrementally as queries run.
    parent_distance: Vec<(Id, usize)>,
    // Cached lowest common ancestors for node pairs.
    common_ancestor: HashMap<(Id, Id), Id>,
    // Depth of each node in the spanning tree (roots at 0).
    tree_depth: HashMap<Id, usize>,
}
/// A proof as a sequence of tree terms, each adjacent pair one rewrite apart.
pub type TreeExplanation<L> = Vec<Rc<TreeTerm<L>>>;
/// A proof as a sequence of fully flattened terms.
pub type FlatExplanation<L> = Vec<FlatTerm<L>>;
/// All (left, right, rule-name) equalities asserted by rule applications.
pub type UnionEqualities = Vec<(Id, Id, Symbol)>;
// Cache of already-built subproofs, keyed by the (from, to) node pair.
type ExplainCache<L> = HashMap<(Id, Id), Rc<TreeTerm<L>>>;
// Cache of plain (rewrite-free) tree renderings of single nodes.
type NodeExplanationCache<L> = HashMap<Id, Rc<TreeTerm<L>>>;
/// A proof of equivalence, stored as trees and flattened lazily on demand.
pub struct Explanation<L: Language> {
    /// The tree-shaped proof terms.
    pub explanation_trees: TreeExplanation<L>,
    // Lazily computed flattened form; `None` until first requested.
    flat_explanation: Option<FlatExplanation<L>>,
}
impl<L: Language + Display + FromOp> Display for Explanation<L> {
    /// Writes the explanation as its s-expression rendering.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.get_sexp())
    }
}
impl<L: Language + Display + FromOp> Explanation<L> {
    /// Get each term of the flattened explanation as a string, newline-joined.
    pub fn get_flat_string(&mut self) -> String {
        self.get_flat_strings().join("\n")
    }
    /// Get the tree-style explanation as a single string.
    pub fn get_string(&self) -> String {
        self.to_string()
    }
    /// Get the tree-style explanation with shared subterms let-bound,
    /// pretty-printed at width 100.
    pub fn get_string_with_let(&self) -> String {
        let mut s = "".to_string();
        pretty_print(&mut s, &self.get_sexp_with_let(), 100, 0).unwrap();
        s
    }
    /// Get each term of the flattened explanation as a string.
    pub fn get_flat_strings(&mut self) -> Vec<String> {
        self.make_flat_explanation()
            .iter()
            .map(|e| e.to_string())
            .collect()
    }
    // Render the whole proof as `(Explanation term ...)`.
    fn get_sexp(&self) -> Sexp {
        let mut items = vec![Sexp::String("Explanation".to_string())];
        for e in self.explanation_trees.iter() {
            items.push(e.get_sexp());
        }
        Sexp::List(items)
    }
    /// Number of rewrite applications in the proof, counting each unordered
    /// pair of adjacent e-nodes at most once.
    pub fn get_tree_size(&self) -> usize {
        let mut seen = Default::default();
        let mut seen_adjacent = Default::default();
        let mut sum = 0;
        for e in self.explanation_trees.iter() {
            sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
        }
        sum
    }
    // Recursive helper for `get_tree_size`. Shared `Rc` subtrees are counted
    // once, deduplicated by pointer identity.
    fn tree_size(
        &self,
        seen: &mut HashSet<*const TreeTerm<L>>,
        seen_adjacent: &mut HashSet<(Id, Id)>,
        current: &Rc<TreeTerm<L>>,
    ) -> usize {
        if !seen.insert(&**current as *const TreeTerm<L>) {
            return 0;
        }
        let mut my_size = 0;
        if current.forward_rule.is_some() {
            my_size += 1;
        }
        if current.backward_rule.is_some() {
            my_size += 1;
        }
        // A term carries at most one rule annotation (forward xor backward).
        assert!(my_size <= 1);
        if my_size == 1 {
            // Record the pair in both orientations so it is counted once.
            if !seen_adjacent.insert((current.current, current.last)) {
                return 0;
            } else {
                seen_adjacent.insert((current.last, current.current));
            }
        }
        // BUGFIX: this line read `¤t.child_proofs` — an HTML-entity
        // corruption (`&curren;`) of `&current`; restored.
        for child_proof in &current.child_proofs {
            for child in child_proof {
                my_size += self.tree_size(seen, seen_adjacent, child);
            }
        }
        my_size
    }
    // Render the proof with repeated subterms bound by nested `let`s.
    fn get_sexp_with_let(&self) -> Sexp {
        let mut shared: HashSet<*const TreeTerm<L>> = Default::default();
        let mut to_let_bind = vec![];
        for term in &self.explanation_trees {
            self.find_to_let_bind(term.clone(), &mut shared, &mut to_let_bind);
        }
        let mut bindings: HashMap<*const TreeTerm<L>, Sexp> = Default::default();
        let mut generated_bindings: Vec<(Sexp, Sexp)> = Default::default();
        for to_bind in to_let_bind {
            if bindings.get(&(&*to_bind as *const TreeTerm<L>)).is_none() {
                let name = Sexp::String("v_".to_string() + &generated_bindings.len().to_string());
                let ast = to_bind.get_sexp_with_bindings(&bindings);
                generated_bindings.push((name.clone(), ast));
                bindings.insert(&*to_bind as *const TreeTerm<L>, name);
            }
        }
        let mut items = vec![Sexp::String("Explanation".to_string())];
        for e in self.explanation_trees.iter() {
            if let Some(existing) = bindings.get(&(&**e as *const TreeTerm<L>)) {
                items.push(existing.clone());
            } else {
                items.push(e.get_sexp_with_bindings(&bindings));
            }
        }
        // Wrap inside-out so earlier bindings are in scope for later ones.
        let mut result = Sexp::List(items);
        for (name, expr) in generated_bindings.into_iter().rev() {
            let let_expr = Sexp::List(vec![name, expr]);
            result = Sexp::List(vec![Sexp::String("let".to_string()), let_expr, result]);
        }
        result
    }
    // Collect terms seen more than once (by pointer): the first sighting
    // registers in `shared`, any later sighting is pushed for let-binding.
    fn find_to_let_bind(
        &self,
        term: Rc<TreeTerm<L>>,
        shared: &mut HashSet<*const TreeTerm<L>>,
        to_let_bind: &mut Vec<Rc<TreeTerm<L>>>,
    ) {
        if !term.child_proofs.is_empty() {
            if shared.insert(&*term as *const TreeTerm<L>) {
                for proof in &term.child_proofs {
                    for child in proof {
                        self.find_to_let_bind(child.clone(), shared, to_let_bind);
                    }
                }
            } else {
                to_let_bind.push(term);
            }
        }
    }
}
impl<L: Language> Explanation<L> {
    /// Construct an explanation from its tree terms; flattening is deferred.
    pub fn new(explanation_trees: TreeExplanation<L>) -> Explanation<L> {
        Explanation {
            explanation_trees,
            flat_explanation: None,
        }
    }
    /// Flatten the tree explanation, caching the result for later calls.
    pub fn make_flat_explanation(&mut self) -> &FlatExplanation<L> {
        if self.flat_explanation.is_some() {
            return self.flat_explanation.as_ref().unwrap();
        } else {
            self.flat_explanation = Some(TreeTerm::flatten_proof(&self.explanation_trees));
            self.flat_explanation.as_ref().unwrap()
        }
    }
    /// Check every step of the flattened proof against the given rules,
    /// panicking on the first invalid step.
    pub fn check_proof<'a, R, N: Analysis<L>>(&mut self, rules: R)
    where
        R: IntoIterator<Item = &'a Rewrite<L, N>>,
        L: 'a,
        N: 'a,
    {
        let rules: Vec<&Rewrite<L, N>> = rules.into_iter().collect();
        let rule_table = Explain::make_rule_table(rules.as_slice());
        self.make_flat_explanation();
        let flat_explanation = self.flat_explanation.as_ref().unwrap();
        // The starting term must carry no rewrite annotation.
        assert!(!flat_explanation[0].has_rewrite_forward());
        assert!(!flat_explanation[0].has_rewrite_backward());
        for i in 0..flat_explanation.len() - 1 {
            let current = &flat_explanation[i];
            let next = &flat_explanation[i + 1];
            // Each step is annotated in exactly one direction.
            let has_forward = next.has_rewrite_forward();
            let has_backward = next.has_rewrite_backward();
            assert!(has_forward ^ has_backward);
            if has_forward {
                assert!(self.check_rewrite_at(current, next, &rule_table, true));
            } else {
                assert!(self.check_rewrite_at(current, next, &rule_table, false));
            }
        }
    }
    // Locate the rewrite annotation at or below this pair of terms and verify
    // it. Rules absent from the table cannot be checked and are accepted.
    fn check_rewrite_at<N: Analysis<L>>(
        &self,
        current: &FlatTerm<L>,
        next: &FlatTerm<L>,
        table: &HashMap<Symbol, &Rewrite<L, N>>,
        is_forward: bool,
    ) -> bool {
        if is_forward && next.forward_rule.is_some() {
            let rule_name = next.forward_rule.as_ref().unwrap();
            if let Some(rule) = table.get(rule_name) {
                Explanation::check_rewrite(current, next, rule)
            } else {
                true
            }
        } else if !is_forward && next.backward_rule.is_some() {
            let rule_name = next.backward_rule.as_ref().unwrap();
            if let Some(rule) = table.get(rule_name) {
                // Backward steps are checked with the terms swapped.
                Explanation::check_rewrite(next, current, rule)
            } else {
                true
            }
        } else {
            // No annotation at this position; descend into the children.
            for (left, right) in current.children.iter().zip(next.children.iter()) {
                if !self.check_rewrite_at(left, right, table, is_forward) {
                    return false;
                }
            }
            true
        }
    }
    // Verify that applying `rewrite`'s syntactic patterns to `current` yields
    // `next`. Rules without pattern ASTs (programmatic appliers) pass.
    fn check_rewrite<'a, N: Analysis<L>>(
        current: &'a FlatTerm<L>,
        next: &'a FlatTerm<L>,
        rewrite: &Rewrite<L, N>,
    ) -> bool {
        if let Some(lhs) = rewrite.searcher.get_pattern_ast() {
            if let Some(rhs) = rewrite.applier.get_pattern_ast() {
                let rewritten = current.rewrite(lhs, rhs);
                if &rewritten != next {
                    return false;
                }
            }
        }
        true
    }
}
/// A term in a tree-shaped proof: an e-node whose children are themselves
/// sequences of proof terms.
#[derive(Debug, Clone)]
pub struct TreeTerm<L: Language> {
    /// The e-node this term represents.
    pub node: L,
    /// Rule that rewrites this term back to the previous term, if any.
    pub backward_rule: Option<Symbol>,
    /// Rule that rewrites the previous term into this one, if any.
    pub forward_rule: Option<Symbol>,
    /// One sub-proof per child of `node`.
    pub child_proofs: Vec<TreeExplanation<L>>,
    // The e-node id this term was rewritten from.
    last: Id,
    // The e-node id this term corresponds to.
    current: Id,
}
impl<L: Language> TreeTerm<L> {
    /// Construct a tree term with no rewrite annotations and placeholder ids.
    pub fn new(node: L, child_proofs: Vec<TreeExplanation<L>>) -> TreeTerm<L> {
        TreeTerm {
            node,
            backward_rule: None,
            forward_rule: None,
            child_proofs,
            current: Id::from(0),
            last: Id::from(0),
        }
    }
    // Flatten a sequence of tree terms into one linear proof. Where a
    // segment's first term carries no rewrite marker, it duplicates the
    // previous segment's last term, so the two are merged.
    fn flatten_proof(proof: &[Rc<TreeTerm<L>>]) -> FlatExplanation<L> {
        let mut flat_proof: FlatExplanation<L> = vec![];
        for tree in proof {
            let mut explanation = tree.flatten_explanation();
            if !flat_proof.is_empty()
                && !explanation[0].has_rewrite_forward()
                && !explanation[0].has_rewrite_backward()
            {
                // Fold the duplicate boundary term into the new segment.
                let last = flat_proof.pop().unwrap();
                explanation[0].combine_rewrites(&last);
            }
            flat_proof.extend(explanation);
        }
        flat_proof
    }
    /// The flat term at the start of this tree term's proof
    /// (first element of every child proof).
    pub fn get_initial_flat_term(&self) -> FlatTerm<L> {
        FlatTerm {
            node: self.node.clone(),
            backward_rule: self.backward_rule,
            forward_rule: self.forward_rule,
            children: self
                .child_proofs
                .iter()
                .map(|child_proof| child_proof[0].get_initial_flat_term())
                .collect(),
        }
    }
    /// The flat term at the end of this tree term's proof
    /// (last element of every child proof).
    pub fn get_last_flat_term(&self) -> FlatTerm<L> {
        FlatTerm {
            node: self.node.clone(),
            backward_rule: self.backward_rule,
            forward_rule: self.forward_rule,
            children: self
                .child_proofs
                .iter()
                .map(|child_proof| child_proof[child_proof.len() - 1].get_last_flat_term())
                .collect(),
        }
    }
    /// Flatten this term: child proofs are replayed one child at a time,
    /// holding the other children fixed at their current representatives.
    pub fn flatten_explanation(&self) -> FlatExplanation<L> {
        let mut proof = vec![];
        let mut child_proofs = vec![];
        let mut representative_terms = vec![];
        for child_explanation in &self.child_proofs {
            let flat_proof = TreeTerm::flatten_proof(child_explanation);
            // Start each child at its initial (annotation-free) form.
            representative_terms.push(flat_proof[0].remove_rewrites());
            child_proofs.push(flat_proof);
        }
        proof.push(FlatTerm::new(
            self.node.clone(),
            representative_terms.clone(),
        ));
        for (i, child_proof) in child_proofs.iter().enumerate() {
            // The first step of child i keeps its rewrite annotations.
            proof.last_mut().unwrap().children[i] = child_proof[0].clone();
            for child in child_proof.iter().skip(1) {
                let mut children = vec![];
                for (j, rep_term) in representative_terms.iter().enumerate() {
                    if j == i {
                        children.push(child.clone());
                    } else {
                        children.push(rep_term.clone());
                    }
                }
                proof.push(FlatTerm::new(self.node.clone(), children));
            }
            // Child i is now fixed at its final form for subsequent steps.
            representative_terms[i] = child_proof.last().unwrap().remove_rewrites();
        }
        // This term's own annotations belong on the first flat step.
        proof[0].backward_rule = self.backward_rule;
        proof[0].forward_rule = self.forward_rule;
        proof
    }
}
/// A fully materialized term in a flattened proof, with optional rewrite
/// annotations. `Eq` is derived; equality itself comes from a hand-written
/// `PartialEq` that ignores the annotations.
#[derive(Debug, Clone, Eq)]
pub struct FlatTerm<L: Language> {
    /// The e-node at this position.
    pub node: L,
    /// Rule rewriting this term back to the previous one, if any.
    pub backward_rule: Option<Symbol>,
    /// Rule rewriting the previous term into this one, if any.
    pub forward_rule: Option<Symbol>,
    /// One flat child term per child of `node`.
    pub children: FlatExplanation<L>,
}
impl<L: Language + Display + FromOp> Display for FlatTerm<L> {
    /// Writes the term as its s-expression rendering.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let rendered = self.get_sexp().to_string();
        f.write_str(&rendered)
    }
}
impl<L: Language> PartialEq for FlatTerm<L> {
    /// Two flat terms are equal when their nodes match and their children are
    /// pairwise equal; rewrite annotations are not compared.
    fn eq(&self, other: &FlatTerm<L>) -> bool {
        self.node.matches(&other.node)
            && self
                .children
                .iter()
                .zip(other.children.iter())
                .all(|(lhs, rhs)| lhs.eq(rhs))
    }
}
impl<L: Language> FlatTerm<L> {
    /// Returns a copy of this term with every rewrite annotation stripped,
    /// recursively.
    pub fn remove_rewrites(&self) -> FlatTerm<L> {
        let bare_children = self
            .children
            .iter()
            .map(FlatTerm::remove_rewrites)
            .collect();
        FlatTerm::new(self.node.clone(), bare_children)
    }
    /// Merges `other`'s rewrite annotations into `self`, recursively.
    /// Panics if both terms annotate the same position.
    fn combine_rewrites(&mut self, other: &FlatTerm<L>) {
        if let Some(rule) = other.forward_rule {
            assert!(self.forward_rule.is_none());
            self.forward_rule = Some(rule);
        }
        if let Some(rule) = other.backward_rule {
            assert!(self.backward_rule.is_none());
            self.backward_rule = Some(rule);
        }
        for (mine, theirs) in self.children.iter_mut().zip(other.children.iter()) {
            mine.combine_rewrites(theirs);
        }
    }
}
impl<L: Language> Default for Explain<L> {
    /// An empty explanation graph; same as [`Explain::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<L: Language + Display + FromOp> FlatTerm<L> {
    /// Render this term (with annotations) as a string.
    pub fn get_string(&self) -> String {
        self.get_sexp().to_string()
    }
    // Build the s-expression, wrapping the term in `Rewrite<=` / `Rewrite=>`
    // markers when annotations are present.
    fn get_sexp(&self) -> Sexp {
        let op = Sexp::String(self.node.to_string());
        let mut expr = if self.node.is_leaf() {
            op
        } else {
            let mut vec = vec![op];
            for child in &self.children {
                vec.push(child.get_sexp());
            }
            Sexp::List(vec)
        };
        if let Some(rule_name) = &self.backward_rule {
            expr = Sexp::List(vec![
                Sexp::String("Rewrite<=".to_string()),
                Sexp::String((*rule_name).to_string()),
                expr,
            ]);
        }
        if let Some(rule_name) = &self.forward_rule {
            expr = Sexp::List(vec![
                Sexp::String("Rewrite=>".to_string()),
                Sexp::String((*rule_name).to_string()),
                expr,
            ]);
        }
        expr
    }
    /// Convert to a [`RecExpr`] by printing without annotations and reparsing.
    pub fn get_recexpr(&self) -> RecExpr<L> {
        self.remove_rewrites().to_string().parse().unwrap()
    }
}
impl<L: Language + Display + FromOp> Display for TreeTerm<L> {
    /// Pretty-prints the term's s-expression at width 80, starting at depth 1.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut rendered = String::new();
        pretty_print(&mut rendered, &self.get_sexp(), 80, 1).unwrap();
        f.write_str(&rendered)
    }
}
impl<L: Language + Display + FromOp> TreeTerm<L> {
    // Render with no let-bindings in scope.
    fn get_sexp(&self) -> Sexp {
        self.get_sexp_with_bindings(&Default::default())
    }
    // Render this term, substituting already-let-bound subterms (looked up by
    // pointer identity) with their binding names.
    fn get_sexp_with_bindings(&self, bindings: &HashMap<*const TreeTerm<L>, Sexp>) -> Sexp {
        let op = Sexp::String(self.node.to_string());
        let mut expr = if self.node.is_leaf() {
            op
        } else {
            let mut vec = vec![op];
            for child in &self.child_proofs {
                assert!(!child.is_empty());
                if child.len() == 1 {
                    // Single-step child proof: inline the term directly.
                    if let Some(existing) = bindings.get(&(&*child[0] as *const TreeTerm<L>)) {
                        vec.push(existing.clone());
                    } else {
                        vec.push(child[0].get_sexp_with_bindings(bindings));
                    }
                } else {
                    // Multi-step child proof: wrap in an `Explanation` list.
                    let mut child_expressions = vec![Sexp::String("Explanation".to_string())];
                    for child_explanation in child.iter() {
                        if let Some(existing) =
                            bindings.get(&(&**child_explanation as *const TreeTerm<L>))
                        {
                            child_expressions.push(existing.clone());
                        } else {
                            child_expressions
                                .push(child_explanation.get_sexp_with_bindings(bindings));
                        }
                    }
                    vec.push(Sexp::List(child_expressions));
                }
            }
            Sexp::List(vec)
        };
        // Wrap with rewrite markers, mirroring FlatTerm::get_sexp.
        if let Some(rule_name) = &self.backward_rule {
            expr = Sexp::List(vec![
                Sexp::String("Rewrite<=".to_string()),
                Sexp::String((*rule_name).to_string()),
                expr,
            ]);
        }
        if let Some(rule_name) = &self.forward_rule {
            expr = Sexp::List(vec![
                Sexp::String("Rewrite=>".to_string()),
                Sexp::String((*rule_name).to_string()),
                expr,
            ]);
        }
        expr
    }
}
impl<L: Language> FlatTerm<L> {
    /// Construct a flat term with no rewrite annotations.
    pub fn new(node: L, children: FlatExplanation<L>) -> FlatTerm<L> {
        FlatTerm {
            node,
            backward_rule: None,
            forward_rule: None,
            children,
        }
    }
    /// Apply a rewrite: bind `lhs`'s variables against this term, then
    /// instantiate `rhs` under those bindings.
    pub fn rewrite(&self, lhs: &PatternAst<L>, rhs: &PatternAst<L>) -> FlatTerm<L> {
        let lhs_nodes = lhs.as_ref();
        let rhs_nodes = rhs.as_ref();
        let mut bindings = Default::default();
        // Pattern ASTs store the root last.
        self.make_bindings(lhs_nodes, lhs_nodes.len() - 1, &mut bindings);
        FlatTerm::from_pattern(rhs_nodes, rhs_nodes.len() - 1, &bindings)
    }
    /// True if this term or any descendant carries a forward-rule annotation.
    pub fn has_rewrite_forward(&self) -> bool {
        self.forward_rule.is_some()
            || self
                .children
                .iter()
                .any(|child| child.has_rewrite_forward())
    }
    /// True if this term or any descendant carries a backward-rule annotation.
    pub fn has_rewrite_backward(&self) -> bool {
        self.backward_rule.is_some()
            || self
                .children
                .iter()
                .any(|child| child.has_rewrite_backward())
    }
    // Instantiate the pattern node at `location`, substituting variables from
    // `bindings`. Panics if a pattern variable is unbound.
    fn from_pattern(
        pattern: &[ENodeOrVar<L>],
        location: usize,
        bindings: &HashMap<Var, &FlatTerm<L>>,
    ) -> FlatTerm<L> {
        match &pattern[location] {
            ENodeOrVar::Var(var) => (*bindings.get(var).unwrap()).clone(),
            ENodeOrVar::ENode(node) => {
                let children = node.fold(vec![], |mut acc, child| {
                    acc.push(FlatTerm::from_pattern(
                        pattern,
                        usize::from(child),
                        bindings,
                    ));
                    acc
                });
                FlatTerm::new(node.clone(), children)
            }
        }
    }
    // Match this term against the pattern node at `location`, recording
    // variable bindings. Panics if the same variable binds to unequal terms
    // or the pattern's e-node does not match.
    fn make_bindings<'a>(
        &'a self,
        pattern: &[ENodeOrVar<L>],
        location: usize,
        bindings: &mut HashMap<Var, &'a FlatTerm<L>>,
    ) {
        match &pattern[location] {
            ENodeOrVar::Var(var) => {
                if let Some(existing) = bindings.get(var) {
                    if existing != &self {
                        panic!(
                            "Invalid proof: binding for variable {:?} does not match between {:?} \n and \n {:?}",
                            var, existing, self);
                    }
                } else {
                    bindings.insert(*var, self);
                }
            }
            ENodeOrVar::ENode(node) => {
                assert!(node.matches(&self.node));
                let mut counter = 0;
                node.for_each(|child| {
                    self.children[counter].make_bindings(pattern, usize::from(child), bindings);
                    counter += 1;
                });
            }
        }
    }
}
/// Heap entry ordered by `cost`, reversed so that the standard max-heap
/// `BinaryHeap` pops the *smallest* cost first (Dijkstra-style min-heap).
#[derive(Copy, Clone, Eq, PartialEq)]
struct HeapState<I> {
    cost: usize,
    item: I,
}
impl<I: Eq + PartialEq> Ord for HeapState<I> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse the natural order on `cost` to get min-heap behavior.
        // The original chained `.then_with(|| self.cost.cmp(&other.cost))`,
        // which re-compares the same costs after they compared Equal and can
        // therefore never change the result; that dead tie-breaker is removed.
        other.cost.cmp(&self.cost)
    }
}
impl<I: Eq + PartialEq> PartialOrd for HeapState<I> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl<L: Language> Explain<L> {
// Build a rewrite-free tree term for a single e-node, recursively rendering
// each child as a one-step proof. Results are memoized per node id.
fn node_to_explanation(
    &self,
    node_id: Id,
    cache: &mut NodeExplanationCache<L>,
) -> Rc<TreeTerm<L>> {
    if let Some(existing) = cache.get(&node_id) {
        existing.clone()
    } else {
        let node = self.explainfind[usize::from(node_id)].node.clone();
        let children = node.fold(vec![], |mut sofar, child| {
            sofar.push(vec![self.node_to_explanation(child, cache)]);
            sofar
        });
        let res = Rc::new(TreeTerm::new(node, children));
        cache.insert(node_id, res.clone());
        res
    }
}
/// Materialize the e-node at `node_id` (and its descendants) as a `RecExpr`.
pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr<L> {
    let mut expr = Default::default();
    let mut id_cache = Default::default();
    self.node_to_recexpr_internal(&mut expr, node_id, &mut id_cache);
    expr
}
// Recursively append this node's subtree to `res`; children are added before
// the parent, so the parent can reference the freshly appended last index.
// NOTE(review): `cache` is consulted but never populated here, and the caller
// passes a fresh empty map — the hit path looks unreachable; confirm whether
// an insert was intended.
fn node_to_recexpr_internal(
    &self,
    res: &mut RecExpr<L>,
    node_id: Id,
    cache: &mut HashMap<Id, Id>,
) {
    let new_node = self.explainfind[usize::from(node_id)]
        .node
        .clone()
        .map_children(|child| {
            if let Some(existing) = cache.get(&child) {
                *existing
            } else {
                self.node_to_recexpr_internal(res, child, cache);
                // The child is now the last node appended to `res`.
                Id::from(res.as_ref().len() - 1)
            }
        });
    res.add(new_node);
}
/// Materialize the e-node at `node_id` as a pattern, turning every id that
/// appears in `substitutions` into a pattern variable bound in the returned
/// substitution.
pub(crate) fn node_to_pattern(
    &self,
    node_id: Id,
    substitutions: &HashMap<Id, Id>,
) -> (Pattern<L>, Subst) {
    let mut ast = Default::default();
    let mut subst = Default::default();
    let mut id_cache = Default::default();
    self.node_to_pattern_internal(&mut ast, node_id, substitutions, &mut subst, &mut id_cache);
    (Pattern::new(ast), subst)
}
// Recursive worker for `node_to_pattern`: ids present in `var_substitutions`
// become variables named `?<id>` (recorded in `subst`); everything else is
// emitted as a concrete e-node, children first.
// NOTE(review): as with `node_to_recexpr_internal`, `cache` is read but never
// written — confirm whether an insert was intended.
fn node_to_pattern_internal(
    &self,
    res: &mut PatternAst<L>,
    node_id: Id,
    var_substitutions: &HashMap<Id, Id>,
    subst: &mut Subst,
    cache: &mut HashMap<Id, Id>,
) {
    if let Some(existing) = var_substitutions.get(&node_id) {
        let var = format!("?{}", node_id).parse().unwrap();
        res.add(ENodeOrVar::Var(var));
        subst.insert(var, *existing);
    } else {
        let new_node = self.explainfind[usize::from(node_id)]
            .node
            .clone()
            .map_children(|child| {
                if let Some(existing) = cache.get(&child) {
                    *existing
                } else {
                    self.node_to_pattern_internal(res, child, var_substitutions, subst, cache);
                    // The child is the last node appended to `res`.
                    Id::from(res.as_ref().len() - 1)
                }
            });
        res.add(ENodeOrVar::ENode(new_node));
    }
}
// Build a rewrite-free flat term for the e-node at `node_id`, recursing into
// all children.
fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
    let node = self.explainfind[usize::from(node_id)].node.clone();
    let mut children = Vec::new();
    node.for_each(|child| children.push(self.node_to_flat_explanation(child)));
    FlatTerm::new(node, children)
}
// Index rewrites by name for quick lookup during proof checking.
fn make_rule_table<'a, N: Analysis<L>>(
    rules: &[&'a Rewrite<L, N>],
) -> HashMap<Symbol, &'a Rewrite<L, N>> {
    let mut table: HashMap<Symbol, &'a Rewrite<L, N>> = Default::default();
    for &rule in rules {
        table.insert(rule.name, rule);
    }
    table
}
/// Sanity-check the explanation graph: every existence chain must terminate
/// (panics on a cycle) and every parent edge justified by a known rule must
/// be a valid application of that rule. Returns `false` on the first invalid
/// rewrite edge.
pub fn check_each_explain<N: Analysis<L>>(&self, rules: &[&Rewrite<L, N>]) -> bool {
    let rule_table = Explain::make_rule_table(rules);
    for i in 0..self.explainfind.len() {
        let explain_node = &self.explainfind[i];
        // Walk the existence chain up to its self-pointing root.
        let mut existance = i;
        let mut seen_existance: HashSet<usize> = Default::default();
        loop {
            seen_existance.insert(existance);
            let next = usize::from(self.explainfind[existance].existance_node);
            if existance == next {
                break;
            }
            existance = next;
            if seen_existance.contains(&existance) {
                panic!("Cycle in existance!");
            }
        }
        // Verify the parent edge; roots point at themselves and are skipped.
        if explain_node.parent_connection.next != Id::from(i) {
            let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
            let mut next_explanation =
                self.node_to_flat_explanation(explain_node.parent_connection.next);
            if let Justification::Rule(rule_name) =
                &explain_node.parent_connection.justification
            {
                if let Some(rule) = rule_table.get(rule_name) {
                    // Orient the pair in the rule's forward direction.
                    if !explain_node.parent_connection.is_rewrite_forward {
                        std::mem::swap(&mut current_explanation, &mut next_explanation);
                    }
                    // BUGFIX: this argument read `¤t_explanation` — an
                    // HTML-entity corruption of `&current_explanation`; restored.
                    if !Explanation::check_rewrite(
                        &current_explanation,
                        &next_explanation,
                        rule,
                    ) {
                        return false;
                    }
                }
            }
        }
    }
    true
}
pub fn new() -> Self {
Explain {
explainfind: vec![],
uncanon_memo: Default::default(),
shortest_explanation_memo: Default::default(),
optimize_explanation_lengths: true,
}
}
/// Record `existance_node` as the reason e-node `node` exists.
pub(crate) fn set_existance_reason(&mut self, node: Id, existance_node: Id) {
    self.explainfind[usize::from(node)].existance_node = existance_node;
}
/// Register a fresh e-node under id `set`; the id must be the next index in
/// the graph. The new node starts as its own union-find root.
pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id {
    assert_eq!(self.explainfind.len(), usize::from(set));
    self.uncanon_memo.insert(node.clone(), set);
    self.explainfind.push(ExplainNode {
        node,
        neighbors: vec![],
        // Self-loop parent connection marks a root.
        parent_connection: Connection {
            justification: Justification::Congruence,
            is_rewrite_forward: false,
            next: set,
            current: set,
        },
        existance_node,
    });
    set
}
// Re-root the explanation spanning tree at `node` by recursively reversing
// every parent connection on the path from `node` to the old root.
fn make_leader(&mut self, node: Id) {
    let next = self.explainfind[usize::from(node)].parent_connection.next;
    if next != node {
        // First re-root the rest of the path, then point `next` back at us.
        self.make_leader(next);
        let node_connection = &self.explainfind[usize::from(node)].parent_connection;
        let pconnection = Connection {
            justification: node_connection.justification.clone(),
            // Reversing the edge flips the rewrite direction.
            is_rewrite_forward: !node_connection.is_rewrite_forward,
            next: node,
            current: next,
        };
        self.explainfind[usize::from(next)].parent_connection = pconnection;
    }
}
/// Record an additional direct rewrite edge between two already-equal nodes,
/// unless a path of cost <= 1 between them is already known.
pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) {
    if node1 == node2 {
        return;
    }
    // Skip if the memo already knows a connection at least this short.
    if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
        if cost <= &1 {
            return;
        }
    }
    let lconnection = Connection {
        justification: justification.clone(),
        is_rewrite_forward: true,
        next: node2,
        current: node1,
    };
    let rconnection = Connection {
        justification,
        is_rewrite_forward: false,
        next: node1,
        current: node2,
    };
    self.explainfind[usize::from(node1)]
        .neighbors
        .push(lconnection);
    self.explainfind[usize::from(node2)]
        .neighbors
        .push(rconnection);
    // A direct rewrite edge has cost 1 in both directions.
    self.shortest_explanation_memo
        .insert((node1, node2), (1, node2));
    self.shortest_explanation_memo
        .insert((node2, node1), (1, node1));
}
/// Union two e-nodes in the explanation graph: re-root `node1`'s tree and
/// hang it under `node2`, recording the justification on both sides.
pub(crate) fn union(
    &mut self,
    node1: Id,
    node2: Id,
    justification: Justification,
    new_rhs: bool,
) {
    // Congruence is only valid between matching e-nodes.
    if let Justification::Congruence = justification {
        assert!(self.explainfind[usize::from(node1)]
            .node
            .matches(&self.explainfind[usize::from(node2)].node));
    }
    if new_rhs {
        // `node2` was created as the result of this union.
        self.set_existance_reason(node2, node1)
    }
    // Make node1 the root of its tree so we can attach it under node2.
    self.make_leader(node1);
    self.explainfind[usize::from(node1)].parent_connection.next = node2;
    // Direct rule applications have known cost 1 in both directions.
    if let Justification::Rule(_) = justification {
        self.shortest_explanation_memo
            .insert((node1, node2), (1, node2));
        self.shortest_explanation_memo
            .insert((node2, node1), (1, node1));
    }
    let pconnection = Connection {
        justification: justification.clone(),
        is_rewrite_forward: true,
        next: node2,
        current: node1,
    };
    let other_pconnection = Connection {
        justification,
        is_rewrite_forward: false,
        next: node1,
        current: node2,
    };
    self.explainfind[usize::from(node1)]
        .neighbors
        .push(pconnection.clone());
    self.explainfind[usize::from(node2)]
        .neighbors
        .push(other_pconnection);
    self.explainfind[usize::from(node1)].parent_connection = pconnection;
}
/// Collect every rule-justified equality, reported once in its forward
/// orientation as (current, next, rule-name).
pub(crate) fn get_union_equalities(&self) -> UnionEqualities {
    self.explainfind
        .iter()
        .flat_map(|node| node.neighbors.iter())
        .filter_map(|connection| match &connection.justification {
            Justification::Rule(rule) if connection.is_rewrite_forward => {
                Some((connection.current, connection.next, *rule))
            }
            _ => None,
        })
        .collect()
}
/// Add every e-node recorded in this explanation graph to the given e-graph,
/// returning the e-graph.
pub(crate) fn populate_enodes<N: Analysis<L>>(&self, mut egraph: EGraph<L, N>) -> EGraph<L, N> {
    for entry in &self.explainfind {
        egraph.add(entry.node.clone());
    }
    egraph
}
/// Produce a proof that `left` and `right` are equivalent, optionally running
/// the shortest-explanation optimization first.
pub(crate) fn explain_equivalence<N: Analysis<L>>(
    &mut self,
    left: Id,
    right: Id,
    unionfind: &mut UnionFind,
    classes: &HashMap<Id, EClass<L, N::Data>>,
) -> Explanation<L> {
    if self.optimize_explanation_lengths {
        self.calculate_shortest_explanations::<N>(left, right, classes, unionfind);
    }
    let mut cache = Default::default();
    let mut enode_cache = Default::default();
    Explanation::new(self.explain_enodes(left, right, &mut cache, &mut enode_cache, false))
}
/// Produce a proof that the e-node `left` exists, by following the chain of
/// existence reasons back to an original term.
pub(crate) fn explain_existance(&mut self, left: Id) -> Explanation<L> {
    let mut cache = Default::default();
    let mut enode_cache = Default::default();
    Explanation::new(self.explain_enode_existance(
        left,
        self.node_to_explanation(left, &mut enode_cache),
        &mut cache,
        &mut enode_cache,
    ))
}
// Find a common ancestor of two nodes in the explanation spanning tree by
// walking both parent chains upward in lock-step, returning the first node
// seen on both sides. Panics (via the assert) if both walks stall at
// different roots, which would mean the nodes are not connected.
fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
    let mut seen_left: HashSet<Id> = Default::default();
    let mut seen_right: HashSet<Id> = Default::default();
    loop {
        seen_left.insert(left);
        if seen_right.contains(&left) {
            return left;
        }
        seen_right.insert(right);
        if seen_left.contains(&right) {
            return right;
        }
        let next_left = self.explainfind[usize::from(left)].parent_connection.next;
        let next_right = self.explainfind[usize::from(right)].parent_connection.next;
        // At least one side must still be able to move upward.
        assert!(next_left != left || next_right != right);
        left = next_left;
        right = next_right;
    }
}
// Collect the parent connections on the path from `node` up to (but not
// including) `ancestor`. `ancestor` must actually lie on `node`'s parent
// chain, otherwise the assert fires when the walk reaches the root.
fn get_connections(&self, mut node: Id, ancestor: Id) -> Vec<Connection> {
    if node == ancestor {
        return vec![];
    }
    let mut nodes = vec![];
    loop {
        let next = self.explainfind[usize::from(node)].parent_connection.next;
        nodes.push(
            self.explainfind[usize::from(node)]
                .parent_connection
                .clone(),
        );
        if next == ancestor {
            return nodes;
        }
        assert!(next != node);
        node = next;
    }
}
// Path between two nodes through their common ancestor in the spanning tree:
// returns (left-to-ancestor edges, right-to-ancestor edges).
fn get_path_unoptimized(&self, left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
    let lca = self.common_ancestor(left, right);
    (
        self.get_connections(left, lca),
        self.get_connections(right, lca),
    )
}
// Find a rule-justified edge from `current` to `next`; if none exists the
// step must be by congruence, so synthesize a congruence connection.
fn get_neighbor(&self, current: Id, next: Id) -> Connection {
    let rule_edge = self.explainfind[usize::from(current)]
        .neighbors
        .iter()
        .find(|neighbor| {
            neighbor.next == next && matches!(neighbor.justification, Justification::Rule(_))
        });
    match rule_edge {
        Some(connection) => connection.clone(),
        None => Connection {
            justification: Justification::Congruence,
            current,
            next,
            is_rewrite_forward: true,
        },
    }
}
// Build a path from `left` to `right`, preferring hops recorded by the
// shortest-explanation memo and falling back to the unoptimized spanning-tree
// path once the memo runs out.
fn get_path(&self, mut left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
    let mut left_connections = vec![];
    loop {
        if left == right {
            return (left_connections, vec![]);
        }
        // Follow the memoized next hop on a shortest path, if known.
        if let Some((_, next)) = self.shortest_explanation_memo.get(&(left, right)) {
            left_connections.push(self.get_neighbor(left, *next));
            left = *next;
        } else {
            break;
        }
    }
    let (restleft, right_connections) = self.get_path_unoptimized(left, right);
    left_connections.extend(restleft);
    (left_connections, right_connections)
}
// Build an existence proof for `node`, prepending steps as we walk the chain
// of existence reasons; `rest_of_proof` is the proof built so far for the
// suffix. Terminates at a node that is its own existence reason.
fn explain_enode_existance(
    &self,
    node: Id,
    rest_of_proof: Rc<TreeTerm<L>>,
    cache: &mut ExplainCache<L>,
    enode_cache: &mut NodeExplanationCache<L>,
) -> TreeExplanation<L> {
    let graphnode = &self.explainfind[usize::from(node)];
    let existance = graphnode.existance_node;
    let existance_node = &self.explainfind[usize::from(existance)];
    // Base case: an original node explains itself.
    if existance == node {
        return vec![self.node_to_explanation(node, enode_cache), rest_of_proof];
    }
    // Case 1: the existence reason is directly connected by a parent edge;
    // explain that single step and continue from the reason node.
    if graphnode.parent_connection.next == existance
        || existance_node.parent_connection.next == node
    {
        let mut connection = graphnode.parent_connection.clone();
        if graphnode.parent_connection.next == existance {
            // Orient the edge from `existance` toward `node`.
            connection.is_rewrite_forward = !connection.is_rewrite_forward;
            std::mem::swap(&mut connection.next, &mut connection.current);
        }
        return self.explain_enode_existance(
            existance,
            self.explain_adjacent(connection, cache, enode_cache, false),
            cache,
            enode_cache,
        );
    }
    // Case 2: `node` exists because it is a child of `existance`; splice the
    // proof so far into the corresponding child slot of the reason's term.
    let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone();
    let mut index_of_child = 0;
    let mut found = false;
    existance_node.node.for_each(|child| {
        if found {
            return;
        }
        if child == node {
            found = true;
        } else {
            index_of_child += 1;
        }
    });
    assert!(found);
    new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof);
    self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache, enode_cache)
}
// Build the sequence of tree terms proving `left` == `right`: start from
// `left`'s plain term, then add one term per connection along the path.
fn explain_enodes(
    &self,
    left: Id,
    right: Id,
    cache: &mut ExplainCache<L>,
    node_explanation_cache: &mut NodeExplanationCache<L>,
    use_unoptimized: bool,
) -> TreeExplanation<L> {
    let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];
    let (left_connections, right_connections) = if use_unoptimized {
        self.get_path_unoptimized(left, right)
    } else {
        self.get_path(left, right)
    };
    for (i, connection) in left_connections
        .iter()
        .chain(right_connections.iter().rev())
        .enumerate()
    {
        let mut connection = connection.clone();
        // The right-side connections point up toward the ancestor, so they
        // must be reversed to continue the proof toward `right`.
        if i >= left_connections.len() {
            connection.is_rewrite_forward = !connection.is_rewrite_forward;
            std::mem::swap(&mut connection.next, &mut connection.current);
        }
        proof.push(self.explain_adjacent(
            connection,
            cache,
            node_explanation_cache,
            use_unoptimized,
        ));
    }
    proof
}
// Explain one edge of the proof. Rule edges yield the target term annotated
// with the rule (in the appropriate direction); congruence edges recurse,
// explaining the equivalence of each pair of children. Results are cached by
// the (current, next) pair.
fn explain_adjacent(
    &self,
    connection: Connection,
    cache: &mut ExplainCache<L>,
    node_explanation_cache: &mut NodeExplanationCache<L>,
    use_unoptimized: bool,
) -> Rc<TreeTerm<L>> {
    let fingerprint = (connection.current, connection.next);
    if let Some(answer) = cache.get(&fingerprint) {
        return answer.clone();
    }
    let term = match connection.justification {
        Justification::Rule(name) => {
            let mut rewritten =
                (*self.node_to_explanation(connection.next, node_explanation_cache)).clone();
            if connection.is_rewrite_forward {
                rewritten.forward_rule = Some(name);
            } else {
                rewritten.backward_rule = Some(name);
            }
            rewritten.current = connection.next;
            rewritten.last = connection.current;
            Rc::new(rewritten)
        }
        Justification::Congruence => {
            // Congruent nodes match, so their children line up pairwise.
            let current_node = &self.explainfind[usize::from(connection.current)].node;
            let next_node = &self.explainfind[usize::from(connection.next)].node;
            assert!(current_node.matches(next_node));
            let mut subproofs = vec![];
            for (left_child, right_child) in current_node
                .children()
                .iter()
                .zip(next_node.children().iter())
            {
                subproofs.push(self.explain_enodes(
                    *left_child,
                    *right_child,
                    cache,
                    node_explanation_cache,
                    use_unoptimized,
                ));
            }
            Rc::new(TreeTerm::new(current_node.clone(), subproofs))
        }
    };
    cache.insert(fingerprint, term.clone());
    term
}
// All e-nodes reachable from `eclass` through equality edges, i.e. the whole
// equivalence class, found by depth-first search.
fn find_all_enodes(&self, eclass: Id) -> HashSet<Id> {
    let mut visited = HashSet::default();
    let mut stack = vec![eclass];
    while let Some(id) = stack.pop() {
        if visited.insert(id) {
            stack.extend(
                self.explainfind[usize::from(id)]
                    .neighbors
                    .iter()
                    .map(|connection| connection.next),
            );
        }
    }
    visited
}
// Compute (and memoize in `depths`) the depth of `node` in the explanation
// spanning tree; roots (their own parent) have depth 0.
// Cleanup: the original checked `depths.get(..).is_none()` and then looked the
// key up again with `return *depths.get(&node).unwrap();` — two redundant map
// lookups and a non-idiomatic `return`; collapsed into a single memo check.
fn add_tree_depths(&self, node: Id, depths: &mut HashMap<Id, usize>) -> usize {
    if let Some(&depth) = depths.get(&node) {
        return depth;
    }
    let parent = self.parent(node);
    let depth = if parent == node {
        0
    } else {
        self.add_tree_depths(parent, depths) + 1
    };
    depths.insert(node, depth);
    depth
}
// Depth of every node in the explanation spanning tree.
fn calculate_tree_depths(&self) -> HashMap<Id, usize> {
    let mut depths = HashMap::default();
    (0..self.explainfind.len()).for_each(|index| {
        self.add_tree_depths(Id::from(index), &mut depths);
    });
    depths
}
// Record that the shortest known path from `current` to `right` has the given
// total `distance` and goes through `next`.
fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: usize) {
    self.shortest_explanation_memo
        .insert((current, right), (distance, next));
}
// Fill the shortest-path memo for every node along `left_connections` toward
// `right`, accumulating costs from the `right` end backward. The final
// accumulated cost must not exceed `target_cost`.
fn populate_path_length(
    &mut self,
    right: Id,
    left_connections: &[Connection],
    distance_memo: &mut DistanceMemo,
    target_cost: usize,
) {
    // The destination is at distance 0 from itself.
    self.shortest_explanation_memo
        .insert((right, right), (0, right));
    let mut last_cost = 0;
    for connection in left_connections.iter().rev() {
        let next = connection.next;
        let current = connection.current;
        // Distance already recorded from `next` onward.
        let next_cost = self
            .shortest_explanation_memo
            .get(&(next, right))
            .unwrap()
            .0;
        let dist = self.connection_distance(connection, distance_memo);
        last_cost = dist + next_cost;
        self.replace_distance(current, next, right, next_cost + dist);
    }
    assert!(last_cost <= target_cost);
}
// Distance between two nodes through their common ancestor in the spanning
// tree: dist(left, root) + dist(right, root) - 2 * dist(ancestor, root).
// Uses saturating arithmetic so an unreachable/overflowing distance becomes
// usize::MAX rather than panicking.
fn distance_between(&mut self, left: Id, right: Id, distance_memo: &mut DistanceMemo) -> usize {
    if left == right {
        return 0;
    }
    let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
        *a
    } else {
        self.common_ancestor(left, right)
    };
    // Compress both paths up to the ancestor first, then all the way up.
    self.calculate_parent_distance(left, ancestor, distance_memo);
    self.calculate_parent_distance(right, ancestor, distance_memo);
    // Id::from(usize::MAX) is the sentinel for "no ancestor bound".
    let a = self.calculate_parent_distance(ancestor, Id::from(usize::MAX), distance_memo);
    let b = self.calculate_parent_distance(left, Id::from(usize::MAX), distance_memo);
    let c = self.calculate_parent_distance(right, Id::from(usize::MAX), distance_memo);
    // All three should now be compressed to the same representative.
    assert!(
        distance_memo.parent_distance[usize::from(ancestor)].0
            == distance_memo.parent_distance[usize::from(left)].0
    );
    assert!(
        distance_memo.parent_distance[usize::from(ancestor)].0
            == distance_memo.parent_distance[usize::from(right)].0
    );
    match b.checked_add(c) {
        Some(added) => added
            .checked_sub(a.checked_mul(2).unwrap_or(0))
            .unwrap_or(usize::MAX),
        None => usize::MAX,
    }
}
/// Cost of a congruence step between two enodes: the sum of the distances
/// between each pair of corresponding children (saturating on overflow).
fn congruence_distance(
    &mut self,
    current: Id,
    next: Id,
    distance_memo: &mut DistanceMemo,
) -> usize {
    // Clone the nodes so we can walk their children while calling
    // `distance_between`, which needs `&mut self`.
    let lhs = self.explainfind[usize::from(current)].node.clone();
    let rhs = self.explainfind[usize::from(next)].node.clone();
    let mut total: usize = 0;
    for (&lc, &rc) in lhs.children().iter().zip(rhs.children().iter()) {
        let child_dist = self.distance_between(lc, rc, distance_memo);
        total = total.saturating_add(child_dist);
    }
    total
}
/// Cost of traversing a single connection: rewrite edges cost 1, while
/// congruence edges cost the recursive distance between child pairs.
fn connection_distance(
    &mut self,
    connection: &Connection,
    distance_memo: &mut DistanceMemo,
) -> usize {
    if let Justification::Congruence = connection.justification {
        self.congruence_distance(connection.current, connection.next, distance_memo)
    } else {
        // Justification::Rule
        1
    }
}
/// Compress the parent-distance pointer for `enode` up toward `ancestor`
/// (or, with the `Id::from(usize::MAX)` sentinel, all the way to the root),
/// returning the accumulated distance stored for `enode`.
///
/// `distance_memo.parent_distance[i]` is a path-compression structure:
/// `(pointer, distance)` where `distance` is the cost from node `i` to
/// `pointer`. Each loop iteration either hops `enode`'s pointer over its
/// pointer's pointer (classic compression) or extends the pointer target's
/// own edge by one step.
fn calculate_parent_distance(
    &mut self,
    enode: Id,
    ancestor: Id,
    distance_memo: &mut DistanceMemo,
) -> usize {
    loop {
        let parent = distance_memo.parent_distance[usize::from(enode)].0;
        let dist = distance_memo.parent_distance[usize::from(enode)].1;
        // Stop once the pointer target is a tree root (self-parented).
        if self.parent(parent) == parent {
            break;
        }
        let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
        if parent_parent != parent {
            // Path compression: skip over `parent` and add its edge cost.
            let new_dist =
                dist.saturating_add(distance_memo.parent_distance[usize::from(parent)].1);
            distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
        } else {
            // `parent`'s own pointer is still a self-loop; advance it one
            // tree edge, unless we've reached the requested ancestor bound.
            if ancestor == Id::from(usize::MAX) {
                break;
            }
            // Don't climb above the ancestor (shallower tree depth).
            if distance_memo.tree_depth.get(&parent).unwrap()
                <= distance_memo.tree_depth.get(&ancestor).unwrap()
            {
                break;
            }
            let connection = &self.explainfind[usize::from(parent)].parent_connection;
            let current = connection.current;
            let next = connection.next;
            let cost = match connection.justification {
                Justification::Congruence => {
                    self.congruence_distance(current, next, distance_memo)
                }
                Justification::Rule(_) => 1,
            };
            distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
        }
    }
    distance_memo.parent_distance[usize::from(enode)].1
}
/// Populate `congruence_neighbors[i]` with the nodes congruent to node `i`:
/// first from explicit congruence parent-connections, then by grouping each
/// eclass's enodes by their canonicalized form.
///
/// Work is bounded by `CONGRUENCE_LIMIT * number_of_nodes`; past that the
/// scan stops early, so the neighbor lists may be incomplete (a heuristic
/// trade-off, not a correctness requirement).
fn find_congruence_neighbors<N: Analysis<L>>(
    &self,
    classes: &HashMap<Id, EClass<L, N::Data>>,
    congruence_neighbors: &mut [Vec<Id>],
    unionfind: &UnionFind,
) {
    let mut counter = 0;
    // Explicit congruence edges recorded in the explanation forest.
    for node in &self.explainfind {
        if let Justification::Congruence = node.parent_connection.justification {
            let current = node.parent_connection.current;
            let next = node.parent_connection.next;
            congruence_neighbors[usize::from(current)].push(next);
            congruence_neighbors[usize::from(next)].push(current);
            counter += 1;
        }
    }
    // Implicit congruences: enodes in the same eclass whose canonical forms
    // (children replaced by their representatives) coincide.
    'outer: for eclass in classes.keys() {
        let enodes = self.find_all_enodes(*eclass);
        let mut cannon_enodes: HashMap<L, Vec<Id>> = Default::default();
        for enode in &enodes {
            let cannon = self.explainfind[usize::from(*enode)]
                .node
                .clone()
                .map_children(|child| unionfind.find(child));
            if let Some(others) = cannon_enodes.get_mut(&cannon) {
                // Congruent to every earlier enode with the same canon form.
                for other in others.iter() {
                    congruence_neighbors[usize::from(*enode)].push(*other);
                    congruence_neighbors[usize::from(*other)].push(*enode);
                }
                counter += 1;
                others.push(*enode);
            } else {
                counter += 1;
                cannon_enodes.insert(cannon, vec![*enode]);
            }
            // Give up once the budget is exhausted.
            if counter > CONGRUENCE_LIMIT * self.explainfind.len() {
                break 'outer;
            }
        }
    }
}
/// Count the congruence edges discovered by `find_congruence_neighbors`.
/// Each edge is recorded from both endpoints, so the total is halved.
pub fn get_num_congr<N: Analysis<L>>(
    &self,
    classes: &HashMap<Id, EClass<L, N::Data>>,
    unionfind: &UnionFind,
) -> usize {
    let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
    self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
    let directed_edges: usize = congruence_neighbors.iter().map(|v| v.len()).sum();
    directed_edges / 2
}
/// Number of explanation nodes tracked by this `Explain` graph.
pub fn get_num_nodes(&self) -> usize {
    self.explainfind.len()
}
/// Dijkstra-style search for the cheapest path from `start` to `end`, where
/// rewrite edges cost 1 and congruence edges cost the recursive distance
/// between corresponding children.
///
/// Returns the path split as `(left_connections, right_connections)`: when
/// the found cost equals the unoptimized tree distance, the existing tree
/// path is reused; otherwise the Dijkstra path is returned (entirely on the
/// left) and the shortest-path memo is updated along it. Returns `None` if
/// `end` was never reached.
fn shortest_path_modulo_congruence(
    &mut self,
    start: Id,
    end: Id,
    congruence_neighbors: &[Vec<Id>],
    distance_memo: &mut DistanceMemo,
) -> Option<(Vec<Connection>, Vec<Connection>)> {
    let mut todo = BinaryHeap::new();
    // Seed the frontier with a zero-cost self-loop at `start`.
    todo.push(HeapState {
        cost: 0,
        item: Connection {
            current: start,
            next: start,
            justification: Justification::Congruence,
            is_rewrite_forward: true,
        },
    });
    // `last` records the connection used when a node was settled;
    // `path_cost` records its final cost.
    let mut last = HashMap::default();
    let mut path_cost = HashMap::default();
    while let Some(state) = todo.pop() {
        let connection = state.item;
        let cost_so_far = state.cost;
        // Fix: this was mojibake (`¤t` — a mangled `&current`/`&curren;`
        // HTML entity) in the original, which did not compile.
        let current = connection.next;
        // A node already present was settled with an equal or cheaper cost.
        if last.contains_key(&current) {
            continue;
        }
        last.insert(current, connection);
        path_cost.insert(current, cost_so_far);
        if current == end {
            break;
        }
        // Rewrite edges always cost 1.
        for neighbor in &self.explainfind[usize::from(current)].neighbors {
            if let Justification::Rule(_) = neighbor.justification {
                todo.push(HeapState {
                    item: neighbor.clone(),
                    cost: cost_so_far.saturating_add(1),
                });
            }
        }
        // Congruence edges cost the summed distance between child pairs.
        for other in congruence_neighbors[usize::from(current)].iter() {
            let distance = self.congruence_distance(current, *other, distance_memo);
            todo.push(HeapState {
                item: Connection {
                    current,
                    next: *other,
                    justification: Justification::Congruence,
                    is_rewrite_forward: true,
                },
                cost: cost_so_far.saturating_add(distance),
            });
        }
    }
    // If `end` was never settled there is no path to report (previously an
    // unwrap here would panic).
    let total_cost = *path_cost.get(&end)?;
    let left_connections;
    let mut right_connections = vec![];
    if total_cost == self.distance_between(start, end, distance_memo) {
        // The existing tree path is already optimal; reuse it.
        let (a_left_connections, a_right_connections) = self.get_path_unoptimized(start, end);
        left_connections = a_left_connections;
        right_connections = a_right_connections;
    } else {
        // Reconstruct the Dijkstra path by walking `last` back from `end`.
        let mut current = end;
        let mut connections = vec![];
        while current != start {
            if let Some(prev_connection) = last.get(&current) {
                connections.push(prev_connection.clone());
                current = prev_connection.current;
            } else {
                break;
            }
        }
        connections.reverse();
        self.populate_path_length(end, &connections, distance_memo, total_cost);
        left_connections = connections;
    }
    Some((left_connections, right_connections))
}
/// Greedily shorten the explanation between `start` and `end`, then recurse
/// (via a work queue) into the child pairs of every congruence step on the
/// chosen path.
///
/// `fuel` bounds the total work: each (start, end) pair consumes fuel equal
/// to the size of `start`'s eclass, and pairs that don't fit in the
/// remaining fuel are skipped entirely.
fn greedy_short_explanations(
    &mut self,
    start: Id,
    end: Id,
    congruence_neighbors: &[Vec<Id>],
    distance_memo: &mut DistanceMemo,
    mut fuel: usize,
) {
    let mut todo_congruence = VecDeque::new();
    todo_congruence.push_back((start, end));
    while !todo_congruence.is_empty() {
        let (start, end) = todo_congruence.pop_front().unwrap();
        let eclass_size = self.find_all_enodes(start).len();
        // Not enough fuel for this eclass: skip the pair, keep draining.
        if fuel < eclass_size {
            continue;
        }
        fuel = fuel.saturating_sub(eclass_size);
        let (left_connections, right_connections) = self
            .shortest_path_modulo_congruence(start, end, congruence_neighbors, distance_memo)
            .unwrap();
        for (i, connection) in left_connections
            .iter()
            .chain(right_connections.iter().rev())
            .enumerate()
        {
            let mut next = connection.next;
            let mut current = connection.current;
            // Right-side connections are traversed in the opposite
            // direction, so swap their endpoints.
            if i >= left_connections.len() {
                std::mem::swap(&mut next, &mut current);
            }
            // Congruence steps justify equality via children; enqueue each
            // child pair so their sub-explanations get shortened too.
            if let Justification::Congruence = connection.justification {
                let current_node = self.explainfind[usize::from(current)].node.clone();
                let next_node = self.explainfind[usize::from(next)].node.clone();
                for (left_child, right_child) in current_node
                    .children()
                    .iter()
                    .zip(next_node.children().iter())
                {
                    todo_congruence.push_back((*left_child, *right_child));
                }
            }
        }
    }
}
#[allow(clippy::too_many_arguments)]
/// Tarjan's offline lowest-common-ancestor algorithm over the explanation
/// tree rooted at `enode`.
///
/// Performs a post-order DFS; after finishing each child's subtree it
/// unions the child into `enode`'s set and marks `enode` as that set's
/// current ancestor. When both endpoints of a query are "black" (finished),
/// their LCA is the ancestor of the other endpoint's set. Results land in
/// `common_ancestor`, keyed in both orders.
fn tarjan_ocla(
    &self,
    enode: Id,
    children: &HashMap<Id, Vec<Id>>,
    common_ancestor_queries: &HashMap<Id, Vec<Id>>,
    black_set: &mut HashSet<Id>,
    unionfind: &mut UnionFind,
    ancestor: &mut Vec<Id>,
    common_ancestor: &mut HashMap<(Id, Id), Id>,
) {
    ancestor[usize::from(enode)] = enode;
    for child in children[&enode].iter() {
        self.tarjan_ocla(
            *child,
            children,
            common_ancestor_queries,
            black_set,
            unionfind,
            ancestor,
            common_ancestor,
        );
        // Merge the finished child subtree into this node's set and point
        // the merged set's ancestor back at `enode`.
        unionfind.union(enode, *child);
        ancestor[usize::from(unionfind.find(enode))] = enode;
    }
    if common_ancestor_queries.get(&enode).is_some() {
        black_set.insert(enode);
        for other in common_ancestor_queries.get(&enode).unwrap() {
            // Only answer the query once the partner has also finished.
            if black_set.contains(other) {
                let ancestor = ancestor[usize::from(unionfind.find(*other))];
                common_ancestor.insert((enode, *other), ancestor);
                common_ancestor.insert((*other, enode), ancestor);
            }
        }
    }
}
/// The parent of `enode` in the explanation forest (the `next` end of its
/// parent connection); roots point to themselves.
fn parent(&self, enode: Id) -> Id {
    self.explainfind[usize::from(enode)].parent_connection.next
}
/// Precompute lowest-common-ancestor answers for every pair of enodes that
/// a congruence edge relates through differing children, running Tarjan's
/// offline LCA once per eclass tree.
fn calculate_common_ancestor<N: Analysis<L>>(
    &self,
    classes: &HashMap<Id, EClass<L, N::Data>>,
    congruence_neighbors: &[Vec<Id>],
) -> HashMap<(Id, Id), Id> {
    // Collect the LCA queries: for each congruent pair, each pair of
    // differing children needs a common ancestor.
    let mut common_ancestor_queries: HashMap<Id, Vec<Id>> = HashMap::default();
    for (s_int, others) in congruence_neighbors.iter().enumerate() {
        let start = &Id::from(s_int);
        for other in others {
            for (left, right) in self.explainfind[usize::from(*start)]
                .node
                .children()
                .iter()
                .zip(self.explainfind[usize::from(*other)].node.children().iter())
            {
                if left != right {
                    // Entry API: one hash lookup per side instead of the
                    // previous contains-then-insert-then-get_mut pattern.
                    common_ancestor_queries.entry(*start).or_default().push(*other);
                    common_ancestor_queries.entry(*other).or_default().push(*start);
                }
            }
        }
    }
    let mut common_ancestor = HashMap::default();
    let mut unionfind = UnionFind::default();
    let mut ancestor = vec![];
    for i in 0..self.explainfind.len() {
        unionfind.make_set();
        ancestor.push(Id::from(i));
    }
    for eclass in classes.keys() {
        let enodes = self.find_all_enodes(*eclass);
        // Build the child lists of this eclass's explanation tree.
        let mut children: HashMap<Id, Vec<Id>> = HashMap::default();
        for enode in &enodes {
            children.insert(*enode, vec![]);
        }
        for enode in &enodes {
            if self.parent(*enode) != *enode {
                children.get_mut(&self.parent(*enode)).unwrap().push(*enode);
            }
        }
        let mut black_set = HashSet::default();
        // Climb to the (self-parented) root to start the DFS there.
        let mut parent = *enodes.iter().next().unwrap();
        while parent != self.parent(parent) {
            parent = self.parent(parent);
        }
        self.tarjan_ocla(
            parent,
            &children,
            &common_ancestor_queries,
            &mut black_set,
            &mut unionfind,
            &mut ancestor,
            &mut common_ancestor,
        );
    }
    common_ancestor
}
/// Precompute congruence neighbors, LCA answers, and tree depths, then run
/// the greedy explanation shortener between `start` and `end`.
fn calculate_shortest_explanations<N: Analysis<L>>(
    &mut self,
    start: Id,
    end: Id,
    classes: &HashMap<Id, EClass<L, N::Data>>,
    unionfind: &UnionFind,
) {
    let node_count = self.explainfind.len();
    let mut congruence_neighbors = vec![vec![]; node_count];
    self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
    // Every node starts out pointing at itself with distance zero.
    let parent_distance: Vec<(Id, usize)> =
        (0..node_count).map(|i| (Id::from(i), 0)).collect();
    let mut distance_memo = DistanceMemo {
        parent_distance,
        common_ancestor: self.calculate_common_ancestor::<N>(classes, &congruence_neighbors),
        tree_depth: self.calculate_tree_depths(),
    };
    self.greedy_short_explanations(
        start,
        end,
        &congruence_neighbors,
        &mut distance_memo,
        GREEDY_NUM_ITERS * node_count,
    );
}
}
#[cfg(test)]
mod tests {
    use super::super::*;

    /// End-to-end check of explanation generation and length optimization:
    /// builds a chain a = c = d = b, asks why (f a) = (f b), then adds a
    /// shorter route through `g` and verifies the optimizer finds it.
    #[test]
    fn simple_explain() {
        use SymbolLang as S;
        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default().with_explanations_enabled();
        let fa = "(f a)".parse().unwrap();
        let fb = "(f b)".parse().unwrap();
        egraph.add_expr(&fa);
        egraph.add_expr(&fb);
        egraph.add_expr(&"c".parse().unwrap());
        egraph.add_expr(&"d".parse().unwrap());
        // Chain of named unions: a -ac- c -cd- d -db- b.
        egraph.union_instantiations(
            &"a".parse().unwrap(),
            &"c".parse().unwrap(),
            &Default::default(),
            "ac".to_string(),
        );
        egraph.union_instantiations(
            &"c".parse().unwrap(),
            &"d".parse().unwrap(),
            &Default::default(),
            "cd".to_string(),
        );
        egraph.union_instantiations(
            &"d".parse().unwrap(),
            &"b".parse().unwrap(),
            &Default::default(),
            "db".to_string(),
        );
        egraph.rebuild();
        assert_eq!(egraph.add_expr(&fa), egraph.add_expr(&fb));
        // The only route is through the 3-step chain, giving 4 flat terms.
        // Repeated to check the explanation is stable across calls.
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            4
        );
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            4
        );
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            4
        );
        // Add a shortcut: (f a) = g = (f b), a 2-step route.
        egraph.union_instantiations(
            &"(f a)".parse().unwrap(),
            &"g".parse().unwrap(),
            &Default::default(),
            "fag".to_string(),
        );
        egraph.union_instantiations(
            &"g".parse().unwrap(),
            &"(f b)".parse().unwrap(),
            &Default::default(),
            "gfb".to_string(),
        );
        egraph.rebuild();
        // Without length optimization the old 4-term explanation persists.
        egraph = egraph.without_explanation_length_optimization();
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            4
        );
        // With optimization the shortcut through g yields 3 flat terms.
        egraph = egraph.with_explanation_length_optimization();
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            3
        );
        assert_eq!(
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len(),
            3
        );
        egraph.dot().to_dot("target/foo.dot").unwrap();
    }
}