egg/
rewrite.rs

1use pattern::apply_pat;
2use std::fmt::{self, Debug, Display};
3use std::sync::Arc;
4
5use crate::*;
6
7/// A rewrite that searches for the lefthand side and applies the righthand side.
8///
9/// The [`rewrite!`] macro is the easiest way to create rewrites.
10///
11/// A [`Rewrite`] consists principally of a [`Searcher`] (the lefthand
12/// side) and an [`Applier`] (the righthand side).
13/// It additionally stores a name used to refer to the rewrite and a
14/// long name used for debugging.
15///
16#[derive(Clone)]
17#[non_exhaustive]
18pub struct Rewrite<L, N> {
19    /// The name of the rewrite.
20    pub name: Symbol,
21    /// The searcher (left-hand side) of the rewrite.
22    pub searcher: Arc<dyn Searcher<L, N> + Sync + Send>,
23    /// The applier (right-hand side) of the rewrite.
24    pub applier: Arc<dyn Applier<L, N> + Sync + Send>,
25}
26
27impl<L, N> Debug for Rewrite<L, N>
28where
29    L: Language + Display + 'static,
30    N: Analysis<L> + 'static,
31{
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        let mut d = f.debug_struct("Rewrite");
34        d.field("name", &self.name);
35
36        // if let Some(pat) = Any::downcast_ref::<dyn Pattern<L>>(&self.searcher) {
37        if let Some(pat) = self.searcher.get_pattern_ast() {
38            d.field("searcher", &DisplayAsDebug(pat));
39        } else {
40            d.field("searcher", &"<< searcher >>");
41        }
42
43        if let Some(pat) = self.applier.get_pattern_ast() {
44            d.field("applier", &DisplayAsDebug(pat));
45        } else {
46            d.field("applier", &"<< applier >>");
47        }
48
49        d.finish()
50    }
51}
52
53impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
54    /// Create a new [`Rewrite`]. You typically want to use the
55    /// [`rewrite!`] macro instead.
56    ///
57    pub fn new(
58        name: impl Into<Symbol>,
59        searcher: impl Searcher<L, N> + Send + Sync + 'static,
60        applier: impl Applier<L, N> + Send + Sync + 'static,
61    ) -> Result<Self, String> {
62        let name = name.into();
63        let searcher = Arc::new(searcher);
64        let applier = Arc::new(applier);
65
66        let bound_vars = searcher.vars();
67        for v in applier.vars() {
68            if !bound_vars.contains(&v) {
69                return Err(format!("Rewrite {} refers to unbound var {}", name, v));
70            }
71        }
72
73        Ok(Self {
74            name,
75            searcher,
76            applier,
77        })
78    }
79
80    /// Call [`search`] on the [`Searcher`].
81    ///
82    /// [`search`]: Searcher::search()
83    pub fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
84        self.searcher.search(egraph)
85    }
86
87    /// Call [`search_with_limit`] on the [`Searcher`].
88    ///
89    /// [`search_with_limit`]: Searcher::search_with_limit()
90    pub fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
91        self.searcher.search_with_limit(egraph, limit)
92    }
93
94    /// Call [`apply_matches`] on the [`Applier`].
95    ///
96    /// [`apply_matches`]: Applier::apply_matches()
97    pub fn apply(&self, egraph: &mut EGraph<L, N>, matches: &[SearchMatches<L>]) -> Vec<Id> {
98        self.applier.apply_matches(egraph, matches, self.name)
99    }
100
101    /// This `run` is for testing use only. You should use things
102    /// from the `egg::run` module
103    #[cfg(test)]
104    pub(crate) fn run(&self, egraph: &mut EGraph<L, N>) -> Vec<Id> {
105        let start = crate::util::Instant::now();
106
107        let matches = self.search(egraph);
108        log::debug!("Found rewrite {} {} times", self.name, matches.len());
109
110        let ids = self.apply(egraph, &matches);
111        let elapsed = start.elapsed();
112        log::debug!(
113            "Applied rewrite {} {} times in {}.{:03}",
114            self.name,
115            ids.len(),
116            elapsed.as_secs(),
117            elapsed.subsec_millis()
118        );
119
120        egraph.rebuild();
121        ids
122    }
123}
124
125/// Searches the given list of e-classes with a limit.
126pub(crate) fn search_eclasses_with_limit<'a, I, S, L, N>(
127    searcher: &'a S,
128    egraph: &EGraph<L, N>,
129    eclasses: I,
130    mut limit: usize,
131) -> Vec<SearchMatches<'a, L>>
132where
133    L: Language,
134    N: Analysis<L>,
135    S: Searcher<L, N> + ?Sized,
136    I: IntoIterator<Item = Id>,
137{
138    let mut ms = vec![];
139    for eclass in eclasses {
140        if limit == 0 {
141            break;
142        }
143        match searcher.search_eclass_with_limit(egraph, eclass, limit) {
144            None => continue,
145            Some(m) => {
146                let len = m.substs.len();
147                assert!(len <= limit);
148                limit -= len;
149                ms.push(m);
150            }
151        }
152    }
153    ms
154}
155
156/// The lefthand side of a [`Rewrite`].
157///
158/// A [`Searcher`] is something that can search the egraph and find
159/// matching substitutions.
160/// Right now the only significant [`Searcher`] is [`Pattern`].
161///
162pub trait Searcher<L, N>
163where
164    L: Language,
165    N: Analysis<L>,
166{
167    /// Search one eclass, returning None if no matches can be found.
168    /// This should not return a SearchMatches with no substs.
169    fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches<L>> {
170        self.search_eclass_with_limit(egraph, eclass, usize::MAX)
171    }
172
173    /// Similar to [`search_eclass`], but return at most `limit` many matches.
174    ///
175    /// Implementation of [`Searcher`] should implement
176    /// [`search_eclass_with_limit`].
177    ///
178    /// [`search_eclass`]: Searcher::search_eclass
179    /// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit
180    fn search_eclass_with_limit(
181        &self,
182        egraph: &EGraph<L, N>,
183        eclass: Id,
184        limit: usize,
185    ) -> Option<SearchMatches<L>>;
186
187    /// Search the whole [`EGraph`], returning a list of all the
188    /// [`SearchMatches`] where something was found.
189    /// This just calls [`Searcher::search_with_limit`] with a big limit.
190    fn search(&self, egraph: &EGraph<L, N>) -> Vec<SearchMatches<L>> {
191        self.search_with_limit(egraph, usize::MAX)
192    }
193
194    /// Similar to [`search`], but return at most `limit` many matches.
195    ///
196    /// [`search`]: Searcher::search
197    fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
198        search_eclasses_with_limit(self, egraph, egraph.classes().map(|e| e.id), limit)
199    }
200
201    /// Returns the number of matches in the e-graph
202    fn n_matches(&self, egraph: &EGraph<L, N>) -> usize {
203        self.search(egraph).iter().map(|m| m.substs.len()).sum()
204    }
205
206    /// For patterns, return the ast directly as a reference
207    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
208        None
209    }
210
211    /// Returns a list of the variables bound by this Searcher
212    fn vars(&self) -> Vec<Var>;
213}
214
215/// The righthand side of a [`Rewrite`].
216///
217/// An [`Applier`] is anything that can do something with a
218/// substitution ([`Subst`]). This allows you to implement rewrites
219/// that determine when and how to respond to a match using custom
220/// logic, including access to the [`Analysis`] data of an [`EClass`].
221///
222/// Notably, [`Pattern`] implements [`Applier`], which suffices in
223/// most cases.
224/// Additionally, `egg` provides [`ConditionalApplier`] to stack
225/// [`Condition`]s onto an [`Applier`], which in many cases can save
226/// you from having to implement your own applier.
227///
228/// # Example
229/// ```
230/// use egg::{rewrite as rw, *};
231/// use std::sync::Arc;
232///
233/// define_language! {
234///     enum Math {
235///         Num(i32),
236///         "+" = Add([Id; 2]),
237///         "*" = Mul([Id; 2]),
238///         Symbol(Symbol),
239///     }
240/// }
241///
242/// type EGraph = egg::EGraph<Math, MinSize>;
243///
244/// // Our metadata in this case will be size of the smallest
245/// // represented expression in the eclass.
246/// #[derive(Default)]
247/// struct MinSize;
248/// impl Analysis<Math> for MinSize {
249///     type Data = usize;
250///     fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
251///         merge_min(to, from)
252///     }
253///     fn make(egraph: &mut EGraph, enode: &Math, _id: Id) -> Self::Data {
254///         let get_size = |i: Id| egraph[i].data;
255///         AstSize.cost(enode, get_size)
256///     }
257/// }
258///
259/// let rules = &[
260///     rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
261///     rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
262///     rw!("add-0"; "(+ ?a 0)" => "?a"),
263///     rw!("mul-0"; "(* ?a 0)" => "0"),
264///     rw!("mul-1"; "(* ?a 1)" => "?a"),
265///     // the rewrite macro parses the rhs as a single token tree, so
266///     // we wrap it in braces (parens work too).
267///     rw!("funky"; "(+ ?a (* ?b ?c))" => { Funky {
268///         a: "?a".parse().unwrap(),
269///         b: "?b".parse().unwrap(),
270///         c: "?c".parse().unwrap(),
271///         ast: "(+ (+ ?a 0) (* (+ ?b 0) (+ ?c 0)))".parse().unwrap(),
272///     }}),
273/// ];
274///
275/// #[derive(Debug, Clone, PartialEq, Eq)]
276/// struct Funky {
277///     a: Var,
278///     b: Var,
279///     c: Var,
280///     ast: PatternAst<Math>,
281/// }
282///
283/// impl Applier<Math, MinSize> for Funky {
284///
285///     fn apply_one(&self, egraph: &mut EGraph, matched_id: Id, subst: &Subst, searcher_pattern: Option<&PatternAst<Math>>, rule_name: Symbol) -> Vec<Id> {
286///         let a: Id = subst[self.a];
287///         // In a custom Applier, you can inspect the analysis data,
///         // which is a powerful combination!
289///         let size_of_a = egraph[a].data;
290///         if size_of_a > 50 {
291///             println!("Too big! Not doing anything");
292///             vec![]
293///         } else {
294///             // we're going to manually add:
295///             // (+ (+ ?a 0) (* (+ ?b 0) (+ ?c 0)))
296///             // to be unified with the original:
297///             // (+    ?a    (*    ?b       ?c   ))
298///             let b: Id = subst[self.b];
299///             let c: Id = subst[self.c];
300///             let zero = egraph.add(Math::Num(0));
301///             let a0 = egraph.add(Math::Add([a, zero]));
302///             let b0 = egraph.add(Math::Add([b, zero]));
303///             let c0 = egraph.add(Math::Add([c, zero]));
304///             let b0c0 = egraph.add(Math::Mul([b0, c0]));
305///             let a0b0c0 = egraph.add(Math::Add([a0, b0c0]));
306///             // Don't forget to union the new node with the matched node!
307///             if egraph.union(matched_id, a0b0c0) {
308///                 vec![a0b0c0]
309///             } else {
310///                 vec![]
311///             }
312///         }
313///     }
314/// }
315///
316/// let start = "(+ x (* y z))".parse().unwrap();
317/// Runner::default().with_expr(&start).run(rules);
318/// ```
319pub trait Applier<L, N>
320where
321    L: Language,
322    N: Analysis<L>,
323{
324    /// Apply many substitutions.
325    ///
326    /// This method should call [`apply_one`] for each match.
327    ///
328    /// It returns the ids resulting from the calls to [`apply_one`].
329    /// The default implementation does this and should suffice for
330    /// most use cases.
331    ///
332    /// [`apply_one`]: Applier::apply_one()
333    fn apply_matches(
334        &self,
335        egraph: &mut EGraph<L, N>,
336        matches: &[SearchMatches<L>],
337        rule_name: Symbol,
338    ) -> Vec<Id> {
339        let mut added = vec![];
340        for mat in matches {
341            let ast = if egraph.are_explanations_enabled() {
342                mat.ast.as_ref().map(|cow| cow.as_ref())
343            } else {
344                None
345            };
346            for subst in &mat.substs {
347                let ids = self.apply_one(egraph, mat.eclass, subst, ast, rule_name);
348                added.extend(ids)
349            }
350        }
351        added
352    }
353
354    /// For patterns, get the ast directly as a reference.
355    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
356        None
357    }
358
359    /// Apply a single substitution.
360    ///
361    /// An [`Applier`] should add things and union them with `eclass`.
362    /// Appliers can also inspect the eclass if necessary using the
363    /// `eclass` parameter.
364    ///
365    /// This should return a list of [`Id`]s of eclasses that
366    /// were changed. There can be zero, one, or many.
367    /// When explanations mode is enabled, a [`PatternAst`] for
368    /// the searcher is provided.
369    ///
370    /// [`apply_matches`]: Applier::apply_matches()
371    fn apply_one(
372        &self,
373        egraph: &mut EGraph<L, N>,
374        eclass: Id,
375        subst: &Subst,
376        searcher_ast: Option<&PatternAst<L>>,
377        rule_name: Symbol,
378    ) -> Vec<Id>;
379
380    /// Returns a list of variables that this Applier assumes are bound.
381    ///
382    /// `egg` will check that the corresponding `Searcher` binds those
383    /// variables.
384    /// By default this return an empty `Vec`, which basically turns off the
385    /// checking.
386    fn vars(&self) -> Vec<Var> {
387        vec![]
388    }
389}
390
/// An [`Applier`] guarded by a [`Condition`].
///
/// For every match, a [`ConditionalApplier`] first calls [`check`] on its
/// [`Condition`]. It forwards to the inner [`Applier`]'s [`apply_one`] only
/// when that check passes.
///
/// The [`rewrite!`] macro documentation contains an example.
///
/// [`apply_one`]: Applier::apply_one()
/// [`check`]: Condition::check()
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConditionalApplier<C, A> {
    /// The [`Condition`] that must [`check`] successfully before `applier`
    /// is invoked through [`apply_one`].
    ///
    /// [`apply_one`]: Applier::apply_one()
    /// [`check`]: Condition::check()
    pub condition: C,
    /// The wrapped [`Applier`], invoked only when `condition` passes.
    pub applier: A,
}
413
414impl<C, A, N, L> Applier<L, N> for ConditionalApplier<C, A>
415where
416    L: Language,
417    C: Condition<L, N>,
418    A: Applier<L, N>,
419    N: Analysis<L>,
420{
421    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
422        self.applier.get_pattern_ast()
423    }
424
425    fn apply_one(
426        &self,
427        egraph: &mut EGraph<L, N>,
428        eclass: Id,
429        subst: &Subst,
430        searcher_ast: Option<&PatternAst<L>>,
431        rule_name: Symbol,
432    ) -> Vec<Id> {
433        if self.condition.check(egraph, eclass, subst) {
434            self.applier
435                .apply_one(egraph, eclass, subst, searcher_ast, rule_name)
436        } else {
437            vec![]
438        }
439    }
440
441    fn vars(&self) -> Vec<Var> {
442        let mut vars = self.applier.vars();
443        vars.extend(self.condition.vars());
444        vars
445    }
446}
447
448/// A condition to check in a [`ConditionalApplier`].
449///
450/// See the [`ConditionalApplier`] docs.
451///
452/// Notably, any function ([`Fn`]) that doesn't mutate other state
453/// and matches the signature of [`check`] implements [`Condition`].
454///
455/// [`check`]: Condition::check()
456/// [`Fn`]: std::ops::Fn
457pub trait Condition<L, N>
458where
459    L: Language,
460    N: Analysis<L>,
461{
462    /// Check a condition.
463    ///
464    /// `eclass` is the eclass [`Id`] where the match (`subst`) occured.
465    /// If this is true, then the [`ConditionalApplier`] will fire.
466    ///
467    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool;
468
469    /// Returns a list of variables that this Condition assumes are bound.
470    ///
471    /// `egg` will check that the corresponding `Searcher` binds those
472    /// variables.
473    /// By default this return an empty `Vec`, which basically turns off the
474    /// checking.
475    fn vars(&self) -> Vec<Var> {
476        vec![]
477    }
478}
479
480impl<L, F, N> Condition<L, N> for F
481where
482    L: Language,
483    N: Analysis<L>,
484    F: Fn(&mut EGraph<L, N>, Id, &Subst) -> bool,
485{
486    fn check(&self, egraph: &mut EGraph<L, N>, eclass: Id, subst: &Subst) -> bool {
487        self(egraph, eclass, subst)
488    }
489}
490
491/// A [`Condition`] that checks if two terms are equivalent.
492///
493/// This condition adds its two [`Pattern`] to the egraph and passes
494/// if and only if they are equivalent (in the same eclass).
495///
496#[derive(Debug, Clone, PartialEq, Eq)]
497pub struct ConditionEqual<L> {
498    p1: Pattern<L>,
499    p2: Pattern<L>,
500}
501
502impl<L: Language> ConditionEqual<L> {
503    /// Create a new [`ConditionEqual`] condition given two patterns.
504    pub fn new(p1: Pattern<L>, p2: Pattern<L>) -> Self {
505        ConditionEqual { p1, p2 }
506    }
507}
508
509impl<L: FromOp> ConditionEqual<L> {
510    /// Create a ConditionEqual by parsing two pattern strings.
511    ///
512    /// This panics if the parsing fails.
513    pub fn parse(a1: &str, a2: &str) -> Self {
514        Self {
515            p1: a1.parse().unwrap(),
516            p2: a2.parse().unwrap(),
517        }
518    }
519}
520
521impl<L, N> Condition<L, N> for ConditionEqual<L>
522where
523    L: Language,
524    N: Analysis<L>,
525{
526    fn check(&self, egraph: &mut EGraph<L, N>, _eclass: Id, subst: &Subst) -> bool {
527        let mut id_buf_1 = vec![0.into(); self.p1.ast.len()];
528        let mut id_buf_2 = vec![0.into(); self.p2.ast.len()];
529        let a1 = apply_pat(&mut id_buf_1, &self.p1.ast, egraph, subst);
530        let a2 = apply_pat(&mut id_buf_2, &self.p2.ast, egraph, subst);
531        a1 == a2
532    }
533
534    fn vars(&self) -> Vec<Var> {
535        let mut vars = self.p1.vars();
536        vars.extend(self.p2.vars());
537        vars
538    }
539}
540
#[cfg(test)]
mod tests {
    use crate::{SymbolLang as S, *};
    use std::str::FromStr;

