ast_grep_core/
ops.rs

1use crate::matcher::{MatchAll, MatchNone, Matcher};
2use crate::meta_var::MetaVarEnv;
3use crate::{Doc, Language, Node};
4use bit_set::BitSet;
5use std::borrow::Cow;
6use std::marker::PhantomData;
7
8pub struct And<L: Language, P1: Matcher<L>, P2: Matcher<L>> {
9  pattern1: P1,
10  pattern2: P2,
11  lang: PhantomData<L>,
12}
13
14impl<L: Language, P1, P2> Matcher<L> for And<L, P1, P2>
15where
16  P1: Matcher<L>,
17  P2: Matcher<L>,
18{
19  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
20    &self,
21    node: Node<'tree, D>,
22    env: &mut Cow<MetaVarEnv<'tree, D>>,
23  ) -> Option<Node<'tree, D>> {
24    let node = self.pattern1.match_node_with_env(node, env)?;
25    self.pattern2.match_node_with_env(node, env)
26  }
27
28  fn potential_kinds(&self) -> Option<BitSet> {
29    let set1 = self.pattern1.potential_kinds();
30    let set2 = self.pattern2.potential_kinds();
31    // if both constituent have Some(bitset), intersect them
32    // otherwise returns either of the non-null set
33    match (&set1, &set2) {
34      (Some(s1), Some(s2)) => Some(s1.intersection(s2).collect()),
35      _ => set1.xor(set2),
36    }
37  }
38}
39
40// we pre-compute and cache potential_kinds. So patterns should not be mutated.
41// Box<[P]> is used here for immutability so that kinds will never be invalidated.
42pub struct All<L: Language, P: Matcher<L>> {
43  patterns: Box<[P]>,
44  kinds: Option<BitSet>,
45  lang: PhantomData<L>,
46}
47
48impl<L: Language, P: Matcher<L>> All<L, P> {
49  pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
50    let patterns: Box<[P]> = patterns.into_iter().collect();
51    let kinds = Self::compute_kinds(&patterns);
52    Self {
53      patterns,
54      kinds,
55      lang: PhantomData,
56    }
57  }
58
59  fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
60    let mut set: Option<BitSet> = None;
61    for pattern in patterns {
62      let Some(n) = pattern.potential_kinds() else {
63        continue;
64      };
65      if let Some(set) = set.as_mut() {
66        set.intersect_with(&n);
67      } else {
68        set = Some(n);
69      }
70    }
71    set
72  }
73
74  pub fn inner(&self) -> &[P] {
75    &self.patterns
76  }
77}
78
79impl<L: Language, P: Matcher<L>> Matcher<L> for All<L, P> {
80  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
81    &self,
82    node: Node<'tree, D>,
83    env: &mut Cow<MetaVarEnv<'tree, D>>,
84  ) -> Option<Node<'tree, D>> {
85    if let Some(kinds) = &self.kinds {
86      if !kinds.contains(node.kind_id().into()) {
87        return None;
88      }
89    }
90    let mut new_env = Cow::Borrowed(env.as_ref());
91    let all_satisfied = self
92      .patterns
93      .iter()
94      .all(|p| p.match_node_with_env(node.clone(), &mut new_env).is_some());
95    if all_satisfied {
96      *env = Cow::Owned(new_env.into_owned());
97      Some(node)
98    } else {
99      None
100    }
101  }
102
103  fn potential_kinds(&self) -> Option<BitSet> {
104    self.kinds.clone()
105  }
106}
107
108// Box<[P]> for immutability and potential_kinds cache correctness
109pub struct Any<L, P> {
110  patterns: Box<[P]>,
111  kinds: Option<BitSet>,
112  lang: PhantomData<L>,
113}
114
115impl<L: Language, P: Matcher<L>> Any<L, P> {
116  pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
117    let patterns: Box<[P]> = patterns.into_iter().collect();
118    let kinds = Self::compute_kinds(&patterns);
119    Self {
120      patterns,
121      kinds,
122      lang: PhantomData,
123    }
124  }
125
126  fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
127    let mut set = BitSet::new();
128    for pattern in patterns {
129      let n = pattern.potential_kinds()?;
130      set.union_with(&n);
131    }
132    Some(set)
133  }
134
135  pub fn inner(&self) -> &[P] {
136    &self.patterns
137  }
138}
139
140impl<L: Language, M: Matcher<L>> Matcher<L> for Any<L, M> {
141  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
142    &self,
143    node: Node<'tree, D>,
144    env: &mut Cow<MetaVarEnv<'tree, D>>,
145  ) -> Option<Node<'tree, D>> {
146    if let Some(kinds) = &self.kinds {
147      if !kinds.contains(node.kind_id().into()) {
148        return None;
149      }
150    }
151    let mut new_env = Cow::Borrowed(env.as_ref());
152    let found = self.patterns.iter().find_map(|p| {
153      new_env = Cow::Borrowed(env.as_ref());
154      p.match_node_with_env(node.clone(), &mut new_env)
155    });
156    if found.is_some() {
157      *env = Cow::Owned(new_env.into_owned());
158      Some(node)
159    } else {
160      None
161    }
162  }
163
164  fn potential_kinds(&self) -> Option<BitSet> {
165    self.kinds.clone()
166  }
167}
168
169pub struct Or<L: Language, P1: Matcher<L>, P2: Matcher<L>> {
170  pattern1: P1,
171  pattern2: P2,
172  lang: PhantomData<L>,
173}
174
175impl<L, P1, P2> Matcher<L> for Or<L, P1, P2>
176where
177  L: Language,
178  P1: Matcher<L>,
179  P2: Matcher<L>,
180{
181  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
182    &self,
183    node: Node<'tree, D>,
184    env: &mut Cow<MetaVarEnv<'tree, D>>,
185  ) -> Option<Node<'tree, D>> {
186    let mut new_env = Cow::Borrowed(env.as_ref());
187    if let Some(ret) = self
188      .pattern1
189      .match_node_with_env(node.clone(), &mut new_env)
190    {
191      *env = Cow::Owned(new_env.into_owned());
192      Some(ret)
193    } else {
194      self.pattern2.match_node_with_env(node, env)
195    }
196  }
197
198  fn potential_kinds(&self) -> Option<BitSet> {
199    let mut set1 = self.pattern1.potential_kinds()?;
200    let set2 = self.pattern2.potential_kinds()?;
201    set1.union_with(&set2);
202    Some(set1)
203  }
204}
205
206pub struct Not<L: Language, M: Matcher<L>> {
207  not: M,
208  lang: PhantomData<L>,
209}
210
211impl<L: Language, M: Matcher<L>> Not<L, M> {
212  pub fn new(not: M) -> Self {
213    Self {
214      not,
215      lang: PhantomData,
216    }
217  }
218
219  pub fn inner(&self) -> &M {
220    &self.not
221  }
222}
223impl<L, P> Matcher<L> for Not<L, P>
224where
225  L: Language,
226  P: Matcher<L>,
227{
228  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
229    &self,
230    node: Node<'tree, D>,
231    env: &mut Cow<MetaVarEnv<'tree, D>>,
232  ) -> Option<Node<'tree, D>> {
233    self
234      .not
235      .match_node_with_env(node.clone(), env)
236      .xor(Some(node))
237  }
238}
239
240#[derive(Clone)]
241pub struct Op<L: Language, M: Matcher<L>> {
242  inner: M,
243  lang: PhantomData<L>,
244}
245
246impl<L, M> Matcher<L> for Op<L, M>
247where
248  L: Language,
249  M: Matcher<L>,
250{
251  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
252    &self,
253    node: Node<'tree, D>,
254    env: &mut Cow<MetaVarEnv<'tree, D>>,
255  ) -> Option<Node<'tree, D>> {
256    let ret = self.inner.match_node_with_env(node, env);
257    ret
258  }
259
260  fn potential_kinds(&self) -> Option<BitSet> {
261    self.inner.potential_kinds()
262  }
263}
264
265/*
266pub struct Predicate<F> {
267  func: F,
268}
269
270impl<L, F> Matcher<L> for Predicate<F>
271where
272  L: Language,
273  F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
274{
275  fn match_node_with_env<'tree, D: Doc<Lang=L>>(
276    &self,
277    node: Node<'tree, D>,
278    env: &mut MetaVarEnv<'tree, D>,
279  ) -> Option<Node<'tree, D>> {
280    (self.func)(&node).then_some(node)
281  }
282}
283*/
284
285/*
286// we don't need specify M for static method
287impl<L: Language> Op<L, MatchNone> {
288  pub fn func<F>(func: F) -> Predicate<F>
289  where
290    F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
291  {
292    Predicate { func }
293  }
294}
295*/
296
297impl<L: Language, M: Matcher<L>> Op<L, M> {
298  pub fn not(pattern: M) -> Not<L, M> {
299    Not {
300      not: pattern,
301      lang: PhantomData,
302    }
303  }
304}
305
306impl<L: Language, M: Matcher<L>> Op<L, M> {
307  pub fn every(pattern: M) -> Op<L, And<L, M, MatchAll>> {
308    Op {
309      inner: And {
310        pattern1: pattern,
311        pattern2: MatchAll,
312        lang: PhantomData,
313      },
314      lang: PhantomData,
315    }
316  }
317  pub fn either(pattern: M) -> Op<L, Or<L, M, MatchNone>> {
318    Op {
319      inner: Or {
320        pattern1: pattern,
321        pattern2: MatchNone,
322        lang: PhantomData,
323      },
324      lang: PhantomData,
325    }
326  }
327
328  pub fn all<MS: IntoIterator<Item = M>>(patterns: MS) -> All<L, M> {
329    All::new(patterns)
330  }
331
332  pub fn any<MS: IntoIterator<Item = M>>(patterns: MS) -> Any<L, M> {
333    Any::new(patterns)
334  }
335
336  pub fn new(matcher: M) -> Op<L, M> {
337    Self {
338      inner: matcher,
339      lang: PhantomData,
340    }
341  }
342}
343
344type NestedAnd<L, M, N, O> = And<L, And<L, M, N>, O>;
345impl<L: Language, M: Matcher<L>, N: Matcher<L>> Op<L, And<L, M, N>> {
346  pub fn and<O: Matcher<L>>(self, other: O) -> Op<L, NestedAnd<L, M, N, O>> {
347    Op {
348      inner: And {
349        pattern1: self.inner,
350        pattern2: other,
351        lang: PhantomData,
352      },
353      lang: PhantomData,
354    }
355  }
356}
357
358type NestedOr<L, M, N, O> = Or<L, Or<L, M, N>, O>;
359impl<L: Language, M: Matcher<L>, N: Matcher<L>> Op<L, Or<L, M, N>> {
360  pub fn or<O: Matcher<L>>(self, other: O) -> Op<L, NestedOr<L, M, N, O>> {
361    Op {
362      inner: Or {
363        pattern1: self.inner,
364        pattern2: other,
365        lang: PhantomData,
366      },
367      lang: PhantomData,
368    }
369  }
370}
371
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use crate::Pattern;
  use crate::Root;

  /// Shorthand for turning a string literal into a TSX `Pattern`.
  trait TsxMatcher {
    fn t(self) -> Pattern<Tsx>;
  }
  impl TsxMatcher for &str {
    fn t(self) -> Pattern<Tsx> {
      Pattern::new(self, Tsx)
    }
  }

  /// Asserts that `matcher` finds a node somewhere in `code`.
  fn test_find(matcher: &impl Matcher<Tsx>, code: &str) {
    let root = Root::str(code, Tsx);
    let found = matcher.find_node(root.root());
    assert!(found.is_some());
  }

  /// Asserts that `matcher` finds no node anywhere in `code`.
  fn test_not_find(matcher: &impl Matcher<Tsx>, code: &str) {
    let root = Root::str(code, Tsx);
    let found = matcher.find_node(root.root());
    assert!(found.is_none());
  }

  /// Returns the text of every node in `code` matched by `matcher`.
  fn find_all(matcher: impl Matcher<Tsx>, code: &str) -> Vec<String> {
    let root = Root::str(code, Tsx);
    let mut texts = Vec::new();
    for found in root.root().find_all(matcher) {
      texts.push(found.text().to_string());
    }
    texts
  }

  /// Number of kinds in the matcher's potential kind set, if it reports one.
  fn kind_count(matcher: &impl Matcher<Tsx>) -> Option<usize> {
    matcher.potential_kinds().map(|kinds| kinds.len())
  }

  #[test]
  fn test_or() {
    let matcher = Or {
      pattern1: "let a = 1",
      pattern2: "const b = 2",
      lang: PhantomData,
    };
    for code in ["let a = 1", "const b = 2"] {
      test_find(&matcher, code);
    }
    for code in ["let a = 2", "const a = 1", "let b = 2", "const b = 1"] {
      test_not_find(&matcher, code);
    }
  }

  #[test]
  fn test_not() {
    let negated = Not {
      not: "let a = 1",
      lang: PhantomData,
    };
    test_find(&negated, "const b = 2");
  }

  #[test]
  fn test_and() {
    let excluded = Not {
      not: "let a = 123",
      lang: PhantomData,
    };
    let matcher = And {
      pattern1: "let a = $_",
      pattern2: excluded,
      lang: PhantomData,
    };
    for code in ["let a = 233", "let a = 456"] {
      test_find(&matcher, code);
    }
    test_not_find(&matcher, "let a = 123");
  }

  #[test]
  fn test_api_and() {
    let matcher = Op::every("let a = $_").and(Op::not("let a = 123"));
    for code in ["let a = 233", "let a = 456"] {
      test_find(&matcher, code);
    }
    test_not_find(&matcher, "let a = 123");
  }

  #[test]
  fn test_api_or() {
    let matcher = Op::either("let a = 1").or("const b = 2");
    for code in ["let a = 1", "const b = 2"] {
      test_find(&matcher, code);
    }
    for code in ["let a = 2", "const a = 1", "let b = 2", "const b = 1"] {
      test_not_find(&matcher, code);
    }
  }

  #[test]
  fn test_multiple_match() {
    let sequential = find_all("$A + b", "let f = () => a + b; let ff = () => c + b");
    assert_eq!(sequential.len(), 2);
    let nested = find_all(
      "function $A() { $$$ }",
      "function a() { function b() { b } }",
    );
    assert_eq!(nested.len(), 2);
  }

  #[test]
  fn test_multiple_match_order() {
    let source = "let f = () => () => () => a + b; let ff = () => c + b";
    let texts = find_all("$A + b", source);
    assert_eq!(texts, ["a + b", "c + b"], "should match source code order");
  }

  /*
  #[test]
  fn test_api_func() {
    let matcher = Op::func(|n| n.text().contains("114514"));
    test_find(&matcher, "let a = 114514");
    test_not_find(&matcher, "let a = 1919810");
  }
  */

  #[test]
  fn test_and_kinds() {
    // one side reports no kinds: the other side's set is used
    let none_right = Op::every("let a = $_".t()).and(Op::not("let a = 123".t()));
    assert_eq!(kind_count(&none_right), Some(1));
    let none_left = Op::every(Op::not("let a = $_".t())).and("let a = 123".t());
    assert_eq!(kind_count(&none_left), Some(1));
    // both sides report the same kind
    let same = Op::every("let a = $_".t()).and("let b = 123".t());
    assert_eq!(kind_count(&same), Some(1));
    // both sides report different kinds: empty intersection
    let different = Op::every("let a = 1".t()).and("console.log(1)".t());
    assert_eq!(kind_count(&different), Some(0));
    // neither side reports kinds
    let both_none = Op::every(Op::not("let a = $_".t())).and(Op::not("let a = 123".t()));
    assert!(both_none.potential_kinds().is_none());
  }

  #[test]
  fn test_or_kinds() {
    // one side reports no kinds: the union is unknown
    let none_right = Op::either("let a = $_".t()).or(Op::not("let a = 123".t()));
    assert!(none_right.potential_kinds().is_none());
    let none_left = Op::either(Op::not("let a = $_".t())).or("let a = 123".t());
    assert!(none_left.potential_kinds().is_none());
    // both sides report the same kind
    let same = Op::either("let a = $_".t()).or("let b = 123".t());
    assert_eq!(kind_count(&same), Some(1));
    // both sides report different kinds
    let different = Op::either("let a = 1".t()).or("console.log(1)".t());
    assert_eq!(kind_count(&different), Some(2));
    // neither side reports kinds
    let both_none = Op::either(Op::not("let a = $_".t())).or(Op::not("let a = 123".t()));
    assert!(both_none.potential_kinds().is_none());
  }

  #[test]
  fn test_all_kinds() {
    // one pattern reports no kinds: the other pattern's set is used
    let none_last = Op::all(["let a = $_".t(), "$A".t()]);
    assert_eq!(kind_count(&none_last), Some(1));
    let none_first = Op::all(["$A".t(), "let a = $_".t()]);
    assert_eq!(kind_count(&none_first), Some(1));
    // both patterns report the same kind
    let same = Op::all(["let a = $_".t(), "let b = 123".t()]);
    assert_eq!(kind_count(&same), Some(1));
    // patterns report different kinds: empty intersection
    let different = Op::all(["let a = 1".t(), "console.log(1)".t()]);
    assert_eq!(kind_count(&different), Some(0));
    // no pattern reports kinds
    let both_none = Op::all(["$A".t(), "$B".t()]);
    assert!(both_none.potential_kinds().is_none());
  }

  #[test]
  fn test_any_kinds() {
    // one pattern reports no kinds: the union is unknown
    let none_last = Op::any(["let a = $_".t(), "$A".t()]);
    assert!(none_last.potential_kinds().is_none());
    let none_first = Op::any(["$A".t(), "let a = $_".t()]);
    assert!(none_first.potential_kinds().is_none());
    // both patterns report the same kind
    let same = Op::any(["let a = $_".t(), "let b = 123".t()]);
    assert_eq!(kind_count(&same), Some(1));
    // patterns report different kinds
    let different = Op::any(["let a = 1".t(), "console.log(1)".t()]);
    assert_eq!(kind_count(&different), Some(2));
    // no pattern reports kinds
    let both_none = Op::any(["$A".t(), "$B".t()]);
    assert!(both_none.potential_kinds().is_none());
  }

  #[test]
  fn test_or_revert_env() {
    let failing_arm = Op::every("foo($A)".t()).and("impossible".t());
    let matcher = Op::either(failing_arm).or("foo($B)".t());
    let code = Root::str("foo(123)", Tsx);
    let found = code.root().find(matcher).expect("should found");
    // bindings from the failed first arm must not leak into the result
    assert!(found.get_env().get_match("A").is_none());
    assert_eq!(found.get_env().get_match("B").unwrap().text(), "123");
  }

  #[test]
  fn test_any_revert_env() {
    let matcher = Op::any([
      Op::all(["foo($A)".t(), "impossible".t()]),
      Op::all(["foo($B)".t()]),
    ]);
    let code = Root::str("foo(123)", Tsx);
    let found = code.root().find(matcher).expect("should found");
    // bindings from the failed first alternative must not leak into the result
    assert!(found.get_env().get_match("A").is_none());
    assert_eq!(found.get_env().get_match("B").unwrap().text(), "123");
  }

  // gh #1225
  #[test]
  fn test_all_revert_env() {
    let matcher = Op::all(["$A(123)".t(), "$B(456)".t()]);
    let code = Root::str("foo(123)", Tsx);
    let call = code.root().find("foo($C)").expect("should exist");
    let target = call.get_node().clone();
    let mut env = Cow::Owned(MetaVarEnv::new());
    let outcome = matcher.match_node_with_env(target, &mut env);
    assert!(outcome.is_none());
    // the first pattern matched, but its binding must be rolled back
    assert!(env.get_match("A").is_none());
  }
}