ast_grep_core/
ops.rs

1use crate::matcher::{MatchAll, MatchNone, Matcher};
2use crate::meta_var::MetaVarEnv;
3use crate::{Doc, Language, Node};
4use bit_set::BitSet;
5use std::borrow::Cow;
6use std::marker::PhantomData;
7
8pub struct And<L: Language, P1: Matcher<L>, P2: Matcher<L>> {
9  pattern1: P1,
10  pattern2: P2,
11  lang: PhantomData<L>,
12}
13
14impl<L: Language, P1, P2> Matcher<L> for And<L, P1, P2>
15where
16  P1: Matcher<L>,
17  P2: Matcher<L>,
18{
19  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
20    &self,
21    node: Node<'tree, D>,
22    env: &mut Cow<MetaVarEnv<'tree, D>>,
23  ) -> Option<Node<'tree, D>> {
24    let node = self.pattern1.match_node_with_env(node, env)?;
25    self.pattern2.match_node_with_env(node, env)
26  }
27
28  fn potential_kinds(&self) -> Option<BitSet> {
29    let set1 = self.pattern1.potential_kinds();
30    let set2 = self.pattern2.potential_kinds();
31    // if both constituent have Some(bitset), intersect them
32    // otherwise returns either of the non-null set
33    match (&set1, &set2) {
34      (Some(s1), Some(s2)) => Some(s1.intersection(s2).collect()),
35      _ => set1.xor(set2),
36    }
37  }
38}
39
40// we pre-compute and cache potential_kinds. So patterns should not be mutated.
41// Box<[P]> is used here for immutability so that kinds will never be invalidated.
42pub struct All<L: Language, P: Matcher<L>> {
43  patterns: Box<[P]>,
44  kinds: Option<BitSet>,
45  lang: PhantomData<L>,
46}
47
48impl<L: Language, P: Matcher<L>> All<L, P> {
49  pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
50    let patterns: Box<[P]> = patterns.into_iter().collect();
51    let kinds = Self::compute_kinds(&patterns);
52    Self {
53      patterns,
54      kinds,
55      lang: PhantomData,
56    }
57  }
58
59  fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
60    let mut set: Option<BitSet> = None;
61    for pattern in patterns {
62      let Some(n) = pattern.potential_kinds() else {
63        continue;
64      };
65      if let Some(set) = set.as_mut() {
66        set.intersect_with(&n);
67      } else {
68        set = Some(n);
69      }
70    }
71    set
72  }
73
74  pub fn inner(&self) -> &[P] {
75    &self.patterns
76  }
77}
78
79impl<L: Language, P: Matcher<L>> Matcher<L> for All<L, P> {
80  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
81    &self,
82    node: Node<'tree, D>,
83    env: &mut Cow<MetaVarEnv<'tree, D>>,
84  ) -> Option<Node<'tree, D>> {
85    if let Some(kinds) = &self.kinds {
86      if !kinds.contains(node.kind_id().into()) {
87        return None;
88      }
89    }
90    let mut new_env = Cow::Borrowed(env.as_ref());
91    let all_satisfied = self
92      .patterns
93      .iter()
94      .all(|p| p.match_node_with_env(node.clone(), &mut new_env).is_some());
95    if all_satisfied {
96      *env = Cow::Owned(new_env.into_owned());
97      Some(node)
98    } else {
99      None
100    }
101  }
102
103  fn potential_kinds(&self) -> Option<BitSet> {
104    self.kinds.clone()
105  }
106}
107
108// Box<[P]> for immutability and potential_kinds cache correctness
109pub struct Any<L, P> {
110  patterns: Box<[P]>,
111  kinds: Option<BitSet>,
112  lang: PhantomData<L>,
113}
114
115impl<L: Language, P: Matcher<L>> Any<L, P> {
116  pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
117    let patterns: Box<[P]> = patterns.into_iter().collect();
118    let kinds = Self::compute_kinds(&patterns);
119    Self {
120      patterns,
121      kinds,
122      lang: PhantomData,
123    }
124  }
125
126  fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
127    let mut set = BitSet::new();
128    for pattern in patterns {
129      let n = pattern.potential_kinds()?;
130      set.union_with(&n);
131    }
132    Some(set)
133  }
134
135  pub fn inner(&self) -> &[P] {
136    &self.patterns
137  }
138}
139
140impl<L: Language, M: Matcher<L>> Matcher<L> for Any<L, M> {
141  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
142    &self,
143    node: Node<'tree, D>,
144    env: &mut Cow<MetaVarEnv<'tree, D>>,
145  ) -> Option<Node<'tree, D>> {
146    if let Some(kinds) = &self.kinds {
147      if !kinds.contains(node.kind_id().into()) {
148        return None;
149      }
150    }
151    let mut new_env = Cow::Borrowed(env.as_ref());
152    let found = self.patterns.iter().find_map(|p| {
153      new_env = Cow::Borrowed(env.as_ref());
154      p.match_node_with_env(node.clone(), &mut new_env)
155    });
156    if found.is_some() {
157      *env = Cow::Owned(new_env.into_owned());
158      Some(node)
159    } else {
160      None
161    }
162  }
163
164  fn potential_kinds(&self) -> Option<BitSet> {
165    self.kinds.clone()
166  }
167}
168
169pub struct Or<L: Language, P1: Matcher<L>, P2: Matcher<L>> {
170  pattern1: P1,
171  pattern2: P2,
172  lang: PhantomData<L>,
173}
174
175impl<L, P1, P2> Matcher<L> for Or<L, P1, P2>
176where
177  L: Language,
178  P1: Matcher<L>,
179  P2: Matcher<L>,
180{
181  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
182    &self,
183    node: Node<'tree, D>,
184    env: &mut Cow<MetaVarEnv<'tree, D>>,
185  ) -> Option<Node<'tree, D>> {
186    let mut new_env = Cow::Borrowed(env.as_ref());
187    if let Some(ret) = self
188      .pattern1
189      .match_node_with_env(node.clone(), &mut new_env)
190    {
191      *env = Cow::Owned(new_env.into_owned());
192      Some(ret)
193    } else {
194      self.pattern2.match_node_with_env(node, env)
195    }
196  }
197
198  fn potential_kinds(&self) -> Option<BitSet> {
199    let mut set1 = self.pattern1.potential_kinds()?;
200    let set2 = self.pattern2.potential_kinds()?;
201    set1.union_with(&set2);
202    Some(set1)
203  }
204}
205
206pub struct Not<L: Language, M: Matcher<L>> {
207  not: M,
208  lang: PhantomData<L>,
209}
210
211impl<L: Language, M: Matcher<L>> Not<L, M> {
212  pub fn new(not: M) -> Self {
213    Self {
214      not,
215      lang: PhantomData,
216    }
217  }
218
219  pub fn inner(&self) -> &M {
220    &self.not
221  }
222}
223impl<L, P> Matcher<L> for Not<L, P>
224where
225  L: Language,
226  P: Matcher<L>,
227{
228  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
229    &self,
230    node: Node<'tree, D>,
231    env: &mut Cow<MetaVarEnv<'tree, D>>,
232  ) -> Option<Node<'tree, D>> {
233    self
234      .not
235      .match_node_with_env(node.clone(), env)
236      .xor(Some(node))
237  }
238}
239
240#[derive(Clone)]
241pub struct Op<L: Language, M: Matcher<L>> {
242  inner: M,
243  lang: PhantomData<L>,
244}
245
246impl<L, M> Matcher<L> for Op<L, M>
247where
248  L: Language,
249  M: Matcher<L>,
250{
251  fn match_node_with_env<'tree, D: Doc<Lang = L>>(
252    &self,
253    node: Node<'tree, D>,
254    env: &mut Cow<MetaVarEnv<'tree, D>>,
255  ) -> Option<Node<'tree, D>> {
256    let ret = self.inner.match_node_with_env(node, env);
257    ret
258  }
259
260  fn potential_kinds(&self) -> Option<BitSet> {
261    self.inner.potential_kinds()
262  }
263}
264
265/*
266pub struct Predicate<F> {
267  func: F,
268}
269
270impl<L, F> Matcher<L> for Predicate<F>
271where
272  L: Language,
273  F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
274{
275  fn match_node_with_env<'tree, D: Doc<Lang=L>>(
276    &self,
277    node: Node<'tree, D>,
278    env: &mut MetaVarEnv<'tree, D>,
279  ) -> Option<Node<'tree, D>> {
280    (self.func)(&node).then_some(node)
281  }
282}
283*/
284
285/*
286// we don't need specify M for static method
287impl<L: Language> Op<L, MatchNone> {
288  pub fn func<F>(func: F) -> Predicate<F>
289  where
290    F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
291  {
292    Predicate { func }
293  }
294}
295*/
296
297impl<L: Language, M: Matcher<L>> Op<L, M> {
298  pub fn not(pattern: M) -> Not<L, M> {
299    Not {
300      not: pattern,
301      lang: PhantomData,
302    }
303  }
304}
305
306impl<L: Language, M: Matcher<L>> Op<L, M> {
307  pub fn every(pattern: M) -> Op<L, And<L, M, MatchAll>> {
308    Op {
309      inner: And {
310        pattern1: pattern,
311        pattern2: MatchAll,
312        lang: PhantomData,
313      },
314      lang: PhantomData,
315    }
316  }
317  pub fn either(pattern: M) -> Op<L, Or<L, M, MatchNone>> {
318    Op {
319      inner: Or {
320        pattern1: pattern,
321        pattern2: MatchNone,
322        lang: PhantomData,
323      },
324      lang: PhantomData,
325    }
326  }
327
328  pub fn all<MS: IntoIterator<Item = M>>(patterns: MS) -> All<L, M> {
329    All::new(patterns)
330  }
331
332  pub fn any<MS: IntoIterator<Item = M>>(patterns: MS) -> Any<L, M> {
333    Any::new(patterns)
334  }
335
336  pub fn new(matcher: M) -> Op<L, M> {
337    Self {
338      inner: matcher,
339      lang: PhantomData,
340    }
341  }
342}
343
344type NestedAnd<L, M, N, O> = And<L, And<L, M, N>, O>;
345impl<L: Language, M: Matcher<L>, N: Matcher<L>> Op<L, And<L, M, N>> {
346  pub fn and<O: Matcher<L>>(self, other: O) -> Op<L, NestedAnd<L, M, N, O>> {
347    Op {
348      inner: And {
349        pattern1: self.inner,
350        pattern2: other,
351        lang: PhantomData,
352      },
353      lang: PhantomData,
354    }
355  }
356}
357
358type NestedOr<L, M, N, O> = Or<L, Or<L, M, N>, O>;
359impl<L: Language, M: Matcher<L>, N: Matcher<L>> Op<L, Or<L, M, N>> {
360  pub fn or<O: Matcher<L>>(self, other: O) -> Op<L, NestedOr<L, M, N, O>> {
361    Op {
362      inner: Or {
363        pattern1: self.inner,
364        pattern2: other,
365        lang: PhantomData,
366      },
367      lang: PhantomData,
368    }
369  }
370}
371
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use crate::matcher::MatcherExt;
  use crate::Pattern;
  use crate::Root;

  // Shared fixtures: the struct-literal and builder-API variants of the
  // `Or` / `And` tests must accept and reject exactly the same snippets.
  const OR_HITS: [&str; 2] = ["let a = 1", "const b = 2"];
  const OR_MISSES: [&str; 4] = ["let a = 2", "const a = 1", "let b = 2", "const b = 1"];
  const AND_HITS: [&str; 2] = ["let a = 233", "let a = 456"];
  const AND_MISSES: [&str; 1] = ["let a = 123"];

  /// Asserts that `matcher` finds at least one node somewhere in `code`.
  fn test_find(matcher: &impl Matcher<Tsx>, code: &str) {
    let doc = Root::str(code, Tsx);
    assert!(matcher.find_node(doc.root()).is_some());
  }

  /// Asserts that `matcher` finds no node anywhere in `code`.
  fn test_not_find(matcher: &impl Matcher<Tsx>, code: &str) {
    let doc = Root::str(code, Tsx);
    assert!(matcher.find_node(doc.root()).is_none());
  }

  /// Asserts every snippet in `hits` matches and every snippet in `misses` does not.
  fn check(matcher: &impl Matcher<Tsx>, hits: &[&str], misses: &[&str]) {
    hits.iter().for_each(|code| test_find(matcher, code));
    misses.iter().for_each(|code| test_not_find(matcher, code));
  }

  /// Texts of all matches of `matcher` in `code`, in source order.
  fn find_all(matcher: impl Matcher<Tsx>, code: &str) -> Vec<String> {
    let doc = Root::str(code, Tsx);
    doc
      .root()
      .find_all(matcher)
      .map(|found| found.text().to_string())
      .collect()
  }

  /// Size of the candidate kind set, or `None` when `matcher` is unrestricted.
  fn kind_count(matcher: &impl Matcher<Tsx>) -> Option<usize> {
    matcher.potential_kinds().map(|kinds| kinds.len())
  }

  /// Shorthand for turning a source snippet into a TSX `Pattern`.
  trait TsxMatcher {
    fn t(self) -> Pattern<Tsx>;
  }
  impl TsxMatcher for &str {
    fn t(self) -> Pattern<Tsx> {
      Pattern::new(self, Tsx)
    }
  }

  #[test]
  fn test_or() {
    let matcher = Or {
      pattern1: "let a = 1",
      pattern2: "const b = 2",
      lang: PhantomData,
    };
    check(&matcher, &OR_HITS, &OR_MISSES);
  }

  #[test]
  fn test_not() {
    let matcher = Not {
      not: "let a = 1",
      lang: PhantomData,
    };
    test_find(&matcher, "const b = 2");
  }

  #[test]
  fn test_and() {
    let excluded = Not {
      not: "let a = 123",
      lang: PhantomData,
    };
    let matcher = And {
      pattern1: "let a = $_",
      pattern2: excluded,
      lang: PhantomData,
    };
    check(&matcher, &AND_HITS, &AND_MISSES);
  }

  #[test]
  fn test_api_and() {
    let matcher = Op::every("let a = $_").and(Op::not("let a = 123"));
    check(&matcher, &AND_HITS, &AND_MISSES);
  }

  #[test]
  fn test_api_or() {
    let matcher = Op::either("let a = 1").or("const b = 2");
    check(&matcher, &OR_HITS, &OR_MISSES);
  }

  #[test]
  fn test_multiple_match() {
    let sequential = find_all("$A + b", "let f = () => a + b; let ff = () => c + b");
    assert_eq!(sequential.len(), 2);
    let nested = find_all(
      "function $A() { $$$ }",
      "function a() { function b() { b } }",
    );
    assert_eq!(nested.len(), 2);
  }

  #[test]
  fn test_multiple_match_order() {
    let texts = find_all(
      "$A + b",
      "let f = () => () => () => a + b; let ff = () => c + b",
    );
    assert_eq!(texts, ["a + b", "c + b"], "should match source code order");
  }

  /*
  #[test]
  fn test_api_func() {
    let matcher = Op::func(|n| n.text().contains("114514"));
    test_find(&matcher, "let a = 114514");
    test_not_find(&matcher, "let a = 1919810");
  }
  */

  #[test]
  fn test_and_kinds() {
    // a `None` operand is ignored by the intersection
    let left_only = Op::every("let a = $_".t()).and(Op::not("let a = 123".t()));
    assert_eq!(kind_count(&left_only), Some(1));
    let right_only = Op::every(Op::not("let a = $_".t())).and("let a = 123".t());
    assert_eq!(kind_count(&right_only), Some(1));
    // both operands share one kind
    let same = Op::every("let a = $_".t()).and("let b = 123".t());
    assert_eq!(kind_count(&same), Some(1));
    // disjoint kinds intersect to the empty set
    let disjoint = Op::every("let a = 1".t()).and("console.log(1)".t());
    assert_eq!(kind_count(&disjoint), Some(0));
    // neither operand restricts the kind
    let open = Op::every(Op::not("let a = $_".t())).and(Op::not("let a = 123".t()));
    assert_eq!(kind_count(&open), None);
  }

  #[test]
  fn test_or_kinds() {
    // one unrestricted alternative makes the whole disjunction unrestricted
    let left_open = Op::either("let a = $_".t()).or(Op::not("let a = 123".t()));
    assert_eq!(kind_count(&left_open), None);
    let right_open = Op::either(Op::not("let a = $_".t())).or("let a = 123".t());
    assert_eq!(kind_count(&right_open), None);
    // both alternatives share one kind
    let same = Op::either("let a = $_".t()).or("let b = 123".t());
    assert_eq!(kind_count(&same), Some(1));
    // different kinds are unioned
    let disjoint = Op::either("let a = 1".t()).or("console.log(1)".t());
    assert_eq!(kind_count(&disjoint), Some(2));
    // neither alternative restricts the kind
    let open = Op::either(Op::not("let a = $_".t())).or(Op::not("let a = 123".t()));
    assert_eq!(kind_count(&open), None);
  }

  #[test]
  fn test_all_kinds() {
    // (pattern sources, expected size of the candidate kind set)
    let cases = [
      // the unrestricted "$A" is skipped by the intersection
      (["let a = $_", "$A"], Some(1)),
      (["$A", "let a = $_"], Some(1)),
      // same kind on both sides
      (["let a = $_", "let b = 123"], Some(1)),
      // disjoint kinds intersect to the empty set
      (["let a = 1", "console.log(1)"], Some(0)),
      // nothing restricts the kind
      (["$A", "$B"], None),
    ];
    for (sources, expected) in cases {
      let matcher = Op::all(sources.map(|src| src.t()));
      assert_eq!(kind_count(&matcher), expected);
    }
  }

  #[test]
  fn test_any_kinds() {
    // (pattern sources, expected size of the candidate kind set)
    let cases = [
      // one unrestricted member makes the union unrestricted
      (["let a = $_", "$A"], None),
      (["$A", "let a = $_"], None),
      // same kind on both sides
      (["let a = $_", "let b = 123"], Some(1)),
      // different kinds are unioned
      (["let a = 1", "console.log(1)"], Some(2)),
      // nothing restricts the kind
      (["$A", "$B"], None),
    ];
    for (sources, expected) in cases {
      let matcher = Op::any(sources.map(|src| src.t()));
      assert_eq!(kind_count(&matcher), expected);
    }
  }

  #[test]
  fn test_or_revert_env() {
    let matcher = Op::either(Op::every("foo($A)".t()).and("impossible".t())).or("foo($B)".t());
    let root = Root::str("foo(123)", Tsx);
    let found = root.root().find(matcher).expect("should found");
    let env = found.get_env();
    assert!(env.get_match("A").is_none());
    assert_eq!(env.get_match("B").unwrap().text(), "123");
  }

  #[test]
  fn test_any_revert_env() {
    let matcher = Op::any([
      Op::all(["foo($A)".t(), "impossible".t()]),
      Op::all(["foo($B)".t()]),
    ]);
    let root = Root::str("foo(123)", Tsx);
    let found = root.root().find(matcher).expect("should found");
    let env = found.get_env();
    assert!(env.get_match("A").is_none());
    assert_eq!(env.get_match("B").unwrap().text(), "123");
  }

  // gh #1225
  #[test]
  fn test_all_revert_env() {
    let matcher = Op::all(["$A(123)".t(), "$B(456)".t()]);
    let root = Root::str("foo(123)", Tsx);
    let found = root.root().find("foo($C)").expect("should exist");
    let target = found.get_node().clone();
    let mut env = Cow::Owned(MetaVarEnv::new());
    // `$A(123)` matches but `$B(456)` does not, so `$A` must not be committed.
    assert!(matcher.match_node_with_env(target, &mut env).is_none());
    assert!(env.get_match("A").is_none());
  }
}