use crate::*;
use std::{
borrow::BorrowMut,
fmt::{self, Debug, Display},
marker::PhantomData,
};
#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};
use log::*;
/// An e-graph: e-classes of equivalent e-nodes plus the bookkeeping
/// (union-find, hashcons, work lists) needed to restore congruence.
///
/// Creating a new e-class or performing a union sets `clean` to `false`;
/// [`EGraph::rebuild`] processes the pending work and sets it back to `true`.
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct EGraph<L: Language, N: Analysis<L>> {
    /// The analysis supplied to [`EGraph::new`].
    pub analysis: N,
    /// Explanation state; `Some` only while explanations are enabled.
    pub(crate) explain: Option<Explain<L>>,
    /// Union-find over every id this e-graph has allocated.
    unionfind: UnionFind,
    /// The original (not canonicalized) e-node recorded for each id,
    /// indexed by `usize::from(id)`.
    nodes: Vec<L>,
    /// Hashcons from e-node (children canonical at insertion time) to id.
    /// Entries can become stale after a union until the next rebuild.
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    memo: HashMap<L, Id>,
    /// Ids whose e-node must be re-canonicalized and re-hashconsed on rebuild.
    pending: Vec<Id>,
    /// Ids whose analysis data must be recomputed on rebuild.
    analysis_pending: UniqueQueue<Id>,
    /// Canonical id -> e-class. Non-canonical ids have no entry here.
    #[cfg_attr(
        feature = "serde-1",
        serde(bound(
            serialize = "N::Data: Serialize",
            deserialize = "N::Data: for<'a> Deserialize<'a>",
        ))
    )]
    pub(crate) classes: HashMap<Id, EClass<L, N::Data>>,
    /// Operator discriminant -> ids of e-classes containing such an e-node.
    /// Recomputed by `rebuild`; not serialized (restored empty).
    #[cfg_attr(feature = "serde-1", serde(skip))]
    #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))]
    classes_by_op: HashMap<L::Discriminant, HashSet<Id>>,
    /// Whether the e-graph has been rebuilt since the last mutation.
    /// Not serialized.
    #[cfg_attr(feature = "serde-1", serde(skip))]
    pub clean: bool,
}
/// Serde `default` for the skipped `classes_by_op` field.
///
/// Produces an empty index; the next [`EGraph::rebuild`] repopulates it.
#[cfg(feature = "serde-1")]
fn default_classes_by_op<K>() -> HashMap<K, HashSet<Id>> {
    Default::default()
}
impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
fn default() -> Self {
Self::new(N::default())
}
}
impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
    /// Prints only the hashcons and the e-classes; the remaining bookkeeping
    /// fields are intentionally left out to keep the output readable.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("EGraph");
        builder.field("memo", &self.memo);
        builder.field("classes", &self.classes);
        builder.finish()
    }
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
    /// Creates an empty e-graph that uses `analysis`.
    ///
    /// Explanations start disabled (see [`EGraph::with_explanations_enabled`])
    /// and `clean` starts as `false`.
    pub fn new(analysis: N) -> Self {
        Self {
            analysis,
            classes: Default::default(),
            unionfind: Default::default(),
            nodes: Default::default(),
            clean: false,
            explain: None,
            pending: Default::default(),
            memo: Default::default(),
            analysis_pending: Default::default(),
            classes_by_op: Default::default(),
        }
    }

    /// Iterates over every e-class, in the class map's iteration order.
    pub fn classes(&self) -> impl ExactSizeIterator<Item = &EClass<L, N::Data>> {
        self.classes.values()
    }

    /// Mutably iterates over every e-class, in the class map's iteration order.
    pub fn classes_mut(&mut self) -> impl ExactSizeIterator<Item = &mut EClass<L, N::Data>> {
        self.classes.values_mut()
    }

    /// Returns the ids of e-classes that contain an e-node with discriminant `op`.
    ///
    /// The index is recomputed by [`EGraph::rebuild`], so it reflects the state
    /// at the last rebuild. Returns `None` if `op` was never indexed. It may
    /// return an empty iterator if `op` was indexed by an earlier rebuild but no
    /// longer occurs, because rebuild clears the sets but keeps their keys.
    pub fn classes_for_op(
        &self,
        op: &L::Discriminant,
    ) -> Option<impl ExactSizeIterator<Item = Id> + '_> {
        self.classes_by_op.get(op).map(|s| s.iter().copied())
    }

    /// The original (non-canonical) e-node recorded for each id,
    /// indexed by `usize::from(id)`.
    pub fn nodes(&self) -> &[L] {
        &self.nodes
    }

    /// Returns `true` when the hashcons has no entries.
    pub fn is_empty(&self) -> bool {
        self.memo.is_empty()
    }

    /// Returns the number of entries in the hashcons (`memo`).
    pub fn total_size(&self) -> usize {
        self.memo.len()
    }

    /// Returns the total number of e-nodes summed over all e-classes.
    pub fn total_number_of_nodes(&self) -> usize {
        self.classes().map(|c| c.len()).sum()
    }

    /// Returns the number of e-classes.
    pub fn number_of_classes(&self) -> usize {
        self.classes.len()
    }

    /// Enables explanations; does nothing if they are already enabled.
    ///
    /// # Panics
    /// Panics if the e-graph already contains any e-node.
    pub fn with_explanations_enabled(mut self) -> Self {
        if self.explain.is_some() {
            return self;
        }
        if self.total_size() > 0 {
            panic!("Need to set explanations enabled before adding any expressions to the egraph.");
        }
        self.explain = Some(Explain::new());
        self
    }

    /// Turns off explanation length optimization.
    ///
    /// # Panics
    /// Panics if explanations are not enabled.
    pub fn without_explanation_length_optimization(mut self) -> Self {
        if let Some(explain) = &mut self.explain {
            explain.optimize_explanation_lengths = false;
            self
        } else {
            panic!("Need to set explanations enabled before setting length optimization.");
        }
    }

    /// Turns on explanation length optimization.
    ///
    /// # Panics
    /// Panics if explanations are not enabled.
    pub fn with_explanation_length_optimization(mut self) -> Self {
        if let Some(explain) = &mut self.explain {
            explain.optimize_explanation_lengths = true;
            self
        } else {
            panic!("Need to set explanations enabled before setting length optimization.");
        }
    }

    /// Builds a fresh e-graph using `analysis` by re-adding every recorded
    /// e-node (in id order) without replaying any unions. The copy has
    /// explanations disabled.
    ///
    /// NOTE(review): explanations are required, presumably because only then
    /// does `nodes` keep a separate entry for every distinct term added
    /// (see `add_uncanonical`).
    ///
    /// # Panics
    /// Panics if explanations are disabled.
    pub fn copy_without_unions(&self, analysis: N) -> Self {
        if self.explain.is_none() {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
        }
        let mut egraph = Self::new(analysis);
        for node in &self.nodes {
            egraph.add(node.clone());
        }
        egraph
    }

    /// Replays every union recorded by `other`'s explanations into `self`,
    /// then rebuilds `self`.
    ///
    /// Both sides of each equality are turned into ground patterns with
    /// `other.id_to_pattern` and unioned via `union_instantiations`, using the
    /// recorded reason as the rule name.
    ///
    /// # Panics
    /// Panics if explanations are not enabled on `other`.
    pub fn egraph_union(&mut self, other: &EGraph<L, N>) {
        let right_unions = other.get_union_equalities();
        for (left, right, why) in right_unions {
            self.union_instantiations(
                &other.id_to_pattern(left, &Default::default()).0.ast,
                &other.id_to_pattern(right, &Default::default()).0.ast,
                &Default::default(),
                why,
            );
        }
        self.rebuild();
    }

    /// Builds an e-graph from `(enode, id)` pairs whose child ids refer to the
    /// `id`s of other pairs.
    ///
    /// The pairs are scanned repeatedly. Each pass adds every e-node whose
    /// children have all been added and that is not already present, and
    /// unions e-nodes that share an `id`. Scanning stops when a pass adds
    /// nothing. Pairs whose children never become available are dropped.
    /// The result is not rebuilt.
    fn from_enodes(enodes: Vec<(L, Id)>, analysis: N) -> Self {
        let mut egraph = Self::new(analysis);
        // Maps an input id to the id it received in `egraph`.
        let mut ids: HashMap<Id, Id> = Default::default();
        loop {
            let mut did_something = false;
            for (enode, id) in &enodes {
                let valid = enode.children().iter().all(|c| ids.contains_key(c));
                if !valid {
                    continue;
                }
                let mut enode = enode.clone().map_children(|c| ids[&c]);
                // Already present: skipping it is what lets the loop terminate.
                if egraph.lookup(&mut enode).is_some() {
                    continue;
                }
                let added = egraph.add(enode);
                if let Some(existing) = ids.get(id) {
                    egraph.union(*existing, added);
                } else {
                    ids.insert(*id, added);
                }
                did_something = true;
            }
            if !did_something {
                break;
            }
        }
        egraph
    }

    /// Intersects two e-graphs with a product construction.
    ///
    /// Every pair of e-classes (one from each side) gets a product id. Every
    /// pair of matching e-nodes becomes an e-node of that product class whose
    /// children are the product ids of the children's canonical classes. The
    /// result is assembled by `from_enodes` using `analysis`.
    ///
    /// Visits all `|classes(self)| * |classes(other)|` class pairs.
    pub fn egraph_intersect(&self, other: &EGraph<L, N>, analysis: N) -> EGraph<L, N> {
        let mut product_map: HashMap<(Id, Id), Id> = Default::default();
        let mut enodes = vec![];
        for class1 in self.classes() {
            for class2 in other.classes() {
                self.intersect_classes(other, &mut enodes, class1.id, class2.id, &mut product_map);
            }
        }
        Self::from_enodes(enodes, analysis)
    }

    /// Returns the product id for `(class1, class2)`.
    ///
    /// The first time a pair is seen it gets the next sequential id,
    /// `product_map.len()`.
    fn get_product_id(class1: Id, class2: Id, product_map: &mut HashMap<(Id, Id), Id>) -> Id {
        if let Some(id) = product_map.get(&(class1, class2)) {
            *id
        } else {
            let id = Id::from(product_map.len());
            product_map.insert((class1, class2), id);
            id
        }
    }

    /// Pushes onto `res` one product e-node for every pair of matching e-nodes
    /// drawn from `self`'s `class1` and `other`'s `class2`. Each e-node is
    /// tagged with the product id of `(class1, class2)`.
    fn intersect_classes(
        &self,
        other: &EGraph<L, N>,
        res: &mut Vec<(L, Id)>,
        class1: Id,
        class2: Id,
        product_map: &mut HashMap<(Id, Id), Id>,
    ) {
        let res_id = Self::get_product_id(class1, class2, product_map);
        for node1 in &self.classes[&class1].nodes {
            for node2 in &other.classes[&class2].nodes {
                if node1.matches(node2) {
                    let children1 = node1.children();
                    let children2 = node2.children();
                    let mut new_node = node1.clone();
                    let children = new_node.children_mut();
                    for (i, (child1, child2)) in children1.iter().zip(children2.iter()).enumerate()
                    {
                        let prod = Self::get_product_id(
                            self.find(*child1),
                            other.find(*child2),
                            product_map,
                        );
                        children[i] = prod;
                    }
                    res.push((new_node, res_id));
                }
            }
        }
    }

    /// Reconstructs the term originally added under `id`.
    ///
    /// Follows the non-canonical `nodes` table rather than the e-class, so
    /// the result is the term that was added under that id, not an arbitrary
    /// equivalent term.
    pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
        let mut res = Default::default();
        let mut cache = Default::default();
        self.id_to_expr_internal(&mut res, id, &mut cache);
        res
    }

    /// Recursive worker for `id_to_expr`. `cache` maps an e-graph id to its
    /// index in `res`, so shared subterms are emitted only once.
    fn id_to_expr_internal(
        &self,
        res: &mut RecExpr<L>,
        node_id: Id,
        cache: &mut HashMap<Id, Id>,
    ) -> Id {
        if let Some(existing) = cache.get(&node_id) {
            return *existing;
        }
        let new_node = self
            .id_to_node(node_id)
            .clone()
            .map_children(|child| self.id_to_expr_internal(res, child, cache));
        let res_id = res.add(new_node);
        cache.insert(node_id, res_id);
        res_id
    }

    /// Returns the original e-node recorded for `id`.
    ///
    /// # Panics
    /// Panics if `id` is outside the `nodes` table.
    pub fn id_to_node(&self, id: Id) -> &L {
        &self.nodes[usize::from(id)]
    }

    /// Like [`EGraph::id_to_expr`] but produces a pattern.
    ///
    /// Any id that is a key of `substitutions` becomes a variable named
    /// `?<id>`. The returned `Subst` binds that variable to the mapped id.
    pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
        let mut res = Default::default();
        let mut subst = Default::default();
        let mut cache = Default::default();
        self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
        (Pattern::new(res), subst)
    }

    /// Recursive worker for `id_to_pattern`. `cache` maps an e-graph id to
    /// its index in `res`.
    fn id_to_pattern_internal(
        &self,
        res: &mut PatternAst<L>,
        node_id: Id,
        var_substitutions: &HashMap<Id, Id>,
        subst: &mut Subst,
        cache: &mut HashMap<Id, Id>,
    ) -> Id {
        if let Some(existing) = cache.get(&node_id) {
            return *existing;
        }
        let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
            let var = format!("?{}", node_id).parse().unwrap();
            subst.insert(var, *existing);
            res.add(ENodeOrVar::Var(var))
        } else {
            let new_node = self.id_to_node(node_id).clone().map_children(|child| {
                self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
            });
            res.add(ENodeOrVar::ENode(new_node))
        };
        cache.insert(node_id, res_id);
        res_id
    }

    /// Returns every union recorded by the explanation machinery.
    ///
    /// # Panics
    /// Panics if explanations are disabled.
    pub fn get_union_equalities(&self) -> UnionEqualities {
        if let Some(explain) = &self.explain {
            explain.get_union_equalities()
        } else {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get union equalities");
        }
    }

    /// Disables explanations and discards any explanation state.
    pub fn with_explanations_disabled(mut self) -> Self {
        self.explain = None;
        self
    }

    /// Returns whether explanations are enabled.
    pub fn are_explanations_enabled(&self) -> bool {
        self.explain.is_some()
    }

    /// Delegates to the explanation state's `get_num_congr`.
    ///
    /// # Panics
    /// Panics if explanations are disabled.
    pub fn get_num_congr(&mut self) -> usize {
        if let Some(explain) = &mut self.explain {
            explain
                .with_nodes(&self.nodes)
                .get_num_congr::<N>(&self.classes, &self.unionfind)
        } else {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
        }
    }

    /// Delegates to the explanation state's `get_num_nodes`.
    ///
    /// # Panics
    /// Panics if explanations are disabled.
    pub fn get_explanation_num_nodes(&mut self) -> usize {
        if let Some(explain) = &mut self.explain {
            explain.with_nodes(&self.nodes).get_num_nodes()
        } else {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
        }
    }

    /// Adds both expressions (uncanonically) and explains why they are equal.
    ///
    /// # Panics
    /// Panics if the expressions are not in the same e-class, or if
    /// explanations are disabled. The expressions are added before either check.
    pub fn explain_equivalence(
        &mut self,
        left_expr: &RecExpr<L>,
        right_expr: &RecExpr<L>,
    ) -> Explanation<L> {
        let left = self.add_expr_uncanonical(left_expr);
        let right = self.add_expr_uncanonical(right_expr);
        self.explain_id_equivalence(left, right)
    }

    /// Explains why the terms with ids `left` and `right` are equal.
    ///
    /// # Panics
    /// Panics if they are not in the same e-class, or if explanations are disabled.
    pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation<L> {
        if self.find(left) != self.find(right) {
            panic!(
                "Tried to explain equivalence between non-equal terms {:?} and {:?}",
                self.id_to_expr(left),
                self.id_to_expr(right)
            );
        }
        if let Some(explain) = &mut self.explain {
            explain.with_nodes(&self.nodes).explain_equivalence::<N>(
                left,
                right,
                &mut self.unionfind,
                &self.classes,
            )
        } else {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
        }
    }

    /// Explains why `left_expr` equals `right_pattern` instantiated with `subst`.
    /// Both are added uncanonically first.
    ///
    /// # Panics
    /// Panics if they are not in the same e-class, or if explanations are disabled.
    pub fn explain_matches(
        &mut self,
        left_expr: &RecExpr<L>,
        right_pattern: &PatternAst<L>,
        subst: &Subst,
    ) -> Explanation<L> {
        let left = self.add_expr_uncanonical(left_expr);
        let right = self.add_instantiation_noncanonical(right_pattern, subst);
        if self.find(left) != self.find(right) {
            panic!(
                "Tried to explain equivalence between non-equal terms {:?} and {:?}",
                left_expr, right_pattern
            );
        }
        if let Some(explain) = &mut self.explain {
            explain.with_nodes(&self.nodes).explain_equivalence::<N>(
                left,
                right,
                &mut self.unionfind,
                &self.classes,
            )
        } else {
            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
        }
    }

    /// Returns the canonical id of the e-class containing `id`.
    pub fn find(&self, id: Id) -> Id {
        self.unionfind.find(id)
    }

    /// Like [`EGraph::find`], but with mutable access to the union-find.
    /// NOTE(review): presumably this lets it compress paths; confirm in `UnionFind`.
    fn find_mut(&mut self, id: Id) -> Id {
        self.unionfind.find_mut(id)
    }

    /// Creates a [`Dot`] view borrowing this e-graph, with an empty config
    /// and anchors enabled (presumably for Graphviz output).
    pub fn dot(&self) -> Dot<L, N> {
        Dot {
            egraph: self,
            config: vec![],
            use_anchors: true,
        }
    }
}
/// Translates an e-graph over language `L` with analysis `A` into an e-graph
/// over [`LanguageMapper::L2`] with analysis [`LanguageMapper::A2`].
///
/// Implementors supply the per-item conversions. The default methods apply
/// them to whole e-classes and e-graphs while keeping every id unchanged.
pub trait LanguageMapper<L, A>
where
    L: Language,
    A: Analysis<L>,
{
    /// Target language.
    type L2: Language;
    /// Target analysis over the target language.
    type A2: Analysis<Self::L2>;

    /// Converts a single e-node.
    fn map_node(&self, node: L) -> Self::L2;

    /// Converts an operator discriminant (used for the `classes_by_op` index).
    fn map_discriminant(
        &self,
        discriminant: L::Discriminant,
    ) -> <Self::L2 as Language>::Discriminant;

    /// Converts the analysis itself.
    fn map_analysis(&self, analysis: A) -> Self::A2;

    /// Converts one e-class's analysis data.
    fn map_data(&self, data: A::Data) -> <Self::A2 as Analysis<Self::L2>>::Data;

    /// Converts an e-class: its nodes via `map_node` (first) and its data via
    /// `map_data` (second). The id and parent list are kept as they are.
    fn map_eclass(
        &self,
        src_eclass: EClass<L, A::Data>,
    ) -> EClass<Self::L2, <Self::A2 as Analysis<Self::L2>>::Data> {
        let EClass {
            id,
            nodes,
            data,
            parents,
        } = src_eclass;
        let converted_nodes = nodes.into_iter().map(|node| self.map_node(node)).collect();
        let converted_data = self.map_data(data);
        EClass {
            id,
            nodes: converted_nodes,
            data: converted_data,
            parents,
        }
    }

    /// Converts a whole e-graph, keeping all ids, union-find state, work lists
    /// and the `clean` flag. Explanation state is dropped (`explain: None`).
    ///
    /// Conversion callbacks run in this order: `map_analysis`, `map_node` over
    /// the hashcons, `map_node` over the node table, `map_eclass` over every
    /// class, and `map_discriminant` over the operator index.
    fn map_egraph(&self, src_egraph: EGraph<L, A>) -> EGraph<Self::L2, Self::A2> {
        let analysis = self.map_analysis(src_egraph.analysis);
        let memo = src_egraph
            .memo
            .into_iter()
            .map(|(node, id)| (self.map_node(node), id))
            .collect();
        let nodes = src_egraph
            .nodes
            .into_iter()
            .map(|node| self.map_node(node))
            .collect();
        let classes = src_egraph
            .classes
            .into_iter()
            .map(|(id, class)| (id, self.map_eclass(class)))
            .collect();
        let classes_by_op = src_egraph
            .classes_by_op
            .into_iter()
            .map(|(op, ids)| (self.map_discriminant(op), ids))
            .collect();
        EGraph {
            analysis,
            explain: None,
            unionfind: src_egraph.unionfind,
            memo,
            pending: src_egraph.pending,
            nodes,
            analysis_pending: src_egraph.analysis_pending,
            classes,
            classes_by_op,
            clean: src_egraph.clean,
        }
    }
}
/// A [`LanguageMapper`] that converts everything through `From` impls.
///
/// `L2` and `A2` are the target language and analysis. The struct holds no
/// data; `PhantomData` only records the target types.
pub struct SimpleLanguageMapper<L2, A2> {
    _phantom: PhantomData<(L2, A2)>,
}
impl<L, A> Default for SimpleLanguageMapper<L, A> {
    /// Creates the mapper; it carries no state, so this is the only way to build one.
    fn default() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
/// Every conversion is the matching `From` impl required by the bounds below.
impl<L, A, L2, A2> LanguageMapper<L, A> for SimpleLanguageMapper<L2, A2>
where
    L: Language,
    A: Analysis<L>,
    L2: Language + From<L>,
    A2: Analysis<L2> + From<A>,
    <L2 as Language>::Discriminant: From<<L as Language>::Discriminant>,
    <A2 as Analysis<L2>>::Data: From<<A as Analysis<L>>::Data>,
{
    type L2 = L2;
    type A2 = A2;

    fn map_node(&self, node: L) -> Self::L2 {
        From::from(node)
    }

    fn map_discriminant(
        &self,
        discriminant: <L as Language>::Discriminant,
    ) -> <Self::L2 as Language>::Discriminant {
        From::from(discriminant)
    }

    fn map_analysis(&self, analysis: A) -> Self::A2 {
        From::from(analysis)
    }

    fn map_data(&self, data: <A as Analysis<L>>::Data) -> <Self::A2 as Analysis<Self::L2>>::Data {
        From::from(data)
    }
}
impl<L: Language, N: Analysis<L>> std::ops::Index<Id> for EGraph<L, N> {
    type Output = EClass<L, N::Data>;

