use log::*;
use std::time::Instant;
use crate::{
egraph::{AddResult, EGraph, Metadata},
expr::{Expr, Id, Language, RecExpr},
util::HashMap,
};
/// A pattern over expressions of language `L`.
///
/// A pattern is either an expression node whose children are themselves
/// patterns, or a wildcard that gets bound to an eclass [`Id`] during search.
#[derive(Debug, PartialEq, Clone)]
pub enum Pattern<L: Language> {
/// An expression node whose children are sub-patterns.
Expr(Expr<L, Pattern<L>>),
/// A wildcard; within one match, every occurrence of the same wildcard
/// must be bound to the same eclass.
Wildcard(L::Wildcard),
}
impl<L: Language> Pattern<L> {
    /// Builds a pattern that mirrors the structure of `e`.
    ///
    /// Every node of `e` is converted recursively into a [`Pattern::Expr`];
    /// the resulting pattern therefore contains no wildcards.
    pub fn from_expr(e: &RecExpr<L>) -> Self {
        // Convert each child expression first, then wrap the rebuilt node.
        let node = e.as_ref().map_children(|sub| Self::from_expr(&sub));
        Pattern::Expr(node)
    }
}
/// A named rewrite rule made of a left-hand and a right-hand pattern.
///
/// Running the rewrite searches for `lhs` and applies `rhs` to every match.
#[derive(Debug, PartialEq, Clone)]
pub struct Rewrite<L: Language> {
/// Human-readable name, used in log messages.
pub name: String,
/// Pattern that is searched for in the egraph.
pub lhs: Pattern<L>,
/// Pattern that is instantiated for each match of `lhs`.
pub rhs: Pattern<L>,
}
impl<L: Language> Rewrite<L> {
    /// Returns the reverse rewrite: `lhs` and `rhs` are swapped and
    /// `"-flipped"` is appended to the name.
    pub fn flip(&self) -> Self {
        let name = format!("{}-flipped", self.name);
        let (lhs, rhs) = (self.rhs.clone(), self.lhs.clone());
        Rewrite { name, lhs, rhs }
    }

    /// Runs this rewrite once over `egraph`.
    ///
    /// All matches of `lhs` are collected first; afterwards `rhs` is applied
    /// to each of them. Progress is reported through `debug!`, and only the
    /// application phase is timed.
    pub fn run<M: Metadata<L>>(&self, egraph: &mut EGraph<L, M>) {
        debug!("Running rewrite '{}'", self.name);
        let found = self.lhs.search(&egraph);
        debug!(
            "Ran the rewrite '{}', found {} matches",
            self.name,
            found.len()
        );

        // Time only the application of the right-hand side.
        let timer = Instant::now();
        found.into_iter().for_each(|found_match| {
            found_match.apply(&self.rhs, egraph);
        });
        let took = timer.elapsed();

        debug!(
            "Applied rewrite {} in {}.{:03}",
            self.name,
            took.as_secs(),
            took.subsec_millis()
        );
    }
}
/// A binding of wildcards to eclass ids; one such map describes one match.
pub type WildMap<L> = HashMap<<L as Language>::Wildcard, Id>;
impl<L: Language> Pattern<L> {
pub fn search<M>(&self, egraph: &EGraph<L, M>) -> Vec<PatternMatches<L>> {
egraph
.classes()
.filter_map(|class| self.search_eclass(egraph, class.id))
.collect()
}
pub fn search_eclass<M>(&self, egraph: &EGraph<L, M>, eclass: Id) -> Option<PatternMatches<L>> {
let initial_mapping = HashMap::default();
let mappings = self.search_pat(0, initial_mapping, egraph, eclass);
if !mappings.is_empty() {
Some(PatternMatches { eclass, mappings })
} else {
None
}
}
fn search_pat<M>(
&self,
depth: usize,
mut var_mapping: WildMap<L>,
egraph: &EGraph<L, M>,
eclass: Id,
) -> Vec<WildMap<L>> {
let indent = " ".repeat(depth);
let pat_expr = match self {
Pattern::Wildcard(w) => {
match var_mapping.get(&w) {
None => {
var_mapping.insert(w.clone(), eclass);
}
Some(&prev_mapped_eclass) => {
if eclass != prev_mapped_eclass {
trace!("{} Failed to bind wildcard {:?}", indent, w);
return vec![];
}
}
}
trace!("{} Bound wildcard {:?} to {}", indent, w, eclass);
return vec![var_mapping];
}
Pattern::Expr(e) => e,
};
let mut new_mappings = Vec::new();
for e in egraph.get_eclass(eclass).iter() {
use Expr::*;
match (pat_expr, e) {
(Variable(pv), Variable(ev)) => {
if pv == ev {
new_mappings.push(var_mapping.clone())
}
}
(Constant(pc), Constant(ec)) => {
if pc == ec {
new_mappings.push(var_mapping.clone())
}
}
(Operator(po, pargs), Operator(eo, eargs)) => {
if po != eo {
continue;
}
if pat_expr.children().len() != e.children().len() {
panic!(
concat!(
"Different length children in pattern and expr\n",
" exp: {:?}\n",
" pat: {:?}"
),
pat_expr, e
);
}
let mut mappings1 = vec![];
let mut mappings2 = vec![var_mapping.clone()];
for (pa, ea) in pargs.iter().zip(eargs) {
std::mem::swap(&mut mappings1, &mut mappings2);
for m in mappings1.drain(..) {
mappings2.extend(pa.search_pat(depth + 1, m, egraph, *ea));
}
}
new_mappings.extend(mappings2);
}
_ => (),
}
}
trace!("{} Found {} mappings", indent, new_mappings.len());
new_mappings
}
}
/// All matches of a pattern found within a single eclass.
#[derive(Debug)]
pub struct PatternMatches<L: Language> {
/// The eclass in which the pattern matched.
pub eclass: Id,
/// One wildcard binding per match found in `eclass`.
pub mappings: Vec<WildMap<L>>,
}
impl<L: Language> PatternMatches<L> {
    /// Instantiates `pattern` once per recorded mapping and unions each
    /// instantiation with the matched eclass.
    ///
    /// Returns the ids produced by `egraph.union` for those mappings whose
    /// instantiation was not already present in the egraph.
    ///
    /// # Panics
    ///
    /// Panics if there are no mappings, or if an instantiation was reported
    /// as already present although `egraph.total_size()` changed.
    // NOTE(review): when the instantiated root already existed, the union is
    // still performed but not reported to the caller -- confirm this is the
    // intended contract.
    pub fn apply<M: Metadata<L>>(
        &self,
        pattern: &Pattern<L>,
        egraph: &mut EGraph<L, M>,
    ) -> Vec<Id> {
        assert_ne!(self.mappings.len(), 0);

        let mut leaders = Vec::new();
        for mapping in &self.mappings {
            let size_before = egraph.total_size();
            let root = self.apply_rec(0, pattern, egraph, mapping);
            let leader = egraph.union(self.eclass, root.id);

            if root.was_there {
                // Nothing new was added, so the egraph size must be unchanged.
                let size_after = egraph.total_size();
                assert_eq!(size_before, size_after);
                continue;
            }
            leaders.push(leader);
        }
        leaders
    }

