egg/rewrite.rs
1use pattern::apply_pat;
2use std::fmt::{self, Debug, Display};
3use std::sync::Arc;
4
5use crate::*;
6
7/// A rewrite that searches for the lefthand side and applies the righthand side.
8///
9/// The [`rewrite!`] macro is the easiest way to create rewrites.
10///
11/// A [`Rewrite`] consists principally of a [`Searcher`] (the lefthand
12/// side) and an [`Applier`] (the righthand side).
13/// It additionally stores a name used to refer to the rewrite and a
14/// long name used for debugging.
15///
16#[derive(Clone)]
17#[non_exhaustive]
18pub struct Rewrite<L, N> {
19 /// The name of the rewrite.
20 pub name: Symbol,
21 /// The searcher (left-hand side) of the rewrite.
22 pub searcher: Arc<dyn Searcher<L, N> + Sync + Send>,
23 /// The applier (right-hand side) of the rewrite.
24 pub applier: Arc<dyn Applier<L, N> + Sync + Send>,
25}
26
27impl<L, N> Debug for Rewrite<L, N>
28where
29 L: Language + Display + 'static,
30 N: Analysis<L> + 'static,
31{
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 let mut d = f.debug_struct("Rewrite");
34 d.field("name", &self.name);
35
36 // if let Some(pat) = Any::downcast_ref::<dyn Pattern<L>>(&self.searcher) {
37 if let Some(pat) = self.searcher.get_pattern_ast() {
38 d.field("searcher", &DisplayAsDebug(pat));
39 } else {
40 d.field("searcher", &"<< searcher >>");
41 }
42
43 if let Some(pat) = self.applier.get_pattern_ast() {
44 d.field("applier", &DisplayAsDebug(pat));
45 } else {
46 d.field("applier", &"<< applier >>");
47 }
48
49 d.finish()
50 }
51}
52
53impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
54 /// Create a new [`Rewrite`]. You typically want to use the
55 /// [`rewrite!`] macro instead.
56 ///
57 pub fn new(
58 name: impl Into<Symbol>,
59 searcher: impl Searcher<L, N> + Send + Sync + 'static,
60 applier: impl Applier<L, N> + Send + Sync + 'static,
61 ) -> Result<Self, String> {
62 let name = name.into();
63 let searcher = Arc::new(searcher);
64 let applier = Arc::new(applier);
65
66 let bound_vars = searcher.vars();
67 for v in applier.vars() {
68 if !bound_vars.contains(&v) {
69 return Err(format!("Rewrite {} refers to unbound var {}", name, v));
70 }
71 }
72
73 Ok(Self {
74 name,
75 searcher,
76 applier,
77 })
78 }
79
80 /// Call [`search`] on the [`Searcher`].
81 ///
82 /// [`search`]: Searcher::search()
83 pub fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
84 self.searcher.search(egraph)
85 }
86
87 /// Call [`search_with_limit`] on the [`Searcher`].
88 ///
89 /// [`search_with_limit`]: Searcher::search_with_limit()
90 pub fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
91 self.searcher.search_with_limit(egraph, limit)
92 }
93
94 /// Call [`apply_matches`] on the [`Applier`].
95 ///
96 /// [`apply_matches`]: Applier::apply_matches()
97 pub fn apply(&self, egraph: &mut EGraph<L, N>, matches: &[SearchMatches<L>]) -> Vec<Id> {
98 self.applier.apply_matches(egraph, matches, self.name)
99 }
100
101 /// This `run` is for testing use only. You should use things
102 /// from the `egg::run` module
103 #[cfg(test)]
104 pub(crate) fn run(&self, egraph: &mut EGraph<L, N>) -> Vec<Id> {
105 let start = crate::util::Instant::now();
106
107 let matches = self.search(egraph);
108 log::debug!("Found rewrite {} {} times", self.name, matches.len());
109
110 let ids = self.apply(egraph, &matches);
111 let elapsed = start.elapsed();
112 log::debug!(
113 "Applied rewrite {} {} times in {}.{:03}",
114 self.name,
115 ids.len(),
116 elapsed.as_secs(),
117 elapsed.subsec_millis()
118 );
119
120 egraph.rebuild();
121 ids
122 }
123}
124
125/// Searches the given list of e-classes with a limit.
126pub(crate) fn search_eclasses_with_limit<'a, I, S, L, N>(
127 searcher: &'a S,
128 egraph: &EGraph<L, N>,
129 eclasses: I,
130 mut limit: usize,
131) -> Vec<SearchMatches<'a, L>>
132where
133 L: Language,
134 N: Analysis<L>,
135 S: Searcher<L, N> + ?Sized,
136 I: IntoIterator<Item = Id>,
137{
138 let mut ms = vec![];
139 for eclass in eclasses {
140 if limit == 0 {
141 break;
142 }
143 match searcher.search_eclass_with_limit(egraph, eclass, limit) {
144 None => continue,
145 Some(m) => {
146 let len = m.substs.len();
147 assert!(len <= limit);
148 limit -= len;
149 ms.push(m);
150 }
151 }
152 }
153 ms
154}
155
156/// The lefthand side of a [`Rewrite`].
157///
158/// A [`Searcher`] is something that can search the egraph and find
159/// matching substitutions.
160/// Right now the only significant [`Searcher`] is [`Pattern`].
161///
162pub trait Searcher<L, N>
163where
164 L: Language,
165 N: Analysis<L>,
166{
167 /// Search one eclass, returning None if no matches can be found.
168 /// This should not return a SearchMatches with no substs.
169 fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches<L>> {
170 self.search_eclass_with_limit(egraph, eclass, usize::MAX)
171 }
172
173 /// Similar to [`search_eclass`], but return at most `limit` many matches.
174 ///
175 /// Implementation of [`Searcher`] should implement
176 /// [`search_eclass_with_limit`].
177 ///
178 /// [`search_eclass`]: Searcher::search_eclass
179 /// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit
180 fn search_eclass_with_limit(
181 &self,
182 egraph: &EGraph<L, N>,
183 eclass: Id,
184 limit: usize,
185 ) -> Option<SearchMatches<L>>;
186
187 /// Search the whole [`EGraph`], returning a list of all the
188 /// [`SearchMatches`] where something was found.
189 /// This just calls [`Searcher::search_with_limit`] with a big limit.
190 fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
191 self.search_with_limit(egraph, usize::MAX)
192 }
193
194 /// Similar to [`search`], but return at most `limit` many matches.
195 ///
196 /// [`search`]: Searcher::search
197 fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
198 search_eclasses_with_limit(self, egraph, egraph.classes().map(|e| e.id), limit)
199 }
200
201 /// Returns the number of matches in the e-graph
202 fn n_matches(&self, egraph: &EGraph<L, N>) -> usize {
203 self.search(egraph).iter().map(|m| m.substs.len()).sum()
204 }
205
206 /// For patterns, return the ast directly as a reference
207 fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
208 None
209 }
210
211 /// Returns a list of the variables bound by this Searcher
212 fn vars(&self) -> Vec<Var>;
213}
214
215/// The righthand side of a [`Rewrite`].
216///
217/// An [`Applier`] is anything that can do something with a
218/// substitution ([`Subst`]). This allows you to implement rewrites
219/// that determine when and how to respond to a match using custom
220/// logic, including access to the [`Analysis`] data of an [`EClass`].
221///
222/// Notably, [`Pattern`] implements [`Applier`], which suffices in
223/// most cases.
224/// Additionally, `egg` provides [`ConditionalApplier`] to stack
225/// [`Condition`]s onto an [`Applier`], which in many cases can save
226/// you from having to implement your own applier.
227///
228/// # Example
229/// ```
230/// use egg::{rewrite as rw, *};
231/// use std::sync::Arc;
232///
233/// define_language! {
234/// enum Math {
235/// Num(i32),
236/// "+" = Add([Id; 2]),
237/// "*" = Mul([Id; 2]),
238/// Symbol(Symbol),
239/// }
240/// }
241///
242/// type EGraph = egg::EGraph<Math, MinSize>;
243///
244/// // Our metadata in this case will be size of the smallest
245/// // represented expression in the eclass.
246/// #[derive(Default)]
247/// struct MinSize;
248/// impl Analysis<Math> for MinSize {
249/// type Data = usize;
250/// fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
251/// merge_min(to, from)
252/// }
253/// fn make(egraph: &mut EGraph, enode: &Math, _id: Id) -> Self::Data {
254/// let get_size = |i: Id| egraph[i].data;
255/// AstSize.cost(enode, get_size)
256/// }
257/// }
258///
259/// let rules = &[
260/// rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
261/// rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
262/// rw!("add-0"; "(+ ?a 0)" => "?a"),
263/// rw!("mul-0"; "(* ?a 0)" => "0"),
264/// rw!("mul-1"; "(* ?a 1)" => "?a"),
265/// // the rewrite macro parses the rhs as a single token tree, so
266/// // we wrap it in braces (parens work too).
267/// rw!("funky"; "(+ ?a (* ?b ?c))" => { Funky {
268/// a: "?a".parse().unwrap(),
269/// b: "?b".parse().unwrap(),
270/// c: "?c".parse().unwrap(),
271/// ast: "(+ (+ ?a 0) (* (+ ?b 0) (+ ?c 0)))".parse().unwrap(),
272/// }}),
273/// ];
274///
275/// #[derive(Debug, Clone, PartialEq, Eq)]
276/// struct Funky {
277/// a: Var,
278/// b: Var,
279/// c: Var,
280/// ast: PatternAst<Math>,
281/// }
282///
283/// impl Applier<Math, MinSize> for Funky {
284///
285/// fn apply_one(&self, egraph: &mut EGraph, matched_id: Id, subst: &Subst, searcher_pattern: Option<&PatternAst<Math>>, rule_name: Symbol) -> Vec<Id> {
286/// let a: Id = subst[self.a];
287/// // In a custom Applier, you can inspect the analysis data,
288/// // which is powerful combination!
289/// let size_of_a = egraph[a].data;
290/// if size_of_a > 50 {
291/// println!("Too big! Not doing anything");
292/// vec![]
293/// } else {
294/// // we're going to manually add:
295/// // (+ (+ ?a 0) (* (+ ?b 0) (+ ?c 0)))
296/// // to be unified with the original:
297/// // (+ ?a (* ?b ?c ))
298/// let b: Id = subst[self.b];
299/// let c: Id = subst[self.c];
300/// let zero = egraph.add(Math::Num(0));
301/// let a0 = egraph.add(Math::Add([a, zero]));
302/// let b0 = egraph.add(Math::Add([b, zero]));
303/// let c0 = egraph.add(Math::Add([c, zero]));
304/// let b0c0 = egraph.add(Math::Mul([b0, c0]));
305/// let a0b0c0 = egraph.add(Math::Add([a0, b0c0]));
306/// // Don't forget to union the new node with the matched node!
307/// if egraph.union(matched_id, a0b0c0) {
308/// vec![a0b0c0]
309/// } else {
310/// vec![]
311/// }
312/// }
313/// }
314/// }
315///
316/// let start = "(+ x (* y z))".parse().unwrap();
317/// Runner::default().with_expr(&start).run(rules);
318/// ```
319pub trait Applier<L, N>
320where
321 L: Language,
322 N: Analysis<L>,
323{
324 /// Apply many substitutions.
325 ///
326 /// This method should call [`apply_one`] for each match.
327 ///
328 /// It returns the ids resulting from the calls to [`apply_one`].
329 /// The default implementation does this and should suffice for
330 /// most use cases.
331 ///
332 /// [`apply_one`]: Applier::apply_one()
333 fn apply_matches(
334 &self,
335 egraph: &mut EGraph<L, N>,
336 matches: &[SearchMatches<L>],
337 rule_name: Symbol,
338 ) -> Vec<Id> {
339 let mut added = vec![];
340 for mat in matches {
341 let ast = if egraph.are_explanations_enabled() {
342 mat.ast.as_ref().map(|cow| cow.as_ref())
343 } else {
344 None
345 };
346 for subst in &mat.substs {
347 let ids = self.apply_one(egraph, mat.eclass, subst, ast, rule_name);
348 added.extend(ids)
349 }
350 }
351 added
352 }
353
354 /// For patterns, get the ast directly as a reference.
355 fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
356 None
357 }
358
359 /// Apply a single substitution.
360 ///
361 /// An [`Applier`] should add things and union them with `eclass`.
362 /// Appliers can also inspect the eclass if necessary using the
363 /// `eclass` parameter.
364 ///
365 /// This should return a list of [`Id`]s of eclasses that
366 /// were changed. There can be zero, one, or many.
367 /// When explanations mode is enabled, a [`PatternAst`] for
368 /// the searcher is provided.
369 ///
370 /// [`apply_matches`]: Applier::apply_matches()
371 fn apply_one(
372 &self,
373 egraph: &mut EGraph<L, N>,
374 eclass: Id,
375 subst: &Subst,
376 searcher_ast: Option<&PatternAst<L>>,
377 rule_name: Symbol,
378 ) -> Vec<Id>;
379
380 /// Returns a list of variables that this Applier assumes are bound.
381 ///
382 /// `egg` will check that the corresponding `Searcher` binds those
383 /// variables.
384 /// By default this return an empty `Vec`, which basically turns off the
385 /// checking.
386 fn vars(&self) -> Vec<Var> {
387 vec![]
388 }
389}
390
/// An [`Applier`] guarded by a [`Condition`].
///
/// For each match, a [`ConditionalApplier`] first calls [`check`] on its
/// [`Condition`]. It calls [`apply_one`] on the wrapped [`Applier`] only
/// when that check passes.
///
/// See the [`rewrite!`] macro documentation for an example.
///
/// [`apply_one`]: Applier::apply_one()
/// [`check`]: Condition::check()
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConditionalApplier<C, A> {
    /// The [`Condition`] that must [`check`] true before
    /// [`apply_one`] is called on `applier`.
    ///
    /// [`apply_one`]: Applier::apply_one()
    /// [`check`]: Condition::check()
    pub condition: C,
    /// The wrapped [`Applier`], invoked only once `condition` passes.
    pub applier: A,
}