    /// Returns the e-class containing `id`; `id` need not be canonical.
    ///
    /// # Panics
    /// Panics with `Invalid id <canonical id>` if no e-class is stored there.
    fn index(&self, id: Id) -> &Self::Output {
        let root = self.find(id);
        match self.classes.get(&root) {
            Some(class) => class,
            None => panic!("Invalid id {}", root),
        }
    }
}
impl<L: Language, N: Analysis<L>> std::ops::IndexMut<Id> for EGraph<L, N> {
    /// Mutable counterpart of `Index`; canonicalizes `id` with `find_mut`.
    ///
    /// # Panics
    /// Panics with `Invalid id <canonical id>` if no e-class is stored there.
    fn index_mut(&mut self, id: Id) -> &mut Self::Output {
        let root = self.find_mut(id);
        match self.classes.get_mut(&root) {
            Some(class) => class,
            None => panic!("Invalid id {}", root),
        }
    }
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
    /// Adds `expr` and returns the canonical id of the e-class containing its root.
    pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
        let id = self.add_expr_uncanonical(expr);
        self.find(id)
    }

    /// Adds `expr` node by node via [`EGraph::add_uncanonical`] and returns
    /// the (possibly non-canonical) id given to its root.
    ///
    /// # Panics
    /// Panics if `expr` is empty.
    pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
        // Relies on children preceding their parents in `expr`, so each child
        // index already has an entry in `new_ids` when its parent is reached.
        let mut new_ids = Vec::with_capacity(expr.len());
        for node in expr {
            let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
            new_ids.push(self.add_uncanonical(new_node));
        }
        *new_ids
            .last()
            .expect("cannot add an empty RecExpr to the egraph")
    }

    /// Instantiates `pat` with `subst`, adds it, and returns the canonical id
    /// of its root's e-class.
    pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
        let id = self.add_instantiation_noncanonical(pat, subst);
        self.find(id)
    }

    /// Instantiates `pat` with `subst` and adds it, returning the
    /// (possibly non-canonical) id of its root.
    ///
    /// Variables resolve to the canonical id of their binding in `subst`.
    /// E-nodes are added with [`EGraph::add_uncanonical`].
    ///
    /// # Panics
    /// Panics if `pat` is empty or if a variable is unbound in `subst`.
    fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
        let mut new_ids = Vec::with_capacity(pat.len());
        for node in pat {
            let id = match node {
                ENodeOrVar::Var(var) => self.find(subst[*var]),
                ENodeOrVar::ENode(node) => {
                    let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
                    self.add_uncanonical(new_node)
                }
            };
            new_ids.push(id);
        }
        *new_ids
            .last()
            .expect("cannot instantiate an empty pattern")
    }

    /// Looks up `enode` in the hashcons and returns its canonical id, if present.
    ///
    /// The e-node's children are canonicalized first. When `enode` is passed
    /// as `&mut L`, the canonicalization is visible to the caller.
    /// Hashcons entries are re-canonicalized by [`EGraph::rebuild`], so a
    /// lookup made between a union and the next rebuild may miss.
    pub fn lookup<B>(&self, enode: B) -> Option<Id>
    where
        B: BorrowMut<L>,
    {
        self.lookup_internal(enode).map(|id| self.find(id))
    }

    /// Like [`EGraph::lookup`], but returns the id stored in the hashcons
    /// without canonicalizing it.
    fn lookup_internal<B>(&self, mut enode: B) -> Option<Id>
    where
        B: BorrowMut<L>,
    {
        let enode = enode.borrow_mut();
        enode.update_children(|id| self.find(id));
        self.memo.get(enode).copied()
    }

    /// Returns the canonical id of `expr`'s root if every subterm is already present.
    pub fn lookup_expr(&self, expr: &RecExpr<L>) -> Option<Id> {
        self.lookup_expr_ids(expr)
            .and_then(|ids| ids.last().copied())
    }

    /// Looks up every subterm of `expr` in order and returns their canonical
    /// ids. Returns `None` as soon as one is missing.
    pub fn lookup_expr_ids(&self, expr: &RecExpr<L>) -> Option<Vec<Id>> {
        let mut new_ids = Vec::with_capacity(expr.len());
        for node in expr {
            let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
            let id = self.lookup(node)?;
            new_ids.push(id)
        }
        Some(new_ids)
    }

    /// Adds `enode` and returns the canonical id of its e-class.
    pub fn add(&mut self, enode: L) -> Id {
        let id = self.add_uncanonical(enode);
        self.find(id)
    }

    /// Adds `enode` (children need not be canonical) and returns its id.
    ///
    /// If an equivalent e-node is already hashconsed:
    /// - without explanations, the existing id is returned;
    /// - with explanations, the id previously given to this exact original
    ///   e-node is returned. If there is none, a fresh id is allocated,
    ///   recorded in `nodes`, and unioned into the existing class with
    ///   `Justification::Congruence`. Such an id may be non-canonical.
    ///
    /// Otherwise a new e-class is created, `N::modify` runs on it, and the
    /// e-graph is marked dirty.
    pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
        let original = enode.clone();
        if let Some(existing_id) = self.lookup_internal(&mut enode) {
            let id = self.find(existing_id);
            if let Some(explain) = self.explain.as_mut() {
                if let Some(existing_explain) = explain.uncanon_memo.get(&original) {
                    *existing_explain
                } else {
                    let new_id = self.unionfind.make_set();
                    explain.add(original.clone(), new_id);
                    debug_assert_eq!(Id::from(self.nodes.len()), new_id);
                    self.nodes.push(original);
                    self.unionfind.union(id, new_id);
                    explain.union(existing_id, new_id, Justification::Congruence);
                    new_id
                }
            } else {
                existing_id
            }
        } else {
            let id = self.make_new_eclass(enode, original.clone());
            if let Some(explain) = self.explain.as_mut() {
                explain.add(original, id);
            }
            // Run `modify` only after the explanation state knows about `id`.
            N::modify(self, id);
            self.clean = false;
            id
        }
    }

    /// Creates a fresh e-class for `enode`, whose children are already canonical.
    ///
    /// It computes the class data with `N::make` from `original`, records
    /// `original` in `nodes`, registers the new id as a parent of each child
    /// class, queues the id in `pending`, and inserts `enode` into the hashcons.
    ///
    /// # Panics
    /// Panics if `enode` was already in the hashcons.
    fn make_new_eclass(&mut self, enode: L, original: L) -> Id {
        let id = self.unionfind.make_set();
        log::trace!("  ...adding to {}", id);
        let class = EClass {
            id,
            nodes: vec![enode.clone()],
            data: N::make(self, &original, id),
            parents: Default::default(),
        };
        // `nodes` is indexed by id, so the new id must be the next slot.
        debug_assert_eq!(Id::from(self.nodes.len()), id);
        self.nodes.push(original);
        enode.for_each(|child| {
            self[child].parents.push(id);
        });
        self.pending.push(id);
        self.classes.insert(id, class);
        assert!(self.memo.insert(enode, id).is_none());
        id
    }

    /// Searches both expressions as patterns and returns the e-class of every
    /// `expr1` match that shares a class with some `expr2` match.
    /// An id is pushed once per matching pair, so it may appear more than once.
    pub fn equivs(&self, expr1: &RecExpr<L>, expr2: &RecExpr<L>) -> Vec<Id> {
        let pat1 = Pattern::from(expr1);
        let pat2 = Pattern::from(expr2);
        let matches1 = pat1.search(self);
        trace!("Matches1: {:?}", matches1);
        let matches2 = pat2.search(self);
        trace!("Matches2: {:?}", matches2);
        let mut equiv_eclasses = Vec::new();
        for m1 in &matches1 {
            for m2 in &matches2 {
                if self.find(m1.eclass) == self.find(m2.eclass) {
                    equiv_eclasses.push(m1.eclass)
                }
            }
        }
        equiv_eclasses
    }

    /// Instantiates both patterns with `subst`, adds them, and unions them,
    /// justified by the rule `rule_name`.
    ///
    /// Returns the canonical id of the merged class, and whether a union
    /// actually happened (`false` if the two were already equal).
    pub fn union_instantiations(
        &mut self,
        from_pat: &PatternAst<L>,
        to_pat: &PatternAst<L>,
        subst: &Subst,
        rule_name: impl Into<Symbol>,
    ) -> (Id, bool) {
        let id1 = self.add_instantiation_noncanonical(from_pat, subst);
        let id2 = self.add_instantiation_noncanonical(to_pat, subst);
        let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
        (self.find(id1), did_union)
    }

    /// Unions `from` and `to`, recording `reason` as the justification.
    /// Returns whether the classes were merged.
    pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into<Symbol>) -> bool {
        self.perform_union(from, to, Some(Justification::Rule(reason.into())))
    }

    /// Unions the e-classes of `id1` and `id2`; returns whether they were merged.
    ///
    /// With explanations enabled, the caller's source location is recorded
    /// as the justification.
    #[track_caller]
    pub fn union(&mut self, id1: Id, id2: Id) -> bool {
        if self.explain.is_some() {
            let caller = std::panic::Location::caller();
            self.union_trusted(id1, id2, caller.to_string())
        } else {
            self.perform_union(id1, id2, None)
        }
    }

    /// Core union. It merges the classes of `enode_id1` and `enode_id2`, and
    /// returns `false` if they were already equal.
    ///
    /// The class with more parents is kept as the root. The absorbed class's
    /// parents are queued for re-canonicalization. Parents are queued for
    /// analysis recomputation according to which side `merge` reports as
    /// changed. `N::modify` then runs on the surviving class.
    ///
    /// # Panics
    /// Panics if explanations are enabled and `rule` is `None` for a real merge.
    fn perform_union(&mut self, enode_id1: Id, enode_id2: Id, rule: Option<Justification>) -> bool {
        N::pre_union(self, enode_id1, enode_id2, &rule);
        self.clean = false;
        let mut id1 = self.find_mut(enode_id1);
        let mut id2 = self.find_mut(enode_id2);
        if id1 == id2 {
            // A rule re-deriving an existing equality is still handed to the
            // explanation state as an alternate rewrite.
            if let Some(rule @ Justification::Rule(_)) = rule {
                if let Some(explain) = &mut self.explain {
                    explain.alternate_rewrite(enode_id1, enode_id2, rule);
                }
            }
            return false;
        }
        // Keep the class with more parents as the root, so the (usually
        // larger) parent list of the survivor is not re-queued.
        let class1_parents = self.classes[&id1].parents.len();
        let class2_parents = self.classes[&id2].parents.len();
        if class1_parents < class2_parents {
            std::mem::swap(&mut id1, &mut id2);
        }
        if let Some(explain) = &mut self.explain {
            explain.union(
                enode_id1,
                enode_id2,
                rule.expect("explanations require a justification for every union"),
            );
        }
        self.unionfind.union(id1, id2);
        assert_ne!(id1, id2);
        let class2 = self.classes.remove(&id2).unwrap();
        let class1 = self.classes.get_mut(&id1).unwrap();
        assert_eq!(id1, class1.id);
        self.pending.extend(class2.parents.iter().copied());
        let did_merge = self.analysis.merge(&mut class1.data, class2.data);
        if did_merge.0 {
            self.analysis_pending.extend(class1.parents.iter().copied());
        }
        if did_merge.1 {
            self.analysis_pending.extend(class2.parents.iter().copied());
        }
        concat_vecs(&mut class1.nodes, class2.nodes);
        concat_vecs(&mut class1.parents, class2.parents);
        N::modify(self, id1);
        true
    }

    /// Overwrites the analysis data of `id`'s e-class and queues its parents
    /// for analysis recomputation, then runs `N::modify` on the class.
    ///
    /// # Panics
    /// Panics if `id` does not belong to a live e-class.
    pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) {
        let id = self.find_mut(id);
        let class = self.classes.get_mut(&id).unwrap();
        class.data = new_data;
        self.analysis_pending.extend(class.parents.iter().copied());
        N::modify(self, id)
    }

    /// Returns a `Debug` value that lists every e-class (sorted by id) with its
    /// data and its sorted e-nodes.
    pub fn dump(&self) -> impl Debug + '_ {
        EGraphDump(self)
    }
}
impl<L: Language + Display, N: Analysis<L>> EGraph<L, N> {
    /// Debugging helper. It prints the smallest term (by [`AstSize`]) in the
    /// e-class of `id`, then checks that every pattern in `goals` matches that
    /// e-class, printing each goal as it is tried.
    ///
    /// # Panics
    /// Panics on the first goal that does not match. The message reports the
    /// goal and the best term found.
    pub fn check_goals(&self, id: Id, goals: &[Pattern<L>]) {
        // `&self` means the e-graph cannot change while goals are checked, so
        // one extraction serves both the summary line and any failure report.
        let (cost, best) = Extractor::new(self, AstSize).find_best(id);
        println!("End ({}): {}", cost, best.pretty(80));
        for (i, goal) in goals.iter().enumerate() {
            println!("Trying to prove goal {}: {}", i, goal.pretty(40));
            if goal.search_eclass(self, id).is_none() {
                panic!(
                    "Could not prove goal {}:\n\
                     {}\n\
                     Best thing found:\n\
                     {}",
                    i,
                    goal.pretty(40),
                    best.pretty(40),
                );
            }
        }
    }
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
    /// Re-canonicalizes the children of every e-node in every e-class, sorts
    /// and dedups each class's nodes, and recomputes `classes_by_op`.
    ///
    /// Returns the number of duplicate e-nodes removed across all classes.
    #[inline(never)]
    fn rebuild_classes(&mut self) -> usize {
        // Take the index so it can be filled while iterating `self.classes`.
        // Keys are kept; only the id sets are emptied.
        let mut classes_by_op = std::mem::take(&mut self.classes_by_op);
        classes_by_op.values_mut().for_each(|ids| ids.clear());
        let mut trimmed = 0;
        let uf = &mut self.unionfind;
        for class in self.classes.values_mut() {
            let old_len = class.len();
            class
                .nodes
                .iter_mut()
                .for_each(|n| n.update_children(|id| uf.find_mut(id)));
            class.nodes.sort_unstable();
            class.nodes.dedup();
            trimmed += old_len - class.nodes.len();
            let mut add = |n: &L| {
                classes_by_op
                    .entry(n.discriminant())
                    .or_default()
                    .insert(class.id)
            };
            // After sorting, nodes with the same operator are presumably
            // adjacent under `L`'s ordering, so only the first of each run is
            // indexed. A repeat would be harmless anyway: the values are sets.
            let mut nodes = class.nodes.iter();
            if let Some(mut prev) = nodes.next() {
                add(prev);
                for n in nodes {
                    if !prev.matches(n) {
                        add(n);
                        prev = n;
                    }
                }
            }
        }
        #[cfg(debug_assertions)]
        for ids in classes_by_op.values_mut() {
            let unique: HashSet<Id> = ids.iter().copied().collect();
            assert_eq!(ids.len(), unique.len());
        }
        self.classes_by_op = classes_by_op;
        trimmed
    }

