Skip to main content

thread_ast_engine/
ops.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use crate::matcher::{MatchAll, MatchNone, Matcher};
8use crate::meta_var::MetaVarEnv;
9use crate::{Doc, Node};
10use bit_set::BitSet;
11use std::borrow::Cow;
12
13pub struct And<P1: Matcher, P2: Matcher> {
14    pattern1: P1,
15    pattern2: P2,
16}
17
18impl<P1, P2> Matcher for And<P1, P2>
19where
20    P1: Matcher,
21    P2: Matcher,
22{
23    fn match_node_with_env<'tree, D: Doc>(
24        &self,
25        node: Node<'tree, D>,
26        env: &mut Cow<MetaVarEnv<'tree, D>>,
27    ) -> Option<Node<'tree, D>> {
28        // keep the original env intact until both arms match
29        let mut new_env = Cow::Borrowed(env.as_ref());
30        let node = self.pattern1.match_node_with_env(node, &mut new_env)?;
31        let ret = self.pattern2.match_node_with_env(node, &mut new_env)?;
32        // both succeed – commit the combined env
33        *env = Cow::Owned(new_env.into_owned());
34        Some(ret)
35    }
36
37    fn potential_kinds(&self) -> Option<BitSet> {
38        let set1 = self.pattern1.potential_kinds();
39        let set2 = self.pattern2.potential_kinds();
40        // if both constituent have Some(bitset), intersect them
41        // otherwise returns either of the non-null set
42        match (&set1, &set2) {
43            (Some(s1), Some(s2)) => Some(s1.intersection(s2).collect()),
44            _ => set1.xor(set2),
45        }
46    }
47}
48
49// we pre-compute and cache potential_kinds. So patterns should not be mutated.
50// Box<[P]> is used here for immutability so that kinds will never be invalidated.
51#[derive(Clone, Debug)]
52pub struct All<P: Matcher> {
53    patterns: Box<[P]>,
54    kinds: Option<BitSet>,
55}
56
57impl<P: Matcher> All<P> {
58    pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
59        let patterns: Box<[P]> = patterns.into_iter().collect();
60        let kinds = Self::compute_kinds(&patterns);
61        Self { patterns, kinds }
62    }
63
64    fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
65        let mut set: Option<BitSet> = None;
66        for pattern in patterns {
67            let Some(n) = pattern.potential_kinds() else {
68                continue;
69            };
70            if let Some(set) = set.as_mut() {
71                set.intersect_with(&n);
72            } else {
73                set = Some(n);
74            }
75        }
76        set
77    }
78
79    #[must_use]
80    pub const fn inner(&self) -> &[P] {
81        &self.patterns
82    }
83}
84
85impl<P: Matcher> Matcher for All<P> {
86    fn match_node_with_env<'tree, D: Doc>(
87        &self,
88        node: Node<'tree, D>,
89        env: &mut Cow<MetaVarEnv<'tree, D>>,
90    ) -> Option<Node<'tree, D>> {
91        if let Some(kinds) = &self.kinds
92            && !kinds.contains(node.kind_id().into())
93        {
94            return None;
95        }
96        let mut new_env = Cow::Borrowed(env.as_ref());
97        let all_satisfied = self
98            .patterns
99            .iter()
100            .all(|p| p.match_node_with_env(node.clone(), &mut new_env).is_some());
101        if all_satisfied {
102            *env = Cow::Owned(new_env.into_owned());
103            Some(node)
104        } else {
105            None
106        }
107    }
108
109    fn potential_kinds(&self) -> Option<BitSet> {
110        self.kinds.clone()
111    }
112}
113
114// Box<[P]> for immutability and potential_kinds cache correctness
115#[derive(Clone, Debug)]
116pub struct Any<P> {
117    patterns: Box<[P]>,
118    kinds: Option<BitSet>,
119}
120
121impl<P: Matcher> Any<P> {
122    pub fn new<PS: IntoIterator<Item = P>>(patterns: PS) -> Self {
123        let patterns: Box<[P]> = patterns.into_iter().collect();
124        let kinds = Self::compute_kinds(&patterns);
125        Self { patterns, kinds }
126    }
127
128    fn compute_kinds(patterns: &[P]) -> Option<BitSet> {
129        let mut set = BitSet::new();
130        for pattern in patterns {
131            let n = pattern.potential_kinds()?;
132            set.union_with(&n);
133        }
134        Some(set)
135    }
136
137    #[must_use]
138    pub const fn inner(&self) -> &[P] {
139        &self.patterns
140    }
141}
142
143impl<M: Matcher> Matcher for Any<M> {
144    fn match_node_with_env<'tree, D: Doc>(
145        &self,
146        node: Node<'tree, D>,
147        env: &mut Cow<MetaVarEnv<'tree, D>>,
148    ) -> Option<Node<'tree, D>> {
149        if let Some(kinds) = &self.kinds
150            && !kinds.contains(node.kind_id().into())
151        {
152            return None;
153        }
154        let mut new_env = Cow::Borrowed(env.as_ref());
155        let found = self.patterns.iter().find_map(|p| {
156            new_env = Cow::Borrowed(env.as_ref());
157            p.match_node_with_env(node.clone(), &mut new_env)
158        });
159        if found.is_some() {
160            *env = Cow::Owned(new_env.into_owned());
161            Some(node)
162        } else {
163            None
164        }
165    }
166
167    fn potential_kinds(&self) -> Option<BitSet> {
168        self.kinds.clone()
169    }
170}
171
172#[derive(Clone, Debug)]
173pub struct Or<P1: Matcher, P2: Matcher> {
174    pattern1: P1,
175    pattern2: P2,
176}
177
178impl<P1, P2> Matcher for Or<P1, P2>
179where
180    P1: Matcher,
181    P2: Matcher,
182{
183    fn match_node_with_env<'tree, D: Doc>(
184        &self,
185        node: Node<'tree, D>,
186        env: &mut Cow<MetaVarEnv<'tree, D>>,
187    ) -> Option<Node<'tree, D>> {
188        let mut new_env = Cow::Borrowed(env.as_ref());
189        if let Some(ret) = self
190            .pattern1
191            .match_node_with_env(node.clone(), &mut new_env)
192        {
193            *env = Cow::Owned(new_env.into_owned());
194            Some(ret)
195        } else {
196            self.pattern2.match_node_with_env(node, env)
197        }
198    }
199
200    fn potential_kinds(&self) -> Option<BitSet> {
201        let mut set1 = self.pattern1.potential_kinds()?;
202        let set2 = self.pattern2.potential_kinds()?;
203        set1.union_with(&set2);
204        Some(set1)
205    }
206}
207
208#[derive(Clone, Debug)]
209pub struct Not<M: Matcher> {
210    not: M,
211}
212
213impl<M: Matcher> Not<M> {
214    pub const fn new(not: M) -> Self {
215        Self { not }
216    }
217
218    pub const fn inner(&self) -> &M {
219        &self.not
220    }
221}
222impl<P> Matcher for Not<P>
223where
224    P: Matcher,
225{
226    fn match_node_with_env<'tree, D: Doc>(
227        &self,
228        node: Node<'tree, D>,
229        env: &mut Cow<MetaVarEnv<'tree, D>>,
230    ) -> Option<Node<'tree, D>> {
231        self.not
232            .match_node_with_env(node.clone(), env)
233            .xor(Some(node))
234    }
235}
236
237#[derive(Clone, Debug)]
238pub struct Op<M: Matcher> {
239    inner: M,
240}
241
242impl<M> Matcher for Op<M>
243where
244    M: Matcher,
245{
246    fn match_node_with_env<'tree, D: Doc>(
247        &self,
248        node: Node<'tree, D>,
249        env: &mut Cow<MetaVarEnv<'tree, D>>,
250    ) -> Option<Node<'tree, D>> {
251        self.inner.match_node_with_env(node, env)
252    }
253
254    fn potential_kinds(&self) -> Option<BitSet> {
255        self.inner.potential_kinds()
256    }
257}
258
259/*
260pub struct Predicate<F> {
261  func: F,
262}
263
264impl<L, F> Matcher for Predicate<F>
265where
266  L: Language,
267  F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
268{
269  fn match_node_with_env<'tree, D: Doc<Lang=L>>(
270    &self,
271    node: Node<'tree, D>,
272    env: &mut MetaVarEnv<'tree, D>,
273  ) -> Option<Node<'tree, D>> {
274    (self.func)(&node).then_some(node)
275  }
276}
277*/
278
279/*
280// we don't need specify M for static method
281impl<L: Language> Op<L, MatchNone> {
282  pub fn func<F>(func: F) -> Predicate<F>
283  where
284    F: for<'tree> Fn(&Node<'tree, StrDoc<L>>) -> bool,
285  {
286    Predicate { func }
287  }
288}
289*/
290
291impl<M: Matcher> Op<M> {
292    pub const fn not(pattern: M) -> Not<M> {
293        Not { not: pattern }
294    }
295}
296
297impl<M: Matcher> Op<M> {
298    pub const fn every(pattern: M) -> Op<And<M, MatchAll>> {
299        Op {
300            inner: And {
301                pattern1: pattern,
302                pattern2: MatchAll,
303            },
304        }
305    }
306    pub const fn either(pattern: M) -> Op<Or<M, MatchNone>> {
307        Op {
308            inner: Or {
309                pattern1: pattern,
310                pattern2: MatchNone,
311            },
312        }
313    }
314
315    pub fn all<MS: IntoIterator<Item = M>>(patterns: MS) -> All<M> {
316        All::new(patterns)
317    }
318
319    pub fn any<MS: IntoIterator<Item = M>>(patterns: MS) -> Any<M> {
320        Any::new(patterns)
321    }
322
323    pub const fn new(matcher: M) -> Self {
324        Self { inner: matcher }
325    }
326}
327
328type NestedAnd<M, N, O> = And<And<M, N>, O>;
329impl<M: Matcher, N: Matcher> Op<And<M, N>> {
330    pub fn and<O: Matcher>(self, other: O) -> Op<NestedAnd<M, N, O>> {
331        Op {
332            inner: And {
333                pattern1: self.inner,
334                pattern2: other,
335            },
336        }
337    }
338}
339
340type NestedOr<M, N, O> = Or<Or<M, N>, O>;
341impl<M: Matcher, N: Matcher> Op<Or<M, N>> {
342    pub fn or<O: Matcher>(self, other: O) -> Op<NestedOr<M, N, O>> {
343        Op {
344            inner: Or {
345                pattern1: self.inner,
346                pattern2: other,
347            },
348        }
349    }
350}
351
#[cfg(test)]
mod test {
    use super::*;
    use crate::Pattern;
    use crate::Root;
    use crate::language::Tsx;
    use crate::matcher::MatcherExt;
    use crate::meta_var::MetaVarEnv;

