Skip to main content

libslide/evaluator_rules/
pattern_match.rs

1use crate::grammar::*;
2use crate::utils::hash;
3
4use std::collections::HashMap;
5use std::rc::Rc;
6
7/// Represents pattern-matched replacements betwen a rule and a target expression.
8///
9/// The rhs of a rule may be transfomed with an instance of `PatternMatch` to obtain the result of a
10/// rule applied on a target expression.
11pub struct PatternMatch<E: Expression> {
12    map: HashMap<
13        u64,   // pointer to rule pattern, like #a
14        Rc<E>, // target expr,             like 10
15    >,
16}
17
18impl<E: Expression> Default for PatternMatch<E> {
19    fn default() -> Self {
20        Self {
21            map: HashMap::new(),
22        }
23    }
24}
25
26pub trait MatchRule<E: Expression> {
27    /// Pattern matches a rule template against an expression. If successful, the results of the
28    /// matching are returned as a mapping of rule to target expressions replacements.
29    ///
30    /// A sucessful pattern matching is one that matches the target expression wholly, abiding the
31    /// expression pattern matching rules.
32    fn match_rule(rule: Rc<ExprPat>, target: Rc<E>) -> Option<PatternMatch<E>>;
33}
34
35impl MatchRule<Expr> for PatternMatch<Expr> {
36    fn match_rule(rule: Rc<ExprPat>, target: Rc<Expr>) -> Option<PatternMatch<Expr>> {
37        match (rule.as_ref(), target.as_ref()) {
38            // The happiest path -- if a pattern matches an expression, return replacements for it!
39            (ExprPat::VarPat(_), Expr::Var(_))
40            | (ExprPat::ConstPat(_), Expr::Const(_))
41            | (ExprPat::AnyPat(_), _) => {
42                let mut replacements = PatternMatch::default();
43                replacements.insert(&rule, target);
44                Some(replacements)
45            }
46            (ExprPat::Const(a), Expr::Const(b)) => {
47                if (a - b).abs() > std::f64::EPSILON {
48                    // Constants don't match; rule can't be applied.
49                    return None;
50                }
51                // Constant values are... constant, so there is no need to replace them.
52                Some(PatternMatch::default())
53            }
54            (ExprPat::BinaryExpr(rule), Expr::BinaryExpr(expr)) => {
55                if rule.op != expr.op {
56                    return None;
57                }
58                // Expressions are of the same type; match the rest of the expression by recursing on
59                // the arguments.
60                let replacements_lhs =
61                    Self::match_rule(Rc::clone(&rule.lhs), Rc::clone(&expr.lhs))?;
62                let replacements_rhs =
63                    Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))?;
64                PatternMatch::try_merge(replacements_lhs, replacements_rhs)
65            }
66            (ExprPat::UnaryExpr(rule), Expr::UnaryExpr(expr)) => {
67                if rule.op != expr.op {
68                    return None;
69                }
70                // Expressions are of the same type; match the rest of the expression by recursing on
71                // the argument.
72                Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))
73            }
74            (ExprPat::Parend(rule), Expr::Parend(expr)) => {
75                Self::match_rule(Rc::clone(rule), Rc::clone(expr))
76            }
77            (ExprPat::Bracketed(rule), Expr::Bracketed(expr)) => {
78                Self::match_rule(Rc::clone(rule), Rc::clone(expr))
79            }
80            _ => None,
81        }
82    }
83}
84
85impl MatchRule<ExprPat> for PatternMatch<ExprPat> {
86    fn match_rule(rule: Rc<ExprPat>, target: Rc<ExprPat>) -> Option<PatternMatch<ExprPat>> {
87        match (rule.as_ref(), target.as_ref()) {
88            (ExprPat::VarPat(_), ExprPat::VarPat(_))
89            | (ExprPat::ConstPat(_), ExprPat::ConstPat(_))
90            | (ExprPat::AnyPat(_), _) => {
91                let mut replacements = PatternMatch::default();
92                replacements.insert(&rule, target);
93                Some(replacements)
94            }
95            (ExprPat::Const(a), ExprPat::Const(b)) => {
96                if (a - b).abs() > std::f64::EPSILON {
97                    return None;
98                }
99                Some(PatternMatch::default())
100            }
101            (ExprPat::BinaryExpr(rule), ExprPat::BinaryExpr(expr)) => {
102                if rule.op != expr.op {
103                    return None;
104                }
105                let replacements_lhs =
106                    Self::match_rule(Rc::clone(&rule.lhs), Rc::clone(&expr.lhs))?;
107                let replacements_rhs =
108                    Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))?;
109                PatternMatch::try_merge(replacements_lhs, replacements_rhs)
110            }
111            (ExprPat::UnaryExpr(rule), ExprPat::UnaryExpr(expr)) => {
112                if rule.op != expr.op {
113                    return None;
114                }
115                Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))
116            }
117            (ExprPat::Parend(rule), ExprPat::Parend(expr)) => {
118                Self::match_rule(Rc::clone(rule), Rc::clone(expr))
119            }
120            (ExprPat::Bracketed(rule), ExprPat::Bracketed(expr)) => {
121                Self::match_rule(Rc::clone(rule), Rc::clone(expr))
122            }
123            _ => None,
124        }
125    }
126}
127
128impl Transformer<Rc<ExprPat>, Rc<Expr>> for PatternMatch<Expr> {
129    /// Transforms a pattern expression into an expression by replacing patterns with target
130    /// expressions known by the [`PatternMatch`].
131    ///
132    /// This transformation can be used to apply a rule on an expression by transforming the RHS
133    /// using patterns matched between the LHS of the rule and the target expression.
134    ///
135    /// [`PatternMatch`]: PatternMatch
136    fn transform(&self, item: Rc<ExprPat>) -> Rc<Expr> {
137        fn transform(
138            repls: &PatternMatch<Expr>,
139            item: Rc<ExprPat>,
140            cache: &mut HashMap<u64, Rc<Expr>>,
141        ) -> Rc<Expr> {
142            if let Some(result) = cache.get(&hash(item.as_ref())) {
143                return Rc::clone(result);
144            }
145
146            let transformed = match item.as_ref() {
147                ExprPat::VarPat(_) | ExprPat::ConstPat(_) | ExprPat::AnyPat(_) => {
148                    match repls.map.get(&hash(&item)) {
149                        Some(transformed) => Rc::clone(transformed),
150
151                        // A pattern can only be transformed into an expression if it has an
152                        // expression replacement. Patterns are be validated before transformation,
153                        // so this branch should never be hit.
154                        None => unreachable!(),
155                    }
156                }
157
158                ExprPat::Const(f) => Expr::Const(*f).into(),
159                ExprPat::BinaryExpr(binary_expr) => Expr::BinaryExpr(BinaryExpr {
160                    op: binary_expr.op,
161                    lhs: transform(repls, Rc::clone(&binary_expr.lhs), cache),
162                    rhs: transform(repls, Rc::clone(&binary_expr.rhs), cache),
163                })
164                .into(),
165                ExprPat::UnaryExpr(unary_expr) => Expr::UnaryExpr(UnaryExpr {
166                    op: unary_expr.op,
167                    rhs: transform(repls, Rc::clone(&unary_expr.rhs), cache),
168                })
169                .into(),
170                ExprPat::Parend(expr) => {
171                    let inner = transform(repls, Rc::clone(expr), cache);
172                    Expr::Parend(inner).into()
173                }
174                ExprPat::Bracketed(expr) => {
175                    let inner = transform(repls, Rc::clone(expr), cache);
176                    Expr::Bracketed(inner).into()
177                }
178            };
179
180            let result = cache
181                .entry(hash(item.as_ref()))
182                .or_insert_with(|| transformed);
183            Rc::clone(result)
184        }
185
186        // Expr pointer -> transformed expression. Assumes that transient expressions of the same
187        // value are reference counters pointing to the same underlying expression. This is done
188        // via common subexpression elimination during parsing.
189        let mut cache = HashMap::new();
190        transform(self, item, &mut cache)
191    }
192}
193
194impl Transformer<Rc<ExprPat>, Rc<ExprPat>> for PatternMatch<ExprPat> {
195    fn transform(&self, item: Rc<ExprPat>) -> Rc<ExprPat> {
196        fn transform(
197            repls: &PatternMatch<ExprPat>,
198            item: Rc<ExprPat>,
199            cache: &mut HashMap<u64, Rc<ExprPat>>,
200        ) -> Rc<ExprPat> {
201            if let Some(result) = cache.get(&hash(item.as_ref())) {
202                return Rc::clone(result);
203            }
204
205            let transformed = match item.as_ref() {
206                ExprPat::VarPat(_) | ExprPat::ConstPat(_) | ExprPat::AnyPat(_) => {
207                    match repls.map.get(&hash(&item)) {
208                        Some(transformed) => Rc::clone(transformed),
209                        None => unreachable!(),
210                    }
211                }
212
213                ExprPat::Const(f) => ExprPat::Const(*f).into(),
214                ExprPat::BinaryExpr(binary_expr) => ExprPat::BinaryExpr(BinaryExpr {
215                    op: binary_expr.op,
216                    lhs: transform(repls, Rc::clone(&binary_expr.lhs), cache),
217                    rhs: transform(repls, Rc::clone(&binary_expr.rhs), cache),
218                })
219                .into(),
220                ExprPat::UnaryExpr(unary_expr) => ExprPat::UnaryExpr(UnaryExpr {
221                    op: unary_expr.op,
222                    rhs: transform(repls, Rc::clone(&unary_expr.rhs), cache),
223                })
224                .into(),
225                ExprPat::Parend(expr) => {
226                    let inner = transform(repls, Rc::clone(expr), cache);
227                    ExprPat::Parend(inner).into()
228                }
229                ExprPat::Bracketed(expr) => {
230                    let inner = transform(repls, Rc::clone(expr), cache);
231                    ExprPat::Bracketed(inner).into()
232                }
233            };
234
235            let result = cache
236                .entry(hash(item.as_ref()))
237                .or_insert_with(|| transformed);
238            Rc::clone(result)
239        }
240
241        // ExprPat pointer -> transformed expression. Assumes that transient expressions of the same
242        // value are reference counters pointing to the same underlying expression. This is done
243        // via common subexpression elimination during parsing.
244        let mut cache = HashMap::new();
245        transform(self, item, &mut cache)
246    }
247}
248
249impl<E: Expression + Eq> PatternMatch<E> {
250    /// Merges two `PatternMatch`. If the `PatternMatch` are of incompatible state (i.e. contain
251    /// different mappings), merging fails and nothing is returned.
252    fn try_merge(left: PatternMatch<E>, right: PatternMatch<E>) -> Option<PatternMatch<E>> {
253        let mut replacements = left;
254        for (from, to_r) in right.map.into_iter() {
255            if let Some(to_l) = replacements.map.get(&from) {
256                if to_r != *to_l {
257                    // Replacement already exists and its value does not match exactly; bail out.
258                    return None;
259                }
260                continue; // no need to insert replacement again
261            }
262            // Replacement is new, add it.
263            replacements.map.insert(from, to_r);
264        }
265        Some(replacements)
266    }
267
268    fn insert(&mut self, k: &Rc<ExprPat>, v: Rc<E>) -> Option<Rc<E>> {
269        self.map.insert(hash(k.as_ref()), v)
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse_expression, parse_expression_pattern, scan};

    /// Parses `prog` as an expression pattern (rule template).
    fn parse_rule(prog: &str) -> ExprPat {
        let (expr, _) = parse_expression_pattern(scan(prog).tokens);
        expr.as_ref().clone()
    }

    /// Parses `prog` as a plain expression; panics if it is not an expression statement.
    fn parse_expr(prog: &str) -> Expr {
        match parse_expression(scan(prog).tokens) {
            (Stmt::Expr(expr), _) => expr,
            _ => unreachable!(),
        }
    }

    mod replacements {
        use super::*;

        #[test]
        fn try_merge() {
            let a = Rc::new(ExprPat::VarPat("a".into()));
            let b = Rc::new(ExprPat::VarPat("b".into()));
            let c = Rc::new(ExprPat::VarPat("c".into()));

            let mut left = PatternMatch::default();
            left.insert(&a, Expr::Const(1.).into());
            left.insert(&b, Expr::Const(2.).into());

            let mut right = PatternMatch::default();
            right.insert(&b, Expr::Const(2.).into());
            right.insert(&c, Expr::Const(3.).into());

            let merged = PatternMatch::try_merge(left, right).unwrap();
            assert_eq!(merged.map.len(), 3);
            assert_eq!(merged.map.get(&hash(&a)).unwrap().to_string(), "1");
            assert_eq!(merged.map.get(&hash(&b)).unwrap().to_string(), "2");
            assert_eq!(merged.map.get(&hash(&c)).unwrap().to_string(), "3");
        }

        #[test]
        fn try_merge_overlapping_non_matching() {
            let a = Rc::new(ExprPat::VarPat("a".into()));

            let mut left = PatternMatch::default();
            left.insert(&a, Expr::Const(1.).into());

            let mut right = PatternMatch::default();
            right.insert(&a, Expr::Const(2.).into());

            let merged = PatternMatch::try_merge(left, right);
            assert!(merged.is_none());
        }

        #[test]
        fn transform_common_subexpression_elimination() {
            let parsed_rule = Rc::new(parse_rule("#a * _b + #a * _b"));
            let parsed_target = Rc::new(parse_expr("0 * 0 + 0 * 0"));

            let repls =
                PatternMatch::match_rule(Rc::clone(&parsed_rule), Rc::clone(&parsed_target))
                    .unwrap();
            let transformed = repls.transform(Rc::clone(&parsed_rule));
            let (l, r) = match transformed.as_ref() {
                Expr::BinaryExpr(BinaryExpr { lhs, rhs, .. }) => (lhs, rhs),
                _ => unreachable!(),
            };
            assert!(std::ptr::eq(l.as_ref(), r.as_ref())); // #a * _b

            let (ll, lr, rl, rr) = match (l.as_ref(), r.as_ref()) {
                (
                    Expr::BinaryExpr(BinaryExpr {
                        lhs: ll, rhs: lr, ..
                    }),
                    Expr::BinaryExpr(BinaryExpr {
                        lhs: rl, rhs: rr, ..
                    }),
                ) => (ll, lr, rl, rr),
                _ => unreachable!(),
            };
            assert!(std::ptr::eq(ll.as_ref(), lr.as_ref())); // check 0s
            assert!(std::ptr::eq(lr.as_ref(), rl.as_ref()));
            assert!(std::ptr::eq(rl.as_ref(), rr.as_ref()));
        }
    }

    mod match_rule {
        use super::*;

        /// Generates one test per `name: rule => target => expected` row. `expected` is `None`
        /// when the rule must not match, or `Some(vec!["pattern: replacement", ...])` listing
        /// every binding the match must produce.
        macro_rules! match_rule_tests {
            ($($name:ident: $rule:expr => $target:expr => $expected_repls:expr)*) => {
            $(
                #[test]
                fn $name() {
                    let parsed_rule = parse_rule($rule);
                    let parsed_target = parse_expr($target);

                    let repls = PatternMatch::match_rule(parsed_rule.into(), parsed_target.into());
                    let (repls, expected_repls): (PatternMatch<Expr>, Vec<&str>) =
                        match (repls, $expected_repls) {
                            (None, expected_matches) => {
                                assert!(
                                    expected_matches.is_none(),
                                    "expected rule {:?} to match {:?}",
                                    $rule,
                                    $target
                                );
                                return;
                            }
                            (Some(repl), expected_matches) => match expected_matches {
                                Some(expected) => (repl, expected),
                                None => panic!(
                                    "expected rule {:?} not to match {:?}",
                                    $rule,
                                    $target
                                ),
                            },
                        };

                    let expected_repls = expected_repls
                        .into_iter()
                        .map(|m| m.split(": "))
                        .map(|mut i| (i.next().unwrap(), i.next().unwrap()))
                        .map(|(r, t)| (parse_rule(r), parse_expr(t)));

                    assert_eq!(expected_repls.len(), repls.map.len());

                    for (expected_pattern, expected_repl) in expected_repls {
                        let actual = repls
                            .map
                            .get(&hash(&expected_pattern))
                            .expect("expected pattern has no replacement");
                        assert_eq!(expected_repl.to_string(), actual.to_string());
                    }
                }
            )*
            }
        }

        match_rule_tests! {
            consts:                     "0" => "0" => Some(vec![])
            consts_unmatched:           "0" => "1" => None

            variable_pattern:           "$a" => "x"     => Some(vec!["$a: x"])
            variable_pattern_on_const:  "$a" => "0"     => None
            variable_pattern_on_binary: "$a" => "x + 0" => None
            variable_pattern_on_unary:  "$a" => "+x"    => None

            const_pattern:              "#a" => "1"     => Some(vec!["#a: 1"])
            const_pattern_on_var:       "#a" => "x"     => None
            const_pattern_on_binary:    "#a" => "1 + x" => None
            const_pattern_on_unary:     "#a" => "+1"    => None

            any_pattern_on_variable:    "_a" => "x"     => Some(vec!["_a: x"])
            any_pattern_on_const:       "_a" => "1"     => Some(vec!["_a: 1"])
            any_pattern_on_binary:      "_a" => "1 + x" => Some(vec!["_a: 1 + x"])
            any_pattern_on_unary:       "_a" => "+(2)"  => Some(vec!["_a: +(2)"])

            binary_pattern:             "$a + #b" => "x + 0" => Some(vec!["$a: x", "#b: 0"])
            binary_pattern_wrong_op:    "$a + #b" => "x - 0" => None
            binary_pattern_partial:     "$a + #b" => "x + y" => None

            unary_pattern:              "+$a" => "+x" => Some(vec!["$a: x"])
            unary_pattern_wrong_op:     "+$a" => "-x" => None
            unary_pattern_partial:      "+$a" => "+1" => None

            parend:                     "($a + #b)" => "(x + 0)" => Some(vec!["$a: x", "#b: 0"])
            parend_on_bracketed:        "($a + #b)" => "[x + 0]" => None

            bracketed:                  "[$a + #b]" => "[x + 0]" => Some(vec!["$a: x", "#b: 0"])
            bracketed_on_parend:        "[$a + #b]" => "(x + 0)" => None
        }

        #[test]
        fn common_subexpression_elimination() {
            let parsed_rule = parse_rule("#a * _b + _c * #d");
            let parsed_target = parse_expr("0 * 0 + 0 * 0");
            let l = match &parsed_target {
                Expr::BinaryExpr(BinaryExpr { lhs, .. }) => Rc::clone(lhs),
                _ => unreachable!(),
            };
            let ll = match l.as_ref() {
                Expr::BinaryExpr(BinaryExpr { lhs, .. }) => lhs,
                _ => unreachable!(),
            };

            let repls =
                PatternMatch::match_rule(Rc::new(parsed_rule), Rc::new(parsed_target)).unwrap();
            let zeros = repls.map.values().collect::<Vec<_>>();
            // One binding per distinct pattern (#a, _b, _c, #d); checked before indexing below.
            assert_eq!(zeros.len(), 4);
            assert!(std::ptr::eq(ll.as_ref(), zeros[0].as_ref()));
            assert!(std::ptr::eq(zeros[0].as_ref(), zeros[1].as_ref()));
            assert!(std::ptr::eq(zeros[1].as_ref(), zeros[2].as_ref()));
            assert!(std::ptr::eq(zeros[2].as_ref(), zeros[3].as_ref()));
        }
    }
}