egg/
pattern.rs

1use fmt::Formatter;
2use log::*;
3use std::borrow::Cow;
4use std::convert::TryInto;
5use std::fmt::{self, Display};
6use std::{convert::TryFrom, str::FromStr};
7
8use thiserror::Error;
9
10use crate::*;
11
/// A pattern that can function as either a [`Searcher`] or [`Applier`].
///
/// A [`Pattern`] is essentially a for-all quantified expression with
/// [`Var`]s as the variables (in the logical sense).
///
/// When creating a [`Rewrite`], the most common thing to use as either
/// the left hand side (the [`Searcher`]) or the right hand side
/// (the [`Applier`]) is a [`Pattern`].
///
/// As a [`Searcher`], a [`Pattern`] does the intuitive
/// thing.
/// Here is a somewhat verbose formal-ish statement:
/// Searching for a pattern in an egraph yields substitutions
/// ([`Subst`]s) _s_ such that, for any _s'_—where instead of
/// mapping a variable to an eclass as _s_ does, _s'_ maps
/// a variable to an arbitrary expression represented by that
/// eclass—_p[s']_ (the pattern under substitution _s'_) is also
/// represented by the egraph.
///
/// As an [`Applier`], a [`Pattern`] performs the given substitution
/// and adds the result to the [`EGraph`].
///
/// Importantly, [`Pattern`] implements [`FromStr`] if the
/// [`Language`] does.
/// This is probably how you'll create most [`Pattern`]s.
///
/// ```
/// use egg::*;
/// define_language! {
///     enum Math {
///         Num(i32),
///         "+" = Add([Id; 2]),
///     }
/// }
///
/// let mut egraph = EGraph::<Math, ()>::default();
/// let a11 = egraph.add_expr(&"(+ 1 1)".parse().unwrap());
/// let a22 = egraph.add_expr(&"(+ 2 2)".parse().unwrap());
///
/// // use Var syntax (leading question mark) to get a
/// // variable in the Pattern
/// let same_add: Pattern<Math> = "(+ ?a ?a)".parse().unwrap();
///
/// // Rebuild before searching
/// egraph.rebuild();
///
/// // This is the search method from the Searcher trait
/// let matches = same_add.search(&egraph);
/// let mut matched_eclasses: Vec<Id> = matches.iter().map(|m| m.eclass).collect();
/// // Sort so the check does not depend on the order eclasses are visited in.
/// matched_eclasses.sort();
/// assert_eq!(matched_eclasses, vec![a11, a22]);
/// ```
///
/// [`FromStr`]: std::str::FromStr
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pattern<L> {
    /// The actual pattern as a [`RecExpr`]
    pub ast: PatternAst<L>,
    /// The program compiled from `ast` when the pattern is constructed;
    /// it is what gets run against an eclass during a search.
    program: machine::Program<L>,
}
71
/// A [`RecExpr`] that represents a [`Pattern`].
///
/// Its nodes are [`ENodeOrVar`]s, so every position is either an enode
/// of the underlying language or a pattern variable.
pub type PatternAst<L> = RecExpr<ENodeOrVar<L>>;
75
76impl<L: Language> PatternAst<L> {
77    /// Returns a new `PatternAst` with the variables renames canonically
78    pub fn alpha_rename(&self) -> Self {
79        let mut vars = HashMap::<Var, Var>::default();
80        let mut new = PatternAst::default();
81
82        fn mkvar(i: usize) -> Var {
83            let vs = &["?x", "?y", "?z", "?w"];
84            match vs.get(i) {
85                Some(v) => v.parse().unwrap(),
86                None => format!("?v{}", i - vs.len()).parse().unwrap(),
87            }
88        }
89
90        for n in self {
91            new.add(match n {
92                ENodeOrVar::ENode(_) => n.clone(),
93                ENodeOrVar::Var(v) => {
94                    let i = vars.len();
95                    ENodeOrVar::Var(*vars.entry(*v).or_insert_with(|| mkvar(i)))
96                }
97            });
98        }
99
100        new
101    }
102}
103
104impl<L: Language> Pattern<L> {
105    /// Creates a new pattern from the given pattern ast.
106    pub fn new(ast: PatternAst<L>) -> Self {
107        let ast = ast.compact();
108        let program = machine::Program::compile_from_pat(&ast);
109        Pattern { ast, program }
110    }
111
112    /// Returns a list of the [`Var`]s in this pattern.
113    pub fn vars(&self) -> Vec<Var> {
114        let mut vars = vec![];
115        for n in &self.ast {
116            if let ENodeOrVar::Var(v) = n {
117                if !vars.contains(v) {
118                    vars.push(*v)
119                }
120            }
121        }
122        vars
123    }
124}
125
126impl<L: Language + Display> Pattern<L> {
127    /// Pretty print this pattern as a sexp with the given width
128    pub fn pretty(&self, width: usize) -> String {
129        self.ast.pretty(width)
130    }
131}
132
/// The language of [`Pattern`]s.
///
/// A node of a pattern is either an enode of the underlying language or
/// a pattern variable.
#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
pub enum ENodeOrVar<L> {
    /// An enode from the underlying [`Language`]
    ENode(L),
    /// A pattern variable
    Var(Var),
}
142
/// The discriminant for the language of [`Pattern`]s.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum ENodeOrVarDiscriminant<L: Language> {
    /// The discriminant of an enode from the underlying [`Language`]
    ENode(L::Discriminant),
    /// A pattern variable; the variable itself serves as the discriminant
    Var(Var),
}
149
150impl<L: Language> Language for ENodeOrVar<L> {
151    type Discriminant = ENodeOrVarDiscriminant<L>;
152
153    #[inline(always)]
154    fn discriminant(&self) -> Self::Discriminant {
155        match self {
156            ENodeOrVar::ENode(n) => ENodeOrVarDiscriminant::ENode(n.discriminant()),
157            ENodeOrVar::Var(v) => ENodeOrVarDiscriminant::Var(*v),
158        }
159    }
160
161    fn matches(&self, _other: &Self) -> bool {
162        panic!("Should never call this")
163    }
164
165    fn children(&self) -> &[Id] {
166        match self {
167            ENodeOrVar::ENode(n) => n.children(),
168            ENodeOrVar::Var(_) => &[],
169        }
170    }
171
172    fn children_mut(&mut self) -> &mut [Id] {
173        match self {
174            ENodeOrVar::ENode(n) => n.children_mut(),
175            ENodeOrVar::Var(_) => &mut [],
176        }
177    }
178}
179
180impl<L: Language + Display> Display for ENodeOrVar<L> {
181    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        match self {
183            Self::ENode(node) => Display::fmt(node, f),
184            Self::Var(var) => Display::fmt(var, f),
185        }
186    }
187}
188
/// An error raised when parsing an [`ENodeOrVar`] from an operator string
/// and its children.
#[derive(Debug, Error)]
pub enum ENodeOrVarParseError<E> {
    /// The operator looked like a pattern variable but failed to parse as
    /// a [`Var`].
    #[error(transparent)]
    BadVar(<Var as FromStr>::Err),

