use pattern::apply_pat;
use std::fmt::{self, Debug, Display};
use std::sync::Arc;
use crate::*;
/// A named rewrite rule that pairs a [`Searcher`] with an [`Applier`].
///
/// Both halves are stored as `Arc`-shared trait objects that are also
/// `Send + Sync`. Because the struct is `#[non_exhaustive]`, code outside
/// this crate cannot build it with a struct literal.
#[derive(Clone)]
#[non_exhaustive]
pub struct Rewrite<L, N> {
/// Name of the rule; it appears in error/log messages and is handed to the
/// applier as `rule_name`.
pub name: Symbol,
/// The half of the rule that finds matches in an e-graph.
pub searcher: Arc<dyn Searcher<L, N> + Sync + Send>,
/// The half of the rule that acts on the matches found by `searcher`.
pub applier: Arc<dyn Applier<L, N> + Sync + Send>,
}
impl<L, N> Debug for Rewrite<L, N>
where
    L: Language + Display + 'static,
    N: Analysis<L> + 'static,
{
    /// Formats the rewrite as a `Rewrite { name, searcher, applier }` struct.
    ///
    /// Each half is shown through its pattern AST when `get_pattern_ast`
    /// returns one; otherwise a fixed placeholder string is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Rewrite");
        out.field("name", &self.name);

        match self.searcher.get_pattern_ast() {
            Some(ast) => {
                out.field("searcher", &DisplayAsDebug(ast));
            }
            None => {
                out.field("searcher", &"<< searcher >>");
            }
        }

        match self.applier.get_pattern_ast() {
            Some(ast) => {
                out.field("applier", &DisplayAsDebug(ast));
            }
            None => {
                out.field("applier", &"<< applier >>");
            }
        }

        out.finish()
    }
}
impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
    /// Builds a rewrite from a name, a searcher, and an applier.
    ///
    /// # Errors
    ///
    /// Returns an `Err` naming the first variable reported by
    /// `applier.vars()` that does not appear in `searcher.vars()`.
    pub fn new(
        name: impl Into<Symbol>,
        searcher: impl Searcher<L, N> + Send + Sync + 'static,
        applier: impl Applier<L, N> + Send + Sync + 'static,
    ) -> Result<Self, String> {
        let name = name.into();
        let searcher = Arc::new(searcher);
        let applier = Arc::new(applier);

        // Every variable the applier uses must be bound by the searcher.
        let bound_vars = searcher.vars();
        let unbound = applier
            .vars()
            .into_iter()
            .find(|var| !bound_vars.contains(var));
        if let Some(var) = unbound {
            return Err(format!("Rewrite {} refers to unbound var {}", name, var));
        }

        Ok(Self {
            name,
            searcher,
            applier,
        })
    }

    /// Runs this rewrite's searcher over `egraph`.
    ///
    /// Delegates directly to [`Searcher::search`].
    pub fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
        self.searcher.search(egraph)
    }

    /// Runs this rewrite's searcher over `egraph`, forwarding `limit`.
    ///
    /// Delegates directly to [`Searcher::search_with_limit`].
    pub fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
        self.searcher.search_with_limit(egraph, limit)
    }

    /// Hands `matches` to this rewrite's applier, tagging the call with the
    /// rewrite's own name, and returns whatever ids the applier reports.
    ///
    /// Delegates directly to [`Applier::apply_matches`].
    pub fn apply(&self, egraph: &mut EGraph<L, N>, matches: &[SearchMatches<L>]) -> Vec<Id> {
        self.applier.apply_matches(egraph, matches, self.name)
    }

    /// Test-only helper: search, apply, log timing, then rebuild the e-graph.
    ///
    /// Returns the ids produced by the apply step.
    #[cfg(test)]
    pub(crate) fn run(&self, egraph: &mut EGraph<L, N>) -> Vec<Id> {
        let timer = crate::util::Instant::now();

        let found = self.search(egraph);
        log::debug!("Found rewrite {} {} times", self.name, found.len());

        let applied = self.apply(egraph, &found);
        let took = timer.elapsed();
        log::debug!(
            "Applied rewrite {} {} times in {}.{:03}",
            self.name,
            applied.len(),
            took.as_secs(),
            took.subsec_millis()
        );

        egraph.rebuild();
        applied
    }
}
/// Searches the given e-classes, in iteration order, until `limit`
/// substitutions have been collected in total.
///
/// For every e-class the searcher is asked for at most the number of
/// substitutions still allowed; classes with no match are skipped.
///
/// # Panics
///
/// Panics if the searcher returns more substitutions for an e-class than the
/// remaining limit it was given.
pub(crate) fn search_eclasses_with_limit<'a, I, S, L, N>(
    searcher: &'a S,
    egraph: &EGraph<L, N>,
    eclasses: I,
    mut limit: usize,
) -> Vec<SearchMatches<'a, L>>
where
    L: Language,
    N: Analysis<L>,
    S: Searcher<L, N> + ?Sized,
    I: IntoIterator<Item = Id>,
{
    let mut collected = Vec::new();
    for class_id in eclasses {
        // Budget exhausted: stop without visiting further classes.
        if limit == 0 {
            break;
        }
        if let Some(hit) = searcher.search_eclass_with_limit(egraph, class_id, limit) {
            let len = hit.substs.len();
            assert!(len <= limit);
            limit -= len;
            collected.push(hit);
        }
    }
    collected
}
/// The searching half of a [`Rewrite`]: something that can look for matches
/// inside an [`EGraph`].
///
/// Implementors must provide [`search_eclass_with_limit`] and [`vars`]; every
/// other method has a default built on top of those.
///
/// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit
/// [`vars`]: Searcher::vars
pub trait Searcher<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Searches a single e-class with the limit set to `usize::MAX`.
    fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches<L>> {
        let unlimited = usize::MAX;
        self.search_eclass_with_limit(egraph, eclass, unlimited)
    }

    /// Searches a single e-class, returning `None` when nothing matches.
    ///
    /// Callers in this file rely on the result holding at most `limit`
    /// substitutions.
    fn search_eclass_with_limit(
        &self,
        egraph: &EGraph<L, N>,
        eclass: Id,
        limit: usize,
    ) -> Option<SearchMatches<L>>;

    /// Searches the whole e-graph with the limit set to `usize::MAX`.
    fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
        let unlimited = usize::MAX;
        self.search_with_limit(egraph, unlimited)
    }

    /// Searches every e-class yielded by `egraph.classes()`, collecting at
    /// most `limit` substitutions overall.
    fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
        let class_ids = egraph.classes().map(|class| class.id);
        search_eclasses_with_limit(self, egraph, class_ids, limit)
    }

    /// Total number of substitutions found by a full [`search`](Searcher::search).
    fn n_matches(&self, egraph: &EGraph<L, N>) -> usize {
        let all_matches = self.search(egraph);
        all_matches.iter().map(|found| found.substs.len()).sum()
    }

    /// The pattern AST behind this searcher, if it has one.
    ///
    /// The default implementation reports none.
    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
        None
    }

    /// The variables this searcher reports as bound; [`Rewrite::new`] checks
    /// the applier's variables against this list.
    fn vars(&self) -> Vec<Var>;
}
/// The applying half of a [`Rewrite`]: something that acts on an [`EGraph`]
/// for each substitution a searcher found.
///
/// Implementors must provide [`apply_one`](Applier::apply_one); the other
/// methods have defaults.
pub trait Applier<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Calls [`apply_one`](Applier::apply_one) for every substitution of
    /// every match and concatenates the returned ids, in order.
    ///
    /// The match's AST is forwarded as `searcher_ast` only while
    /// `egraph.are_explanations_enabled()` is true; otherwise `None` is passed.
    fn apply_matches(
        &self,
        egraph: &mut EGraph<L, N>,
        matches: &[SearchMatches<L>],
        rule_name: Symbol,
    ) -> Vec<Id> {
        let mut new_ids = Vec::new();
        for found in matches {
            let searcher_ast = if egraph.are_explanations_enabled() {
                found.ast.as_ref().map(|ast| ast.as_ref())
            } else {
                None
            };
            for subst in found.substs.iter() {
                new_ids.extend(self.apply_one(
                    egraph,
                    found.eclass,
                    subst,
                    searcher_ast,
                    rule_name,
                ));
            }
        }
        new_ids
    }

    /// The pattern AST behind this applier, if it has one.
    ///
    /// The default implementation reports none.
    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
        None
    }

    /// Applies a single substitution for the matched `eclass`.
    ///
    /// `searcher_ast` is the AST of the match when one was supplied by the
    /// caller, and `rule_name` identifies the rewrite being applied. The
    /// returned ids are collected by
    /// [`apply_matches`](Applier::apply_matches); presumably they are the
    /// e-classes that were changed (confirm against implementors).
    fn apply_one(
        &self,
        egraph: &mut EGraph<L, N>,
        eclass: Id,
        subst: &Subst,
        searcher_ast: Option<&PatternAst<L>>,
        rule_name: Symbol,
    ) -> Vec<Id>;

    /// The variables this applier refers to; empty by default.
    ///
    /// [`Rewrite::new`] rejects a rewrite whose applier reports a variable
    /// the searcher does not.
    fn vars(&self) -> Vec<Var> {
        Vec::new()
    }
}
/// An applier wrapper that runs the inner `applier` only when `condition`
/// holds for the match being applied.
///
/// When the condition check fails, applying yields no ids.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConditionalApplier<C, A> {
/// The condition checked before each application.
pub condition: C,
/// The applier invoked when `condition` passes.
pub applier: A,
}
impl<C, A, N, L> Applier<L, N> for ConditionalApplier<C, A>
where
    L: Language,
    C: Condition<L, N>,
    A: Applier<L, N>,
    N: Analysis<L>,
{
    /// Forwards to the wrapped applier's pattern AST.
    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
        self.applier.get_pattern_ast()
    }

    /// Applies the wrapped applier only if the condition passes for this
    /// `eclass` and `subst`; otherwise returns an empty list.
    fn apply_one(
        &self,
        egraph: &mut EGraph<L, N>,
        eclass: Id,
        subst: &Subst,
        searcher_ast: Option<&PatternAst<L>>,
        rule_name: Symbol,
    ) -> Vec<Id> {
        if !self.condition.check(egraph, eclass, subst) {
            return Vec::new();
        }
        self.applier
            .apply_one(egraph, eclass, subst, searcher_ast, rule_name)
    }

    /// The wrapped applier's variables followed by the condition's variables.
    fn vars(&self) -> Vec<Var> {
        self.applier
            .vars()
            .into_iter()
            .chain(self.condition.vars())
            .collect()
    }
}
/// A predicate evaluated by [`ConditionalApplier`] before it applies a match.
pub trait Condition<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Decides whether the match on `eclass` with substitution `subst`
    /// should be applied.
    ///
    /// The e-graph is passed mutably, so an implementation is allowed to
    /// modify it while checking.
    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool;

    /// The variables this condition refers to; empty by default.
    ///
    /// [`ConditionalApplier`] appends these to its applier's variables.
    fn vars(&self) -> Vec<Var> {
        Vec::new()
    }
}
/// Lets any `Fn(&mut EGraph<L, N>, Id, &Subst) -> bool` act as a
/// [`Condition`]. Such conditions keep the default (empty) `vars`.
impl<L, F, N> Condition<L, N> for F
where
    L: Language,
    N: Analysis<L>,
    F: Fn(&mut EGraph<L, N>, Id, &Subst) -> bool,
{
    /// Calls the function itself with the given arguments.
    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
        (*self)(egraph, eclass, subst)
    }
}
/// A [`Condition`] built from two patterns.
///
/// Its check instantiates both patterns with the match's substitution via
/// `apply_pat` and passes when the two resulting ids are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConditionEqual<L> {
// First pattern to instantiate and compare.
p1: Pattern<L>,
// Second pattern to instantiate and compare.
p2: Pattern<L>,
}
impl<L: Language> ConditionEqual<L> {
    /// Creates a condition that compares the two given patterns.
    pub fn new(p1: Pattern<L>, p2: Pattern<L>) -> Self {
        Self { p1, p2 }
    }
}
impl<L: FromOp> ConditionEqual<L> {
pub fn parse(a1: &str, a2: &str) -> Self {
Self {
p1: a1.parse().unwrap(),
p2: a2.parse().unwrap(),
}
}
}
impl<L, N> Condition<L, N> for ConditionEqual<L>
where
L: Language,
N: Analysis<L>,
{
fn check(&self, egraph: &mut EGraph<L, N>, _eclass: Id, subst: &Subst) -> bool {
let mut id_buf_1 = vec![0.into(); self.p1.ast.len()];
let mut id_buf_2 = vec![0.into(); self.p2.ast.len()];
let a1 = apply_pat(&mut id_buf_1, &self.p1.ast, egraph, subst);
let a2 = apply_pat(&mut id_buf_2, &self.p2.ast, egraph, subst);
a1 == a2
}
fn vars(&self) -> Vec<Var> {
let mut vars = self.p1.vars();
vars.extend(self.p2.vars());
vars
}
}
// Unit tests for conditional rewrites and hand-written appliers.
#[cfg(test)]
mod tests {
use crate::{SymbolLang as S, *};
use std::str::FromStr;
// E-graph over `SymbolLang` with the unit analysis, shared by both tests.
type EGraph = crate::EGraph<S, ()>;
// A rewrite guarded by `ConditionEqual` must not fire until the two
// condition patterns have been unioned in the e-graph.
#[test]
fn conditional_rewrite() {
crate::init_logger();
let mut egraph = EGraph::default();
// Build `(* x 2)`.
let x = egraph.add(S::leaf("x"));
let y = egraph.add(S::leaf("2"));
let mul = egraph.add(S::new("*", vec![x, y]));
let true_pat = Pattern::from_str("TRUE").unwrap();
egraph.add(S::leaf("TRUE"));
let pow2b = Pattern::from_str("(is-power2 ?b)").unwrap();
// The rule applies only when `(is-power2 ?b)` and `TRUE` compare equal.
let mul_to_shift = rewrite!(
"mul_to_shift";
"(* ?a ?b)" => "(>> ?a (log2 ?b))"
if ConditionEqual::new(pow2b, true_pat)
);
println!("rewrite shouldn't do anything yet");
egraph.rebuild();
let apps = mul_to_shift.run(&mut egraph);
assert!(apps.is_empty());
println!("Add the needed equality");
// Union `(is-power2 2)` with `TRUE` so the condition can succeed.
egraph.union_instantiations(
&"(is-power2 2)".parse().unwrap(),
&"TRUE".parse().unwrap(),
&Default::default(),
"direct-union".to_string(),
);
println!("Should fire now");
egraph.rebuild();
let apps = mul_to_shift.run(&mut egraph);
// The only reported id is the class of the original multiplication.
assert_eq!(apps, vec![egraph.find(mul)]);
}
// A custom `Applier` implementation can be used as the right-hand side of
// a rewrite.
#[test]
fn fn_rewrite() {
crate::init_logger();
let mut egraph = EGraph::default();
let start = RecExpr::from_str("(+ x y)").unwrap();
let goal = RecExpr::from_str("xy").unwrap();
let root = egraph.add_expr(&start);
// Operator symbol of the first node stored in e-class `id`.
fn get(egraph: &EGraph, id: Id) -> Symbol {
egraph[id].nodes[0].op
}
// Applier that concatenates the operator symbols bound to `?a` and `?b`
// into a single new leaf. The `_rhs` field is not read by `apply_one`.
#[derive(Debug)]
struct Appender {
_rhs: PatternAst<S>,
}
impl Applier<SymbolLang, ()> for Appender {
fn apply_one(
&self,
egraph: &mut EGraph,
eclass: Id,
subst: &Subst,
searcher_ast: Option<&PatternAst<SymbolLang>>,
rule_name: Symbol,
) -> Vec<Id> {
let a: Var = "?a".parse().unwrap();
let b: Var = "?b".parse().unwrap();
let a = get(egraph, subst[a]);
let b = get(egraph, subst[b]);
let s = format!("{}{}", a, b);
if let Some(ast) = searcher_ast {
// A searcher AST was supplied: union the instantiated searcher
// pattern with the new leaf through `union_instantiations`.
let (id, did_something) = egraph.union_instantiations(
ast,
&PatternAst::from_str(&s).unwrap(),
subst,
rule_name,
);
if did_something {
vec![id]
} else {
vec![]
}
} else {
// No searcher AST: add the leaf and union it with the matched class.
let added = egraph.add(S::leaf(&s));
if egraph.union(added, eclass) {
vec![eclass]
} else {
vec![]
}
}
}
}
let fold_add = rewrite!(
"fold_add"; "(+ ?a ?b)" => { Appender { _rhs: "?a".parse().unwrap()}}
);
egraph.rebuild();
fold_add.run(&mut egraph);
// After one run, `(+ x y)` and `xy` must share the root's e-class.
assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]);
}
}