    /// Recursively adds the instantiation of `pattern` under `mapping` to
    /// `egraph`.
    ///
    /// `depth` is only used to indent trace output. In the returned
    /// [`AddResult`], `was_there` is true only if this node and all of its
    /// descendants were reported as already present; wildcards always count
    /// as present.
    fn apply_rec<M: Metadata<L>>(
        &self,
        depth: usize,
        pattern: &Pattern<L>,
        egraph: &mut EGraph<L, M>,
        mapping: &WildMap<L>,
    ) -> AddResult {
        trace!("{}apply_rec {:2?}", " ".repeat(depth), pattern);

        let outcome = match pattern {
            // A wildcard resolves to the eclass it was bound to during search.
            Pattern::Wildcard(w) => AddResult {
                id: mapping[w],
                was_there: true,
            },
            Pattern::Expr(Expr::Constant(c)) => egraph.add(Expr::Constant(c.clone())),
            Pattern::Expr(Expr::Variable(v)) => egraph.add(Expr::Variable(v.clone())),
            // The only remaining kind of node is an operator.
            Pattern::Expr(op_node) => {
                let mut children_were_there = true;
                let node = op_node.clone().map_children(|child| {
                    let sub = self.apply_rec(depth + 1, &child, egraph, mapping);
                    children_were_there &= sub.was_there;
                    sub.id
                });
                trace!("{}adding: {:?}", " ".repeat(depth), node);

                let mut added = egraph.add(node);
                added.was_there &= children_were_there;
                added
            }
        };

        trace!("{}result: {:?}", " ".repeat(depth), outcome);
        outcome
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        expr::{
            tests::{op, var, TestLang},
            QuestionMarkName,
        },
        util::hashmap,
    };

    /// Unions `x + y` with `z + w`, then checks that the commutativity
    /// pattern matches both nodes of that eclass and that applying the
    /// right-hand side reports two applications.
    #[test]
    fn simple_match() {
        use crate::extract::Extractor;

        crate::init_logger();
        let mut egraph = EGraph::<TestLang, ()>::default();

        // Build `x + y` and `z + w` and put them into the same eclass.
        let x = egraph.add(var("x")).id;
        let y = egraph.add(var("y")).id;
        let x_plus_y = egraph.add(op("+", vec![x, y])).id;

        let z = egraph.add(var("z")).id;
        let w = egraph.add(var("w")).id;
        let z_plus_w = egraph.add(op("+", vec![z, w])).id;

        egraph.union(x_plus_y, z_plus_w);
        egraph.rebuild();

        let wild_a: QuestionMarkName = "?a".parse().unwrap();
        let wild_b: QuestionMarkName = "?b".parse().unwrap();

        let lhs = Pattern::Expr(op(
            "+",
            vec![
                Pattern::Wildcard(wild_a.clone()),
                Pattern::Wildcard(wild_b.clone()),
            ],
        ));
        let rhs = Pattern::Expr(op(
            "+",
            vec![
                Pattern::Wildcard(wild_b.clone()),
                Pattern::Wildcard(wild_a.clone()),
            ],
        ));
        let commute_plus = crate::pattern::Rewrite {
            name: "commute_plus".into(),
            lhs,
            rhs,
        };

        let eclass = egraph.just_find(x_plus_y);
        let matches = commute_plus.lhs.search_eclass(&egraph, eclass).unwrap();
        assert_eq!(matches.mappings.len(), 2);

        let applications = matches.apply(&commute_plus.rhs, &mut egraph);
        egraph.rebuild();
        assert_eq!(applications.len(), 2);

        // The two mappings may be reported in either order.
        let xy = hashmap(&[(&wild_a, x), (&wild_b, y)]);
        let zw = hashmap(&[(&wild_a, z), (&wild_b, w)]);
        let in_order = vec![xy.clone(), zw.clone()];
        if matches.mappings != in_order {
            assert_eq!(matches.mappings, vec![zw, xy])
        }

        info!("Here are the mappings!");
        for found in matches.mappings.iter() {
            info!("mappings: {:?}", found);
        }

        egraph.dump_dot("simple-match.dot");

        let ext = Extractor::new(&egraph);
        let best = ext.find_best(2);
        eprintln!("Best: {:#?}", best.expr);
    }
}