egg/
macros.rs

1#[allow(unused_imports)]
2use crate::*;
3
/** A macro to easily create a [`Language`].

`define_language` derives `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`,
`Hash`, and `Clone` on the given `enum` so it can implement [`Language`].
The macro also implements [`Display`] and [`FromOp`] for the `enum`
based on either the data of variants or the provided strings.

The final variant **must have a trailing comma**; this is due to limitations in
macro parsing.

The language discriminant will use the cases of the enum (the enum discriminant).

See [`LanguageChildren`] for acceptable types of children `Id`s.

Note that you can always implement [`Language`] yourself by just not using this
macro.

Data variants may also carry children: put the data type first and the
children type second, as in `Other(Symbol, Vec<Id>)` in the example below.

# Example

The following macro invocation shows the accepted forms of variants:
```
# use egg::*;
define_language! {
    enum SimpleLanguage {
        // string variant with no children
        "pi" = Pi,

        // string variants with an array of child `Id`s (any static size)
        // any type that implements LanguageChildren may be used here
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "*" = Mul([Id; 2]),

        // can also do a variable number of children in a boxed slice
        // this will only match if the lengths are the same
        "list" = List(Box<[Id]>),

        // string variants with a single child `Id`
        // note that this is distinct from `Sub`, even though it has the same
        // string, because it has a different number of children
        "-"  = Neg(Id),

        // data variants with a single field
        // this field must implement `FromStr` and `Display`
        Num(i32),
        // language items are parsed in order, and we want symbol to
        // be a fallback, so we put it last
        Symbol(Symbol),
        // This is the ultimate fallback, it will parse any operator (as a string)
        // and any number of children.
        // Note that if there were 0 children, the previous branch would have succeeded
        Other(Symbol, Vec<Id>),
    }
}
```

It is also possible to define languages that are generic over some bounded type.
You must use the `where`-like syntax below to specify the bounds on the type;
they cannot go in the `enum` definition.
The generic parameter list may end with a trailing comma.
You need at least the following bounds, since they are required by the [`Language`] trait.

# Example
```rust
use egg::*;
use std::{
    fmt::{Debug, Display},
    hash::Hash,
    str::FromStr,
};
define_language! {
    enum GenericLang<S, T> {
        String(S),
        Number(T),
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "/" = Div([Id; 2]),
        "*" = Mult([Id; 2]),
    }
    where
    S: Hash + Debug + Display + Clone + Eq + Ord + FromStr,
    T: Hash + Debug + Display + Clone + Eq + Ord + FromStr,
    // also required by the macro impl that parses S, T
    <S as FromStr>::Err: Debug,
    <T as FromStr>::Err: Debug,
}
```

[`Display`]: std::fmt::Display
**/
#[macro_export]
macro_rules! define_language {
    ($(#[$meta:meta])* $vis:vis enum $name:ident
     // Generic parameters are matched as bare identifiers (optionally followed
     // by a trailing comma); their bounds arrive through the `where` tail.
     // Parsing trick taken from https://stackoverflow.com/a/51580104
     $(<$($gen:ident),* $(,)?>)?
     { $($variants:tt)* }
     $(where $($where:tt)*)?) => {
        // Hand everything to the internal muncher: generics and bounds are
        // normalised into bracketed lists, and the six `{}` groups are the
        // (initially empty) accumulators it fills one variant at a time.
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($($gen),*)?] { $($variants)* }
            [$($($where)*)?]
            -> {} {} {} {} {} {}
        );
    };
}
109
#[doc(hidden)]
#[macro_export]
macro_rules! __define_language {
    // Internal token muncher behind `define_language!`.
    //
    // Input shape:
    //   <attrs> <vis> enum <name> [<generics>] { <variants left> } [<where predicates>]
    //   -> {enum body} {`matches` arms} {`children` arms} {`children_mut` arms}
    //      {`Display` arms} {`from_op` arms}
    //
    // Each recursive rule peels the first variant off `<variants left>`, appends
    // one entry to each accumulator, and recurses. The generics list is always
    // comma separated, in every rule.

    // Base case: no variants left, so emit the enum and its trait impls.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*] {}
     [$($where:tt)*]
     ->
     $decl:tt {$($matches:tt)*} $children:tt $children_mut:tt
     $display:tt {$($from_op:tt)*}
    ) => {
        $(#[$meta])*
        #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
        $vis enum $name <$($gen),*> $decl

        impl<$($gen),*> $crate::Language for $name <$($gen),*> where $($where)* {
            // Absolute `::std` paths so a `std` item at the call site cannot shadow them.
            type Discriminant = ::std::mem::Discriminant<Self>;

            #[inline(always)]
            fn discriminant(&self) -> Self::Discriminant {
                ::std::mem::discriminant(self)
            }

            #[inline(always)]
            fn matches(&self, other: &Self) -> bool {
                ::std::mem::discriminant(self) == ::std::mem::discriminant(other) &&
                match (self, other) { $($matches)* _ => false }
            }

            fn children(&self) -> &[$crate::Id] { match self $children }
            fn children_mut(&mut self) -> &mut [$crate::Id] { match self $children_mut }
        }

        impl<$($gen),*> ::std::fmt::Display for $name <$($gen),*> where $($where)* {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // We need to pass `f` to the match expression for hygiene
                // reasons.
                match (self, f) $display
            }
        }

        impl<$($gen),*> $crate::FromOp for $name <$($gen),*> where $($where)* {
            type Error = $crate::FromOpError;

            fn from_op(op: &str, children: ::std::vec::Vec<$crate::Id>) -> ::std::result::Result<Self, Self::Error> {
                match (op, children) {
                    // Arms are tried in declaration order; the first that accepts wins.
                    $($from_op)*
                    (op, children) => Err($crate::FromOpError::new(op, children)),
                }
            }
        }
    };

    // `"op" = Variant,` : a named operator with no children.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $string:literal = $variant:ident,
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)*          $variant, }
            { $($matches)*       ($name::$variant, $name::$variant) => true, }
            { $($children)*      $name::$variant => &[], }
            { $($children_mut)*  $name::$variant => &mut [], }
            { $($display)*       ($name::$variant, f) => f.write_str($string), }
            { $($from_op)*       ($string, children) if children.is_empty() => Ok($name::$variant), }
        );
    };

    // `"op" = Variant(Children),` : a named operator whose single field holds
    // the children (any `LanguageChildren` type).
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $string:literal = $variant:ident ($ids:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)*          $variant($ids), }
            { $($matches)*       ($name::$variant(l), $name::$variant(r)) => $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), }
            { $($children)*      $name::$variant(ids) => $crate::LanguageChildren::as_slice(ids), }
            { $($children_mut)*  $name::$variant(ids) => $crate::LanguageChildren::as_mut_slice(ids), }
            { $($display)*       ($name::$variant(..), f) => f.write_str($string), }
            { $($from_op)*       (op, children) if op == $string && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => {
                  let children = <$ids as $crate::LanguageChildren>::from_vec(children);
                  Ok($name::$variant(children))
              },
            }
        );
    };

    // `Variant(Data),` : a childless variant whose operator text is parsed
    // into `Data` and printed back through `Display`.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $variant:ident ($data:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)*          $variant($data), }
            { $($matches)*       ($name::$variant(data1), $name::$variant(data2)) => data1 == data2, }
            { $($children)*      $name::$variant(_data) => &[], }
            { $($children_mut)*  $name::$variant(_data) => &mut [], }
            { $($display)*       ($name::$variant(data), f) => ::std::fmt::Display::fmt(data, f), }
            // Cheap child-count check first, so the operator is only parsed
            // when the arm could actually apply.
            { $($from_op)*       (op, children) if children.is_empty() && op.parse::<$data>().is_ok() => Ok($name::$variant(op.parse::<$data>().unwrap())), }
        );
    };

    // `Variant(Data, Children),` : a data variant that also carries children.
    // The generics list is matched comma separated here as well, so this form
    // can be used in languages with more than one generic parameter.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $variant:ident ($data:ty, $ids:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)*          $variant($data, $ids), }
            { $($matches)*       ($name::$variant(d1, l), $name::$variant(d2, r)) => d1 == d2 && $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), }
            { $($children)*      $name::$variant(_, ids) => $crate::LanguageChildren::as_slice(ids), }
            { $($children_mut)*  $name::$variant(_, ids) => $crate::LanguageChildren::as_mut_slice(ids), }
            { $($display)*       ($name::$variant(data, _), f) => ::std::fmt::Display::fmt(data, f), }
            // Child-count check first, then the (potentially costlier) parse.
            { $($from_op)*       (op, children) if <$ids as $crate::LanguageChildren>::can_be_length(children.len()) && op.parse::<$data>().is_ok() => {
                  let data = op.parse::<$data>().unwrap();
                  let children = <$ids as $crate::LanguageChildren>::from_vec(children);
                  Ok($name::$variant(data, children))
              },
            }
        );
    };
}
266
/** A macro to easily make [`Rewrite`]s.

The `rewrite!` macro greatly simplifies creating simple, purely
syntactic rewrites while also allowing more complex ones.

This panics if [`Rewrite::new`](Rewrite::new()) fails.

The simplest form `rewrite!(a; b => c)` creates a [`Rewrite`]
with name `a`, [`Searcher`] `b`, and [`Applier`] `c`.
Note that in the `b` and `c` position, the macro only accepts a single
token tree (see the [macros reference][macro] for more info).
In short, that means you should pass in an identifier, literal, or
something surrounded by parentheses or braces.

If you pass in a literal to the `b` or `c` position, the macro will
try to parse it as a [`Pattern`] which implements both [`Searcher`]
and [`Applier`].

The macro also accepts any number of `if <expr>` forms at the end,
where the given expression should implement [`Condition`].
For each of these, the macro will wrap the given applier in a
[`ConditionalApplier`] with the given condition, with the first condition being
the outermost, and the last condition being the innermost.

The bidirectional form `rewrite!(a; b <=> c)` produces a `Vec` holding two
rewrites: `b => c` named `a`, and `c => b` named `a` with `-rev` appended.
Any `if <expr>` conditions are applied to both directions.

# Example
```
# use egg::*;
define_language! {
    enum SimpleLanguage {
        Num(i32),
        "+" = Add([Id; 2]),
        "-" = Sub([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
    }
}

type EGraph = egg::EGraph<SimpleLanguage, ()>;

let mut rules: Vec<Rewrite<SimpleLanguage, ()>> = vec![
    rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
    rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),

    rewrite!("mul-0"; "(* ?a 0)" => "0"),

    rewrite!("silly"; "(* ?a 1)" => { MySillyApplier("foo") }),

    rewrite!("something_conditional";
             "(/ ?a ?b)" => "(* ?a (/ 1 ?b))"
             if is_not_zero("?b")),
];

// rewrite! supports bidirectional rules too
// it returns a Vec of length 2, so you need to concat
rules.extend(vec![
    rewrite!("add-0"; "(+ ?a 0)" <=> "?a"),
    rewrite!("mul-1"; "(* ?a 1)" <=> "?a"),
].concat());

#[derive(Debug)]
struct MySillyApplier(&'static str);
impl Applier<SimpleLanguage, ()> for MySillyApplier {
    fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst, _: Option<&PatternAst<SimpleLanguage>>, _: Symbol) -> Vec<Id> {
        panic!()
    }
}

// This returns a function that implements Condition
fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
    let var = var.parse().unwrap();
    let zero = SimpleLanguage::Num(0);
    // note this check is just an example,
    // checking for the absence of 0 is insufficient since 0 could be merged in later
    // see https://github.com/egraphs-good/egg/issues/297
    move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero)
}
```

[macro]: https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables
**/
#[macro_export]
macro_rules! rewrite {
    // One-directional rule: `name; lhs => rhs [if cond]*`.
    (
        $name:expr;
        $lhs:tt => $rhs:tt
        $(if $cond:expr)*
    )  => {{
        // Literal sides are parsed as `Pattern`s; anything else is used as given.
        let searcher = $crate::__rewrite!(@parse Pattern $lhs);
        let core_applier = $crate::__rewrite!(@parse Pattern $rhs);
        // Wrap the applier once per condition, first condition outermost.
        let applier = $crate::__rewrite!(@applier core_applier; $($cond,)*);
        $crate::Rewrite::new($name.to_string(), searcher, applier).unwrap()
    }};
    // Bidirectional rule: expands to a two-element `Vec` of one-directional rules.
    (
        $name:expr;
        $lhs:tt <=> $rhs:tt
        $(if $cond:expr)*
    )  => {{
        // Evaluate the name expression exactly once.
        let name = $name;
        // The forward rule renders the name with `to_string()`, so format the
        // reverse name the same way: no clone, no extra trait requirements,
        // and absolute `::std` paths cannot be shadowed at the call site.
        let name2 = ::std::format!("{}-rev", name);
        ::std::vec![
            $crate::rewrite!(name;  $lhs => $rhs $(if $cond)*),
            $crate::rewrite!(name2; $rhs => $lhs $(if $cond)*)
        ]
    }};
}
374
/** Builds a [`Rewrite`] whose two sides are [`MultiPattern`]s.

The syntax mirrors the [`rewrite!`] macro:
`multi_rewrite!(name; multipattern => multipattern)`.
A string literal on either side is parsed as a [`MultiPattern`]; any other
token tree is used as given.

Only this one-directional, unconditional form is accepted.
Like [`rewrite!`], this unwraps the result of `Rewrite::new`.

**/
#[macro_export]
macro_rules! multi_rewrite {
    // limited multipattern support
    ($name:expr; $lhs:tt => $rhs:tt) => {{
        // Both sides are resolved (left first) before the name is rendered.
        let (lhs_side, rhs_side) = (
            $crate::__rewrite!(@parse MultiPattern $lhs),
            $crate::__rewrite!(@parse MultiPattern $rhs),
        );
        $crate::Rewrite::new($name.to_string(), lhs_side, rhs_side).unwrap()
    }};
}
394
#[doc(hidden)]
#[macro_export]
macro_rules! __rewrite {
    // `@parse Kind side`: a literal side is parsed into `$crate::Kind<_>`
    // (unwrapping the parse result). This rule must stay ahead of the `expr`
    // rule, which would otherwise accept literals too.
    (@parse $kind:ident $text:literal) => {
        $text.parse::<$crate::$kind<_>>().unwrap()
    };
    // `@parse Kind side`: any other expression is passed through untouched.
    (@parse $kind:ident $side:expr) => { $side };
    // `@applier inner;`: no conditions left, so the applier is used as is.
    (@applier $inner:expr;) => { $inner };
    // `@applier inner; c1, c2, ...`: wrap with the first condition and recurse
    // on the rest, so the first condition ends up outermost.
    (@applier $inner:expr; $first:expr, $($rest:expr,)*) => {
        $crate::ConditionalApplier {
            condition: $first,
            applier: $crate::__rewrite!(@applier $inner; $($rest,)*)
        }
    };
}
410
#[cfg(test)]
mod tests {
    //! Smoke tests for the macros in this file.