    /// The operator looked like a pattern variable but was given children;
    /// carries the offending operator string.
    #[error("tried to parse pattern variable {0:?} as an operator")]
    UnexpectedVar(String),

    /// The underlying language failed to parse the operator.
    #[error(transparent)]
    BadOp(E),
}
200
201impl<L: FromOp> FromOp for ENodeOrVar<L> {
202    type Error = ENodeOrVarParseError<L::Error>;
203
204    fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
205        use ENodeOrVarParseError::*;
206
207        if op.starts_with('?') && op.len() > 1 {
208            if children.is_empty() {
209                op.parse().map(Self::Var).map_err(BadVar)
210            } else {
211                Err(UnexpectedVar(op.to_owned()))
212            }
213        } else {
214            L::from_op(op, children).map(Self::ENode).map_err(BadOp)
215        }
216    }
217}
218
219impl<L: FromOp> std::str::FromStr for Pattern<L> {
220    type Err = RecExprParseError<ENodeOrVarParseError<L::Error>>;
221
222    fn from_str(s: &str) -> Result<Self, Self::Err> {
223        PatternAst::from_str(s).map(Self::from)
224    }
225}
226
227impl<'a, L: Language> From<&'a [L]> for Pattern<L> {
228    fn from(expr: &'a [L]) -> Self {
229        let ast = expr.iter().cloned().map(ENodeOrVar::ENode).collect();
230        Self::new(ast)
231    }
232}
233
234impl<L: Language> From<RecExpr<L>> for Pattern<L> {
235    fn from(expr: RecExpr<L>) -> Self {
236        let ast = expr.into_iter().map(ENodeOrVar::ENode).collect();
237        Self::new(ast)
238    }
239}
240
241impl<L: Language> From<&RecExpr<L>> for Pattern<L> {
242    fn from(expr: &RecExpr<L>) -> Self {
243        Self::from(expr.as_ref())
244    }
245}
246
impl<L: Language> From<PatternAst<L>> for Pattern<L> {
    /// Equivalent to [`Pattern::new`].
    fn from(ast: PatternAst<L>) -> Self {
        Self::new(ast)
    }
}
252
253impl<L: Language> TryFrom<PatternAst<L>> for RecExpr<L> {
254    type Error = Var;
255    fn try_from(ast: PatternAst<L>) -> Result<Self, Self::Error> {
256        ast.into_iter()
257            .map(|n| match n {
258                ENodeOrVar::ENode(n) => Ok(n),
259                ENodeOrVar::Var(v) => Err(v),
260            })
261            .collect()
262    }
263}
264
265impl<L: Language> TryFrom<Pattern<L>> for RecExpr<L> {
266    type Error = Var;
267    fn try_from(pat: Pattern<L>) -> Result<Self, Self::Error> {
268        pat.ast.try_into()
269    }
270}
271
272impl<L: Language + Display> Display for Pattern<L> {
273    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
274        Display::fmt(&self.ast, f)
275    }
276}
277
/// The result of searching a [`Searcher`] over one eclass.
///
/// Note that one [`SearchMatches`] can contain many found
/// substitutions. So taking the length of a list of [`SearchMatches`]
/// tells you how many eclasses something was matched in, _not_ how
/// many matches were found total.
#[derive(Debug)]
pub struct SearchMatches<'a, L: Language> {
    /// The eclass id that these matches were found in.
    pub eclass: Id,
    /// The substitutions for each match.
    pub substs: Vec<Subst>,
    /// Optionally, an ast for the matches used in proof production.
    pub ast: Option<Cow<'a, PatternAst<L>>>,
}
294
295impl<L: Language, A: Analysis<L>> Searcher<L, A> for Pattern<L> {
296    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
297        Some(&self.ast)
298    }
299
300    fn search_with_limit(&self, egraph: &EGraph<L, A>, limit: usize) -> Vec<SearchMatches<L>> {
301        match self.ast.last().unwrap() {
302            ENodeOrVar::ENode(e) => {
303                let key = e.discriminant();
304                match egraph.classes_for_op(&key) {
305                    None => vec![],
306                    Some(ids) => rewrite::search_eclasses_with_limit(self, egraph, ids, limit),
307                }
308            }
309            ENodeOrVar::Var(_) => rewrite::search_eclasses_with_limit(
310                self,
311                egraph,
312                egraph.classes().map(|e| e.id),
313                limit,
314            ),
315        }
316    }
317
318    fn search_eclass_with_limit(
319        &self,
320        egraph: &EGraph<L, A>,
321        eclass: Id,
322        limit: usize,
323    ) -> Option<SearchMatches<L>> {
324        let substs = self.program.run_with_limit(egraph, eclass, limit);
325        if substs.is_empty() {
326            None
327        } else {
328            let ast = Some(Cow::Borrowed(&self.ast));
329            Some(SearchMatches {
330                eclass,
331                substs,
332                ast,
333            })
334        }
335    }
336
337    fn vars(&self) -> Vec<Var> {
338        Pattern::vars(self)
339    }
340}
341
342impl<L, A> Applier<L, A> for Pattern<L>
343where
344    L: Language,
345    A: Analysis<L>,
346{
347    fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
348        Some(&self.ast)
349    }
350
351    fn apply_matches(
352        &self,
353        egraph: &mut EGraph<L, A>,
354        matches: &[SearchMatches<L>],
355        rule_name: Symbol,
356    ) -> Vec<Id> {
357        let mut added = vec![];
358        let mut id_buf = vec![0.into(); self.ast.len()];
359        for mat in matches {
360            let sast = mat.ast.as_ref().map(|cow| cow.as_ref());
361            for subst in &mat.substs {
362                let did_something;
363                let id;
364                if egraph.are_explanations_enabled() {
365                    let (id_temp, did_something_temp) =
366                        egraph.union_instantiations(sast.unwrap(), &self.ast, subst, rule_name);
367                    did_something = did_something_temp;
368                    id = id_temp;
369                } else {
370                    id = apply_pat(&mut id_buf, &self.ast, egraph, subst);
371                    did_something = egraph.union(id, mat.eclass);
372                }
373
374                if did_something {
375                    added.push(id)
376                }
377            }
378        }
379        added
380    }
381
382    fn apply_one(
383        &self,
384        egraph: &mut EGraph<L, A>,
385        eclass: Id,
386        subst: &Subst,
387        searcher_ast: Option<&PatternAst<L>>,
388        rule_name: Symbol,
389    ) -> Vec<Id> {
390        let mut id_buf = vec![0.into(); self.ast.len()];
391        let id = apply_pat(&mut id_buf, &self.ast, egraph, subst);
392
393        if let Some(ast) = searcher_ast {
394            let (from, did_something) =
395                egraph.union_instantiations(ast, &self.ast, subst, rule_name);
396            if did_something {
397                vec![from]
398            } else {
399                vec![]
400            }
401        } else if egraph.union(eclass, id) {
402            vec![eclass]
403        } else {
404            vec![]
405        }
406    }
407
408    fn vars(&self) -> Vec<Var> {
409        Pattern::vars(self)
410    }
411}
412
413pub(crate) fn apply_pat<L: Language, A: Analysis<L>>(
414    ids: &mut [Id],
415    pat: &[ENodeOrVar<L>],
416    egraph: &mut EGraph<L, A>,
417    subst: &Subst,
418) -> Id {
419    debug_assert_eq!(pat.len(), ids.len());
420    trace!("apply_rec {:2?} {:?}", pat, subst);
421
422    for (i, pat_node) in pat.iter().enumerate() {
423        let id = match pat_node {
424            ENodeOrVar::Var(w) => subst[*w],
425            ENodeOrVar::ENode(e) => {
426                let n = e.clone().map_children(|child| ids[usize::from(child)]);
427                trace!("adding: {:?}", n);
428                egraph.add(n)
429            }
430        };
431        ids[i] = id;
432    }
433
434    *ids.last().unwrap()
435}
436
#[cfg(test)]
mod tests {