    /// Consistency check used under `debug_assert!` after a rebuild.
    ///
    /// It panics if two e-classes contain the same e-node, or if the hashcons
    /// disagrees with the class contents. Otherwise it returns `true`.
    #[inline(never)]
    fn check_memo(&self) -> bool {
        let mut test_memo = HashMap::default();
        for (&id, class) in self.classes.iter() {
            assert_eq!(class.id, id);
            for node in &class.nodes {
                if let Some(old) = test_memo.insert(node, id) {
                    assert_eq!(
                        self.find(old),
                        self.find(id),
                        "Found unexpected equivalence for {:?}\n{:?}\nvs\n{:?}",
                        node,
                        self[self.find(id)].nodes,
                        self[self.find(old)].nodes,
                    );
                }
            }
        }
        for (n, e) in test_memo {
            assert_eq!(e, self.find(e));
            assert_eq!(
                Some(e),
                self.memo.get(n).map(|id| self.find(*id)),
                "Entry for {:?} at {} in test_memo was incorrect",
                n,
                e
            );
        }
        true
    }

    /// Drains `pending` and `analysis_pending` until both are empty, and
    /// returns the number of unions performed along the way.
    ///
    /// For each pending id, its original e-node is re-canonicalized and
    /// re-inserted into the hashcons. If another id already occupied that
    /// slot, the two are congruent and are unioned; that union can queue more work.
    /// For each analysis-pending id, its data is recomputed with `N::remake`
    /// and merged into its class. If the class data changed, its parents are
    /// queued and `N::modify` runs.
    #[inline(never)]
    fn process_unions(&mut self) -> usize {
        let mut n_unions = 0;
        while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
            while let Some(class) = self.pending.pop() {
                let mut node = self.nodes[usize::from(class)].clone();
                node.update_children(|id| self.find_mut(id));
                if let Some(memo_class) = self.memo.insert(node, class) {
                    let did_something =
                        self.perform_union(memo_class, class, Some(Justification::Congruence));
                    n_unions += did_something as usize;
                }
            }
            while let Some(class_id) = self.analysis_pending.pop() {
                let node = self.nodes[usize::from(class_id)].clone();
                let class_id = self.find_mut(class_id);
                let node_data = N::remake(self, &node, class_id);
                let class = self.classes.get_mut(&class_id).unwrap();
                let did_merge = self.analysis.merge(&mut class.data, node_data);
                if did_merge.0 {
                    self.analysis_pending.extend(class.parents.iter().copied());
                    N::modify(self, class_id)
                }
            }
        }
        assert!(self.pending.is_empty());
        assert!(self.analysis_pending.is_empty());
        n_unions
    }

    /// Restores the e-graph invariants after adds and unions, and sets `clean` to `true`.
    ///
    /// It processes all pending work (congruence closure and analysis
    /// propagation), then re-canonicalizes every class and rebuilds the
    /// operator index. Timing and size statistics are logged at `info` level.
    ///
    /// Returns the number of unions performed during the rebuild.
    pub fn rebuild(&mut self) -> usize {
        let old_hc_size = self.memo.len();
        let old_n_eclasses = self.number_of_classes();
        let start = Instant::now();
        let n_unions = self.process_unions();
        let trimmed_nodes = self.rebuild_classes();
        let elapsed = start.elapsed();
        info!(
            concat!(
                "REBUILT! in {}.{:03}s\n",
                " Old: hc size {}, eclasses: {}\n",
                " New: hc size {}, eclasses: {}\n",
                " unions: {}, trimmed nodes: {}"
            ),
            elapsed.as_secs(),
            elapsed.subsec_millis(),
            old_hc_size,
            old_n_eclasses,
            self.memo.len(),
            self.number_of_classes(),
            n_unions,
            trimmed_nodes,
        );
        debug_assert!(self.check_memo());
        self.clean = true;
        n_unions
    }

    /// Delegates to the explanation state's `check_each_explain` with `rules`.
    ///
    /// # Panics
    /// Panics if explanations are disabled.
    pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
        if let Some(explain) = &mut self.explain {
            explain.with_nodes(&self.nodes).check_each_explain(rules)
        } else {
            panic!("Can't check explain when explanations are off");
        }
    }
}
/// `Debug` adapter returned by [`EGraph::dump`]; it borrows the e-graph.
struct EGraphDump<'a, L: Language, N: Analysis<L>>(&'a EGraph<L, N>);
impl<'a, L: Language, N: Analysis<L>> Debug for EGraphDump<'a, L, N> {
    /// Writes one line per e-class, ordered by id, in the form
    /// `<id> (<data>): <sorted e-nodes>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let egraph = self.0;
        let mut sorted_ids = egraph.classes().map(|class| class.id).collect::<Vec<Id>>();
        sorted_ids.sort();
        sorted_ids.into_iter().try_for_each(|id| {
            let class = &egraph[id];
            let mut sorted_nodes = class.nodes.clone();
            sorted_nodes.sort();
            writeln!(f, "{} ({:?}): {:?}", id, class.data, sorted_nodes)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: adds a leaf twice and a node over both copies, unions
    /// `x` with `y` via pattern instantiation, and rebuilds without panicking.
    #[test]
    fn simple_add() {
        use SymbolLang as S;
        crate::init_logger();

        let mut egraph = EGraph::<S, ()>::default();
        let first_x = egraph.add(S::leaf("x"));
        let second_x = egraph.add(S::leaf("x"));
        let _sum = egraph.add(S::new("+", vec![first_x, second_x]));

        let lhs = "x".parse().unwrap();
        let rhs = "y".parse().unwrap();
        egraph.union_instantiations(
            &lhs,
            &rhs,
            &Default::default(),
            "union x and y".to_string(),
        );
        egraph.rebuild();
    }

    /// Compile-time check that `EGraph` implements `Serialize`/`Deserialize`,
    /// plus a JSON round of serialization that must not fail.
    #[cfg(all(feature = "serde-1", feature = "serde_json"))]
    #[test]
    fn test_serde() {
        fn requires_serialize(_: &impl Serialize) {}
        fn requires_deserialize<'a>(_: &impl Deserialize<'a>) {}

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        egraph.add_expr(&"(foo bar baz)".parse().unwrap());
        requires_serialize(&egraph);
        requires_deserialize(&egraph);

        let json = serde_json::to_string_pretty(&egraph).unwrap();
        println!("{}", json);
    }
}