    /// Test-only shorthand that turns a source snippet into a TSX `Pattern`.
    trait TsxMatcher {
        fn t(self) -> Pattern;
    }
    impl TsxMatcher for &str {
        fn t(self) -> Pattern {
            Pattern::new(self, &Tsx)
        }
    }

    /// Asserts that `matcher` finds a node somewhere in `code`.
    fn test_find(matcher: &impl Matcher, code: &str) {
        let root = Root::str(code, Tsx);
        let found = matcher.find_node(root.root());
        assert!(found.is_some());
    }

    /// Asserts that `matcher` finds no node in `code`.
    fn test_not_find(matcher: &impl Matcher, code: &str) {
        let root = Root::str(code, Tsx);
        let found = matcher.find_node(root.root());
        assert!(found.is_none());
    }

    /// Returns the text of every node in `code` that `matcher` finds.
    fn find_all(matcher: impl Matcher, code: &str) -> Vec<String> {
        let root = Root::str(code, Tsx);
        let mut texts = Vec::new();
        for found in root.root().find_all(matcher) {
            texts.push(found.text().to_string());
        }
        texts
    }

    /// Number of kinds a matcher reports, or `None` when it reports no kind set.
    fn kind_count(matcher: &impl Matcher) -> Option<usize> {
        matcher.potential_kinds().map(|v| v.len())
    }