    use crate::{SymbolLang as S, *};

    type EGraph = crate::EGraph<S, ()>;

    /// Unions `(+ x y)` with `(+ z w)`, then searches for and applies a
    /// commutativity rewrite and checks the match and application counts.
    #[test]
    fn simple_match() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let (plus_id, _) = egraph.union_instantiations(
            &"(+ x y)".parse().unwrap(),
            &"(+ z w)".parse().unwrap(),
            &Default::default(),
            "union_plus".to_string(),
        );
        egraph.rebuild();

        let commute_plus = rewrite!(
            "commute_plus";
            "(+ ?a ?b)" => "(+ ?b ?a)"
        );

        // Count substitutions, not the number of eclasses that matched.
        let matches = commute_plus.search(&egraph);
        let n_matches: usize = matches.iter().map(|m| m.substs.len()).sum();
        assert_eq!(n_matches, 2, "matches is wrong: {:#?}", matches);

        let applications = commute_plus.apply(&mut egraph, &matches);
        egraph.rebuild();
        assert_eq!(applications.len(), 2);

        let actual_substs: Vec<Subst> = matches.iter().flat_map(|m| m.substs.clone()).collect();

        println!("Here are the substs!");
        for m in &actual_substs {
            println!("substs: {:?}", m);
        }

