Macro egg::rewrite

macro_rules! rewrite {
    (
        $name:expr;
        $lhs:tt => $rhs:tt
        $(if $cond:expr)*
    ) => { ... };
    (
        $name:expr;
        $lhs:tt <=> $rhs:tt
        $(if $cond:expr)*
    ) => { ... };
}

A macro to easily make Rewrites.

The rewrite! macro greatly simplifies creating simple, purely syntactic rewrites while also allowing more complex ones.

The macro panics if Rewrite::new returns an error.
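To make the panic condition concrete: one kind of rule a rewrite constructor typically rejects is one whose applier uses a pattern variable the searcher never binds, such as `"(+ ?a 0)" => "(+ ?a ?c)"`. The sketch below is a hypothetical, egg-free imitation of that kind of check on plain strings. `pattern_vars` and `check_rule` are illustrative names, not part of egg's API, and egg's real validation and error messages may differ.

```rust
use std::collections::HashSet;

/// Collect the pattern variables (tokens starting with '?') from a
/// hypothetical s-expression pattern string such as "(+ ?a ?b)".
fn pattern_vars(pattern: &str) -> HashSet<String> {
    pattern
        .split(|c: char| c.is_whitespace() || c == '(' || c == ')')
        .filter(|tok| tok.starts_with('?'))
        .map(|tok| tok.to_string())
        .collect()
}

/// Hypothetical stand-in for a constructor-time check: every variable used
/// on the right-hand side must be bound by the left-hand side.
fn check_rule(name: &str, lhs: &str, rhs: &str) -> Result<(), String> {
    let bound = pattern_vars(lhs);
    let mut used: Vec<String> = pattern_vars(rhs).into_iter().collect();
    used.sort(); // sort so the reported variable is deterministic
    for v in used {
        if !bound.contains(&v) {
            return Err(format!("rule {}: unbound variable {}", name, v));
        }
    }
    Ok(())
}
```

In `rewrite!`, an `Err` like this would surface as a panic because the result of `Rewrite::new` is unwrapped.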

The simplest form, rewrite!(a; b => c), creates a Rewrite with name a, Searcher b, and Applier c. Note that in the b and c positions, the macro only accepts a single token tree (see the macros reference for more info). In short, you should pass in an identifier, a literal, or something surrounded by parentheses or braces.
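The single-token-tree restriction is a property of the `tt` fragment in `macro_rules!`, so it can be seen without egg at all. The hypothetical `toy_rule!` macro below copies only the `$name:expr; $lhs:tt => $rhs:tt` shape of `rewrite!` and returns the three pieces as a tuple; it is a sketch of the matching behavior, not egg's macro.

```rust
/// Hypothetical toy macro with the same `$name:expr; $lhs:tt => $rhs:tt`
/// shape as `rewrite!`. It returns the pieces as a tuple so the effect of
/// the `tt` fragment is easy to observe.
macro_rules! toy_rule {
    ($name:expr; $lhs:tt => $rhs:tt) => {
        ($name, $lhs, $rhs)
    };
}

const SEVEN: i32 = 7;

fn examples() -> Vec<(&'static str, &'static str, i32)> {
    vec![
        // A literal is a single token tree.
        toy_rule!("lit"; "lhs" => 7),
        // So is an identifier.
        toy_rule!("ident"; "lhs" => SEVEN),
        // A parenthesized expression is one token tree, so `(1 + 2)` matches.
        toy_rule!("paren"; "lhs" => (1 + 2)),
        // A braced block is one token tree too; it becomes a block expression.
        toy_rule!("brace"; "lhs" => { let x = 4; x * 10 }),
        // `toy_rule!("bad"; "lhs" => 1 + 2)` would NOT compile: `$rhs:tt`
        // captures only `1`, and the leftover `+ 2` matches no rule.
    ]
}
```

This is why the example further down wraps the custom applier as `{ MySillyApplier("foo") }`: without the braces, the call would be several token trees.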

If you pass a literal in the b or c position, the macro will try to parse it as a Pattern, which implements both Searcher and Applier.
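The literal-versus-other distinction can be sketched with two `macro_rules!` arms: one matching a `literal` fragment and calling `.parse()`, and a fallback that passes any other token tree through unchanged. Everything here is hypothetical: `ToyPattern` stands in for egg's Pattern, its `FromStr` only checks balanced parentheses, and egg's actual expansion may be organized differently.

```rust
use std::str::FromStr;

/// Hypothetical stand-in for egg's Pattern: it only stores the source text.
#[derive(Debug, PartialEq)]
struct ToyPattern(String);

impl FromStr for ToyPattern {
    type Err = String;

    /// Toy parser: accepts any string whose parentheses balance.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut depth: i32 = 0;
        for c in s.chars() {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth < 0 {
                        return Err(format!("unbalanced: {}", s));
                    }
                }
                _ => {}
            }
        }
        if depth == 0 {
            Ok(ToyPattern(s.to_string()))
        } else {
            Err(format!("unbalanced: {}", s))
        }
    }
}

/// A literal is parsed (panicking on failure, like `rewrite!`);
/// any other single token tree is used as the value directly.
macro_rules! toy_side {
    ($t:literal) => {
        $t.parse::<ToyPattern>().unwrap()
    };
    ($t:tt) => {
        $t
    };
}
```

Because the literal arm is tried first, a string goes through the parser, while a braced expression skips it entirely.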

The macro also accepts any number of if <expr> forms at the end, where the given expression should implement Condition. For each of these, the macro will wrap the given applier in a ConditionalApplier with the given condition, with the first condition being the outermost, and the last condition being the innermost.
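The nesting order (first condition outermost, last innermost) falls out naturally from a recursive macro that peels off one condition at a time. The sketch below is a hypothetical model: `ToyApplier::Conditional` plays the role of ConditionalApplier, conditions are plain strings, and `render` prints the nesting so the order is visible.

```rust
/// Hypothetical model of an applier, optionally guarded by conditions.
#[derive(Debug)]
enum ToyApplier {
    Core(&'static str),
    Conditional {
        condition: &'static str,
        inner: Box<ToyApplier>,
    },
}

/// Wrap `core` in one Conditional per condition. The first condition is
/// handled first, so it ends up outermost; the last ends up innermost.
macro_rules! wrap_conds {
    ($core:expr;) => {
        $core
    };
    ($core:expr; $cond:expr $(, $rest:expr)*) => {
        ToyApplier::Conditional {
            condition: $cond,
            inner: Box::new(wrap_conds!($core; $($rest),*)),
        }
    };
}

impl ToyApplier {
    /// Render the nesting, e.g. "if c1 { if c2 { apply } }".
    fn render(&self) -> String {
        match self {
            ToyApplier::Core(name) => name.to_string(),
            ToyApplier::Conditional { condition, inner } => {
                format!("if {} {{ {} }}", condition, inner.render())
            }
        }
    }
}
```

With conditions `"c1", "c2"`, the rendered result is `if c1 { if c2 { apply } }`, so `c1` is checked before `c2`.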

Example

// Bring egg's items into scope: rewrite!, define_language!, Rewrite,
// Applier, Id, Subst, PatternAst, Symbol, ...
use egg::*;
define_language! {
    enum SimpleLanguage {
        Num(i32),
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
    }
}

type EGraph = egg::EGraph<SimpleLanguage, ()>;

let mut rules: Vec<Rewrite<SimpleLanguage, ()>> = vec![
    rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
    rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),

    rewrite!("mul-0"; "(* ?a 0)" => "0"),

    rewrite!("silly"; "(* ?a 1)" => { MySillyApplier("foo") }),

    rewrite!("something_conditional";
             "(/ ?a ?b)" => "(* ?a (/ 1 ?b))"
             if is_not_zero("?b")),
];

// rewrite! supports bidirectional rules too
// it returns a Vec of length 2, so you need to concat
rules.extend(vec![
    rewrite!("add-0"; "(+ ?a 0)" <=> "?a"),
    rewrite!("mul-1"; "(* ?a 1)" <=> "?a"),
].concat());

#[derive(Debug)]
struct MySillyApplier(&'static str);
impl Applier<SimpleLanguage, ()> for MySillyApplier {
    fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst, _: Option<&PatternAst<SimpleLanguage>>, _: Symbol) -> Vec<Id> {
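        // Placeholder: this applier only demonstrates that any Applier can be
        // passed in braces; it panics if the rule ever actually fires.
        panic!()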
    }
}

// This returns a function that implements Condition
fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
    let var = var.parse().unwrap();
    let zero = SimpleLanguage::Num(0);
    move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero)
}