414impl<C, A, N, L> Applier<L, N> for ConditionalApplier<C, A>
415where
416 L: Language,
417 C: Condition<L, N>,
418 A: Applier<L, N>,
419 N: Analysis<L>,
420{
421 fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
422 self.applier.get_pattern_ast()
423 }
424
425 fn apply_one(
426 &self,
427 egraph: &mut EGraph<L, N>,
428 eclass: Id,
429 subst: &Subst,
430 searcher_ast: Option<&PatternAst<L>>,
431 rule_name: Symbol,
432 ) -> Vec<Id> {
433 if self.condition.check(egraph, eclass, subst) {
434 self.applier
435 .apply_one(egraph, eclass, subst, searcher_ast, rule_name)
436 } else {
437 vec![]
438 }
439 }
440
441 fn vars(&self) -> Vec<Var> {
442 let mut vars = self.applier.vars();
443 vars.extend(self.condition.vars());
444 vars
445 }
446}
447
448/// A condition to check in a [`ConditionalApplier`].
449///
450/// See the [`ConditionalApplier`] docs.
451///
452/// Notably, any function ([`Fn`]) that doesn't mutate other state
453/// and matches the signature of [`check`] implements [`Condition`].
454///
455/// [`check`]: Condition::check()
456/// [`Fn`]: std::ops::Fn
457pub trait Condition<L, N>
458where
459 L: Language,
460 N: Analysis<L>,
461{
462 /// Check a condition.
463 ///
464 /// `eclass` is the eclass [`Id`] where the match (`subst`) occured.
465 /// If this is true, then the [`ConditionalApplier`] will fire.
466 ///
467 fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool;
468
469 /// Returns a list of variables that this Condition assumes are bound.
470 ///
471 /// `egg` will check that the corresponding `Searcher` binds those
472 /// variables.
473 /// By default this return an empty `Vec`, which basically turns off the
474 /// checking.
475 fn vars(&self) -> Vec<Var> {
476 vec![]
477 }
478}
479
480impl<L, F, N> Condition<L, N> for F
481where
482 L: Language,
483 N: Analysis<L>,
484 F: Fn(&mut EGraph<L, N>, Id, &Subst) -> bool,
485{
486 fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
487 self(egraph, eclass, subst)
488 }
489}
490
491/// A [`Condition`] that checks if two terms are equivalent.
492///
493/// This condition adds its two [`Pattern`] to the egraph and passes
494/// if and only if they are equivalent (in the same eclass).
495///
496#[derive(Debug, Clone, PartialEq, Eq)]
497pub struct ConditionEqual<L> {
498 p1: Pattern<L>,
499 p2: Pattern<L>,
500}
501
502impl<L: Language> ConditionEqual<L> {
503 /// Create a new [`ConditionEqual`] condition given two patterns.
504 pub fn new(p1: Pattern<L>, p2: Pattern<L>) -> Self {
505 ConditionEqual { p1, p2 }
506 }
507}
508
509impl<L: FromOp> ConditionEqual<L> {
510 /// Create a ConditionEqual by parsing two pattern strings.
511 ///
512 /// This panics if the parsing fails.
513 pub fn parse(a1: &str, a2: &str) -> Self {
514 Self {
515 p1: a1.parse().unwrap(),
516 p2: a2.parse().unwrap(),
517 }
518 }
519}
520
521impl<L, N> Condition<L, N> for ConditionEqual<L>
522where
523 L: Language,
524 N: Analysis<L>,
525{
526 fn check(&self, egraph: &mut EGraph<L, N>, _eclass: Id, subst: &Subst) -> bool {
527 let mut id_buf_1 = vec![0.into(); self.p1.ast.len()];
528 let mut id_buf_2 = vec![0.into(); self.p2.ast.len()];
529 let a1 = apply_pat(&mut id_buf_1, &self.p1.ast, egraph, subst);
530 let a2 = apply_pat(&mut id_buf_2, &self.p2.ast, egraph, subst);
531 a1 == a2
532 }
533
534 fn vars(&self) -> Vec<Var> {
535 let mut vars = self.p1.vars();
536 vars.extend(self.p2.vars());
537 vars
538 }
539}
540
#[cfg(test)]
mod tests {

