1use fmt::Formatter;
2use log::*;
3use std::borrow::Cow;
4use std::convert::TryInto;
5use std::fmt::{self, Display};
6use std::{convert::TryFrom, str::FromStr};
7
8use thiserror::Error;
9
10use crate::*;
11
/// A pattern that can function as either a [`Searcher`] or an [`Applier`].
///
/// Build one with [`Pattern::new`], parse one from a string (via `FromStr`),
/// or convert from a [`RecExpr`] / [`PatternAst`] via `From`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pattern<L> {
    /// The pattern itself: a [`RecExpr`] of enodes and pattern variables.
    pub ast: PatternAst<L>,
    // The matcher compiled from `ast` in `Pattern::new`; run on each e-class
    // during search.
    program: machine::Program<L>,
}
71
/// A [`RecExpr`] whose nodes are either language enodes or pattern variables.
/// The root of the pattern is the last node.
pub type PatternAst<L> = RecExpr<ENodeOrVar<L>>;
75
76impl<L: Language> PatternAst<L> {
77 pub fn alpha_rename(&self) -> Self {
79 let mut vars = HashMap::<Var, Var>::default();
80 let mut new = PatternAst::default();
81
82 fn mkvar(i: usize) -> Var {
83 let vs = &["?x", "?y", "?z", "?w"];
84 match vs.get(i) {
85 Some(v) => v.parse().unwrap(),
86 None => format!("?v{}", i - vs.len()).parse().unwrap(),
87 }
88 }
89
90 for n in self {
91 new.add(match n {
92 ENodeOrVar::ENode(_) => n.clone(),
93 ENodeOrVar::Var(v) => {
94 let i = vars.len();
95 ENodeOrVar::Var(*vars.entry(*v).or_insert_with(|| mkvar(i)))
96 }
97 });
98 }
99
100 new
101 }
102}
103
104impl<L: Language> Pattern<L> {
105 pub fn new(ast: PatternAst<L>) -> Self {
107 let ast = ast.compact();
108 let program = machine::Program::compile_from_pat(&ast);
109 Pattern { ast, program }
110 }
111
112 pub fn vars(&self) -> Vec<Var> {
114 let mut vars = vec![];
115 for n in &self.ast {
116 if let ENodeOrVar::Var(v) = n {
117 if !vars.contains(v) {
118 vars.push(*v)
119 }
120 }
121 }
122 vars
123 }
124}
125
126impl<L: Language + Display> Pattern<L> {
127 pub fn pretty(&self, width: usize) -> String {
129 self.ast.pretty(width)
130 }
131}
132
/// A node in a [`PatternAst`]: either an enode from the underlying language
/// or a pattern variable.
// NOTE: the derived `PartialOrd`/`Ord` compare by variant order first (every
// `ENode` sorts before every `Var`), so reordering the variants changes ordering.
#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
pub enum ENodeOrVar<L> {
    /// An enode from the underlying [`Language`]. Its children index into the
    /// enclosing [`PatternAst`].
    ENode(L),
    /// A pattern variable. When the pattern is applied, it is looked up in a
    /// [`Subst`].
    Var(Var),
}
142
/// The discriminant of an [`ENodeOrVar`], as returned by its
/// [`Language::discriminant`] implementation.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum ENodeOrVarDiscriminant<L: Language> {
    /// The discriminant of the wrapped enode.
    ENode(L::Discriminant),
    /// For a variable, the variable itself serves as the discriminant.
    Var(Var),
}
149
150impl<L: Language> Language for ENodeOrVar<L> {
151 type Discriminant = ENodeOrVarDiscriminant<L>;
152
153 #[inline(always)]
154 fn discriminant(&self) -> Self::Discriminant {
155 match self {
156 ENodeOrVar::ENode(n) => ENodeOrVarDiscriminant::ENode(n.discriminant()),
157 ENodeOrVar::Var(v) => ENodeOrVarDiscriminant::Var(*v),
158 }
159 }
160
161 fn matches(&self, _other: &Self) -> bool {
162 panic!("Should never call this")
163 }
164
165 fn children(&self) -> &[Id] {
166 match self {
167 ENodeOrVar::ENode(n) => n.children(),
168 ENodeOrVar::Var(_) => &[],
169 }
170 }
171
172 fn children_mut(&mut self) -> &mut [Id] {
173 match self {
174 ENodeOrVar::ENode(n) => n.children_mut(),
175 ENodeOrVar::Var(_) => &mut [],
176 }
177 }
178}
179
180impl<L: Language + Display> Display for ENodeOrVar<L> {
181 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
182 match self {
183 Self::ENode(node) => Display::fmt(node, f),
184 Self::Var(var) => Display::fmt(var, f),
185 }
186 }
187}
188
/// Error returned when parsing an [`ENodeOrVar`] from an operator fails.
#[derive(Debug, Error)]
pub enum ENodeOrVarParseError<E> {
    /// The operator looked like a pattern variable (`?` followed by at least
    /// one more byte) but could not be parsed as a [`Var`].
    #[error(transparent)]
    BadVar(<Var as FromStr>::Err),