        // Writes a dot file under `target/`; the test fails if that fails.
        egraph.dot().to_dot("target/simple-match.dot").unwrap();

        use crate::extract::{AstSize, Extractor};

        // The extraction result is only printed, not asserted on.
        let ext = Extractor::new(&egraph, AstSize);
        let (_, best) = ext.find_best(plus_id);
        eprintln!("Best: {:#?}", best);
    }

    /// Checks match counts for patterns that repeat a variable.
    #[test]
    fn nonlinear_patterns() {
        crate::init_logger();
        let mut egraph = EGraph::default();
        egraph.add_expr(&"(f a a)".parse().unwrap());
        egraph.add_expr(&"(f a (g a))))".parse().unwrap());
        egraph.add_expr(&"(f a (g b))))".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 0 1)".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 1 0)".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 0 0)".parse().unwrap());
        egraph.rebuild();

        // Parses the pattern and counts its matches in the egraph.
        let n_matches = |s: &str| s.parse::<Pattern<S>>().unwrap().n_matches(&egraph);

        assert_eq!(n_matches("(f ?x ?y)"), 3);
        assert_eq!(n_matches("(f ?x ?x)"), 1);
        assert_eq!(n_matches("(f ?x (g ?y))))"), 2);
        assert_eq!(n_matches("(f ?x (g ?x))))"), 1);
        assert_eq!(n_matches("(h ?x 0 0)"), 1);
    }

    /// Checks that the limited search variants return
    /// `min(limit, total matches)` substitutions.
    #[test]
    fn search_with_limit() {
        crate::init_logger();
        let init_expr = &"(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 6)))))".parse().unwrap();
        let rules: Vec<Rewrite<_, ()>> = vec![
            rewrite!("comm"; "(+ ?x ?y)" => "(+ ?y ?x)"),
            rewrite!("assoc"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"),
        ];
        let runner = Runner::default().with_expr(init_expr).run(&rules);
        let egraph = &runner.egraph;

        // Total number of substitutions across all matched eclasses.
        let len = |m: &Vec<SearchMatches<S>>| -> usize { m.iter().map(|m| m.substs.len()).sum() };

        // Unlimited search over the whole egraph.
        let pat = &"(+ ?x (+ ?y ?z))".parse::<Pattern<S>>().unwrap();
        let m = pat.search(egraph);
        let match_size = 2100;
        assert_eq!(len(&m), match_size);

        for limit in [1, 10, 100, 1000, 10000] {
            let m = pat.search_with_limit(egraph, limit);
            assert_eq!(len(&m), usize::min(limit, match_size));
        }

        // Same checks restricted to the eclass of the initial expression.
        let id = egraph.lookup_expr(init_expr).unwrap();
        let m = pat.search_eclass(egraph, id).unwrap();
        let match_size = 540;
        assert_eq!(m.substs.len(), match_size);

        for limit in [1, 10, 100, 1000] {
            let m1 = pat.search_eclass_with_limit(egraph, id, limit).unwrap();
            assert_eq!(m1.substs.len(), usize::min(limit, match_size));
        }
    }
}