    use crate::*;

    // Small language covering each variant form except `Variant(Data, Children)`.
    define_language! {
        enum Simple {
            "+" = Add([Id; 2]),
            "-" = Sub([Id; 2]),
            "*" = Mul([Id; 2]),
            "-" = Neg(Id),
            "list" = List(Box<[Id]>),
            "pi" = Pi,
            Int(i32),
            Var(Symbol),
        }
    }

    /// Children of a generated node can be rewritten in place.
    #[test]
    fn modify_children() {
        let mut node = Simple::Add([Id::from(0), Id::from(0)]);
        node.for_each_mut(|child| *child = Id::from(1));
        let expected = Simple::Add([Id::from(1), Id::from(1)]);
        assert_eq!(node, expected);
    }

    /// `rewrite!` accepts literal sides, braced expressions, and `<=>`.
    #[test]
    fn some_rewrites() {
        let mut rules: Vec<Rewrite<Simple, ()>> = vec![
            // a literal right-hand side gets parsed
            rewrite!("rule"; "cons" => "f"),
            // a braced right-hand side is taken as is, without parsing
            rewrite!("rule"; "f" => { "pat".parse::<Pattern<_>>().unwrap() }),
        ];
        let both_ways = rewrite!("two-way"; "foo" <=> "bar");
        rules.extend(both_ways);
    }

    /// An applier variable that the searcher never binds is rejected.
    #[test]
    #[should_panic(expected = "refers to unbound var ?x")]
    fn rewrite_simple_panic() {
        let _: Rewrite<Simple, ()> = rewrite!("bad"; "?a" => "?x");
    }

    /// The same check covers variables that only a condition mentions.
    #[test]
    #[should_panic(expected = "refers to unbound var ?x")]
    fn rewrite_conditional_panic() {
        let unbound: Pattern<Simple> = "?x".parse().unwrap();
        let _: Rewrite<Simple, ()> = rewrite!(
            "bad"; "?a" => "?a" if ConditionEqual::new(unbound.clone(), unbound)
        );
    }
}