use std::collections::HashMap;
use std::{
borrow::BorrowMut,
fmt::{self, Debug},
};
use indexmap::IndexMap;
use log::*;
use crate::{
Analysis, AstSize, Dot, EClass, Extractor, Id, Language, Pattern, RecExpr, Searcher, UnionFind,
};
/// An e-graph over the language `L`, carrying per-class analysis data of type `N::Data`.
///
/// E-classes are stored in `classes` and addressed by `Id`; ids are canonicalized
/// through `unionfind` before a class is looked up.
#[derive(Clone)]
pub struct EGraph<L: Language, N: Analysis<L>> {
/// The analysis instance handed to `new`; its `merge` is invoked when class data is combined.
pub analysis: N,
/// Maps an e-node to the id of the e-class it was registered under.
memo: HashMap<L, Id>,
/// Union-find over class ids; `find` resolves an id to its canonical representative.
unionfind: UnionFind,
/// One slot per id ever created; a slot becomes `None` when its class is merged into another.
classes: SparseVec<EClass<L, N::Data>>,
/// Ids of classes that absorbed another class and have not yet been processed by `process_unions`.
dirty_unions: Vec<Id>,
/// Number of classes processed by `process_unions` since the last `rebuild`.
repairs_since_rebuild: usize,
/// For each operator discriminant, the ids of the classes holding a node with that discriminant.
/// Recomputed by `rebuild_classes`.
pub(crate) classes_by_op: IndexMap<std::mem::Discriminant<L>, indexmap::IndexSet<Id>>,
}
/// A vector of optional, boxed slots. The e-graph indexes it by `usize::from(id)` and
/// leaves `None` in the slot of a class that has been merged away.
type SparseVec<T> = Vec<Option<Box<T>>>;
impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
    /// Creates an empty e-graph around the analysis type's default value.
    fn default() -> Self {
        let analysis: N = Default::default();
        Self::new(analysis)
    }
}
impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
    /// Formats the e-graph as a struct showing only its `memo` and `classes` fields.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("EGraph");
        builder.field("memo", &self.memo);
        builder.field("classes", &self.classes);
        builder.finish()
    }
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
pub fn new(analysis: N) -> Self {
Self {
analysis,
memo: Default::default(),
classes: Default::default(),
unionfind: Default::default(),
dirty_unions: Default::default(),
classes_by_op: IndexMap::default(),
repairs_since_rebuild: 0,
}
}
pub fn classes(&self) -> impl Iterator<Item = &EClass<L, N::Data>> {
self.classes
.iter()
.filter_map(Option::as_ref)
.map(AsRef::as_ref)
}
pub fn classes_mut(&mut self) -> impl Iterator<Item = &mut EClass<L, N::Data>> {
self.classes
.iter_mut()
.filter_map(Option::as_mut)
.map(AsMut::as_mut)
}
pub fn is_empty(&self) -> bool {
self.memo.is_empty()
}
pub fn total_size(&self) -> usize {
self.memo.len()
}
pub fn total_number_of_nodes(&self) -> usize {
self.classes().map(|c| c.len()).sum()
}
pub fn number_of_classes(&self) -> usize {
self.classes().count()
}
pub fn find(&self, id: Id) -> Id {
self.unionfind.find(id)
}
pub fn dot(&self) -> Dot<L, N> {
Dot { egraph: self }
}
}
impl<L: Language, N: Analysis<L>> std::ops::Index<Id> for EGraph<L, N> {
    type Output = EClass<L, N::Data>;