    /// The operator looked like a pattern variable but was given children.
    #[error("tried to parse pattern variable {0:?} as an operator")]
    UnexpectedVar(String),

    /// The operator was not a variable and the underlying language's
    /// `from_op` rejected it. Wraps that language-specific error.
    #[error(transparent)]
    BadOp(E),
}
200
201impl<L: FromOp> FromOp for ENodeOrVar<L> {
202 type Error = ENodeOrVarParseError<L::Error>;
203
204 fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
205 use ENodeOrVarParseError::*;
206
207 if op.starts_with('?') && op.len() > 1 {
208 if children.is_empty() {
209 op.parse().map(Self::Var).map_err(BadVar)
210 } else {
211 Err(UnexpectedVar(op.to_owned()))
212 }
213 } else {
214 L::from_op(op, children).map(Self::ENode).map_err(BadOp)
215 }
216 }
217}
218
219impl<L: FromOp> std::str::FromStr for Pattern<L> {
220 type Err = RecExprParseError<ENodeOrVarParseError<L::Error>>;
221
222 fn from_str(s: &str) -> Result<Self, Self::Err> {
223 PatternAst::from_str(s).map(Self::from)
224 }
225}
226
227impl<'a, L: Language> From<&'a [L]> for Pattern<L> {
228 fn from(expr: &'a [L]) -> Self {
229 let ast = expr.iter().cloned().map(ENodeOrVar::ENode).collect();
230 Self::new(ast)
231 }
232}
233
234impl<L: Language> From<RecExpr<L>> for Pattern<L> {
235 fn from(expr: RecExpr<L>) -> Self {
236 let ast = expr.into_iter().map(ENodeOrVar::ENode).collect();
237 Self::new(ast)
238 }
239}
240
241impl<L: Language> From<&RecExpr<L>> for Pattern<L> {
242 fn from(expr: &RecExpr<L>) -> Self {
243 Self::from(expr.as_ref())
244 }
245}
246
247impl<L: Language> From<PatternAst<L>> for Pattern<L> {
248 fn from(ast: PatternAst<L>) -> Self {
249 Self::new(ast)
250 }
251}
252
253impl<L: Language> TryFrom<PatternAst<L>> for RecExpr<L> {
254 type Error = Var;
255 fn try_from(ast: PatternAst<L>) -> Result<Self, Self::Error> {
256 ast.into_iter()
257 .map(|n| match n {
258 ENodeOrVar::ENode(n) => Ok(n),
259 ENodeOrVar::Var(v) => Err(v),
260 })
261 .collect()
262 }
263}
264
265impl<L: Language> TryFrom<Pattern<L>> for RecExpr<L> {
266 type Error = Var;
267 fn try_from(pat: Pattern<L>) -> Result<Self, Self::Error> {
268 pat.ast.try_into()
269 }
270}
271
272impl<L: Language + Display> Display for Pattern<L> {
273 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
274 Display::fmt(&self.ast, f)
275 }
276}
277
/// The result of searching for a pattern within a single e-class.
///
/// Produced by [`Searcher`] implementations and consumed by [`Applier`]s.
#[derive(Debug)]
pub struct SearchMatches<'a, L: Language> {
    /// The e-class that was matched.
    pub eclass: Id,
    /// One substitution per match found in `eclass`.
    pub substs: Vec<Subst>,
    /// The searcher's pattern AST, if it has one.
    ///
    /// `Pattern`'s searcher fills this with a borrow of its own AST, and
    /// `Pattern`'s applier unwraps it when explanations are enabled.
    pub ast: Option<Cow<'a, PatternAst<L>>>,
}
294
295impl<L: Language, A: Analysis<L>> Searcher<L, A> for Pattern<L> {
296 fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
297 Some(&self.ast)
298 }
299
300 fn search_with_limit(&self, egraph: &EGraph<L, A>, limit: usize) -> Vec<SearchMatches<L>> {
301 match self.ast.last().unwrap() {
302 ENodeOrVar::ENode(e) => {
303 let key = e.discriminant();
304 match egraph.classes_for_op(&key) {
305 None => vec![],
306 Some(ids) => rewrite::search_eclasses_with_limit(self, egraph, ids, limit),
307 }
308 }
309 ENodeOrVar::Var(_) => rewrite::search_eclasses_with_limit(
310 self,
311 egraph,
312 egraph.classes().map(|e| e.id),
313 limit,
314 ),
315 }
316 }
317
318 fn search_eclass_with_limit(
319 &self,
320 egraph: &EGraph<L, A>,
321 eclass: Id,
322 limit: usize,
323 ) -> Option<SearchMatches<L>> {
324 let substs = self.program.run_with_limit(egraph, eclass, limit);
325 if substs.is_empty() {
326 None
327 } else {
328 let ast = Some(Cow::Borrowed(&self.ast));
329 Some(SearchMatches {
330 eclass,
331 substs,
332 ast,
333 })
334 }
335 }
336
337 fn vars(&self) -> Vec<Var> {
338 Pattern::vars(self)
339 }
340}
341
342impl<L, A> Applier<L, A> for Pattern<L>
343where
344 L: Language,
345 A: Analysis<L>,
346{
347 fn get_pattern_ast(&self) -> Option<&PatternAst<L>> {
348 Some(&self.ast)
349 }
350
351 fn apply_matches(
352 &self,
353 egraph: &mut EGraph<L, A>,
354 matches: &[SearchMatches<L>],
355 rule_name: Symbol,
356 ) -> Vec<Id> {
357 let mut added = vec![];
358 let mut id_buf = vec![0.into(); self.ast.len()];
359 for mat in matches {
360 let sast = mat.ast.as_ref().map(|cow| cow.as_ref());
361 for subst in &mat.substs {
362 let did_something;
363 let id;
364 if egraph.are_explanations_enabled() {
365 let (id_temp, did_something_temp) =
366 egraph.union_instantiations(sast.unwrap(), &self.ast, subst, rule_name);
367 did_something = did_something_temp;
368 id = id_temp;
369 } else {
370 id = apply_pat(&mut id_buf, &self.ast, egraph, subst);
371 did_something = egraph.union(id, mat.eclass);
372 }
373
374 if did_something {
375 added.push(id)
376 }
377 }
378 }
379 added
380 }
381
382 fn apply_one(
383 &self,
384 egraph: &mut EGraph<L, A>,
385 eclass: Id,
386 subst: &Subst,
387 searcher_ast: Option<&PatternAst<L>>,
388 rule_name: Symbol,
389 ) -> Vec<Id> {
390 let mut id_buf = vec![0.into(); self.ast.len()];
391 let id = apply_pat(&mut id_buf, &self.ast, egraph, subst);
392
393 if let Some(ast) = searcher_ast {
394 let (from, did_something) =
395 egraph.union_instantiations(ast, &self.ast, subst, rule_name);
396 if did_something {
397 vec![from]
398 } else {
399 vec![]
400 }
401 } else if egraph.union(eclass, id) {
402 vec![eclass]
403 } else {
404 vec![]
405 }
406 }
407
408 fn vars(&self) -> Vec<Var> {
409 Pattern::vars(self)
410 }
411}
412
413pub(crate) fn apply_pat<L: Language, A: Analysis<L>>(
414 ids: &mut [Id],
415 pat: &[ENodeOrVar<L>],
416 egraph: &mut EGraph<L, A>,
417 subst: &Subst,
418) -> Id {
419 debug_assert_eq!(pat.len(), ids.len());
420 trace!("apply_rec {:2?} {:?}", pat, subst);
421
422 for (i, pat_node) in pat.iter().enumerate() {
423 let id = match pat_node {
424 ENodeOrVar::Var(w) => subst[*w],
425 ENodeOrVar::ENode(e) => {
426 let n = e.clone().map_children(|child| ids[usize::from(child)]);
427 trace!("adding: {:?}", n);
428 egraph.add(n)
429 }
430 };
431 ids[i] = id;
432 }
433
434 *ids.last().unwrap()
435}
436
#[cfg(test)]
mod tests {