    #[test]
    fn test_or() {
        let matcher = Or {
            pattern1: "let a = 1",
            pattern2: "const b = 2",
        };
        for code in ["let a = 1", "const b = 2"] {
            test_find(&matcher, code);
        }
        for code in ["let a = 2", "const a = 1", "let b = 2", "const b = 1"] {
            test_not_find(&matcher, code);
        }
    }

    #[test]
    fn test_not() {
        let matcher = Not { not: "let a = 1" };
        test_find(&matcher, "const b = 2");
    }

    #[test]
    fn test_and() {
        let matcher = And {
            pattern1: "let a = $_",
            pattern2: Not { not: "let a = 123" },
        };
        for code in ["let a = 233", "let a = 456"] {
            test_find(&matcher, code);
        }
        test_not_find(&matcher, "let a = 123");
    }

    #[test]
    fn test_api_and() {
        let matcher = Op::every("let a = $_").and(Op::not("let a = 123"));
        for code in ["let a = 233", "let a = 456"] {
            test_find(&matcher, code);
        }
        test_not_find(&matcher, "let a = 123");
    }

    #[test]
    fn test_api_or() {
        let matcher = Op::either("let a = 1").or("const b = 2");
        for code in ["let a = 1", "const b = 2"] {
            test_find(&matcher, code);
        }
        for code in ["let a = 2", "const a = 1", "let b = 2", "const b = 1"] {
            test_not_find(&matcher, code);
        }
    }

