1use crate::language::Language;
2use crate::match_tree::{match_end_non_recursive, match_node_non_recursive, MatchStrictness};
3use crate::matcher::{kind_utils, KindMatcher, KindMatcherError, Matcher};
4use crate::meta_var::{MetaVarEnv, MetaVariable};
5use crate::source::SgNode;
6use crate::{Doc, Node, Root};
7
8use bit_set::BitSet;
9use thiserror::Error;
10
11use std::borrow::Cow;
12use std::collections::HashSet;
13
/// A compiled structural pattern: a tree of [`PatternNode`]s plus matching options.
///
/// Created with [`Pattern::try_new`] / [`Pattern::new`] from a pattern snippet, with
/// [`Pattern::contextual`] from a context snippet plus a kind selector, or via
/// `From<Node>`.
#[derive(Clone)]
pub struct Pattern {
    /// Root of the pattern tree that candidate nodes are compared against.
    pub node: PatternNode,
    /// Kind id a candidate node must have to match. Within this file it is only set
    /// (to `Some`) by contextual construction; `From<Node>` leaves it `None`.
    root_kind: Option<u16>,
    /// How strictly candidates are compared with the pattern (defaults to `Smart`).
    pub strictness: MatchStrictness,
}
20
/// Input handed to a language's pattern builder: the pre-processed pattern source
/// and, for contextual patterns, the kind name used to select the pattern node.
pub struct PatternBuilder<'a> {
    /// Kind name of the node to extract from the parsed source. `None` means the whole
    /// source must parse to a single node.
    selector: Option<&'a str>,
    /// Pattern source after `Language::pre_process_pattern` has been applied.
    src: Cow<'a, str>,
}
25
26impl PatternBuilder<'_> {
27 pub fn build<D, F>(&self, parse: F) -> Result<Pattern, PatternError>
28 where
29 F: FnOnce(&str) -> Result<D, String>,
30 D: Doc,
31 {
32 let doc = parse(&self.src).map_err(PatternError::Parse)?;
33 let root = Root::doc(doc);
34 if let Some(selector) = self.selector {
35 self.contextual(&root, selector)
36 } else {
37 self.single(&root)
38 }
39 }
40 fn single<D: Doc>(&self, root: &Root<D>) -> Result<Pattern, PatternError> {
41 let goal = root.root();
42 if goal.children().len() == 0 {
43 return Err(PatternError::NoContent(self.src.to_string()));
44 }
45 if !is_single_node(&goal.inner) {
46 return Err(PatternError::MultipleNode(self.src.to_string()));
47 }
48 let node = Pattern::single_matcher(root);
49 Ok(Pattern::from(node))
50 }
51
52 fn contextual<D: Doc>(&self, root: &Root<D>, selector: &str) -> Result<Pattern, PatternError> {
53 let goal = root.root();
54 let kind_matcher = KindMatcher::try_new(selector, root.lang().clone())?;
55 let Some(node) = goal.find(&kind_matcher) else {
56 return Err(PatternError::NoSelectorInContext {
57 context: self.src.to_string(),
58 selector: selector.into(),
59 });
60 };
61 Ok(Pattern {
62 root_kind: Some(node.kind_id()),
63 node: convert_node_to_pattern(node.get_node().clone()),
64 strictness: MatchStrictness::Smart,
65 })
66 }
67}
68
/// One node of a compiled pattern tree.
#[derive(Clone)]
pub enum PatternNode {
    /// A meta variable (such as `$A`) recognized by `Language::extract_meta_var`.
    MetaVar {
        meta_var: MetaVariable,
    },
    /// A leaf of the parsed pattern source.
    Terminal {
        /// Source text of the leaf.
        text: String,
        /// Whether the leaf is a named node; unnamed leaves are reported by `is_trivial`.
        is_named: bool,
        /// Kind id of the leaf.
        kind_id: u16,
    },
    /// A non-leaf node of the parsed pattern source.
    Internal {
        /// Kind id of the node.
        kind_id: u16,
        /// Converted children; children reported as missing (`is_missing`) are skipped.
        children: Vec<PatternNode>,
    },
}
86
87impl PatternNode {
88 pub fn is_trivial(&self) -> bool {
90 match self {
91 PatternNode::Terminal { is_named, .. } => !*is_named,
92 _ => false,
93 }
94 }
95
96 pub fn fixed_string(&self) -> Cow<'_, str> {
97 match &self {
98 PatternNode::Terminal { text, .. } => Cow::Borrowed(text),
99 PatternNode::MetaVar { .. } => Cow::Borrowed(""),
100 PatternNode::Internal { children, .. } => {
101 children
102 .iter()
103 .map(|n| n.fixed_string())
104 .fold(Cow::Borrowed(""), |longest, curr| {
105 if longest.len() >= curr.len() {
106 longest
107 } else {
108 curr
109 }
110 })
111 }
112 }
113 }
114}
115impl<'r, D: Doc> From<Node<'r, D>> for PatternNode {
116 fn from(node: Node<'r, D>) -> Self {
117 convert_node_to_pattern(node)
118 }
119}
120
121impl<'r, D: Doc> From<Node<'r, D>> for Pattern {
122 fn from(node: Node<'r, D>) -> Self {
123 Self {
124 node: convert_node_to_pattern(node),
125 root_kind: None,
126 strictness: MatchStrictness::Smart,
127 }
128 }
129}
130
131fn convert_node_to_pattern<D: Doc>(node: Node<'_, D>) -> PatternNode {
132 if let Some(meta_var) = extract_var_from_node(&node) {
133 PatternNode::MetaVar { meta_var }
134 } else if node.is_leaf() {
135 PatternNode::Terminal {
136 text: node.text().to_string(),
137 is_named: node.is_named(),
138 kind_id: node.kind_id(),
139 }
140 } else {
141 let children = node.children().filter_map(|n| {
142 if n.is_missing() {
143 None
144 } else {
145 Some(PatternNode::from(n))
146 }
147 });
148 PatternNode::Internal {
149 kind_id: node.kind_id(),
150 children: children.collect(),
151 }
152 }
153}
154
155fn extract_var_from_node<D: Doc>(goal: &Node<'_, D>) -> Option<MetaVariable> {
156 let key = goal.text();
157 goal.lang().extract_meta_var(&key)
158}
159
/// Errors produced while building a [`Pattern`].
#[derive(Debug, Error)]
pub enum PatternError {
    /// The parser callback rejected the pattern source; holds its error message.
    #[error("Fails to parse the pattern query: `{0}`")]
    Parse(String),
    /// The parsed root had no children; holds the pattern source.
    #[error("No AST root is detected. Please check the pattern source `{0}`.")]
    NoContent(String),
    /// The source parsed to more than one top-level node; holds the pattern source.
    #[error("Multiple AST nodes are detected. Please check the pattern source `{0}`.")]
    MultipleNode(String),
    /// The contextual selector could not be turned into a kind matcher.
    #[error(transparent)]
    InvalidKind(#[from] KindMatcherError),
    /// No node in the context matched the selector kind.
    #[error("Fails to create Contextual pattern: selector `{selector}` matches no node in the context `{context}`.")]
    NoSelectorInContext { context: String, selector: String },
}
173
174#[inline]
175fn is_single_node<'r, N: SgNode<'r>>(n: &N) -> bool {
176 match n.children().len() {
177 1 => true,
178 2 => {
179 let c = n.child(1).expect("second child must exist");
180 c.is_missing() || c.kind().is_empty()
183 }
184 _ => false,
185 }
186}
187impl Pattern {
188 pub fn has_error(&self) -> bool {
189 let kind = match &self.node {
190 PatternNode::Terminal { kind_id, .. } => *kind_id,
191 PatternNode::Internal { kind_id, .. } => *kind_id,
192 PatternNode::MetaVar { .. } => match self.root_kind {
193 Some(k) => k,
194 None => return false,
195 },
196 };
197 kind_utils::is_error_kind(kind)
198 }
199
200 pub fn fixed_string(&self) -> Cow<'_, str> {
201 self.node.fixed_string()
202 }
203
204 pub fn defined_vars(&self) -> HashSet<&str> {
207 let mut vars = HashSet::new();
208 collect_vars(&self.node, &mut vars);
209 vars
210 }
211}
212
213fn meta_var_name(meta_var: &MetaVariable) -> Option<&str> {
214 use MetaVariable as MV;
215 match meta_var {
216 MV::Capture(name, _) => Some(name),
217 MV::MultiCapture(name) => Some(name),
218 MV::Dropped(_) => None,
219 MV::Multiple => None,
220 }
221}
222
223fn collect_vars<'p>(p: &'p PatternNode, vars: &mut HashSet<&'p str>) {
224 match p {
225 PatternNode::MetaVar { meta_var, .. } => {
226 if let Some(name) = meta_var_name(meta_var) {
227 vars.insert(name);
228 }
229 }
230 PatternNode::Terminal { .. } => {
231 }
233 PatternNode::Internal { children, .. } => {
234 for c in children {
235 collect_vars(c, vars);
236 }
237 }
238 }
239}
240
241impl Pattern {
242 pub fn try_new<L: Language>(src: &str, lang: L) -> Result<Self, PatternError> {
243 let processed = lang.pre_process_pattern(src);
244 let builder = PatternBuilder {
245 selector: None,
246 src: processed,
247 };
248 lang.build_pattern(&builder)
249 }
250
251 pub fn new<L: Language>(src: &str, lang: L) -> Self {
252 Self::try_new(src, lang).unwrap()
253 }
254
255 pub fn with_strictness(mut self, strictness: MatchStrictness) -> Self {
256 self.strictness = strictness;
257 self
258 }
259
260 pub fn contextual<L: Language>(
261 context: &str,
262 selector: &str,
263 lang: L,
264 ) -> Result<Self, PatternError> {
265 let processed = lang.pre_process_pattern(context);
266 let builder = PatternBuilder {
267 selector: Some(selector),
268 src: processed,
269 };
270 lang.build_pattern(&builder)
271 }
272 fn single_matcher<D: Doc>(root: &Root<D>) -> Node<'_, D> {
273 let node = root.root();
275 let mut inner = node.inner;
276 while is_single_node(&inner) {
277 inner = inner.child(0).unwrap();
278 }
279 Node { inner, root }
280 }
281}
282
283impl Matcher for Pattern {
284 fn match_node_with_env<'tree, D: Doc>(
285 &self,
286 node: Node<'tree, D>,
287 env: &mut Cow<MetaVarEnv<'tree, D>>,
288 ) -> Option<Node<'tree, D>> {
289 if let Some(k) = self.root_kind {
290 if node.kind_id() != k {
291 return None;
292 }
293 }
294 let mut may_write = Cow::Borrowed(env.as_ref());
296 let node = match_node_non_recursive(self, node, &mut may_write)?;
297 if let Cow::Owned(map) = may_write {
298 *env = Cow::Owned(map);
300 }
301 Some(node)
302 }
303
304 fn potential_kinds(&self) -> Option<bit_set::BitSet> {
305 if matches!(self.strictness, MatchStrictness::Template) {
307 return None;
308 }
309 let kind = match self.node {
310 PatternNode::Terminal { kind_id, .. } => kind_id,
311 PatternNode::MetaVar { .. } => self.root_kind?,
312 PatternNode::Internal { kind_id, .. } => {
313 if kind_utils::is_error_kind(kind_id) {
314 return None;
316 }
317 kind_id
318 }
319 };
320
321 let mut kinds = BitSet::new();
322 kinds.insert(kind.into());
323 Some(kinds)
324 }
325
326 fn get_match_len<D: Doc>(&self, node: Node<'_, D>) -> Option<usize> {
327 let start = node.range().start;
328 let end = match_end_non_recursive(self, node)?;
329 Some(end - start)
330 }
331}
332impl std::fmt::Debug for PatternNode {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 match self {
335 Self::MetaVar { meta_var, .. } => write!(f, "{meta_var:?}"),
336 Self::Terminal { text, .. } => write!(f, "{text}"),
337 Self::Internal { children, .. } => write!(f, "{children:?}"),
338 }
339 }
340}
341
342impl std::fmt::Debug for Pattern {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 write!(f, "{:?}", self.node)
345 }
346}
347
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::Tsx;
    use crate::matcher::MatcherExt;
    use crate::meta_var::MetaVarEnv;
    use crate::tree_sitter::StrDoc;
    use std::collections::HashMap;

    /// Parses `s` as TSX source to serve as a match candidate.
    fn pattern_node(s: &str) -> Root<StrDoc<Tsx>> {
        Root::str(s, Tsx)
    }

    /// Asserts that pattern `s1` matches somewhere inside source `s2`.
    fn test_match(s1: &str, s2: &str) {
        let pattern = Pattern::new(s1, Tsx);
        let cand = pattern_node(s2);
        let cand = cand.root();
        assert!(
            pattern.find_node(cand.clone()).is_some(),
            "goal: {:?}, candidate: {}",
            pattern,
            cand.get_inner_node().to_sexp(),
        );
    }

    /// Asserts that pattern `s1` matches nowhere inside source `s2`.
    fn test_non_match(s1: &str, s2: &str) {
        let pattern = Pattern::new(s1, Tsx);
        let cand = pattern_node(s2);
        let cand = cand.root();
        assert!(
            pattern.find_node(cand.clone()).is_none(),
            "goal: {:?}, candidate: {}",
            pattern,
            cand.get_inner_node().to_sexp(),
        );
    }

    #[test]
    fn test_meta_variable() {
        test_match("const a = $VALUE", "const a = 123");
        test_match("const $VARIABLE = $VALUE", "const a = 123");
        // Meta variables must bind to any identifier and any kind of value, not just numbers.
        test_match("const $VARIABLE = $VALUE", "const b = 'abc'");
    }

    #[test]
    fn test_whitespace() {
        test_match("function t() { }", "function t() {}");
        test_match("function t() {}", "function t() { }");
    }

    /// Matches `goal_str` against `cand` and returns the captured meta variables as text.
    fn match_env(goal_str: &str, cand: &str) -> HashMap<String, String> {
        let pattern = Pattern::new(goal_str, Tsx);
        let cand = pattern_node(cand);
        let cand = cand.root();
        let nm = pattern.find_node(cand).unwrap();
        HashMap::from(nm.get_env().clone())
    }

    #[test]
    fn test_meta_variable_env() {
        let env = match_env("const a = $VALUE", "const a = 123");
        assert_eq!(env["VALUE"], "123");
    }

    #[test]
    fn test_pattern_should_not_pollute_env() {
        // `114` does not match `514`, so the tentative `$A` binding must be discarded.
        let pattern = Pattern::new("const $A = 114", Tsx);
        let cand = pattern_node("const a = 514");
        let cand = cand.root().child(0).unwrap();
        let map = MetaVarEnv::new();
        let mut env = Cow::Borrowed(&map);
        let nm = pattern.match_node_with_env(cand, &mut env);
        assert!(nm.is_none());
        assert!(env.get_match("A").is_none());
        assert!(map.get_match("A").is_none());
    }

    #[test]
    fn test_match_non_atomic() {
        let env = match_env("const a = $VALUE", "const a = 5 + 3");
        assert_eq!(env["VALUE"], "5 + 3");
    }

    #[test]
    fn test_class_assignment() {
        test_match("class $C { $MEMBER = $VAL}", "class A {a = 123}");
        test_non_match("class $C { $MEMBER = $VAL; b = 123; }", "class A {a = 123}");
        test_non_match("a = 123", "class B {b = 123}");
    }

    #[test]
    fn test_return() {
        test_match("$A($B)", "return test(123)");
    }

    #[test]
    fn test_contextual_pattern() {
        let pattern =
            Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
        let cand = pattern_node("class B { b = 123 }");
        assert!(pattern.find_node(cand.root()).is_some());
        let cand = pattern_node("let b = 123");
        assert!(pattern.find_node(cand.root()).is_none());
    }

    #[test]
    fn test_contextual_match_with_env() {
        let pattern =
            Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
        let cand = pattern_node("class B { b = 123 }");
        let nm = pattern.find_node(cand.root()).expect("test");
        let env = nm.get_env();
        let env = HashMap::from(env.clone());
        assert_eq!(env["F"], "b");
        assert_eq!(env["I"], "123");
    }

    #[test]
    fn test_contextual_unmatch_with_env() {
        let pattern =
            Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
        let cand = pattern_node("let b = 123");
        let nm = pattern.find_node(cand.root());
        assert!(nm.is_none());
    }

    /// Looks up the TSX kind id for `kind_str` as a `usize` for `BitSet` queries.
    fn get_kind(kind_str: &str) -> usize {
        Tsx.kind_to_id(kind_str).into()
    }

    #[test]
    fn test_pattern_potential_kinds() {
        let pattern = Pattern::new("const a = 1", Tsx);
        let kind = get_kind("lexical_declaration");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_pattern_with_non_root_meta_var() {
        let pattern = Pattern::new("const $A = $B", Tsx);
        let kind = get_kind("lexical_declaration");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_bare_wildcard() {
        let pattern = Pattern::new("$A", Tsx);
        assert!(pattern.potential_kinds().is_none());
    }

    #[test]
    fn test_contextual_potential_kinds() {
        let pattern =
            Pattern::contextual("class A { $F = $I }", "public_field_definition", Tsx).expect("test");
        let kind = get_kind("public_field_definition");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_contextual_wildcard() {
        let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
        let kind = get_kind("property_identifier");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    // Ignored: multi-statement sources are currently rejected with
    // `PatternError::MultipleNode`, so `Pattern::new` panics here.
    #[test]
    #[ignore]
    fn test_multi_node_pattern() {
        let pattern = Pattern::new("a;b;c;", Tsx);
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        test_match("a;b;c", "a;b;c;");
    }

    // Ignored for the same reason as `test_multi_node_pattern`.
    #[test]
    #[ignore]
    fn test_multi_node_meta_var() {
        let env = match_env("a;$B;c", "a;b;c");
        assert_eq!(env["B"], "b");
        let env = match_env("a;$B;c", "a;1+2+3;c");
        assert_eq!(env["B"], "1+2+3");
    }

    #[test]
    #[ignore]
    fn test_pattern_size() {
        assert_eq!(std::mem::size_of::<Pattern>(), 40);
    }

    #[test]
    fn test_error_kind() {
        let ret = Pattern::contextual("a", "property_identifier", Tsx);
        assert!(ret.is_err());
        let ret = Pattern::new("123+", Tsx);
        assert!(ret.has_error());
    }

    #[test]
    fn test_bare_wildcard_in_context() {
        let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
        let cand = pattern_node("let b = 123");
        assert!(pattern.find_node(cand.root()).is_none());
    }

    #[test]
    fn test_pattern_fixed_string() {
        let pattern = Pattern::new("class A { $F }", Tsx);
        assert_eq!(pattern.fixed_string(), "class");
        let pattern = Pattern::contextual("class A { $F }", "property_identifier", Tsx).expect("test");
        assert!(pattern.fixed_string().is_empty());
    }

    #[test]
    fn test_pattern_error() {
        let pattern = Pattern::try_new("", Tsx);
        assert!(matches!(pattern, Err(PatternError::NoContent(_))));
        let pattern = Pattern::try_new("12 3344", Tsx);
        assert!(matches!(pattern, Err(PatternError::MultipleNode(_))));
    }

    #[test]
    fn test_debug_pattern() {
        let pattern = Pattern::new("var $A = 1", Tsx);
        assert_eq!(
            format!("{pattern:?}"),
            "[var, [Capture(\"A\", true), =, 1]]"
        );
    }

    /// Returns the sorted names of the meta variables defined by pattern `s`.
    fn defined_vars(s: &str) -> Vec<String> {
        let pattern = Pattern::new(s, Tsx);
        let mut vars: Vec<_> = pattern
            .defined_vars()
            .into_iter()
            .map(String::from)
            .collect();
        vars.sort();
        vars
    }

    #[test]
    fn test_extract_meta_var_from_pattern() {
        let vars = defined_vars("var $A = 1");
        assert_eq!(vars, ["A"]);
    }

    #[test]
    fn test_extract_complex_meta_var() {
        let vars = defined_vars("function $FUNC($$$ARGS): $RET { $$$BODY }");
        assert_eq!(vars, ["ARGS", "BODY", "FUNC", "RET"]);
    }

    #[test]
    fn test_extract_duplicate_meta_var() {
        let vars = defined_vars("var $A = $A");
        assert_eq!(vars, ["A"]);
    }

    #[test]
    fn test_contextual_pattern_vars() {
        let pattern = Pattern::contextual("<div ref={$A}/>", "jsx_attribute", Tsx).expect("correct");
        assert_eq!(pattern.defined_vars(), ["A"].into_iter().collect());
    }

    #[test]
    fn test_gh_1087() {
        test_match("($P) => $F($P)", "(x) => bar(x)");
    }

    #[test]
    fn test_template_pattern_have_no_kinds() {
        let pattern = Pattern::new("$A = $B", Tsx).with_strictness(MatchStrictness::Template);
        assert!(pattern.potential_kinds().is_none());
        let pattern = Pattern::contextual("{a: b}", "pair", Tsx)
            .expect("should create template pattern")
            .with_strictness(MatchStrictness::Template);
        assert!(pattern.potential_kinds().is_none());
    }
}