    use crate::{SymbolLang as S, *};

    type EGraph = crate::EGraph<S, ()>;

    #[test]
    fn simple_match() {
        crate::init_logger();
        let mut egraph = EGraph::default();

        let (plus_id, _) = egraph.union_instantiations(
            &"(+ x y)".parse().unwrap(),
            &"(+ z w)".parse().unwrap(),
            &Default::default(),
            "union_plus".to_string(),
        );
        egraph.rebuild();

        let commute_plus = rewrite!(
            "commute_plus";
            "(+ ?a ?b)" => "(+ ?b ?a)"
        );

        // Both `(+ x y)` and `(+ z w)` live in the same class, so two matches.
        let matches = commute_plus.search(&egraph);
        let n_matches: usize = matches.iter().map(|m| m.substs.len()).sum();
        assert_eq!(n_matches, 2, "matches is wrong: {:#?}", matches);

        let applications = commute_plus.apply(&mut egraph, &matches);
        egraph.rebuild();
        assert_eq!(applications.len(), 2);

        let actual_substs: Vec<Subst> = matches.iter().flat_map(|m| m.substs.clone()).collect();

        println!("Here are the substs!");
        for m in &actual_substs {
            println!("substs: {:?}", m);
        }

        egraph.dot().to_dot("target/simple-match.dot").unwrap();

        use crate::extract::{AstSize, Extractor};

        let ext = Extractor::new(&egraph, AstSize);
        let (_, best) = ext.find_best(plus_id);
        eprintln!("Best: {:#?}", best);
    }

