1use crate::grammar::*;
2use crate::utils::hash;
3
4use std::collections::HashMap;
5use std::rc::Rc;
6
7pub struct PatternMatch<E: Expression> {
12 map: HashMap<
13 u64, Rc<E>, >,
16}
17
18impl<E: Expression> Default for PatternMatch<E> {
19 fn default() -> Self {
20 Self {
21 map: HashMap::new(),
22 }
23 }
24}
25
26pub trait MatchRule<E: Expression> {
27 fn match_rule(rule: Rc<ExprPat>, target: Rc<E>) -> Option<PatternMatch<E>>;
33}
34
35impl MatchRule<Expr> for PatternMatch<Expr> {
36 fn match_rule(rule: Rc<ExprPat>, target: Rc<Expr>) -> Option<PatternMatch<Expr>> {
37 match (rule.as_ref(), target.as_ref()) {
38 (ExprPat::VarPat(_), Expr::Var(_))
40 | (ExprPat::ConstPat(_), Expr::Const(_))
41 | (ExprPat::AnyPat(_), _) => {
42 let mut replacements = PatternMatch::default();
43 replacements.insert(&rule, target);
44 Some(replacements)
45 }
46 (ExprPat::Const(a), Expr::Const(b)) => {
47 if (a - b).abs() > std::f64::EPSILON {
48 return None;
50 }
51 Some(PatternMatch::default())
53 }
54 (ExprPat::BinaryExpr(rule), Expr::BinaryExpr(expr)) => {
55 if rule.op != expr.op {
56 return None;
57 }
58 let replacements_lhs =
61 Self::match_rule(Rc::clone(&rule.lhs), Rc::clone(&expr.lhs))?;
62 let replacements_rhs =
63 Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))?;
64 PatternMatch::try_merge(replacements_lhs, replacements_rhs)
65 }
66 (ExprPat::UnaryExpr(rule), Expr::UnaryExpr(expr)) => {
67 if rule.op != expr.op {
68 return None;
69 }
70 Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))
73 }
74 (ExprPat::Parend(rule), Expr::Parend(expr)) => {
75 Self::match_rule(Rc::clone(rule), Rc::clone(expr))
76 }
77 (ExprPat::Bracketed(rule), Expr::Bracketed(expr)) => {
78 Self::match_rule(Rc::clone(rule), Rc::clone(expr))
79 }
80 _ => None,
81 }
82 }
83}
84
85impl MatchRule<ExprPat> for PatternMatch<ExprPat> {
86 fn match_rule(rule: Rc<ExprPat>, target: Rc<ExprPat>) -> Option<PatternMatch<ExprPat>> {
87 match (rule.as_ref(), target.as_ref()) {
88 (ExprPat::VarPat(_), ExprPat::VarPat(_))
89 | (ExprPat::ConstPat(_), ExprPat::ConstPat(_))
90 | (ExprPat::AnyPat(_), _) => {
91 let mut replacements = PatternMatch::default();
92 replacements.insert(&rule, target);
93 Some(replacements)
94 }
95 (ExprPat::Const(a), ExprPat::Const(b)) => {
96 if (a - b).abs() > std::f64::EPSILON {
97 return None;
98 }
99 Some(PatternMatch::default())
100 }
101 (ExprPat::BinaryExpr(rule), ExprPat::BinaryExpr(expr)) => {
102 if rule.op != expr.op {
103 return None;
104 }
105 let replacements_lhs =
106 Self::match_rule(Rc::clone(&rule.lhs), Rc::clone(&expr.lhs))?;
107 let replacements_rhs =
108 Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))?;
109 PatternMatch::try_merge(replacements_lhs, replacements_rhs)
110 }
111 (ExprPat::UnaryExpr(rule), ExprPat::UnaryExpr(expr)) => {
112 if rule.op != expr.op {
113 return None;
114 }
115 Self::match_rule(Rc::clone(&rule.rhs), Rc::clone(&expr.rhs))
116 }
117 (ExprPat::Parend(rule), ExprPat::Parend(expr)) => {
118 Self::match_rule(Rc::clone(rule), Rc::clone(expr))
119 }
120 (ExprPat::Bracketed(rule), ExprPat::Bracketed(expr)) => {
121 Self::match_rule(Rc::clone(rule), Rc::clone(expr))
122 }
123 _ => None,
124 }
125 }
126}
127
128impl Transformer<Rc<ExprPat>, Rc<Expr>> for PatternMatch<Expr> {
129 fn transform(&self, item: Rc<ExprPat>) -> Rc<Expr> {
137 fn transform(
138 repls: &PatternMatch<Expr>,
139 item: Rc<ExprPat>,
140 cache: &mut HashMap<u64, Rc<Expr>>,
141 ) -> Rc<Expr> {
142 if let Some(result) = cache.get(&hash(item.as_ref())) {
143 return Rc::clone(result);
144 }
145
146 let transformed = match item.as_ref() {
147 ExprPat::VarPat(_) | ExprPat::ConstPat(_) | ExprPat::AnyPat(_) => {
148 match repls.map.get(&hash(&item)) {
149 Some(transformed) => Rc::clone(transformed),
150
151 None => unreachable!(),
155 }
156 }
157
158 ExprPat::Const(f) => Expr::Const(*f).into(),
159 ExprPat::BinaryExpr(binary_expr) => Expr::BinaryExpr(BinaryExpr {
160 op: binary_expr.op,
161 lhs: transform(repls, Rc::clone(&binary_expr.lhs), cache),
162 rhs: transform(repls, Rc::clone(&binary_expr.rhs), cache),
163 })
164 .into(),
165 ExprPat::UnaryExpr(unary_expr) => Expr::UnaryExpr(UnaryExpr {
166 op: unary_expr.op,
167 rhs: transform(repls, Rc::clone(&unary_expr.rhs), cache),
168 })
169 .into(),
170 ExprPat::Parend(expr) => {
171 let inner = transform(repls, Rc::clone(expr), cache);
172 Expr::Parend(inner).into()
173 }
174 ExprPat::Bracketed(expr) => {
175 let inner = transform(repls, Rc::clone(expr), cache);
176 Expr::Bracketed(inner).into()
177 }
178 };
179
180 let result = cache
181 .entry(hash(item.as_ref()))
182 .or_insert_with(|| transformed);
183 Rc::clone(result)
184 }
185
186 let mut cache = HashMap::new();
190 transform(self, item, &mut cache)
191 }
192}
193
194impl Transformer<Rc<ExprPat>, Rc<ExprPat>> for PatternMatch<ExprPat> {
195 fn transform(&self, item: Rc<ExprPat>) -> Rc<ExprPat> {
196 fn transform(
197 repls: &PatternMatch<ExprPat>,
198 item: Rc<ExprPat>,
199 cache: &mut HashMap<u64, Rc<ExprPat>>,
200 ) -> Rc<ExprPat> {
201 if let Some(result) = cache.get(&hash(item.as_ref())) {
202 return Rc::clone(result);
203 }
204
205 let transformed = match item.as_ref() {
206 ExprPat::VarPat(_) | ExprPat::ConstPat(_) | ExprPat::AnyPat(_) => {
207 match repls.map.get(&hash(&item)) {
208 Some(transformed) => Rc::clone(transformed),
209 None => unreachable!(),
210 }
211 }
212
213 ExprPat::Const(f) => ExprPat::Const(*f).into(),
214 ExprPat::BinaryExpr(binary_expr) => ExprPat::BinaryExpr(BinaryExpr {
215 op: binary_expr.op,
216 lhs: transform(repls, Rc::clone(&binary_expr.lhs), cache),
217 rhs: transform(repls, Rc::clone(&binary_expr.rhs), cache),
218 })
219 .into(),
220 ExprPat::UnaryExpr(unary_expr) => ExprPat::UnaryExpr(UnaryExpr {
221 op: unary_expr.op,
222 rhs: transform(repls, Rc::clone(&unary_expr.rhs), cache),
223 })
224 .into(),
225 ExprPat::Parend(expr) => {
226 let inner = transform(repls, Rc::clone(expr), cache);
227 ExprPat::Parend(inner).into()
228 }
229 ExprPat::Bracketed(expr) => {
230 let inner = transform(repls, Rc::clone(expr), cache);
231 ExprPat::Bracketed(inner).into()
232 }
233 };
234
235 let result = cache
236 .entry(hash(item.as_ref()))
237 .or_insert_with(|| transformed);
238 Rc::clone(result)
239 }
240
241 let mut cache = HashMap::new();
245 transform(self, item, &mut cache)
246 }
247}
248
249impl<E: Expression + Eq> PatternMatch<E> {
250 fn try_merge(left: PatternMatch<E>, right: PatternMatch<E>) -> Option<PatternMatch<E>> {
253 let mut replacements = left;
254 for (from, to_r) in right.map.into_iter() {
255 if let Some(to_l) = replacements.map.get(&from) {
256 if to_r != *to_l {
257 return None;
259 }
260 continue; }
262 replacements.map.insert(from, to_r);
264 }
265 Some(replacements)
266 }
267
268 fn insert(&mut self, k: &Rc<ExprPat>, v: Rc<E>) -> Option<Rc<E>> {
269 self.map.insert(hash(k.as_ref()), v)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse_expression, parse_expression_pattern, scan};

    /// Scans and parses `prog` as an expression pattern.
    fn parse_rule(prog: &str) -> ExprPat {
        let tokens = scan(prog).tokens;
        let (pattern, _) = parse_expression_pattern(tokens);
        pattern.as_ref().clone()
    }

    /// Scans and parses `prog`, which must be an expression statement.
    fn parse_expr(prog: &str) -> Expr {
        let (stmt, _) = parse_expression(scan(prog).tokens);
        match stmt {
            Stmt::Expr(expr) => expr,
            _ => unreachable!(),
        }
    }

    /// Returns the operands of `expr`, which must be a binary expression.
    fn binary_operands(expr: &Expr) -> (&Rc<Expr>, &Rc<Expr>) {
        match expr {
            Expr::BinaryExpr(BinaryExpr { lhs, rhs, .. }) => (lhs, rhs),
            _ => unreachable!(),
        }
    }

    mod replacements {
        use super::*;

        #[test]
        fn try_merge() {
            let a = Rc::new(ExprPat::VarPat("a".into()));
            let b = Rc::new(ExprPat::VarPat("b".into()));
            let c = Rc::new(ExprPat::VarPat("c".into()));

            let mut left: PatternMatch<Expr> = PatternMatch::default();
            left.insert(&a, Expr::Const(1.).into());
            left.insert(&b, Expr::Const(2.).into());

            // `b` is bound on both sides, to equal values.
            let mut right: PatternMatch<Expr> = PatternMatch::default();
            right.insert(&b, Expr::Const(2.).into());
            right.insert(&c, Expr::Const(3.).into());

            let merged = PatternMatch::try_merge(left, right).unwrap();
            assert_eq!(merged.map.len(), 3);
            assert_eq!(merged.map.get(&hash(&a)).unwrap().to_string(), "1");
            assert_eq!(merged.map.get(&hash(&b)).unwrap().to_string(), "2");
            assert_eq!(merged.map.get(&hash(&c)).unwrap().to_string(), "3");
        }

        #[test]
        fn try_merge_overlapping_non_matching() {
            let a = Rc::new(ExprPat::VarPat("a".into()));

            let mut left: PatternMatch<Expr> = PatternMatch::default();
            left.insert(&a, Expr::Const(1.).into());

            // Same pattern, different value: the merge must be rejected.
            let mut right: PatternMatch<Expr> = PatternMatch::default();
            right.insert(&a, Expr::Const(2.).into());

            let merged = PatternMatch::try_merge(left, right);
            assert!(merged.is_none());
        }

        #[test]
        fn transform_common_subexpression_elimination() {
            let parsed_rule = Rc::new(parse_rule("#a * _b + #a * _b"));
            let parsed_target = Rc::new(parse_expr("0 * 0 + 0 * 0"));

            let repls =
                PatternMatch::match_rule(Rc::clone(&parsed_rule), Rc::clone(&parsed_target))
                    .unwrap();
            let transformed = repls.transform(Rc::clone(&parsed_rule));

            // Both operands of the sum must be the very same allocation.
            let (l, r) = binary_operands(transformed.as_ref());
            assert!(std::ptr::eq(l.as_ref(), r.as_ref()));

            // ... and so must all four operands of the two products.
            let (ll, lr) = binary_operands(l.as_ref());
            let (rl, rr) = binary_operands(r.as_ref());
            assert!(std::ptr::eq(ll.as_ref(), lr.as_ref()));
            assert!(std::ptr::eq(lr.as_ref(), rl.as_ref()));
            assert!(std::ptr::eq(rl.as_ref(), rr.as_ref()));
        }
    }

    mod match_rule {
        use super::*;

        /// Generates one test per `name: rule => target => expected` entry.
        ///
        /// `expected` is `None` when the rule must not match, or `Some(vec![..])` of
        /// `"pattern: replacement"` strings listing every replacement the match must hold.
        macro_rules! match_rule_tests {
            ($($name:ident: $rule:expr => $target:expr => $expected_repls:expr)*) => {
                $(
                    #[test]
                    fn $name() {
                        let parsed_rule = parse_rule($rule);
                        let parsed_target = parse_expr($target);

                        let outcome: Option<PatternMatch<Expr>> =
                            PatternMatch::match_rule(parsed_rule.into(), parsed_target.into());
                        let expectation: Option<Vec<&str>> = $expected_repls;

                        let (repls, expected) = match (outcome, expectation) {
                            (None, expectation) => {
                                assert!(expectation.is_none());
                                return;
                            }
                            (Some(found), expectation) => {
                                assert!(expectation.is_some());
                                (found, expectation.unwrap())
                            }
                        };

                        // Lazily turn each "pattern: replacement" string into parsed values.
                        let expected = expected.into_iter().map(|entry| {
                            let mut parts = entry.split(": ");
                            let pattern = parts.next().unwrap();
                            let replacement = parts.next().unwrap();
                            (parse_rule(pattern), parse_expr(replacement))
                        });

                        assert_eq!(expected.len(), repls.map.len());

                        for (pattern, replacement) in expected {
                            assert_eq!(
                                replacement.to_string(),
                                repls.map.get(&hash(&pattern)).unwrap().to_string()
                            );
                        }
                    }
                )*
            }
        }

        match_rule_tests! {
            consts:                     "0" => "0" => Some(vec![])
            consts_unmatched:           "0" => "1" => None

            variable_pattern:           "$a" => "x" => Some(vec!["$a: x"])
            variable_pattern_on_const:  "$a" => "0" => None
            variable_pattern_on_binary: "$a" => "x + 0" => None
            variable_pattern_on_unary:  "$a" => "+x" => None

            const_pattern:              "#a" => "1" => Some(vec!["#a: 1"])
            const_pattern_on_var:       "#a" => "x" => None
            const_pattern_on_binary:    "#a" => "1 + x" => None
            const_pattern_on_unary:     "#a" => "+1" => None

            any_pattern_on_variable:    "_a" => "x" => Some(vec!["_a: x"])
            any_pattern_on_const:       "_a" => "1" => Some(vec!["_a: 1"])
            any_pattern_on_binary:      "_a" => "1 + x" => Some(vec!["_a: 1 + x"])
            any_pattern_on_unary:       "_a" => "+(2)" => Some(vec!["_a: +(2)"])

            binary_pattern:             "$a + #b" => "x + 0" => Some(vec!["$a: x", "#b: 0"])
            binary_pattern_wrong_op:    "$a + #b" => "x - 0" => None
            binary_pattern_partial:     "$a + #b" => "x + y" => None

            unary_pattern:              "+$a" => "+x" => Some(vec!["$a: x"])
            unary_pattern_wrong_op:     "+$a" => "-x" => None
            unary_pattern_partial:      "+$a" => "+1" => None

            parend:                     "($a + #b)" => "(x + 0)" => Some(vec!["$a: x", "#b: 0"])
            parend_on_bracketed:        "($a + #b)" => "[x + 0]" => None

            bracketed:                  "[$a + #b]" => "[x + 0]" => Some(vec!["$a: x", "#b: 0"])
            bracketed_on_parend:        "[$a + #b]" => "(x + 0)" => None
        }

        #[test]
        fn common_subexpression_elimination() {
            let parsed_rule = parse_rule("#a * _b + _c * #d");
            let parsed_target = parse_expr("0 * 0 + 0 * 0");

            // Keep a handle on the left product; `parsed_target` is moved into the match below.
            let l = Rc::clone(binary_operands(&parsed_target).0);
            let ll = binary_operands(l.as_ref()).0;

            let repls =
                PatternMatch::match_rule(Rc::new(parsed_rule), Rc::new(parsed_target)).unwrap();

            // All four replacements must be the very allocation of the target's first `0`.
            let zeros = repls.map.values().collect::<Vec<_>>();
            assert!(std::ptr::eq(ll.as_ref(), zeros[0].as_ref()));
            assert!(std::ptr::eq(zeros[0].as_ref(), zeros[1].as_ref()));
            assert!(std::ptr::eq(zeros[1].as_ref(), zeros[2].as_ref()));
            assert!(std::ptr::eq(zeros[2].as_ref(), zeros[3].as_ref()));
        }
    }
}