    /// Returns the e-class that `id` canonically refers to.
    ///
    /// # Panics
    /// Panics if the canonical id is out of bounds or its slot is empty.
    fn index(&self, id: Id) -> &Self::Output {
        let canonical = self.find(id);
        match self.classes[usize::from(canonical)].as_deref() {
            Some(class) => class,
            None => panic!("Invalid id {}", canonical),
        }
    }
}
impl<L: Language, N: Analysis<L>> std::ops::IndexMut<Id> for EGraph<L, N> {
    /// Returns a mutable reference to the e-class that `id` canonically refers to.
    ///
    /// # Panics
    /// Panics if the canonical id is out of bounds or its slot is empty.
    fn index_mut(&mut self, id: Id) -> &mut Self::Output {
        let canonical = self.find(id);
        match self.classes[usize::from(canonical)].as_deref_mut() {
            Some(class) => class,
            None => panic!("Invalid id {}", canonical),
        }
    }
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Adds every node reachable from the root (last node) of `expr` and returns the
/// id of the class holding the root.
pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
    let nodes: &[L] = expr.as_ref();
    self.add_expr_rec(nodes)
}
/// Recursively adds the expression whose root is the last element of `expr`.
///
/// Each child index `i` of the root is treated as the root of the prefix
/// `expr[..=i]`, which is added first; the root is then added with its children
/// replaced by the resulting class ids.
///
/// # Panics
/// Panics if `expr` is empty.
fn add_expr_rec(&mut self, expr: &[L]) -> Id {
    log::trace!("Adding expr {:?}", expr);
    let root = expr.last().unwrap().clone();
    let enode = root.map_children(|child_index| {
        let end = usize::from(child_index) + 1;
        self.add_expr_rec(&expr[..end])
    });
    let id = self.add(enode);
    log::trace!("Added!! expr {:?}", expr);
    id
}
/// Looks `enode` up in the memo and returns the canonical id of its class, if any.
///
/// The node's children are rewritten in place to their canonical ids before the
/// lookup, which is why a mutable borrow of the node is required.
pub fn lookup<B>(&self, mut enode: B) -> Option<Id>
where
    B: BorrowMut<L>,
{
    let node: &mut L = enode.borrow_mut();
    node.update_children(|child| self.find(child));
    match self.memo.get(&*node) {
        Some(&hit) => Some(self.find(hit)),
        None => None,
    }
}
/// Adds `enode` to the e-graph and returns the id of the class that holds it.
///
/// If `lookup` finds the node in the memo, the existing class id is returned and
/// nothing is changed. Otherwise a fresh class is created, registered as a parent
/// of each child class, recorded in the memo, and passed to `N::modify`.
///
/// # Panics
/// Panics if the freshly allocated id is not the next free slot in `classes`, or
/// if the memo already contained the node.
pub fn add(&mut self, mut enode: L) -> Id {
    if let Some(existing) = self.lookup(&mut enode) {
        return existing;
    }

    let id = self.unionfind.make_set();
    log::trace!(" ...adding to {}", id);
    let nodes = vec![enode.clone()];
    let data = N::make(self, &enode);
    let class = Box::new(EClass {
        id,
        nodes,
        data,
        parents: Default::default(),
    });

    // Register the new node in the parent list of every child class.
    enode.for_each(|child| {
        let entry = (enode.clone(), id);
        self[child].parents.push(entry);
    });

    assert_eq!(self.classes.len(), usize::from(id));
    self.classes.push(Some(class));
    assert!(self.memo.insert(enode, id).is_none());
    N::modify(self, id);
    id
}
/// Searches the e-graph for both expressions and collects the classes where they meet.
///
/// For every pair of matches (one per expression) whose classes share a canonical
/// id, the class id of the first expression's match is pushed; a class can
/// therefore appear more than once in the result.
pub fn equivs(&self, expr1: &RecExpr<L>, expr2: &RecExpr<L>) -> Vec<Id> {
    let matches1 = Pattern::from(expr1.as_ref()).search(self);
    trace!("Matches1: {:?}", matches1);
    let matches2 = Pattern::from(expr2.as_ref()).search(self);
    trace!("Matches2: {:?}", matches2);

    let mut equiv_eclasses = Vec::new();
    for m1 in matches1.iter() {
        let root1 = self.find(m1.eclass);
        equiv_eclasses.extend(
            matches2
                .iter()
                .filter(|m2| self.find(m2.eclass) == root1)
                .map(|_| m1.eclass),
        );
    }
    equiv_eclasses
}
/// Checks that every pattern in `goals` matches the class `id`.
///
/// Prints the smallest (by `AstSize`) expression extracted from `id`, then tries
/// each goal in order, printing it first.
///
/// # Panics
/// Panics on the first goal that does not match, reporting the goal and the best
/// extracted expression.
pub fn check_goals(&self, id: Id, goals: &[Pattern<L>]) {
    let (cost, best) = Extractor::new(self, AstSize).find_best(id);
    println!("End ({}): {}", cost, best.pretty(80));

    for (i, goal) in goals.iter().enumerate() {
        println!("Trying to prove goal {}: {}", i, goal.pretty(40));
        if goal.search_eclass(self, id).is_some() {
            continue;
        }
        let best = Extractor::new(self, AstSize).find_best(id).1;
        panic!(
            "Could not prove goal {}:\n{}\nBest thing found:\n{}",
            i,
            goal.pretty(40),
            best.pretty(40),
        );
    }
}
#[inline]
fn union_impl(&mut self, id1: Id, id2: Id) -> (Id, bool) {
fn concat<T>(to: &mut Vec<T>, mut from: Vec<T>) {
if to.len() < from.len() {
std::mem::swap(to, &mut from)
}
to.extend(from);
}
let (to, from) = self.unionfind.union(id1, id2);
debug_assert_eq!(to, self.find(id1));
debug_assert_eq!(to, self.find(id2));
if to != from {
self.dirty_unions.push(to);
let from_class = self.classes[usize::from(from)].take().unwrap();
let to_class = self.classes[usize::from(to)].as_mut().unwrap();
self.analysis.merge(&mut to_class.data, from_class.data);
concat(&mut to_class.nodes, from_class.nodes);
concat(&mut to_class.parents, from_class.parents);
N::modify(self, to);
}
(to, to != from)
}
/// Merges the classes of `id1` and `id2`.
///
/// Returns the resulting id and whether anything was merged. When the
/// "upward-merging" feature is enabled, pending unions are processed immediately
/// after a successful merge.
pub fn union(&mut self, id1: Id, id2: Id) -> (Id, bool) {
    let (root, merged) = self.union_impl(id1, id2);
    if merged && cfg!(feature = "upward-merging") {
        self.process_unions();
    }
    (root, merged)
}
pub fn dump<'a>(&'a self) -> impl Debug + 'a {
EGraphDump(self)
}
}
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Canonicalizes, sorts and deduplicates the nodes of every class, and recomputes
/// `classes_by_op` from scratch.
///
/// Returns the number of nodes removed by deduplication across all classes.
#[inline(never)]
fn rebuild_classes(&mut self) -> usize {
    // Reuse the existing map and its per-operator sets, but start from empty sets.
    let mut classes_by_op = std::mem::take(&mut self.classes_by_op);
    for ids in classes_by_op.values_mut() {
        ids.clear();
    }

    let uf = &self.unionfind;
    let mut trimmed = 0;
    for class in self.classes.iter_mut().flatten() {
        let len_before = class.len();
        for node in class.nodes.iter_mut() {
            node.update_children(|id| uf.find(id));
        }
        class.nodes.sort_unstable();
        class.nodes.dedup();
        trimmed += len_before - class.nodes.len();

        // Register the class once per run of nodes that `matches` the last
        // registered node; the nodes were sorted just above.
        let class_id = class.id;
        let mut prev: Option<&L> = None;
        for node in class.nodes.iter() {
            let starts_new_run = match prev {
                Some(last) => !last.matches(node),
                None => true,
            };
            if starts_new_run {
                #[allow(clippy::mem_discriminant_non_enum)]
                classes_by_op
                    .entry(std::mem::discriminant(node))
                    .or_default()
                    .insert(class_id);
                prev = Some(node);
            }
        }
    }

    // Debug builds re-check that no class id is listed twice for one operator.
    #[cfg(debug_assertions)]
    for ids in classes_by_op.values_mut() {
        let unique: indexmap::IndexSet<Id> = ids.iter().copied().collect();
        assert_eq!(ids.len(), unique.len());
    }

    self.classes_by_op = classes_by_op;
    trimmed
}
/// Consistency check of the memo against the stored classes.
///
/// Builds a node-to-class map from the classes and asserts that each stored class
/// carries the id of its slot, that a node found in two slots belongs to the same
/// canonical class, and that the memo resolves every node to the class it is
/// stored in. Always returns `true` so it can be wrapped in `debug_assert!`.
///
/// # Panics
/// Panics if any of the checks above fails.
#[inline(never)]
fn check_memo(&self) -> bool {
    let mut test_memo = IndexMap::new();

    for (index, slot) in self.classes.iter().enumerate() {
        let id = Id::from(index);
        if let Some(class) = slot {
            assert_eq!(class.id, id);
            for node in class.nodes.iter() {
                if let Some(old) = test_memo.insert(node, id) {
                    assert_eq!(
                        self.find(old),
                        self.find(id),
                        "Found unexpected equivalence for {:?}\n{:?}\nvs\n{:?}",
                        node,
                        self[self.find(id)].nodes,
                        self[self.find(old)].nodes,
                    );
                }
            }
        }
    }

    for (n, e) in test_memo {
        assert_eq!(e, self.find(e));
        assert_eq!(
            Some(e),
            self.memo.get(n).map(|&found| self.find(found)),
            "Entry for {:?} at {} in test_memo was incorrect",
            n,
            e
        );
    }

    true
}
/// Drains `dirty_unions`, repairing the parent lists and memo entries of every
/// queued class, and keeps going until no further unions are produced.
///
/// Repairing a class can reveal parent nodes that became identical after
/// canonicalization; their classes are collected in `to_union` and merged at the
/// end of each pass, which may queue more classes for the next pass.
///
/// # Panics
/// Panics if the worklist taken for a pass is empty, or if either worklist is
/// non-empty when the loop finishes.
#[inline(never)]
fn process_unions(&mut self) {
// Pairs of class ids discovered to hold the same node; merged after each pass.
let mut to_union = vec![];
while !self.dirty_unions.is_empty() {
// Take the current worklist; unions performed below refill `self.dirty_unions`.
let mut todo = std::mem::take(&mut self.dirty_unions);
// Canonicalize every queued id.
todo.iter_mut().for_each(|id| *id = self.find(*id));
// Without "upward-merging", drop duplicate ids so each class is handled once per pass.
if cfg!(not(feature = "upward-merging")) {
todo.sort_unstable();
todo.dedup();
}
assert!(!todo.is_empty());
for id in todo {
self.repairs_since_rebuild += 1;
// Temporarily move the parent list out of the class so `self` can be used freely.
let mut parents = std::mem::take(&mut self[id].parents);
// Remove the memo entries keyed by the parents' current (possibly stale) nodes.
for (n, _e) in &parents {
self.memo.remove(n);
}
// Canonicalize each parent node's children and the parent's class id.
parents.iter_mut().for_each(|(n, id)| {
n.update_children(|child| self.find(child));
*id = self.find(*id);
});
parents.sort_unstable();
// Adjacent entries with equal nodes are collapsed into one; the two class
// ids involved are recorded so they get merged later.
parents.dedup_by(|(n1, e1), (n2, e2)| {
n1 == n2 && {
to_union.push((*e1, *e2));
true
}
});
// Re-insert the canonical nodes; a displaced memo entry means another class
// already held this node, so schedule a merge of the two.
for (n, e) in &parents {
if let Some(old) = self.memo.insert(n.clone(), *e) {
to_union.push((old, *e));
}
}
self.propagate_metadata(&parents);
// Put the repaired parent list back.
self[id].parents = parents;
N::modify(self, id);
}
// Perform the merges found in this pass; each real merge queues more work.
for (id1, id2) in to_union.drain(..) {
let (to, did_something) = self.union_impl(id1, id2);
if did_something {
self.dirty_unions.push(to);
}
}
}
assert!(self.dirty_unions.is_empty());
assert!(to_union.is_empty());
}
/// Processes all pending unions and then rebuilds the per-class node lists.
///
/// Logs timing and size statistics at `info` level and, in debug builds, runs
/// `check_memo` afterwards. Returns the value of the repair counter accumulated
/// since the previous rebuild, resetting it to zero.
pub fn rebuild(&mut self) -> usize {
    let hc_size_before = self.memo.len();
    let eclasses_before = self.number_of_classes();
    let timer = instant::Instant::now();

    self.process_unions();
    let n_unions = std::mem::replace(&mut self.repairs_since_rebuild, 0);
    let trimmed_nodes = self.rebuild_classes();

    let elapsed = timer.elapsed();
    info!(
        concat!(
            "REBUILT! in {}.{:03}s\n",
            " Old: hc size {}, eclasses: {}\n",
            " New: hc size {}, eclasses: {}\n",
            " unions: {}, trimmed nodes: {}"
        ),
        elapsed.as_secs(),
        elapsed.subsec_millis(),
        hc_size_before,
        eclasses_before,
        self.memo.len(),
        self.number_of_classes(),
        n_unions,
        trimmed_nodes,
    );

    debug_assert!(self.check_memo());
    n_unions
}
/// Recomputes analysis data for each `(node, class)` pair in `parents` and pushes
/// changes upward.
///
/// For every pair, the node's data is rebuilt with `N::make` and merged into its
/// class. When `merge` returns `true` (presumably meaning the class data changed;
/// confirm against the `Analysis` contract), the class's own parents are processed
/// recursively and `N::modify` is called on the class.
#[inline(never)]
fn propagate_metadata(&mut self, parents: &[(L, Id)]) {
    for (node, class_id) in parents {
        let root = self.find(*class_id);
        let fresh = N::make(self, node);
        let class = self.classes[usize::from(root)].as_mut().unwrap();
        let merged = self.analysis.merge(&mut class.data, fresh);
        if !merged {
            continue;
        }
        // Move the parent list out while recursing, then restore it.
        let upstream = std::mem::take(&mut class.parents);
        self.propagate_metadata(&upstream);
        self[root].parents = upstream;
        N::modify(self, root)
    }
}
}
/// Borrowing wrapper around an `EGraph` whose `Debug` impl prints the graph's
/// classes; this is the value returned by `EGraph::dump`.
struct EGraphDump<'a, L: Language, N: Analysis<L>>(&'a EGraph<L, N>);
impl<'a, L: Language, N: Analysis<L>> Debug for EGraphDump<'a, L, N> {
    /// Writes one line per class, in ascending id order, of the form
    /// `id: [nodes]` with the nodes sorted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let egraph = self.0;
        let mut ids: Vec<Id> = egraph.classes().map(|class| class.id).collect();
        ids.sort();
        ids.into_iter().try_for_each(|id| {
            let mut nodes = egraph[id].nodes.clone();
            nodes.sort();
            writeln!(f, "{}: {:?}", id, nodes)
        })
    }
}
#[cfg(test)]
mod tests {
    use crate::*;

    /// Adds a few nodes, unions two leaves, rebuilds, and checks the resulting
    /// structure. Also writes the graph to `target/foo.dot`.
    #[test]
    fn simple_add() {
        use SymbolLang as S;
        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default();

        let x = egraph.add(S::leaf("x"));
        let x2 = egraph.add(S::leaf("x"));
        // Adding an identical node twice must return the existing class.
        assert_eq!(x, x2);

        let _plus = egraph.add(S::new("+", vec![x, x2]));
        let y = egraph.add(S::leaf("y"));
        // A new leaf starts out in its own class.
        assert_ne!(egraph.find(x), egraph.find(y));

        egraph.union(x, y);
        egraph.rebuild();
        // After the union the two leaves share a class; together with the `+`
        // class that leaves two classes holding three nodes.
        assert_eq!(egraph.find(x), egraph.find(y));
        assert_eq!(egraph.number_of_classes(), 2);
        assert_eq!(egraph.total_number_of_nodes(), 3);

        egraph.dot().to_dot("target/foo.dot").unwrap();
    }
}