    #[test]
    fn nonlinear_patterns() {
        crate::init_logger();
        let mut egraph = EGraph::default();
        // Parentheses are balanced here. Earlier versions had stray trailing
        // `))` that only worked if the parser ignored trailing tokens.
        egraph.add_expr(&"(f a a)".parse().unwrap());
        egraph.add_expr(&"(f a (g a))".parse().unwrap());
        egraph.add_expr(&"(f a (g b))".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 0 1)".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 1 0)".parse().unwrap());
        egraph.add_expr(&"(h (foo a b) 0 0)".parse().unwrap());
        egraph.rebuild();

        let n_matches = |s: &str| s.parse::<Pattern<S>>().unwrap().n_matches(&egraph);

        // A repeated variable must bind to the same e-class at every occurrence.
        assert_eq!(n_matches("(f ?x ?y)"), 3);
        assert_eq!(n_matches("(f ?x ?x)"), 1);
        assert_eq!(n_matches("(f ?x (g ?y))"), 2);
        assert_eq!(n_matches("(f ?x (g ?x))"), 1);
        assert_eq!(n_matches("(h ?x 0 0)"), 1);
    }

    #[test]
    fn search_with_limit() {
        crate::init_logger();
        let init_expr = &"(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 6)))))".parse().unwrap();
        let rules: Vec<Rewrite<_, ()>> = vec![
            rewrite!("comm"; "(+ ?x ?y)" => "(+ ?y ?x)"),
            rewrite!("assoc"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"),
        ];
        let runner = Runner::default().with_expr(init_expr).run(&rules);
        let egraph = &runner.egraph;

        let len = |m: &Vec<SearchMatches<S>>| -> usize { m.iter().map(|m| m.substs.len()).sum() };

        let pat = &"(+ ?x (+ ?y ?z))".parse::<Pattern<S>>().unwrap();
        let m = pat.search(egraph);
        let match_size = 2100;
        assert_eq!(len(&m), match_size);

        // The limit caps the total number of substitutions across all classes.
        for limit in [1, 10, 100, 1000, 10000] {
            let m = pat.search_with_limit(egraph, limit);
            assert_eq!(len(&m), usize::min(limit, match_size));
        }

        let id = egraph.lookup_expr(init_expr).unwrap();
        let m = pat.search_eclass(egraph, id).unwrap();
        let match_size = 540;
        assert_eq!(m.substs.len(), match_size);

        for limit in [1, 10, 100, 1000] {
            let m1 = pat.search_eclass_with_limit(egraph, id, limit).unwrap();
            assert_eq!(m1.substs.len(), usize::min(limit, match_size));
        }
    }
}