Skip to main content

ast_grep_core/matcher/
pattern.rs

1use crate::language::Language;
2use crate::match_tree::{match_end_non_recursive, match_node_non_recursive, MatchStrictness};
3use crate::matcher::{kind_utils, KindMatcher, KindMatcherError, Matcher};
4use crate::meta_var::{MetaVarEnv, MetaVariable};
5use crate::source::SgNode;
6use crate::{Doc, Node, Root};
7
8use bit_set::BitSet;
9use thiserror::Error;
10
11use std::borrow::Cow;
12use std::collections::HashSet;
13
14#[derive(Clone)]
15pub struct Pattern {
16  pub node: PatternNode,
17  root_kind: Option<u16>,
18  pub strictness: MatchStrictness,
19}
20
/// Inputs needed to build a [`Pattern`]: the pattern source and an optional
/// kind selector. Handed to the language, which builds it via `build`.
pub struct PatternBuilder<'a> {
  /// Kind name to locate inside `src`; `None` builds a single-node pattern.
  selector: Option<&'a str>,
  /// Pattern source after the language's `pre_process_pattern` step.
  src: Cow<'a, str>,
}
25
26impl PatternBuilder<'_> {
27  pub fn build<D, F>(&self, parse: F) -> Result<Pattern, PatternError>
28  where
29    F: FnOnce(&str) -> Result<D, String>,
30    D: Doc,
31  {
32    let doc = parse(&self.src).map_err(PatternError::Parse)?;
33    let root = Root::doc(doc);
34    if let Some(selector) = self.selector {
35      self.contextual(&root, selector)
36    } else {
37      self.single(&root)
38    }
39  }
40  fn single<D: Doc>(&self, root: &Root<D>) -> Result<Pattern, PatternError> {
41    let goal = root.root();
42    if goal.children().len() == 0 {
43      return Err(PatternError::NoContent(self.src.to_string()));
44    }
45    if !is_single_node(&goal.inner) {
46      return Err(PatternError::MultipleNode(self.src.to_string()));
47    }
48    let node = Pattern::single_matcher(root);
49    Ok(Pattern::from(node))
50  }
51
52  fn contextual<D: Doc>(&self, root: &Root<D>, selector: &str) -> Result<Pattern, PatternError> {
53    let goal = root.root();
54    let kind_matcher = KindMatcher::try_new(selector, root.lang().clone())?;
55    let Some(node) = goal.find(&kind_matcher) else {
56      return Err(PatternError::NoSelectorInContext {
57        context: self.src.to_string(),
58        selector: selector.into(),
59      });
60    };
61    Ok(Pattern {
62      root_kind: Some(node.kind_id()),
63      node: convert_node_to_pattern(node.get_node().clone()),
64      strictness: MatchStrictness::Smart,
65    })
66  }
67}
68
/// Inspectable tree view of a pattern, produced by `Pattern::dump`.
///
/// Which nodes appear, and which of `kind`/`text` are filled, depends on the
/// pattern's strictness.
pub struct DumpPattern<'p> {
  /// True when this entry is a meta variable such as `$A` or `$$$ARGS`.
  pub is_meta_var: bool,
  /// Display name of the node kind; `None` for unnamed terminals and for
  /// named terminals under template strictness.
  pub kind: Option<Cow<'static, str>>,
  /// Node text: the meta variable spelling, a terminal's source text, or
  /// empty for internal nodes.
  pub text: Cow<'p, str>,
  /// Dumped children, in source order.
  pub children: Vec<DumpPattern<'p>>,
}
75
76fn dump_pattern_impl<'p>(
77  pattern: &'p PatternNode,
78  strictness: &MatchStrictness,
79  to_kind_str: &impl Fn(u16) -> Option<Cow<'static, str>>,
80) -> Option<DumpPattern<'p>> {
81  match pattern {
82    PatternNode::MetaVar { meta_var } => {
83      let meta_var = match meta_var {
84        MetaVariable::Capture(name, _) => format!("${name}"),
85        MetaVariable::MultiCapture(name) => format!("$$${name}"),
86        MetaVariable::Multiple => "$$$".to_string(),
87        MetaVariable::Dropped(_) => "$_".to_string(),
88      };
89      Some(DumpPattern {
90        is_meta_var: true,
91        kind: Some("MetaVar".into()),
92        text: meta_var.into(),
93        children: vec![],
94      })
95    }
96    PatternNode::Terminal {
97      text,
98      kind_id,
99      is_named,
100    } => {
101      if !*is_named {
102        if matches!(
103          strictness,
104          MatchStrictness::Cst | MatchStrictness::Smart | MatchStrictness::Template
105        ) {
106          return Some(DumpPattern {
107            is_meta_var: false,
108            kind: None,
109            text: text.into(),
110            children: vec![],
111          });
112        }
113        return None;
114      }
115      let kind = if matches!(strictness, MatchStrictness::Template) {
116        None
117      } else {
118        Some(to_kind_str(*kind_id).unwrap_or("UNKNOWN".into()))
119      };
120      let text = if matches!(strictness, MatchStrictness::Signature) {
121        ""
122      } else {
123        text
124      };
125      Some(DumpPattern {
126        is_meta_var: false,
127        kind,
128        text: text.into(),
129        children: vec![],
130      })
131    }
132    PatternNode::Internal { kind_id, children } => {
133      let kind = if matches!(strictness, MatchStrictness::Template) {
134        Cow::Borrowed("(node)")
135      } else {
136        to_kind_str(*kind_id).unwrap_or_else(|| "UNKNOWN".into())
137      };
138      let children = children
139        .iter()
140        .filter_map(|n| dump_pattern_impl(n, strictness, to_kind_str))
141        .collect();
142      Some(DumpPattern {
143        is_meta_var: false,
144        kind: Some(kind),
145        text: Cow::Borrowed(""),
146        children,
147      })
148    }
149  }
150}
151
152#[derive(Clone)]
153pub enum PatternNode {
154  MetaVar {
155    meta_var: MetaVariable,
156  },
157  /// Node without children.
158  Terminal {
159    text: String,
160    is_named: bool,
161    kind_id: u16,
162  },
163  /// Non-Terminal Syntax Nodes are called Internal
164  Internal {
165    kind_id: u16,
166    children: Vec<PatternNode>,
167  },
168}
169
170impl PatternNode {
171  // for skipping trivial nodes in goal after ellipsis
172  pub fn is_trivial(&self) -> bool {
173    match self {
174      PatternNode::Terminal { is_named, .. } => !*is_named,
175      _ => false,
176    }
177  }
178
179  pub fn fixed_string(&self) -> Cow<'_, str> {
180    match &self {
181      PatternNode::Terminal { text, .. } => Cow::Borrowed(text),
182      PatternNode::MetaVar { .. } => Cow::Borrowed(""),
183      PatternNode::Internal { children, .. } => {
184        children
185          .iter()
186          .map(|n| n.fixed_string())
187          .fold(Cow::Borrowed(""), |longest, curr| {
188            if longest.len() >= curr.len() {
189              longest
190            } else {
191              curr
192            }
193          })
194      }
195    }
196  }
197}
198impl<'r, D: Doc> From<Node<'r, D>> for PatternNode {
199  fn from(node: Node<'r, D>) -> Self {
200    convert_node_to_pattern(node)
201  }
202}
203
204impl<'r, D: Doc> From<Node<'r, D>> for Pattern {
205  fn from(node: Node<'r, D>) -> Self {
206    Self {
207      node: convert_node_to_pattern(node),
208      root_kind: None,
209      strictness: MatchStrictness::Smart,
210    }
211  }
212}
213
214fn convert_node_to_pattern<D: Doc>(node: Node<'_, D>) -> PatternNode {
215  if let Some(meta_var) = extract_var_from_node(&node) {
216    PatternNode::MetaVar { meta_var }
217  } else if node.is_leaf() {
218    PatternNode::Terminal {
219      text: node.text().to_string(),
220      is_named: node.is_named(),
221      kind_id: node.kind_id(),
222    }
223  } else {
224    let children = node.children().filter_map(|n| {
225      if n.is_missing() {
226        None
227      } else {
228        Some(PatternNode::from(n))
229      }
230    });
231    PatternNode::Internal {
232      kind_id: node.kind_id(),
233      children: children.collect(),
234    }
235  }
236}
237
238fn extract_var_from_node<D: Doc>(goal: &Node<'_, D>) -> Option<MetaVariable> {
239  let key = goal.text();
240  goal.lang().extract_meta_var(&key)
241}
242
243#[derive(Debug, Error)]
244pub enum PatternError {
245  #[error("Fails to parse the pattern query: `{0}`")]
246  Parse(String),
247  #[error("No AST root is detected. Please check the pattern source `{0}`.")]
248  NoContent(String),
249  #[error("Multiple AST nodes are detected. Please check the pattern source `{0}`.")]
250  MultipleNode(String),
251  #[error(transparent)]
252  InvalidKind(#[from] KindMatcherError),
253  #[error("Fails to create Contextual pattern: selector `{selector}` matches no node in the context `{context}`.")]
254  NoSelectorInContext { context: String, selector: String },
255}
256
257#[inline]
258fn is_single_node<'r, N: SgNode<'r>>(n: &N) -> bool {
259  match n.children().len() {
260    1 => true,
261    2 => {
262      let c = n.child(1).expect("second child must exist");
263      // some language will have weird empty syntax node at the end
264      // see golang's `$A = 0` pattern test case
265      c.is_missing() || c.kind().is_empty()
266    }
267    _ => false,
268  }
269}
270impl Pattern {
271  pub fn dump(
272    &self,
273    kind_id_to_name: &impl Fn(u16) -> Option<Cow<'static, str>>,
274  ) -> Option<DumpPattern<'_>> {
275    dump_pattern_impl(&self.node, &self.strictness, kind_id_to_name)
276  }
277  pub fn has_error(&self) -> bool {
278    let kind = match &self.node {
279      PatternNode::Terminal { kind_id, .. } => *kind_id,
280      PatternNode::Internal { kind_id, .. } => *kind_id,
281      PatternNode::MetaVar { .. } => match self.root_kind {
282        Some(k) => k,
283        None => return false,
284      },
285    };
286    kind_utils::is_error_kind(kind)
287  }
288
289  pub fn fixed_string(&self) -> Cow<'_, str> {
290    self.node.fixed_string()
291  }
292
293  /// Get all defined variables in the pattern.
294  /// Used for validating rules and report undefined variables.
295  pub fn defined_vars(&self) -> HashSet<&str> {
296    let mut vars = HashSet::new();
297    collect_vars(&self.node, &mut vars);
298    vars
299  }
300}
301
302fn meta_var_name(meta_var: &MetaVariable) -> Option<&str> {
303  use MetaVariable as MV;
304  match meta_var {
305    MV::Capture(name, _) => Some(name),
306    MV::MultiCapture(name) => Some(name),
307    MV::Dropped(_) => None,
308    MV::Multiple => None,
309  }
310}
311
312fn collect_vars<'p>(p: &'p PatternNode, vars: &mut HashSet<&'p str>) {
313  match p {
314    PatternNode::MetaVar { meta_var, .. } => {
315      if let Some(name) = meta_var_name(meta_var) {
316        vars.insert(name);
317      }
318    }
319    PatternNode::Terminal { .. } => {
320      // collect nothing for terminal nodes!
321    }
322    PatternNode::Internal { children, .. } => {
323      for c in children {
324        collect_vars(c, vars);
325      }
326    }
327  }
328}
329
330impl Pattern {
331  pub fn try_new<L: Language>(src: &str, lang: L) -> Result<Self, PatternError> {
332    let processed = lang.pre_process_pattern(src);
333    let builder = PatternBuilder {
334      selector: None,
335      src: processed,
336    };
337    lang.build_pattern(&builder)
338  }
339
340  pub fn new<L: Language>(src: &str, lang: L) -> Self {
341    Self::try_new(src, lang).unwrap()
342  }
343
344  pub fn with_strictness(mut self, strictness: MatchStrictness) -> Self {
345    self.strictness = strictness;
346    self
347  }
348
349  pub fn contextual<L: Language>(
350    context: &str,
351    selector: &str,
352    lang: L,
353  ) -> Result<Self, PatternError> {
354    let processed = lang.pre_process_pattern(context);
355    let builder = PatternBuilder {
356      selector: Some(selector),
357      src: processed,
358    };
359    lang.build_pattern(&builder)
360  }
361  fn single_matcher<D: Doc>(root: &Root<D>) -> Node<'_, D> {
362    // debug_assert!(matches!(self.style, PatternStyle::Single));
363    let node = root.root();
364    let mut inner = node.inner;
365    while is_single_node(&inner) {
366      inner = inner.child(0).unwrap();
367    }
368    Node { inner, root }
369  }
370}
371
372impl Matcher for Pattern {
373  fn match_node_with_env<'tree, D: Doc>(
374    &self,
375    node: Node<'tree, D>,
376    env: &mut Cow<MetaVarEnv<'tree, D>>,
377  ) -> Option<Node<'tree, D>> {
378    if let Some(k) = self.root_kind {
379      if node.kind_id() != k {
380        return None;
381      }
382    }
383    // do not pollute the env if pattern does not match
384    let mut may_write = Cow::Borrowed(env.as_ref());
385    let node = match_node_non_recursive(self, node, &mut may_write)?;
386    if let Cow::Owned(map) = may_write {
387      // only change env when pattern matches
388      *env = Cow::Owned(map);
389    }
390    Some(node)
391  }
392
393  fn potential_kinds(&self) -> Option<bit_set::BitSet> {
394    // if strictness is Template, we can match any kind
395    if matches!(self.strictness, MatchStrictness::Template) {
396      return None;
397    }
398    let kind = match self.node {
399      PatternNode::Terminal { kind_id, .. } => kind_id,
400      PatternNode::MetaVar { .. } => self.root_kind?,
401      PatternNode::Internal { kind_id, .. } => {
402        if kind_utils::is_error_kind(kind_id) {
403          // error can match any kind
404          return None;
405        }
406        kind_id
407      }
408    };
409
410    let mut kinds = BitSet::new();
411    kinds.insert(kind.into());
412    Some(kinds)
413  }
414
415  fn get_match_len<D: Doc>(&self, node: Node<'_, D>) -> Option<usize> {
416    let start = node.range().start;
417    let end = match_end_non_recursive(self, node)?;
418    Some(end - start)
419  }
420}
421impl std::fmt::Debug for PatternNode {
422  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423    match self {
424      Self::MetaVar { meta_var, .. } => write!(f, "{meta_var:?}"),
425      Self::Terminal { text, .. } => write!(f, "{text}"),
426      Self::Internal { children, .. } => write!(f, "{children:?}"),
427    }
428  }
429}
430
431impl std::fmt::Debug for Pattern {
432  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433    write!(f, "{:?}", self.node)
434  }
435}
436
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use crate::matcher::MatcherExt;
  use crate::meta_var::MetaVarEnv;
  use crate::tree_sitter::StrDoc;
  use std::collections::HashMap;

  /// Parses `s` as TSX to serve as a match candidate.
  fn pattern_node(s: &str) -> Root<StrDoc<Tsx>> {
    Root::str(s, Tsx)
  }

  /// Asserts that pattern `s1` matches somewhere inside candidate source `s2`.
  fn test_match(s1: &str, s2: &str) {
    let pattern = Pattern::new(s1, Tsx);
    let cand = pattern_node(s2);
    let cand = cand.root();
    assert!(
      pattern.find_node(cand.clone()).is_some(),
      "goal: {:?}, candidate: {}",
      pattern,
      cand.get_inner_node().to_sexp(),
    );
  }

  /// Asserts that pattern `s1` matches nowhere inside candidate source `s2`.
  fn test_non_match(s1: &str, s2: &str) {
    let pattern = Pattern::new(s1, Tsx);
    let cand = pattern_node(s2);
    let cand = cand.root();
    assert!(
      pattern.find_node(cand.clone()).is_none(),
      "goal: {:?}, candidate: {}",
      pattern,
      cand.get_inner_node().to_sexp(),
    );
  }

  #[test]
  fn test_meta_variable() {
    test_match("const a = $VALUE", "const a = 123");
    test_match("const $VARIABLE = $VALUE", "const a = 123");
    // meta variable only in the binding position, literal value
    test_match("const $VARIABLE = 123", "const a = 123");
  }

  #[test]
  fn test_whitespace() {
    test_match("function t() { }", "function t() {}");
    test_match("function t() {}", "function t() {  }");
  }

  /// Matches `goal_str` against `cand` and returns the captured variables as text.
  fn match_env(goal_str: &str, cand: &str) -> HashMap<String, String> {
    let pattern = Pattern::new(goal_str, Tsx);
    let cand = pattern_node(cand);
    let cand = cand.root();
    let nm = pattern.find_node(cand).unwrap();
    HashMap::from(nm.get_env().clone())
  }

  #[test]
  fn test_meta_variable_env() {
    let env = match_env("const a = $VALUE", "const a = 123");
    assert_eq!(env["VALUE"], "123");
  }

  #[test]
  fn test_pattern_should_not_pollute_env() {
    // gh issue #1164
    let pattern = Pattern::new("const $A = 114", Tsx);
    let cand = pattern_node("const a = 514");
    let cand = cand.root().child(0).unwrap();
    let map = MetaVarEnv::new();
    let mut env = Cow::Borrowed(&map);
    let nm = pattern.match_node_with_env(cand, &mut env);
    assert!(nm.is_none());
    assert!(env.get_match("A").is_none());
    assert!(map.get_match("A").is_none());
  }

  #[test]
  fn test_match_non_atomic() {
    let env = match_env("const a = $VALUE", "const a = 5 + 3");
    assert_eq!(env["VALUE"], "5 + 3");
  }

  #[test]
  fn test_class_assignment() {
    test_match("class $C { $MEMBER = $VAL}", "class A {a = 123}");
    test_non_match("class $C { $MEMBER = $VAL; b = 123; }", "class A {a = 123}");
    // test_match("a = 123", "class A {a = 123}");
    test_non_match("a = 123", "class B {b = 123}");
  }

  #[test]
  fn test_return() {
    test_match("$A($B)", "return test(123)");
  }

  #[test]
  fn test_contextual_pattern() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("class B { b = 123 }");
    assert!(pattern.find_node(cand.root()).is_some());
    let cand = pattern_node("let b = 123");
    assert!(pattern.find_node(cand.root()).is_none());
  }

  #[test]
  fn test_contextual_match_with_env() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("class B { b = 123 }");
    let nm = pattern.find_node(cand.root()).expect("test");
    let env = nm.get_env();
    let env = HashMap::from(env.clone());
    assert_eq!(env["F"], "b");
    assert_eq!(env["I"], "123");
  }

  #[test]
  fn test_contextual_unmatch_with_env() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("let b = 123");
    let nm = pattern.find_node(cand.root());
    assert!(nm.is_none());
  }

  /// Kind id of `kind_str` in TSX, widened for `BitSet` lookups.
  fn get_kind(kind_str: &str) -> usize {
    Tsx.kind_to_id(kind_str).into()
  }

  #[test]
  fn test_pattern_potential_kinds() {
    let pattern = Pattern::new("const a = 1", Tsx);
    let kind = get_kind("lexical_declaration");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_pattern_with_non_root_meta_var() {
    let pattern = Pattern::new("const $A = $B", Tsx);
    let kind = get_kind("lexical_declaration");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_bare_wildcard() {
    let pattern = Pattern::new("$A", Tsx);
    // wildcard should match anything, so kinds should be None
    assert!(pattern.potential_kinds().is_none());
  }

  #[test]
  fn test_contextual_potential_kinds() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let kind = get_kind("public_field_definition");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_contextual_wildcard() {
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    let kind = get_kind("property_identifier");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  #[ignore]
  fn test_multi_node_pattern() {
    let pattern = Pattern::new("a;b;c;", Tsx);
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    test_match("a;b;c", "a;b;c;");
  }

  #[test]
  #[ignore]
  fn test_multi_node_meta_var() {
    let env = match_env("a;$B;c", "a;b;c");
    assert_eq!(env["B"], "b");
    let env = match_env("a;$B;c", "a;1+2+3;c");
    assert_eq!(env["B"], "1+2+3");
  }

  #[test]
  #[ignore]
  fn test_pattern_size() {
    assert_eq!(std::mem::size_of::<Pattern>(), 40);
  }

  #[test]
  fn test_error_kind() {
    let ret = Pattern::contextual("a", "property_identifier", Tsx);
    assert!(ret.is_err());
    let ret = Pattern::new("123+", Tsx);
    assert!(ret.has_error());
  }

  #[test]
  fn test_bare_wildcard_in_context() {
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    let cand = pattern_node("let b = 123");
    // it should not match
    assert!(pattern.find_node(cand.root()).is_none());
  }

  #[test]
  fn test_pattern_fixed_string() {
    let pattern = Pattern::new("class A { $F }", Tsx);
    assert_eq!(pattern.fixed_string(), "class");
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    assert!(pattern.fixed_string().is_empty());
  }

  #[test]
  fn test_pattern_error() {
    let pattern = Pattern::try_new("", Tsx);
    assert!(matches!(pattern, Err(PatternError::NoContent(_))));
    let pattern = Pattern::try_new("12  3344", Tsx);
    assert!(matches!(pattern, Err(PatternError::MultipleNode(_))));
  }

  #[test]
  fn test_debug_pattern() {
    let pattern = Pattern::new("var $A = 1", Tsx);
    assert_eq!(
      format!("{pattern:?}"),
      "[var, [Capture(\"A\", true), =, 1]]"
    );
  }

  /// Sorted names of the variables defined by pattern `s`.
  fn defined_vars(s: &str) -> Vec<String> {
    let pattern = Pattern::new(s, Tsx);
    let mut vars: Vec<_> = pattern
      .defined_vars()
      .into_iter()
      .map(String::from)
      .collect();
    vars.sort();
    vars
  }

  #[test]
  fn test_extract_meta_var_from_pattern() {
    let vars = defined_vars("var $A = 1");
    assert_eq!(vars, ["A"]);
  }

  #[test]
  fn test_extract_complex_meta_var() {
    let vars = defined_vars("function $FUNC($$$ARGS): $RET { $$$BODY }");
    assert_eq!(vars, ["ARGS", "BODY", "FUNC", "RET"]);
  }

  #[test]
  fn test_extract_duplicate_meta_var() {
    let vars = defined_vars("var $A = $A");
    assert_eq!(vars, ["A"]);
  }

  #[test]
  fn test_contextual_pattern_vars() {
    let pattern = Pattern::contextual("<div ref={$A}/>", "jsx_attribute", Tsx).expect("correct");
    assert_eq!(pattern.defined_vars(), ["A"].into_iter().collect());
  }

  #[test]
  fn test_gh_1087() {
    test_match("($P) => $F($P)", "(x) => bar(x)");
  }

  #[test]
  fn test_template_pattern_have_no_kinds() {
    let pattern = Pattern::new("$A = $B", Tsx).with_strictness(MatchStrictness::Template);
    assert!(pattern.potential_kinds().is_none());
    let pattern = Pattern::contextual("{a: b}", "pair", Tsx)
      .expect("should create template pattern")
      .with_strictness(MatchStrictness::Template);
    assert!(pattern.potential_kinds().is_none());
  }
}