ast_grep_core/matcher/pattern.rs

1use crate::language::Language;
2use crate::match_tree::{match_end_non_recursive, match_node_non_recursive, MatchStrictness};
3use crate::matcher::{kind_utils, KindMatcher, KindMatcherError, Matcher};
4use crate::meta_var::{MetaVarEnv, MetaVariable};
5use crate::source::SgNode;
6use crate::{Doc, Node, Root};
7
8use bit_set::BitSet;
9use thiserror::Error;
10
11use std::borrow::Cow;
12use std::collections::HashSet;
13
14#[derive(Clone)]
15pub struct Pattern {
16  pub node: PatternNode,
17  root_kind: Option<u16>,
18  pub strictness: MatchStrictness,
19}
20
/// Input handed to a language so it can build a [`Pattern`]: the pre-processed
/// pattern source plus an optional selector naming the node kind to extract.
pub struct PatternBuilder<'a> {
  /// Node kind to pick out of `src`; `None` means `src` itself is the pattern.
  selector: Option<&'a str>,
  /// Pattern source after the language's pre-processing step.
  src: Cow<'a, str>,
}
25
26impl PatternBuilder<'_> {
27  pub fn build<D, F>(&self, parse: F) -> Result<Pattern, PatternError>
28  where
29    F: FnOnce(&str) -> Result<D, String>,
30    D: Doc,
31  {
32    let doc = parse(&self.src).map_err(PatternError::Parse)?;
33    let root = Root::doc(doc);
34    if let Some(selector) = self.selector {
35      self.contextual(&root, selector)
36    } else {
37      self.single(&root)
38    }
39  }
40  fn single<D: Doc>(&self, root: &Root<D>) -> Result<Pattern, PatternError> {
41    let goal = root.root();
42    if goal.children().len() == 0 {
43      return Err(PatternError::NoContent(self.src.to_string()));
44    }
45    if !is_single_node(&goal.inner) {
46      return Err(PatternError::MultipleNode(self.src.to_string()));
47    }
48    let node = Pattern::single_matcher(root);
49    Ok(Pattern::from(node))
50  }
51
52  fn contextual<D: Doc>(&self, root: &Root<D>, selector: &str) -> Result<Pattern, PatternError> {
53    let goal = root.root();
54    let kind_matcher = KindMatcher::try_new(selector, root.lang().clone())?;
55    let Some(node) = goal.find(&kind_matcher) else {
56      return Err(PatternError::NoSelectorInContext {
57        context: self.src.to_string(),
58        selector: selector.into(),
59      });
60    };
61    Ok(Pattern {
62      root_kind: Some(node.kind_id()),
63      node: convert_node_to_pattern(node.get_node().clone()),
64      strictness: MatchStrictness::Smart,
65    })
66  }
67}
68
69#[derive(Clone)]
70pub enum PatternNode {
71  MetaVar {
72    meta_var: MetaVariable,
73  },
74  /// Node without children.
75  Terminal {
76    text: String,
77    is_named: bool,
78    kind_id: u16,
79  },
80  /// Non-Terminal Syntax Nodes are called Internal
81  Internal {
82    kind_id: u16,
83    children: Vec<PatternNode>,
84  },
85}
86
87impl PatternNode {
88  // for skipping trivial nodes in goal after ellipsis
89  pub fn is_trivial(&self) -> bool {
90    match self {
91      PatternNode::Terminal { is_named, .. } => !*is_named,
92      _ => false,
93    }
94  }
95
96  pub fn fixed_string(&self) -> Cow<'_, str> {
97    match &self {
98      PatternNode::Terminal { text, .. } => Cow::Borrowed(text),
99      PatternNode::MetaVar { .. } => Cow::Borrowed(""),
100      PatternNode::Internal { children, .. } => {
101        children
102          .iter()
103          .map(|n| n.fixed_string())
104          .fold(Cow::Borrowed(""), |longest, curr| {
105            if longest.len() >= curr.len() {
106              longest
107            } else {
108              curr
109            }
110          })
111      }
112    }
113  }
114}
115impl<'r, D: Doc> From<Node<'r, D>> for PatternNode {
116  fn from(node: Node<'r, D>) -> Self {
117    convert_node_to_pattern(node)
118  }
119}
120
121impl<'r, D: Doc> From<Node<'r, D>> for Pattern {
122  fn from(node: Node<'r, D>) -> Self {
123    Self {
124      node: convert_node_to_pattern(node),
125      root_kind: None,
126      strictness: MatchStrictness::Smart,
127    }
128  }
129}
130
131fn convert_node_to_pattern<D: Doc>(node: Node<'_, D>) -> PatternNode {
132  if let Some(meta_var) = extract_var_from_node(&node) {
133    PatternNode::MetaVar { meta_var }
134  } else if node.is_leaf() {
135    PatternNode::Terminal {
136      text: node.text().to_string(),
137      is_named: node.is_named(),
138      kind_id: node.kind_id(),
139    }
140  } else {
141    let children = node.children().filter_map(|n| {
142      if n.is_missing() {
143        None
144      } else {
145        Some(PatternNode::from(n))
146      }
147    });
148    PatternNode::Internal {
149      kind_id: node.kind_id(),
150      children: children.collect(),
151    }
152  }
153}
154
155fn extract_var_from_node<D: Doc>(goal: &Node<'_, D>) -> Option<MetaVariable> {
156  let key = goal.text();
157  goal.lang().extract_meta_var(&key)
158}
159
160#[derive(Debug, Error)]
161pub enum PatternError {
162  #[error("Fails to parse the pattern query: `{0}`")]
163  Parse(String),
164  #[error("No AST root is detected. Please check the pattern source `{0}`.")]
165  NoContent(String),
166  #[error("Multiple AST nodes are detected. Please check the pattern source `{0}`.")]
167  MultipleNode(String),
168  #[error(transparent)]
169  InvalidKind(#[from] KindMatcherError),
170  #[error("Fails to create Contextual pattern: selector `{selector}` matches no node in the context `{context}`.")]
171  NoSelectorInContext { context: String, selector: String },
172}
173
174#[inline]
175fn is_single_node<'r, N: SgNode<'r>>(n: &N) -> bool {
176  match n.children().len() {
177    1 => true,
178    2 => {
179      let c = n.child(1).expect("second child must exist");
180      // some language will have weird empty syntax node at the end
181      // see golang's `$A = 0` pattern test case
182      c.is_missing() || c.kind().is_empty()
183    }
184    _ => false,
185  }
186}
187impl Pattern {
188  pub fn has_error(&self) -> bool {
189    let kind = match &self.node {
190      PatternNode::Terminal { kind_id, .. } => *kind_id,
191      PatternNode::Internal { kind_id, .. } => *kind_id,
192      PatternNode::MetaVar { .. } => match self.root_kind {
193        Some(k) => k,
194        None => return false,
195      },
196    };
197    kind_utils::is_error_kind(kind)
198  }
199
200  pub fn fixed_string(&self) -> Cow<'_, str> {
201    self.node.fixed_string()
202  }
203
204  /// Get all defined variables in the pattern.
205  /// Used for validating rules and report undefined variables.
206  pub fn defined_vars(&self) -> HashSet<&str> {
207    let mut vars = HashSet::new();
208    collect_vars(&self.node, &mut vars);
209    vars
210  }
211}
212
213fn meta_var_name(meta_var: &MetaVariable) -> Option<&str> {
214  use MetaVariable as MV;
215  match meta_var {
216    MV::Capture(name, _) => Some(name),
217    MV::MultiCapture(name) => Some(name),
218    MV::Dropped(_) => None,
219    MV::Multiple => None,
220  }
221}
222
223fn collect_vars<'p>(p: &'p PatternNode, vars: &mut HashSet<&'p str>) {
224  match p {
225    PatternNode::MetaVar { meta_var, .. } => {
226      if let Some(name) = meta_var_name(meta_var) {
227        vars.insert(name);
228      }
229    }
230    PatternNode::Terminal { .. } => {
231      // collect nothing for terminal nodes!
232    }
233    PatternNode::Internal { children, .. } => {
234      for c in children {
235        collect_vars(c, vars);
236      }
237    }
238  }
239}
240
241impl Pattern {
242  pub fn try_new<L: Language>(src: &str, lang: L) -> Result<Self, PatternError> {
243    let processed = lang.pre_process_pattern(src);
244    let builder = PatternBuilder {
245      selector: None,
246      src: processed,
247    };
248    lang.build_pattern(&builder)
249  }
250
251  pub fn new<L: Language>(src: &str, lang: L) -> Self {
252    Self::try_new(src, lang).unwrap()
253  }
254
255  pub fn with_strictness(mut self, strictness: MatchStrictness) -> Self {
256    self.strictness = strictness;
257    self
258  }
259
260  pub fn contextual<L: Language>(
261    context: &str,
262    selector: &str,
263    lang: L,
264  ) -> Result<Self, PatternError> {
265    let processed = lang.pre_process_pattern(context);
266    let builder = PatternBuilder {
267      selector: Some(selector),
268      src: processed,
269    };
270    lang.build_pattern(&builder)
271  }
272  fn single_matcher<D: Doc>(root: &Root<D>) -> Node<'_, D> {
273    // debug_assert!(matches!(self.style, PatternStyle::Single));
274    let node = root.root();
275    let mut inner = node.inner;
276    while is_single_node(&inner) {
277      inner = inner.child(0).unwrap();
278    }
279    Node { inner, root }
280  }
281}
282
283impl Matcher for Pattern {
284  fn match_node_with_env<'tree, D: Doc>(
285    &self,
286    node: Node<'tree, D>,
287    env: &mut Cow<MetaVarEnv<'tree, D>>,
288  ) -> Option<Node<'tree, D>> {
289    if let Some(k) = self.root_kind {
290      if node.kind_id() != k {
291        return None;
292      }
293    }
294    // do not pollute the env if pattern does not match
295    let mut may_write = Cow::Borrowed(env.as_ref());
296    let node = match_node_non_recursive(self, node, &mut may_write)?;
297    if let Cow::Owned(map) = may_write {
298      // only change env when pattern matches
299      *env = Cow::Owned(map);
300    }
301    Some(node)
302  }
303
304  fn potential_kinds(&self) -> Option<bit_set::BitSet> {
305    // if strictness is Template, we can match any kind
306    if matches!(self.strictness, MatchStrictness::Template) {
307      return None;
308    }
309    let kind = match self.node {
310      PatternNode::Terminal { kind_id, .. } => kind_id,
311      PatternNode::MetaVar { .. } => self.root_kind?,
312      PatternNode::Internal { kind_id, .. } => {
313        if kind_utils::is_error_kind(kind_id) {
314          // error can match any kind
315          return None;
316        }
317        kind_id
318      }
319    };
320
321    let mut kinds = BitSet::new();
322    kinds.insert(kind.into());
323    Some(kinds)
324  }
325
326  fn get_match_len<D: Doc>(&self, node: Node<'_, D>) -> Option<usize> {
327    let start = node.range().start;
328    let end = match_end_non_recursive(self, node)?;
329    Some(end - start)
330  }
331}
332impl std::fmt::Debug for PatternNode {
333  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334    match self {
335      Self::MetaVar { meta_var, .. } => write!(f, "{meta_var:?}"),
336      Self::Terminal { text, .. } => write!(f, "{text}"),
337      Self::Internal { children, .. } => write!(f, "{children:?}"),
338    }
339  }
340}
341
342impl std::fmt::Debug for Pattern {
343  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344    write!(f, "{:?}", self.node)
345  }
346}
347
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use crate::matcher::MatcherExt;
  use crate::meta_var::MetaVarEnv;
  use crate::tree_sitter::StrDoc;
  use std::collections::HashMap;

  fn pattern_node(s: &str) -> Root<StrDoc<Tsx>> {
    Root::str(s, Tsx)
  }

  fn test_match(s1: &str, s2: &str) {
    let pattern = Pattern::new(s1, Tsx);
    let cand = pattern_node(s2);
    let cand = cand.root();
    assert!(
      pattern.find_node(cand.clone()).is_some(),
      "goal: {:?}, candidate: {}",
      pattern,
      cand.get_inner_node().to_sexp(),
    );
  }
  fn test_non_match(s1: &str, s2: &str) {
    let pattern = Pattern::new(s1, Tsx);
    let cand = pattern_node(s2);
    let cand = cand.root();
    assert!(
      pattern.find_node(cand.clone()).is_none(),
      "goal: {:?}, candidate: {}",
      pattern,
      cand.get_inner_node().to_sexp(),
    );
  }

  #[test]
  fn test_meta_variable() {
    test_match("const a = $VALUE", "const a = 123");
    test_match("const $VARIABLE = $VALUE", "const a = 123");
    // literal identifiers in the pattern must still match exactly
    test_non_match("const a = $VALUE", "const b = 123");
  }

  #[test]
  fn test_whitespace() {
    test_match("function t() { }", "function t() {}");
    test_match("function t() {}", "function t() {  }");
  }

  fn match_env(goal_str: &str, cand: &str) -> HashMap<String, String> {
    let pattern = Pattern::new(goal_str, Tsx);
    let cand = pattern_node(cand);
    let cand = cand.root();
    let nm = pattern.find_node(cand).unwrap();
    HashMap::from(nm.get_env().clone())
  }

  #[test]
  fn test_meta_variable_env() {
    let env = match_env("const a = $VALUE", "const a = 123");
    assert_eq!(env["VALUE"], "123");
  }

  #[test]
  fn test_pattern_should_not_pollute_env() {
    // gh issue #1164
    let pattern = Pattern::new("const $A = 114", Tsx);
    let cand = pattern_node("const a = 514");
    let cand = cand.root().child(0).unwrap();
    let map = MetaVarEnv::new();
    let mut env = Cow::Borrowed(&map);
    let nm = pattern.match_node_with_env(cand, &mut env);
    assert!(nm.is_none());
    assert!(env.get_match("A").is_none());
    assert!(map.get_match("A").is_none());
  }

  #[test]
  fn test_match_non_atomic() {
    let env = match_env("const a = $VALUE", "const a = 5 + 3");
    assert_eq!(env["VALUE"], "5 + 3");
  }

  #[test]
  fn test_class_assignment() {
    test_match("class $C { $MEMBER = $VAL}", "class A {a = 123}");
    test_non_match("class $C { $MEMBER = $VAL; b = 123; }", "class A {a = 123}");
    // test_match("a = 123", "class A {a = 123}");
    test_non_match("a = 123", "class B {b = 123}");
  }

  #[test]
  fn test_return() {
    test_match("$A($B)", "return test(123)");
  }

  #[test]
  fn test_contextual_pattern() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("class B { b = 123 }");
    assert!(pattern.find_node(cand.root()).is_some());
    let cand = pattern_node("let b = 123");
    assert!(pattern.find_node(cand.root()).is_none());
  }

  #[test]
  fn test_contextual_match_with_env() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("class B { b = 123 }");
    let nm = pattern.find_node(cand.root()).expect("test");
    let env = nm.get_env();
    let env = HashMap::from(env.clone());
    assert_eq!(env["F"], "b");
    assert_eq!(env["I"], "123");
  }

  #[test]
  fn test_contextual_unmatch_with_env() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let cand = pattern_node("let b = 123");
    let nm = pattern.find_node(cand.root());
    assert!(nm.is_none());
  }

  fn get_kind(kind_str: &str) -> usize {
    Tsx.kind_to_id(kind_str).into()
  }

  #[test]
  fn test_pattern_potential_kinds() {
    let pattern = Pattern::new("const a = 1", Tsx);
    let kind = get_kind("lexical_declaration");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_pattern_with_non_root_meta_var() {
    let pattern = Pattern::new("const $A = $B", Tsx);
    let kind = get_kind("lexical_declaration");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_bare_wildcard() {
    let pattern = Pattern::new("$A", Tsx);
    // wildcard should match anything, so kinds should be None
    assert!(pattern.potential_kinds().is_none());
  }

  #[test]
  fn test_contextual_potential_kinds() {
    let pattern =
      Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
    let kind = get_kind("public_field_definition");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  fn test_contextual_wildcard() {
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    let kind = get_kind("property_identifier");
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    assert!(kinds.contains(kind));
  }

  #[test]
  #[ignore]
  fn test_multi_node_pattern() {
    let pattern = Pattern::new("a;b;c;", Tsx);
    let kinds = pattern.potential_kinds().expect("should have kinds");
    assert_eq!(kinds.len(), 1);
    test_match("a;b;c", "a;b;c;");
  }

  #[test]
  #[ignore]
  fn test_multi_node_meta_var() {
    let env = match_env("a;$B;c", "a;b;c");
    assert_eq!(env["B"], "b");
    let env = match_env("a;$B;c", "a;1+2+3;c");
    assert_eq!(env["B"], "1+2+3");
  }

  #[test]
  #[ignore]
  fn test_pattern_size() {
    assert_eq!(std::mem::size_of::<Pattern>(), 40);
  }

  #[test]
  fn test_error_kind() {
    let ret = Pattern::contextual("a", "property_identifier", Tsx);
    assert!(ret.is_err());
    let ret = Pattern::new("123+", Tsx);
    assert!(ret.has_error());
  }

  #[test]
  fn test_bare_wildcard_in_context() {
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    let cand = pattern_node("let b = 123");
    // it should not match
    assert!(pattern.find_node(cand.root()).is_none());
  }

  #[test]
  fn test_pattern_fixed_string() {
    let pattern = Pattern::new("class A { $F }", Tsx);
    assert_eq!(pattern.fixed_string(), "class");
    let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
    assert!(pattern.fixed_string().is_empty());
  }

  #[test]
  fn test_pattern_error() {
    let pattern = Pattern::try_new("", Tsx);
    assert!(matches!(pattern, Err(PatternError::NoContent(_))));
    let pattern = Pattern::try_new("12  3344", Tsx);
    assert!(matches!(pattern, Err(PatternError::MultipleNode(_))));
  }

  #[test]
  fn test_debug_pattern() {
    let pattern = Pattern::new("var $A = 1", Tsx);
    assert_eq!(
      format!("{pattern:?}"),
      "[var, [Capture(\"A\", true), =, 1]]"
    );
  }

  fn defined_vars(s: &str) -> Vec<String> {
    let pattern = Pattern::new(s, Tsx);
    let mut vars: Vec<_> = pattern
      .defined_vars()
      .into_iter()
      .map(String::from)
      .collect();
    vars.sort();
    vars
  }

  #[test]
  fn test_extract_meta_var_from_pattern() {
    let vars = defined_vars("var $A = 1");
    assert_eq!(vars, ["A"]);
  }

  #[test]
  fn test_extract_complex_meta_var() {
    let vars = defined_vars("function $FUNC($$$ARGS): $RET { $$$BODY }");
    assert_eq!(vars, ["ARGS", "BODY", "FUNC", "RET"]);
  }

  #[test]
  fn test_extract_duplicate_meta_var() {
    let vars = defined_vars("var $A = $A");
    assert_eq!(vars, ["A"]);
  }

  #[test]
  fn test_contextual_pattern_vars() {
    let pattern = Pattern::contextual("<div ref={$A}/>", "jsx_attribute", Tsx).expect("correct");
    assert_eq!(pattern.defined_vars(), ["A"].into_iter().collect());
  }

  #[test]
  fn test_gh_1087() {
    test_match("($P) => $F($P)", "(x) => bar(x)");
  }

  #[test]
  fn test_template_pattern_have_no_kinds() {
    let pattern = Pattern::new("$A = $B", Tsx).with_strictness(MatchStrictness::Template);
    assert!(pattern.potential_kinds().is_none());
    let pattern = Pattern::contextual("{a: b}", "pair", Tsx)
      .expect("should create template pattern")
      .with_strictness(MatchStrictness::Template);
    assert!(pattern.potential_kinds().is_none());
  }
}