    #[test]
    fn test_multiple_match() {
        let side_by_side = find_all("$A + b", "let f = () => a + b; let ff = () => c + b");
        assert_eq!(side_by_side.len(), 2);
        let enclosing = find_all(
            "function $A() { $$$ }",
            "function a() { function b() { b } }",
        );
        assert_eq!(enclosing.len(), 2);
    }

    #[test]
    fn test_multiple_match_order() {
        let texts = find_all(
            "$A + b",
            "let f = () => () => () => a + b; let ff = () => c + b",
        );
        assert_eq!(texts, ["a + b", "c + b"], "should match source code order");
    }

    /*
    #[test]
    fn test_api_func() {
      let matcher = Op::func(|n| n.text().contains("114514"));
      test_find(&matcher, "let a = 114514");
      test_not_find(&matcher, "let a = 1919810");
    }
    */

    #[test]
    fn test_and_kinds() {
        // one arm reports no kinds: the other arm's set is used
        let matcher = Op::every("let a = $_".t()).and(Op::not("let a = 123".t()));
        assert_eq!(kind_count(&matcher), Some(1));
        let matcher = Op::every(Op::not("let a = $_".t())).and("let a = 123".t());
        assert_eq!(kind_count(&matcher), Some(1));
        // both arms report the same kind
        let matcher = Op::every("let a = $_".t()).and("let b = 123".t());
        assert_eq!(kind_count(&matcher), Some(1));
        // arms report different kinds: empty intersection
        let matcher = Op::every("let a = 1".t()).and("console.log(1)".t());
        assert_eq!(kind_count(&matcher), Some(0));
        // neither arm reports kinds
        let matcher = Op::every(Op::not("let a = $_".t())).and(Op::not("let a = 123".t()));
        assert_eq!(matcher.potential_kinds(), None);
    }