    use crate::{SymbolLang as S, *};
    use std::str::FromStr;

    type EGraph = crate::EGraph<S, ()>;

    // A conditional rewrite must stay inert until its condition
    // (`(is-power2 ?b)` equivalent to `TRUE`) holds in the e-graph.
    #[test]
    fn conditional_rewrite() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let x = egraph.add(S::leaf("x"));
        let two = egraph.add(S::leaf("2"));
        let product = egraph.add(S::new("*", vec![x, two]));

        let true_pat = Pattern::from_str("TRUE").unwrap();
        egraph.add(S::leaf("TRUE"));

        let pow2b = Pattern::from_str("(is-power2 ?b)").unwrap();
        let mul_to_shift = rewrite!(
            "mul_to_shift";
            "(* ?a ?b)" => "(>> ?a (log2 ?b))"
            if ConditionEqual::new(pow2b, true_pat)
        );

        println!("rewrite shouldn't do anything yet");
        egraph.rebuild();
        let apps = mul_to_shift.run(&mut egraph);
        assert!(apps.is_empty());

        println!("Add the needed equality");
        let is_pow2: PatternAst<S> = "(is-power2 2)".parse().unwrap();
        let truth: PatternAst<S> = "TRUE".parse().unwrap();
        egraph.union_instantiations(
            &is_pow2,
            &truth,
            &Default::default(),
            "direct-union".to_string(),
        );

