use std::fmt;
use std::rc::Rc;
use crate::{Analysis, EGraph, Id, Language, Pattern, SearchMatches, Subst, Var};
/// A named rewrite rule: a [`Searcher`] (left-hand side) paired with an
/// [`Applier`] (right-hand side).
///
/// Build one with [`Rewrite::new`], which rejects rules whose applier refers
/// to a variable the searcher does not bind.
#[derive(Clone)]
#[non_exhaustive]
pub struct Rewrite<L, N> {
    name: String,
    long_name: String,
    searcher: Rc<dyn Searcher<L, N>>,
    applier: Rc<dyn Applier<L, N>>,
    // Type-erased handles to the *same* allocations as `searcher` / `applier`.
    // The `Debug` impl downcasts these to recover a concrete `Pattern<L>`;
    // the trait objects above cannot be downcast because neither `Searcher`
    // nor `Applier` has `Any` as a supertrait.
    searcher_any: Rc<dyn std::any::Any>,
    applier_any: Rc<dyn std::any::Any>,
}

impl<L, N> fmt::Debug for Rewrite<L, N>
where
    L: Language + 'static,
    N: 'static,
{
    /// Formats the rule's names. The searcher and applier are rendered with
    /// their `Display` impl when they are `Pattern<L>`s; otherwise a
    /// placeholder string is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::any::Any;

        // Lets a `Display` value be passed where `debug_struct` expects `Debug`.
        struct DisplayAsDebug<T>(T);
        impl<T: fmt::Display> fmt::Debug for DisplayAsDebug<T> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "{}", self.0)
            }
        }

        let mut d = f.debug_struct("Rewrite");
        d.field("name", &self.name)
            .field("long_name", &self.long_name);

        // Downcast the type-erased handles, NOT `&self.searcher`: coercing an
        // `&Rc<dyn Searcher<..>>` to `&dyn Any` carries the type id of the
        // `Rc` itself, so a downcast to `Pattern<L>` could never succeed.
        let searcher: &dyn Any = &*self.searcher_any;
        if let Some(pat) = searcher.downcast_ref::<Pattern<L>>() {
            d.field("searcher", &DisplayAsDebug(pat));
        } else {
            d.field("searcher", &"<< searcher >>");
        }

        let applier: &dyn Any = &*self.applier_any;
        if let Some(pat) = applier.downcast_ref::<Pattern<L>>() {
            d.field("applier", &DisplayAsDebug(pat));
        } else {
            d.field("applier", &"<< applier >>");
        }

        d.finish()
    }
}

impl<L, N> Rewrite<L, N> {
    /// Returns the rule's short name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the rule's long name.
    pub fn long_name(&self) -> &str {
        &self.long_name
    }
}

impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
    /// Creates a rewrite from a searcher and an applier.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message if any variable reported
    /// by `applier.vars()` is missing from `searcher.vars()`.
    pub fn new(
        name: impl Into<String>,
        long_name: impl Into<String>,
        searcher: impl Searcher<L, N> + 'static,
        applier: impl Applier<L, N> + 'static,
    ) -> Result<Self, String> {
        let name = name.into();
        let long_name = long_name.into();
        let searcher = Rc::new(searcher);
        let applier = Rc::new(applier);

        let bound_vars = searcher.vars();
        for v in applier.vars() {
            if !bound_vars.contains(&v) {
                return Err(format!("Rewrite {} refers to unbound var {}", name, v));
            }
        }

        // Refcount bumps only: these share the allocations created above and
        // keep the concrete type recoverable for `Debug`.
        let searcher_any: Rc<dyn std::any::Any> = searcher.clone();
        let applier_any: Rc<dyn std::any::Any> = applier.clone();

        Ok(Self {
            name,
            long_name,
            searcher,
            applier,
            searcher_any,
            applier_any,
        })
    }

    /// Runs the searcher over the whole e-graph.
    pub fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches> {
        self.searcher.search(egraph)
    }

    /// Applies the applier to `matches`, returning whatever
    /// `Applier::apply_matches` reports.
    pub fn apply(&self, egraph: &mut EGraph<L, N>, matches: &[SearchMatches]) -> Vec<Id> {
        self.applier.apply_matches(egraph, matches)
    }

    /// Test helper: search, apply, then rebuild, logging counts and timing.
    #[cfg(test)]
    pub(crate) fn run(&self, egraph: &mut EGraph<L, N>) -> Vec<Id> {
        let start = instant::Instant::now();

        let matches = self.search(egraph);
        log::debug!("Found rewrite {} {} times", self.name, matches.len());

        let ids = self.apply(egraph, &matches);
        let elapsed = start.elapsed();
        log::debug!(
            "Applied rewrite {} {} times in {}.{:03}",
            self.name,
            ids.len(),
            elapsed.as_secs(),
            elapsed.subsec_millis()
        );

        egraph.rebuild();
        ids
    }
}
/// The left-hand side of a rewrite: something that can find matches in an
/// [`EGraph`].
pub trait Searcher<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Searches one e-class, returning `None` when nothing matches there.
    fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches>;

    /// Searches every e-class yielded by `egraph.classes()` and keeps the
    /// `Some` results, in iteration order.
    fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches> {
        let mut found = Vec::new();
        for class in egraph.classes() {
            if let Some(hit) = self.search_eclass(egraph, class.id) {
                found.push(hit);
            }
        }
        found
    }

    /// The pattern variables this searcher binds.
    fn vars(&self) -> Vec<Var>;
}
/// The right-hand side of a rewrite: something that adds to an [`EGraph`]
/// for each match found by a [`Searcher`].
pub trait Applier<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Applies to every substitution of every match.
    ///
    /// Each id produced by `apply_one` is unioned with the matched e-class.
    /// The first component of `union`'s result is collected only when the
    /// union reported that it changed something.
    fn apply_matches(&self, egraph: &mut EGraph<L, N>, matches: &[SearchMatches]) -> Vec<Id> {
        let mut changed = Vec::new();
        for m in matches {
            let root = m.eclass;
            for subst in &m.substs {
                for new_id in self.apply_one(egraph, root, subst) {
                    let (merged, did_something) = egraph.union(new_id, root);
                    if did_something {
                        changed.push(merged);
                    }
                }
            }
        }
        changed
    }

    /// Applies once for `eclass` under `subst`, returning the ids to union
    /// with `eclass`.
    fn apply_one(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> Vec<Id>;

    /// Pattern variables this applier reads; by default none.
    fn vars(&self) -> Vec<Var> {
        Vec::new()
    }
}
/// An applier that runs `applier` only when `condition` holds for the
/// current e-class and substitution.
#[derive(Debug, Clone)]
pub struct ConditionalApplier<C, A> {
    /// Consulted before every application.
    pub condition: C,
    /// Applied when `condition` holds.
    pub applier: A,
}
impl<C, A, N, L> Applier<L, N> for ConditionalApplier<C, A>
where
    L: Language,
    C: Condition<L, N>,
    A: Applier<L, N>,
    N: Analysis<L>,
{
    /// Delegates to the inner applier if the condition passes; otherwise
    /// produces no ids.
    fn apply_one(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> Vec<Id> {
        if !self.condition.check(egraph, eclass, subst) {
            return vec![];
        }
        self.applier.apply_one(egraph, eclass, subst)
    }

    /// The inner applier's variables followed by the condition's variables,
    /// so `Rewrite::new` also checks that the condition's variables are bound.
    fn vars(&self) -> Vec<Var> {
        self.applier
            .vars()
            .into_iter()
            .chain(self.condition.vars())
            .collect()
    }
}
/// A predicate that gates a [`ConditionalApplier`].
///
/// `check` receives the e-graph mutably, so an implementation may add to it
/// while deciding.
pub trait Condition<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Returns `true` when the guarded applier should run for `eclass`
    /// under `subst`.
    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool;

    /// Pattern variables this condition reads; by default none.
    fn vars(&self) -> Vec<Var> {
        Vec::new()
    }
}
/// Any function or closure `Fn(&mut EGraph, Id, &Subst) -> bool` can be used
/// as a [`Condition`]; it reports no variables.
impl<L, F, N> Condition<L, N> for F
where
    L: Language,
    N: Analysis<L>,
    F: Fn(&mut EGraph<L, N>, Id, &Subst) -> bool,
{
    /// Forwards all arguments to the wrapped function.
    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
        let predicate = self;
        predicate(egraph, eclass, subst)
    }
}
/// A condition that holds when two appliers, each applied with the current
/// e-class and substitution, yield the same e-class.
///
/// Derives `Clone` and `Debug` like `ConditionalApplier`, so a
/// `ConditionalApplier` wrapping this condition can use its derived impls.
#[derive(Clone, Debug)]
pub struct ConditionEqual<A1, A2>(pub A1, pub A2);
impl<L: Language> ConditionEqual<Pattern<L>, Pattern<L>> {
    /// Builds a condition comparing two patterns given as strings.
    ///
    /// # Panics
    ///
    /// Panics if either string does not parse as a `Pattern<L>`. The panic
    /// message includes the offending input and the parse error.
    pub fn parse(a1: &str, a2: &str) -> Self {
        let parse_pattern = |src: &str| -> Pattern<L> {
            src.parse().unwrap_or_else(|err| {
                panic!(
                    "ConditionEqual::parse: could not parse pattern {:?}: {:?}",
                    src, err
                )
            })
        };
        Self(parse_pattern(a1), parse_pattern(a2))
    }
}
impl<L, N, A1, A2> Condition<L, N> for ConditionEqual<A1, A2>
where
    L: Language,
    N: Analysis<L>,
    A1: Applier<L, N>,
    A2: Applier<L, N>,
{
    /// Applies both sides (which may add nodes to `egraph`) and reports
    /// whether they land in the same e-class.
    ///
    /// # Panics
    ///
    /// Panics if either side does not produce exactly one id.
    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
        let lhs = self.0.apply_one(egraph, eclass, subst);
        let rhs = self.1.apply_one(egraph, eclass, subst);
        assert_eq!(
            lhs.len(),
            1,
            "ConditionEqual: left side must produce exactly one e-class id"
        );
        assert_eq!(
            rhs.len(),
            1,
            "ConditionEqual: right side must produce exactly one e-class id"
        );
        // Compare canonical ids. `apply_matches` performs unions between
        // `apply_one` calls without rebuilding, so raw ids may be stale.
        egraph.find(lhs[0]) == egraph.find(rhs[0])
    }

    /// Variables of the left side followed by those of the right side.
    fn vars(&self) -> Vec<Var> {
        let mut vars = self.0.vars();
        vars.extend(self.1.vars());
        vars
    }
}
#[cfg(test)]
mod tests {
// Tests for guarded rewrites (`ConditionEqual`) and a hand-written `Applier`.
use crate::{SymbolLang as S, *};
use std::str::FromStr;
type EGraph = crate::EGraph<S, ()>;
// A rewrite guarded by `ConditionEqual((is-power2 ?b), TRUE)` must not fire
// until `(is-power2 2)` has been unioned with `TRUE`.
#[test]
fn conditional_rewrite() {
crate::init_logger();
let mut egraph = EGraph::default();
let x = egraph.add(S::leaf("x"));
let y = egraph.add(S::leaf("2"));
let mul = egraph.add(S::new("*", vec![x, y]));
let true_pat = Pattern::from_str("TRUE").unwrap();
let true_id = egraph.add(S::leaf("TRUE"));
let pow2b = Pattern::from_str("(is-power2 ?b)").unwrap();
let mul_to_shift = rewrite!(
"mul_to_shift";
"(* ?a ?b)" => "(>> ?a (log2 ?b))"
if ConditionEqual(pow2b, true_pat)
);
// First run: the condition cannot hold yet, so no applications are expected.
println!("rewrite shouldn't do anything yet");
egraph.rebuild();
let apps = mul_to_shift.run(&mut egraph);
assert!(apps.is_empty());
// Make `(is-power2 2)` equal to `TRUE` so the condition is satisfied.
println!("Add the needed equality");
let two_ispow2 = egraph.add(S::new("is-power2", vec![y]));
egraph.union(two_ispow2, true_id);
// Second run: the rule should apply exactly once, on the class of `mul`.
println!("Should fire now");
egraph.rebuild();
let apps = mul_to_shift.run(&mut egraph);
assert_eq!(apps, vec![egraph.find(mul)]);
}
// A custom `Applier` that concatenates the symbols bound to `?a` and `?b`;
// after one run, `(+ x y)` should be equivalent to the leaf `xy`.
#[test]
fn fn_rewrite() {
crate::init_logger();
let mut egraph = EGraph::default();
let start = RecExpr::from_str("(+ x y)").unwrap();
let goal = RecExpr::from_str("xy").unwrap();
let root = egraph.add_expr(&start);
// Returns the operator of the first node stored in e-class `id`.
fn get(egraph: &EGraph, id: Id) -> Symbol {
egraph[id].nodes[0].op
}
#[derive(Debug)]
struct Appender;
impl Applier<SymbolLang, ()> for Appender {
// Builds the leaf whose symbol is `?a` followed by `?b` and returns its id.
fn apply_one(&self, egraph: &mut EGraph, _eclass: Id, subst: &Subst) -> Vec<Id> {
let a: Var = "?a".parse().unwrap();
let b: Var = "?b".parse().unwrap();
let a = get(&egraph, subst[a]);
let b = get(&egraph, subst[b]);
let s = format!("{}{}", a, b);
vec![egraph.add(S::leaf(&s))]
}
}
let fold_add = rewrite!(
"fold_add"; "(+ ?a ?b)" => { Appender }
);
egraph.rebuild();
fold_add.run(&mut egraph);
assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]);
}
}