    #[test]
    fn test_or_kinds() {
        // one arm reports no kinds: the union is unknown
        let matcher = Op::either("let a = $_".t()).or(Op::not("let a = 123".t()));
        assert_eq!(matcher.potential_kinds(), None);
        let matcher = Op::either(Op::not("let a = $_".t())).or("let a = 123".t());
        assert_eq!(matcher.potential_kinds(), None);
        // both arms report the same kind
        let matcher = Op::either("let a = $_".t()).or("let b = 123".t());
        assert_eq!(kind_count(&matcher), Some(1));
        // arms report different kinds
        let matcher = Op::either("let a = 1".t()).or("console.log(1)".t());
        assert_eq!(kind_count(&matcher), Some(2));
        // neither arm reports kinds
        let matcher = Op::either(Op::not("let a = $_".t())).or(Op::not("let a = 123".t()));
        assert_eq!(matcher.potential_kinds(), None);
    }

    #[test]
    fn test_all_kinds() {
        // one pattern reports no kinds: the other pattern's set is used
        let matcher = Op::all(["let a = $_".t(), "$A".t()]);
        assert_eq!(kind_count(&matcher), Some(1));
        let matcher = Op::all(["$A".t(), "let a = $_".t()]);
        assert_eq!(kind_count(&matcher), Some(1));
        // both patterns report the same kind
        let matcher = Op::all(["let a = $_".t(), "let b = 123".t()]);
        assert_eq!(kind_count(&matcher), Some(1));
        // patterns report different kinds: empty intersection
        let matcher = Op::all(["let a = 1".t(), "console.log(1)".t()]);
        assert_eq!(kind_count(&matcher), Some(0));
        // no pattern reports kinds
        let matcher = Op::all(["$A".t(), "$B".t()]);
        assert_eq!(matcher.potential_kinds(), None);
    }

    #[test]
    fn test_any_kinds() {
        // one pattern reports no kinds: the union is unknown
        let matcher = Op::any(["let a = $_".t(), "$A".t()]);
        assert_eq!(matcher.potential_kinds(), None);
        let matcher = Op::any(["$A".t(), "let a = $_".t()]);
        assert_eq!(matcher.potential_kinds(), None);
        // both patterns report the same kind
        let matcher = Op::any(["let a = $_".t(), "let b = 123".t()]);
        assert_eq!(kind_count(&matcher), Some(1));
        // patterns report different kinds
        let matcher = Op::any(["let a = 1".t(), "console.log(1)".t()]);
        assert_eq!(kind_count(&matcher), Some(2));
        // no pattern reports kinds
        let matcher = Op::any(["$A".t(), "$B".t()]);
        assert_eq!(matcher.potential_kinds(), None);
    }

    #[test]
    fn test_or_revert_env() {
        let failing_arm = Op::every("foo($A)".t()).and("impossible".t());
        let matcher = Op::either(failing_arm).or("foo($B)".t());
        let code = Root::str("foo(123)", Tsx);
        let matches = code.root().find(matcher).expect("should found");
        let env = matches.get_env();
        assert!(env.get_match("A").is_none());
        assert_eq!(env.get_match("B").unwrap().text(), "123");
    }

    #[test]
    fn test_any_revert_env() {
        let matcher = Op::any([
            Op::all(["foo($A)".t(), "impossible".t()]),
            Op::all(["foo($B)".t()]),
        ]);
        let code = Root::str("foo(123)", Tsx);
        let matches = code.root().find(matcher).expect("should found");
        let env = matches.get_env();
        assert!(env.get_match("A").is_none());
        assert_eq!(env.get_match("B").unwrap().text(), "123");
    }

    // gh #1225
    #[test]
    fn test_all_revert_env() {
        let matcher = Op::all(["$A(123)".t(), "$B(456)".t()]);
        let code = Root::str("foo(123)", Tsx);
        let call = code.root().find("foo($C)").expect("should exist");
        let call_node = call.get_node().clone();
        let mut env = Cow::Owned(MetaVarEnv::new());
        let outcome = matcher.match_node_with_env(call_node, &mut env);
        assert!(outcome.is_none());
        assert!(env.get_match("A").is_none());
    }
}