        println!("Should fire now");
        egraph.rebuild();
        let apps = mul_to_shift.run(&mut egraph);
        assert_eq!(apps, vec![egraph.find(product)]);
    }

    // A hand-written `Applier` folds `(+ a b)` into the concatenated
    // symbol `ab`.
    #[test]
    fn fn_rewrite() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let start = RecExpr::from_str("(+ x y)").unwrap();
        let goal = RecExpr::from_str("xy").unwrap();

        let root = egraph.add_expr(&start);

        // Operator symbol of the first e-node in `id`'s e-class.
        fn get(egraph: &EGraph, id: Id) -> Symbol {
            egraph[id].nodes[0].op
        }

        #[derive(Debug)]
        struct Appender {
            _rhs: PatternAst<S>,
        }

        impl Applier<SymbolLang, ()> for Appender {
            fn apply_one(
                &self,
                egraph: &mut EGraph,
                eclass: Id,
                subst: &Subst,
                searcher_ast: Option<&PatternAst<SymbolLang>>,
                rule_name: Symbol,
            ) -> Vec<Id> {
                let var_a: Var = "?a".parse().unwrap();
                let var_b: Var = "?b".parse().unwrap();
                let folded = format!(
                    "{}{}",
                    get(egraph, subst[var_a]),
                    get(egraph, subst[var_b])
                );
                match searcher_ast {
                    // Explanations on: union the searcher pattern with the
                    // folded pattern via their instantiations.
                    Some(ast) => {
                        let rhs: PatternAst<SymbolLang> = PatternAst::from_str(&folded).unwrap();
                        let (id, did_something) =
                            egraph.union_instantiations(ast, &rhs, subst, rule_name);
                        if did_something {
                            vec![id]
                        } else {
                            vec![]
                        }
                    }
                    // Explanations off: add the folded leaf and union it directly.
                    None => {
                        let leaf = egraph.add(S::leaf(&folded));
                        if egraph.union(leaf, eclass) {
                            vec![eclass]
                        } else {
                            vec![]
                        }
                    }
                }
            }
        }

        let fold_add = rewrite!(
            "fold_add"; "(+ ?a ?b)" => { Appender { _rhs: "?a".parse().unwrap() } }
        );

        egraph.rebuild();
        fold_add.run(&mut egraph);
        assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]);
    }
}