egg/macros.rs
1#[allow(unused_imports)]
2use crate::*;
3
4/** A macro to easily create a [`Language`].
5
6`define_language` derives `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`,
7`Hash`, and `Clone` on the given `enum` so it can implement [`Language`].
8The macro also implements [`Display`] and [`FromOp`] for the `enum`
9based on either the data of variants or the provided strings.
10
11The final variant **must have a trailing comma**; this is due to limitations in
12macro parsing.
13
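`Language::Discriminant` is the enum's own `std::mem::Discriminant<Self>`, so two nodes have equal discriminants exactly when they are the same enum variant (fields and children are not compared).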
15
16See [`LanguageChildren`] for acceptable types of children `Id`s.
17
18Note that you can always implement [`Language`] yourself by just not using this
19macro.
20
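A data variant may also carry children by taking a second field, e.g. `Other(Symbol, Vec<Id>)`:
the first field is parsed from the operator string, and the second must implement [`LanguageChildren`].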
23
24# Example
25
The following macro invocation shows the accepted forms of variants:
27```
28# use egg::*;
29define_language! {
30 enum SimpleLanguage {
31 // string variant with no children
32 "pi" = Pi,
33
34 // string variants with an array of child `Id`s (any static size)
35 // any type that implements LanguageChildren may be used here
36 "+" = Add([Id; 2]),
37 "-" = Sub([Id; 2]),
38 "*" = Mul([Id; 2]),
39
40 // can also do a variable number of children in a boxed slice
41 // this will only match if the lengths are the same
42 "list" = List(Box<[Id]>),
43
44 // string variants with a single child `Id`
45 // note that this is distinct from `Sub`, even though it has the same
46 // string, because it has a different number of children
47 "-" = Neg(Id),
48
49 // data variants with a single field
50 // this field must implement `FromStr` and `Display`
51 Num(i32),
52 // language items are parsed in order, and we want symbol to
53 // be a fallback, so we put it last
54 Symbol(Symbol),
55 // This is the ultimate fallback, it will parse any operator (as a string)
56 // and any number of children.
57 // Note that if there were 0 children, the previous branch would have succeeded
58 Other(Symbol, Vec<Id>),
59 }
60}
61```
62
It is also possible to define languages that are generic over some bounded type.
The type parameters must be plain identifiers. Their bounds must be given with the `where`-like syntax below; they cannot go in the `enum` definition itself.
You need at least the following bounds, since they are required by the [`Language`] trait.
66# Example
67```rust
68use egg::*;
69use std::{
70 fmt::{Debug, Display},
71 hash::Hash,
72 str::FromStr,
73};
74define_language! {
75 enum GenericLang<S, T> {
76 String(S),
77 Number(T),
78 "+" = Add([Id; 2]),
79 "-" = Sub([Id; 2]),
80 "/" = Div([Id; 2]),
81 "*" = Mult([Id; 2]),
82 }
83 where
84 S: Hash + Debug + Display + Clone + Eq + Ord + Hash + FromStr,
85 T: Hash + Debug + Display + Clone + Eq + Ord + Hash + FromStr,
86 // also required by the macro impl that parses S, T
87 <S as FromStr>::Err: Debug,
88 <T as FromStr>::Err: Debug,
89}
90```
91
92[`Display`]: std::fmt::Display
93**/
#[macro_export]
macro_rules! define_language {
    ($(#[$meta:meta])* $vis:vis enum $name:ident
     // Generic parameters must be plain identifiers; their bounds go in the
     // trailing `where` clause. A trailing comma (`<S, T,>`) is accepted.
     $(<$($gen:ident),* $(,)?>)?
     { $($variants:tt)* }
     $(where $($where:tt)*)?) => {
        // Hand off to the internal tt-muncher. The six trailing `{}` groups are
        // accumulators for, in order: the enum declaration, the `matches` arms,
        // the `children` arms, the `children_mut` arms, the `Display` arms and
        // the `FromOp` arms. They are filled in one variant at a time.
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($($gen),*)?] { $($variants)* }
            [$($($where)*)?]
            -> {} {} {} {} {} {}
        );
    };
}
109
#[doc(hidden)]
#[macro_export]
macro_rules! __define_language {
    // Base case: every variant has been consumed. Emit the enum together with
    // its `Language`, `Display` and `FromOp` impls, built from the six
    // accumulated groups (declaration, `matches` arms, `children` arms,
    // `children_mut` arms, `Display` arms, `FromOp` arms).
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*] {}
     [$($where:tt)*]
     ->
     $decl:tt {$($matches:tt)*} $children:tt $children_mut:tt
     $display:tt {$($from_op:tt)*}
    ) => {
        $(#[$meta])*
        #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
        $vis enum $name <$($gen),*> $decl

        impl<$($gen),*> $crate::Language for $name <$($gen),*> where $($where)* {
            // Fully qualified `::std` like the rest of the expansion, so an
            // item named `std` in the caller's scope cannot interfere.
            type Discriminant = ::std::mem::Discriminant<Self>;

            #[inline(always)]
            fn discriminant(&self) -> Self::Discriminant {
                ::std::mem::discriminant(self)
            }

            #[inline(always)]
            fn matches(&self, other: &Self) -> bool {
                ::std::mem::discriminant(self) == ::std::mem::discriminant(other) &&
                match (self, other) { $($matches)* _ => false }
            }

            fn children(&self) -> &[$crate::Id] { match self $children }
            fn children_mut(&mut self) -> &mut [$crate::Id] { match self $children_mut }
        }

        impl<$($gen),*> ::std::fmt::Display for $name <$($gen),*> where $($where)* {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // We need to pass `f` to the match expression for hygiene
                // reasons.
                match (self, f) $display
            }
        }

        impl<$($gen),*> $crate::FromOp for $name <$($gen),*> where $($where)* {
            type Error = $crate::FromOpError;

            fn from_op(op: &str, children: ::std::vec::Vec<$crate::Id>) -> ::std::result::Result<Self, Self::Error> {
                // Arms appear in variant declaration order; the first arm whose
                // pattern/guard accepts `(op, children)` wins.
                match (op, children) {
                    $($from_op)*
                    (op, children) => Err($crate::FromOpError::new(op, children)),
                }
            }
        }
    };

    // Unit variant with an operator string and no children, e.g. `"pi" = Pi,`.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $string:literal = $variant:ident,
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)* $variant, }
            { $($matches)* ($name::$variant, $name::$variant) => true, }
            { $($children)* $name::$variant => &[], }
            { $($children_mut)* $name::$variant => &mut [], }
            { $($display)* ($name::$variant, f) => f.write_str($string), }
            { $($from_op)* ($string, children) if children.is_empty() => Ok($name::$variant), }
        );
    };

    // Operator string with children, e.g. `"+" = Add([Id; 2]),`. The children
    // type must implement `LanguageChildren`; two nodes match only if they
    // have the same number of children.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $string:literal = $variant:ident ($ids:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)* $variant($ids), }
            { $($matches)* ($name::$variant(l), $name::$variant(r)) => $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), }
            { $($children)* $name::$variant(ids) => $crate::LanguageChildren::as_slice(ids), }
            { $($children_mut)* $name::$variant(ids) => $crate::LanguageChildren::as_mut_slice(ids), }
            { $($display)* ($name::$variant(..), f) => f.write_str($string), }
            { $($from_op)* (op, children) if op == $string && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => {
                  let children = <$ids as $crate::LanguageChildren>::from_vec(children);
                  Ok($name::$variant(children))
              },
            }
        );
    };

    // Data variant with a single field and no children, e.g. `Num(i32),`.
    // The operator string is parsed into the field with `FromStr`, and the
    // field is printed with `Display`.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $variant:ident ($data:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)* $variant($data), }
            { $($matches)* ($name::$variant(data1), $name::$variant(data2)) => data1 == data2, }
            { $($children)* $name::$variant(_data) => &[], }
            { $($children_mut)* $name::$variant(_data) => &mut [], }
            { $($display)* ($name::$variant(data), f) => ::std::fmt::Display::fmt(data, f), }
            { $($from_op)* (op, children) if op.parse::<$data>().is_ok() && children.is_empty() => Ok($name::$variant(op.parse().unwrap())), }
        );
    };

    // Data variant with children, e.g. `Other(Symbol, Vec<Id>),`. The first
    // field is parsed from the operator string; the second holds the children.
    // The generic list uses the same comma-separated form as every other arm,
    // so languages with two or more type parameters can use this variant kind.
    ($(#[$meta:meta])* $vis:vis enum $name:ident [$($gen:ident),*]
     {
         $variant:ident ($data:ty, $ids:ty),
         $($variants:tt)*
     }
     [$($where:tt)*]
     ->
     { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* }
     { $($display:tt)* } { $($from_op:tt)* }
    ) => {
        $crate::__define_language!(
            $(#[$meta])* $vis enum $name [$($gen),*]
            { $($variants)* }
            [$($where)*]
            ->
            { $($decl)* $variant($data, $ids), }
            { $($matches)* ($name::$variant(d1, l), $name::$variant(d2, r)) => d1 == d2 && $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), }
            { $($children)* $name::$variant(_, ids) => $crate::LanguageChildren::as_slice(ids), }
            { $($children_mut)* $name::$variant(_, ids) => $crate::LanguageChildren::as_mut_slice(ids), }
            { $($display)* ($name::$variant(data, _), f) => ::std::fmt::Display::fmt(data, f), }
            { $($from_op)* (op, children) if op.parse::<$data>().is_ok() && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => {
                  let data = op.parse::<$data>().unwrap();
                  let children = <$ids as $crate::LanguageChildren>::from_vec(children);
                  Ok($name::$variant(data, children))
              },
            }
        );
    };
}
266
267/** A macro to easily make [`Rewrite`]s.
268
269The `rewrite!` macro greatly simplifies creating simple, purely
270syntactic rewrites while also allowing more complex ones.
271
This panics if a string literal in the searcher or applier position fails to parse as a [`Pattern`], or if [`Rewrite::new`](Rewrite::new()) fails.
273
274The simplest form `rewrite!(a; b => c)` creates a [`Rewrite`]
275with name `a`, [`Searcher`] `b`, and [`Applier`] `c`.
276Note that in the `b` and `c` position, the macro only accepts a single
277token tree (see the [macros reference][macro] for more info).
278In short, that means you should pass in an identifier, literal, or
279something surrounded by parentheses or braces.
280
281If you pass in a literal to the `b` or `c` position, the macro will
282try to parse it as a [`Pattern`] which implements both [`Searcher`]
283and [`Applier`].
284
285The macro also accepts any number of `if <expr>` forms at the end,
286where the given expression should implement [`Condition`].
287For each of these, the macro will wrap the given applier in a
288[`ConditionalApplier`] with the given condition, with the first condition being
289the outermost, and the last condition being the innermost.
290
291# Example
292```
293# use egg::*;
294use std::borrow::Cow;
295use std::sync::Arc;
296define_language! {
297 enum SimpleLanguage {
298 Num(i32),
299 "+" = Add([Id; 2]),
300 "-" = Sub([Id; 2]),
301 "*" = Mul([Id; 2]),
302 "/" = Div([Id; 2]),
303 }
304}
305
306type EGraph = egg::EGraph<SimpleLanguage, ()>;
307
308let mut rules: Vec<Rewrite<SimpleLanguage, ()>> = vec![
309 rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
310 rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
311
312 rewrite!("mul-0"; "(* ?a 0)" => "0"),
313
314 rewrite!("silly"; "(* ?a 1)" => { MySillyApplier("foo") }),
315
316 rewrite!("something_conditional";
317 "(/ ?a ?b)" => "(* ?a (/ 1 ?b))"
318 if is_not_zero("?b")),
319];
320
321// rewrite! supports bidirectional rules too
322// it returns a Vec of length 2, so you need to concat
323rules.extend(vec![
324 rewrite!("add-0"; "(+ ?a 0)" <=> "?a"),
325 rewrite!("mul-1"; "(* ?a 1)" <=> "?a"),
326].concat());
327
328#[derive(Debug)]
329struct MySillyApplier(&'static str);
330impl Applier<SimpleLanguage, ()> for MySillyApplier {
331 fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst, _: Option<&PatternAst<SimpleLanguage>>, _: Symbol) -> Vec<Id> {
332 panic!()
333 }
334}
335
336// This returns a function that implements Condition
337fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
338 let var = var.parse().unwrap();
339 let zero = SimpleLanguage::Num(0);
340 // note this check is just an example,
341 // checking for the absence of 0 is insufficient since 0 could be merged in later
342 // see https://github.com/egraphs-good/egg/issues/297
343 move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero)
344}
345```
346
347[macro]: https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables
348**/
#[macro_export]
macro_rules! rewrite {
    // One-directional rule: `name; lhs => rhs`, optionally followed by any
    // number of `if <condition>` clauses. String literals in the lhs/rhs
    // position are parsed as `Pattern`s; any other token tree is used as-is.
    (
        $name:expr;
        $lhs:tt => $rhs:tt
        $(if $cond:expr)*
    ) => {{
        let searcher = $crate::__rewrite!(@parse Pattern $lhs);
        let core_applier = $crate::__rewrite!(@parse Pattern $rhs);
        // Wrap the applier in one `ConditionalApplier` per condition, with the
        // first condition outermost.
        let applier = $crate::__rewrite!(@applier core_applier; $($cond,)*);
        $crate::Rewrite::new($name.to_string(), searcher, applier).unwrap()
    }};
    // Bidirectional rule: expands to a `Vec` of two rewrites, the forward rule
    // named `name` and the reversed rule named `name` followed by "-rev".
    // The condition tokens are pasted into both rules, so each condition
    // expression is evaluated once per direction.
    (
        $name:expr;
        $lhs:tt <=> $rhs:tt
        $(if $cond:expr)*
    ) => {{
        let name = $name;
        // `format!` only needs `Display`, the bound the forward rule's
        // `.to_string()` already imposes, and borrows `name` instead of
        // requiring `Clone + Into<String>`.
        let name2 = format!("{}-rev", name);
        vec![
            $crate::rewrite!(name; $lhs => $rhs $(if $cond)*),
            $crate::rewrite!(name2; $rhs => $lhs $(if $cond)*)
        ]
    }};
}
374
375/** A macro to easily make [`Rewrite`]s using [`MultiPattern`]s.
376
377Similar to the [`rewrite!`] macro,
378this macro uses the form `multi_rewrite!(name; multipattern => multipattern)`.
String literals will be parsed as [`MultiPattern`]s.
380
381**/
#[macro_export]
macro_rules! multi_rewrite {
    // Only the unconditional `name; lhs => rhs` form is supported. Each side
    // is either a string literal (parsed as a `MultiPattern`) or an
    // expression used verbatim as the searcher/applier.
    (
        $name:expr;
        $lhs:tt => $rhs:tt
    ) => {{
        let multi_searcher = $crate::__rewrite!(@parse MultiPattern $lhs);
        let multi_applier = $crate::__rewrite!(@parse MultiPattern $rhs);
        let built = $crate::Rewrite::new($name.to_string(), multi_searcher, multi_applier);
        built.unwrap()
    }};
}
394
#[doc(hidden)]
#[macro_export]
macro_rules! __rewrite {
    // A string literal in searcher/applier position is parsed into
    // `$crate::$kind<_>` (e.g. `Pattern` or `MultiPattern`); a parse
    // failure panics via `unwrap`. This arm must stay before the `expr` arm.
    (@parse $kind:ident $src:literal) => {
        $src.parse::<$crate::$kind<_>>().unwrap()
    };
    // Any other single token tree is taken verbatim.
    (@parse $kind:ident $given:expr) => { $given };
    // No conditions left: use the applier unwrapped.
    (@applier $inner:expr;) => { $inner };
    // Wrap in a `ConditionalApplier` for `$head` whose inner applier handles
    // the remaining conditions, so the first condition ends up outermost.
    (@applier $inner:expr; $head:expr, $($rest:expr,)*) => {
        $crate::ConditionalApplier {
            condition: $head,
            applier: $crate::__rewrite!(@applier $inner; $($rest,)*)
        }
    };
}
410
#[cfg(test)]
mod tests {

    use crate::*;

    define_language! {
        enum Simple {
            "+" = Add([Id; 2]),
            "-" = Sub([Id; 2]),
            "*" = Mul([Id; 2]),
            "-" = Neg(Id),
            "list" = List(Box<[Id]>),
            "pi" = Pi,
            Int(i32),
            Var(Symbol),
        }
    }

    #[test]
    fn modify_children() {
        let mut add = Simple::Add([0.into(), 0.into()]);
        add.for_each_mut(|id| *id = 1.into());
        assert_eq!(add, Simple::Add([1.into(), 1.into()]));
    }

    #[test]
    fn some_rewrites() {
        let mut rws: Vec<Rewrite<Simple, ()>> = vec![
            // here it should parse the rhs
            rewrite!("rule"; "cons" => "f"),
            // here it should just accept the rhs without trying to parse
            rewrite!("rule"; "f" => { "pat".parse::<Pattern<_>>().unwrap() }),
        ];
        // a bidirectional rule expands to a forward and a reversed rewrite
        rws.extend(rewrite!("two-way"; "foo" <=> "bar"));
        assert_eq!(rws.len(), 4);
    }

    #[test]
    fn from_op_and_display() {
        // "-" is shared by Sub and Neg; the child count picks the variant.
        assert_eq!(
            Simple::from_op("-", vec![0.into(), 1.into()]).unwrap(),
            Simple::Sub([0.into(), 1.into()])
        );
        assert_eq!(Simple::from_op("-", vec![0.into()]).unwrap(), Simple::Neg(0.into()));
        // data variants are parsed from the operator string
        assert_eq!(Simple::from_op("7", vec![]).unwrap(), Simple::Int(7));
        // a leaf operator given children matches no variant
        assert!(Simple::from_op("pi", vec![0.into()]).is_err());

        assert_eq!(Simple::Pi.to_string(), "pi");
        assert_eq!(Simple::Int(-3).to_string(), "-3");
        assert_eq!(Simple::Add([0.into(), 1.into()]).to_string(), "+");
    }

    #[test]
    #[should_panic(expected = "refers to unbound var ?x")]
    fn rewrite_simple_panic() {
        let _: Rewrite<Simple, ()> = rewrite!("bad"; "?a" => "?x");
    }

    #[test]
    #[should_panic(expected = "refers to unbound var ?x")]
    fn rewrite_conditional_panic() {
        let x: Pattern<Simple> = "?x".parse().unwrap();
        let _: Rewrite<Simple, ()> = rewrite!(
            "bad"; "?a" => "?a" if ConditionEqual::new(x.clone(), x)
        );
    }
}