    type EGraph = crate::EGraph<S, ()>;

    // A rewrite guarded by `ConditionEqual` must stay inert until the two
    // guarded patterns are unioned, and fire afterwards.
    #[test]
    fn conditional_rewrite() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let x = egraph.add(S::leaf("x"));
        let two = egraph.add(S::leaf("2"));
        let product = egraph.add(S::new("*", vec![x, two]));

        let true_pat = Pattern::from_str("TRUE").unwrap();
        egraph.add(S::leaf("TRUE"));

        let pow2b = Pattern::from_str("(is-power2 ?b)").unwrap();
        let mul_to_shift = rewrite!(
            "mul_to_shift";
            "(* ?a ?b)" => "(>> ?a (log2 ?b))"
            if ConditionEqual::new(pow2b, true_pat)
        );

        println!("rewrite shouldn't do anything yet");
        egraph.rebuild();
        let fired = mul_to_shift.run(&mut egraph);
        assert!(fired.is_empty());

        println!("Add the needed equality");
        egraph.union_instantiations(
            &"(is-power2 2)".parse().unwrap(),
            &"TRUE".parse().unwrap(),
            &Default::default(),
            "direct-union".to_string(),
        );

        println!("Should fire now");
        egraph.rebuild();
        let fired = mul_to_shift.run(&mut egraph);
        assert_eq!(fired, vec![egraph.find(product)]);
    }

    // A hand-written applier that concatenates the symbols bound to ?a and
    // ?b, e.g. turning `(+ x y)` into `xy`.
    #[test]
    fn fn_rewrite() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let start = RecExpr::from_str("(+ x y)").unwrap();
        let goal = RecExpr::from_str("xy").unwrap();

        let root = egraph.add_expr(&start);

        // Operator of the first e-node stored in e-class `id`.
        fn get(egraph: &EGraph, id: Id) -> Symbol {
            egraph[id].nodes[0].op
        }

        #[derive(Debug)]
        struct Appender {
            _rhs: PatternAst<S>,
        }

        impl Applier<SymbolLang, ()> for Appender {
            fn apply_one(
                &self,
                egraph: &mut EGraph,
                eclass: Id,
                subst: &Subst,
                searcher_ast: Option<&PatternAst<SymbolLang>>,
                rule_name: Symbol,
            ) -> Vec<Id> {
                let var_a: Var = "?a".parse().unwrap();
                let var_b: Var = "?b".parse().unwrap();
                let joined = format!("{}{}", get(egraph, subst[var_a]), get(egraph, subst[var_b]));

                let changed = match searcher_ast {
                    Some(ast) => {
                        let rhs = PatternAst::from_str(&joined).unwrap();
                        let (id, did_something) =
                            egraph.union_instantiations(ast, &rhs, subst, rule_name);
                        if did_something {
                            Some(id)
                        } else {
                            None
                        }
                    }
                    None => {
                        let added = egraph.add(S::leaf(&joined));
                        if egraph.union(added, eclass) {
                            Some(eclass)
                        } else {
                            None
                        }
                    }
                };
                changed.into_iter().collect()
            }
        }

        let fold_add = rewrite!(
            "fold_add"; "(+ ?a ?b)" => { Appender { _rhs: "?a".parse().unwrap()}}
        );

        egraph.rebuild();
        fold_add.run(&mut egraph);
        assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]);
    }
}