1use super::StrDoc;
22use crate::matcher::{Matcher, MatcherExt};
23use crate::tree_sitter::LanguageExt;
24use crate::{Doc, Node, NodeMatch, Root};
25
26use tree_sitter as ts;
27
28use std::collections::VecDeque;
29use std::marker::PhantomData;
30
31pub struct Visitor<M, A = PreOrder> {
32 reentrant: bool,
34 named_only: bool,
36 matcher: M,
38 algorithm: PhantomData<A>,
40}
41
42impl<M> Visitor<M> {
43 pub fn new(matcher: M) -> Visitor<M> {
44 Visitor {
45 reentrant: true,
46 named_only: false,
47 matcher,
48 algorithm: PhantomData,
49 }
50 }
51}
52
53impl<M, A> Visitor<M, A> {
54 pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
55 Visitor {
56 reentrant: self.reentrant,
57 named_only: self.named_only,
58 matcher: self.matcher,
59 algorithm: PhantomData,
60 }
61 }
62
63 pub fn reentrant(self, reentrant: bool) -> Self {
64 Self { reentrant, ..self }
65 }
66
67 pub fn named_only(self, named_only: bool) -> Self {
68 Self { named_only, ..self }
69 }
70}
71
72impl<M, A> Visitor<M, A>
73where
74 A: Algorithm,
75{
76 pub fn visit<L: LanguageExt>(
77 self,
78 node: Node<'_, StrDoc<L>>,
79 ) -> Visit<'_, StrDoc<L>, A::Traversal<'_, L>, M>
80 where
81 M: Matcher,
82 {
83 let traversal = A::traverse(node);
84 Visit {
85 reentrant: self.reentrant,
86 named: self.named_only,
87 matcher: self.matcher,
88 traversal,
89 lang: PhantomData,
90 }
91 }
92}
93
94pub struct Visit<'t, D, T, M> {
95 reentrant: bool,
96 named: bool,
97 matcher: M,
98 traversal: T,
99 lang: PhantomData<&'t D>,
100}
101impl<'t, D, T, M> Visit<'t, D, T, M>
102where
103 D: Doc + 't,
104 T: Traversal<'t, D>,
105 M: Matcher,
106{
107 #[inline]
108 fn mark_match(&mut self, depth: Option<usize>) {
109 if !self.reentrant {
110 self.traversal.calibrate_for_match(depth);
111 }
112 }
113}
114
115impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
116where
117 D: Doc + 't,
118 T: Traversal<'t, D>,
119 M: Matcher,
120{
121 type Item = NodeMatch<'t, D>;
122 fn next(&mut self) -> Option<Self::Item> {
123 loop {
124 let match_depth = self.traversal.get_current_depth();
125 let node = self.traversal.next()?;
126 let pass_named = !self.named || node.is_named();
127 if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
128 self.mark_match(Some(match_depth));
129 return Some(node_match);
130 } else {
131 self.mark_match(None);
132 }
133 }
134 }
135}
136
137pub trait Algorithm {
138 type Traversal<'t, L: LanguageExt>: Traversal<'t, StrDoc<L>>;
139 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L>;
140}
141
142pub struct PreOrder;
143impl Algorithm for PreOrder {
144 type Traversal<'t, L: LanguageExt> = Pre<'t, L>;
145 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
146 Pre::new(&node)
147 }
148}
149pub struct PostOrder;
150impl Algorithm for PostOrder {
151 type Traversal<'t, L: LanguageExt> = Post<'t, L>;
152 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
153 Post::new(&node)
154 }
155}
156
/// A node iterator that a [`Visit`] can steer while it is being consumed.
pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
    /// Adjusts the traversal after a node has been yielded.
    ///
    /// [`Visit`] calls this only for non-reentrant visits. `depth` is
    /// `Some(d)` when the yielded node matched, where `d` is the value of
    /// [`Traversal::get_current_depth`] sampled right before that node was
    /// yielded; it is `None` when the node did not match.
    fn calibrate_for_match(&mut self, depth: Option<usize>);
    /// Returns the traversal's current depth relative to its starting node.
    fn get_current_depth(&self) -> usize;
}
170
171pub struct TsPre<'tree> {
173 cursor: ts::TreeCursor<'tree>,
174 start_id: Option<usize>,
177 current_depth: usize,
178}
179
180impl<'tree> TsPre<'tree> {
181 pub fn new(node: &ts::Node<'tree>) -> Self {
182 Self {
183 cursor: node.walk(),
184 start_id: Some(node.id()),
185 current_depth: 0,
186 }
187 }
188 fn step_down(&mut self) -> bool {
189 if self.cursor.goto_first_child() {
190 self.current_depth += 1;
191 true
192 } else {
193 false
194 }
195 }
196
197 fn trace_up(&mut self, start: usize) {
199 let cursor = &mut self.cursor;
200 while cursor.node().id() != start {
201 if cursor.goto_next_sibling() {
203 return;
204 }
205 self.current_depth -= 1;
206 if !cursor.goto_parent() {
208 break;
211 }
212 }
213 self.start_id = None;
215 }
216}
217
218impl<'tree> Iterator for TsPre<'tree> {
220 type Item = ts::Node<'tree>;
221 fn next(&mut self) -> Option<Self::Item> {
226 let start = self.start_id?;
228 let cursor = &mut self.cursor;
229 let inner = cursor.node(); let ret = Some(inner);
231 if self.step_down() {
233 return ret;
234 }
235 self.trace_up(start);
238 ret
239 }
240}
241
242pub struct Pre<'tree, L: LanguageExt> {
243 root: &'tree Root<StrDoc<L>>,
244 inner: TsPre<'tree>,
245}
246impl<'tree, L: LanguageExt> Iterator for Pre<'tree, L> {
247 type Item = Node<'tree, StrDoc<L>>;
248 fn next(&mut self) -> Option<Self::Item> {
249 let inner = self.inner.next()?;
250 Some(self.root.adopt(inner))
251 }
252}
253
254impl<'t, L: LanguageExt> Pre<'t, L> {
255 pub fn new(node: &Node<'t, StrDoc<L>>) -> Self {
256 let inner = TsPre::new(&node.inner);
257 Self {
258 root: node.root,
259 inner,
260 }
261 }
262}
263
264impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Pre<'t, L> {
265 fn calibrate_for_match(&mut self, depth: Option<usize>) {
266 let Some(depth) = depth else {
268 return;
269 };
270 if self.inner.current_depth <= depth {
272 return;
273 }
274 debug_assert!(self.inner.current_depth > depth);
275 if let Some(start) = self.inner.start_id {
276 self.inner.cursor.goto_parent();
278 self.inner.trace_up(start);
279 }
280 }
281
282 #[inline]
283 fn get_current_depth(&self) -> usize {
284 self.inner.current_depth
285 }
286}
287
288pub struct Prune<'tree, L: LanguageExt> {
294 cursor: ts::TreeCursor<'tree>,
295 root: &'tree Root<StrDoc<L>>,
296 start_id: Option<usize>,
297 current_depth: usize,
298}
299
300#[derive(Clone, Copy, Debug, Eq, PartialEq)]
305pub struct PruneSubtree<'tree> {
306 root_id: usize,
307 root_depth: usize,
308 tree: PhantomData<&'tree ()>,
309}
310
311impl<'tree, L: LanguageExt> Prune<'tree, L> {
312 pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
313 Self {
314 cursor: node.inner.walk(),
315 root: node.root,
316 start_id: Some(node.inner.id()),
317 current_depth: 0,
318 }
319 }
320
321 pub fn current_node(&self) -> Option<Node<'tree, StrDoc<L>>> {
322 self.start_id.map(|_| self.root.adopt(self.cursor.node()))
323 }
324
325 pub fn current_subtree(&self) -> PruneSubtree<'tree> {
326 debug_assert!(self.start_id.is_some());
327 PruneSubtree {
328 root_id: self.cursor.node().id(),
329 root_depth: self.current_depth,
330 tree: PhantomData,
331 }
332 }
333
334 pub fn has_left_subtree(&self, subtree: PruneSubtree<'tree>) -> bool {
335 if self.start_id.is_none() {
336 return true;
337 }
338 self.current_depth <= subtree.root_depth && self.cursor.node().id() != subtree.root_id
339 }
340
341 pub fn descend(&mut self) {
342 if self.cursor.goto_first_child() {
343 self.current_depth += 1;
344 return;
345 }
346 self.skip_subtree();
347 }
348
349 pub fn skip_subtree(&mut self) {
350 let Some(start) = self.start_id else {
351 return;
352 };
353 while self.cursor.node().id() != start {
354 if self.cursor.goto_next_sibling() {
355 return;
356 }
357 self.current_depth = self.current_depth.saturating_sub(1);
358 if !self.cursor.goto_parent() {
359 break;
360 }
361 }
362 self.start_id = None;
363 }
364}
365
366pub struct Post<'tree, L: LanguageExt> {
368 cursor: ts::TreeCursor<'tree>,
369 root: &'tree Root<StrDoc<L>>,
370 start_id: Option<usize>,
371 current_depth: usize,
372 match_depth: usize,
373}
374
375impl<'tree, L: LanguageExt> Post<'tree, L> {
377 pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
378 let mut ret = Self {
379 cursor: node.inner.walk(),
380 root: node.root,
381 start_id: Some(node.inner.id()),
382 current_depth: 0,
383 match_depth: 0,
384 };
385 ret.trace_down();
386 ret
387 }
388 fn trace_down(&mut self) {
389 while self.cursor.goto_first_child() {
390 self.current_depth += 1;
391 }
392 }
393 fn step_up(&mut self) {
394 self.current_depth -= 1;
395 self.cursor.goto_parent();
396 }
397}
398
399impl<'tree, L: LanguageExt> Iterator for Post<'tree, L> {
401 type Item = Node<'tree, StrDoc<L>>;
402 fn next(&mut self) -> Option<Self::Item> {
403 let start = self.start_id?;
405 let cursor = &mut self.cursor;
406 let node = self.root.adopt(cursor.node());
407 if node.inner.id() == start {
409 self.start_id = None
410 } else if cursor.goto_next_sibling() {
411 self.trace_down();
413 } else {
414 self.step_up();
415 }
416 Some(node)
417 }
418}
419
420impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Post<'t, L> {
421 fn calibrate_for_match(&mut self, depth: Option<usize>) {
422 if let Some(depth) = depth {
423 debug_assert!(depth >= self.match_depth);
426 self.match_depth = depth;
427 return;
428 }
429 if self.current_depth >= self.match_depth {
431 return;
432 }
433 let Some(start) = self.start_id else {
434 return;
435 };
436 while self.cursor.node().id() != start {
437 self.match_depth = self.current_depth;
438 if self.cursor.goto_next_sibling() {
439 self.trace_down();
441 return;
442 }
443 self.step_up();
444 }
445 self.start_id = None;
447 }
448
449 #[inline]
450 fn get_current_depth(&self) -> usize {
451 self.current_depth
452 }
453}
454
455pub struct Level<'tree, L: LanguageExt> {
460 deque: VecDeque<ts::Node<'tree>>,
461 cursor: ts::TreeCursor<'tree>,
462 root: &'tree Root<StrDoc<L>>,
463}
464
465impl<'tree, L: LanguageExt> Level<'tree, L> {
466 pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
467 let mut deque = VecDeque::new();
468 deque.push_back(node.inner);
469 let cursor = node.inner.walk();
470 Self {
471 deque,
472 cursor,
473 root: node.root,
474 }
475 }
476}
477
478impl<'tree, L: LanguageExt> Iterator for Level<'tree, L> {
480 type Item = Node<'tree, StrDoc<L>>;
481 fn next(&mut self) -> Option<Self::Item> {
482 let inner = self.deque.pop_front()?;
483 let children = inner.children(&mut self.cursor);
484 self.deque.extend(children);
485 Some(self.root.adopt(inner))
486 }
487}
488
#[cfg(test)]
mod test {
    // Iterative traversals are checked against simple recursive reference
    // implementations built on `Node::children`.
    use super::*;
    use crate::language::Tsx;
    use std::ops::Range;

    /// Recursive reference pre-order walk: node first, then each child subtree.
    fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
        let mut ret = vec![node.range()];
        ret.extend(node.children().flat_map(pre_order));
        ret
    }

    /// Recursive reference post-order walk: each child subtree, then the node.
    fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
        let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
        ret.push(node.range());
        ret
    }

    fn pre_order_equivalent(source: &str) {
        let grep = Tsx.ast_grep(source);
        let node = grep.root();
        let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
        let recursive = pre_order(node);
        assert_eq!(iterative, recursive);
    }

    fn post_order_equivalent(source: &str) {
        let grep = Tsx.ast_grep(source);
        let node = grep.root();
        let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
        let recursive = post_order(node);
        assert_eq!(iterative, recursive);
    }

    const CASES: &[&str] = &[
        "console.log('hello world')",
        "let a = (a, b, c)",
        "function test() { let a = 1; let b = 2; a === b}",
        "[[[[[[]]]]], 1 , 2 ,3]",
        "class A { test() { class B {} } }",
    ];

    #[test]
    fn test_pre_order() {
        for case in CASES {
            pre_order_equivalent(case);
        }
    }

    #[test]
    fn test_prune_pre_order_skips_subtree() {
        let grep = Tsx.ast_grep(
            r#"
function a() { foo(); }
function b() { bar(); }
"#,
        );
        let node = grep.root();
        let mut traversal = Prune::new(&node);
        let mut visited = vec![];
        while let Some(node) = traversal.current_node() {
            let kind = node.kind().into_owned();
            let skip = kind == "function_declaration";
            visited.push(kind);
            if skip {
                traversal.skip_subtree();
            } else {
                traversal.descend();
            }
        }

        assert_eq!(
            visited,
            vec![
                "program".to_string(),
                "function_declaration".to_string(),
                "function_declaration".to_string()
            ]
        );
    }

    #[test]
    fn test_prune_subtree_scope_tracks_exit() {
        let grep = Tsx.ast_grep(
            r#"
function a() { foo(); }
function b() { bar(); }
"#,
        );
        let node = grep.root();
        let mut traversal = Prune::new(&node);
        traversal.descend();
        let subtree = traversal.current_subtree();
        assert!(!traversal.has_left_subtree(subtree));
        traversal.descend();
        assert!(!traversal.has_left_subtree(subtree));
        while traversal.current_node().is_some() && !traversal.has_left_subtree(subtree) {
            traversal.skip_subtree();
        }

        let node = traversal
            .current_node()
            .expect("traversal should move to the next sibling");
        assert_eq!(node.kind().as_ref(), "function_declaration");
        assert!(traversal.has_left_subtree(subtree));
    }

    #[test]
    fn test_post_order() {
        for case in CASES {
            post_order_equivalent(case);
        }
    }

    #[test]
    fn test_different_order() {
        for case in CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
            let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
            let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
            assert_ne!(pre, post);
            assert_ne!(pre, level);
            assert_ne!(post, level);
        }
    }

    #[test]
    fn test_fused_traversal() {
        for case in CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let mut pre = Pre::new(&node);
            let mut post = Post::new(&node);
            while pre.next().is_some() {}
            while post.next().is_some() {}
            assert!(pre.next().is_none());
            assert!(pre.next().is_none());
            assert!(post.next().is_none());
            assert!(post.next().is_none());
        }
    }

    #[test]
    fn test_non_root_traverse() {
        let grep = Tsx.ast_grep("let a = 123; let b = 123;");
        let node = grep.root();
        let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
        let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
        let node2 = node.child(0).unwrap();
        let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
        let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
        assert_ne!(pre, pre2);
        assert_ne!(post, post2);
        assert!(pre[1..].starts_with(&pre2));
        assert!(post.starts_with(&post2));
    }

    /// Reference non-reentrant pre-order search: a matching node is reported
    /// and its descendants are not searched.
    fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
        if node.matches(matcher) {
            vec![node.range()]
        } else {
            node.children()
                .flat_map(|n| pre_order_with_matcher(n, matcher))
                .collect()
        }
    }

    /// Reference non-reentrant post-order search: a node is reported only if
    /// it matches and none of its descendants was reported.
    fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
        let mut ret: Vec<_> = node
            .children()
            .flat_map(|n| post_order_with_matcher(n, matcher))
            .collect();
        if ret.is_empty() && node.matches(matcher) {
            ret.push(node.range());
        }
        ret
    }

    const MATCHER_CASES: &[&str] = &[
        "Some(123)",
        "Some(1, 2, Some(2))",
        "NoMatch",
        "NoMatch(Some(123))",
        "Some(1, Some(2), Some(3))",
        "Some(1, Some(2), Some(Some(3)))",
    ];

    #[test]
    fn test_pre_order_visitor() {
        let matcher = "Some($$$)";
        for case in MATCHER_CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let recur = pre_order_with_matcher(grep.root(), matcher);
            let visit: Vec<_> = Visitor::new(matcher)
                .reentrant(false)
                .visit(node)
                .map(|n| n.range())
                .collect();
            assert_eq!(recur, visit);
        }
    }

    #[test]
    fn test_post_order_visitor() {
        let matcher = "Some($$$)";
        for case in MATCHER_CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let recur = post_order_with_matcher(grep.root(), matcher);
            let visit: Vec<_> = Visitor::new(matcher)
                .algorithm::<PostOrder>()
                .reentrant(false)
                .visit(node)
                .map(|n| n.range())
                .collect();
            assert_eq!(recur, visit);
        }
    }

    #[test]
    fn test_traversal_leaf() {
        let matcher = "true";
        let case = "((((true))));true";
        let grep = Tsx.ast_grep(case);
        let recur = pre_order_with_matcher(grep.root(), matcher);
        let visit: Vec<_> = Visitor::new(matcher)
            .reentrant(false)
            .visit(grep.root())
            .map(|n| n.range())
            .collect();
        assert_eq!(recur, visit);
        let recur = post_order_with_matcher(grep.root(), matcher);
        let visit: Vec<_> = Visitor::new(matcher)
            .algorithm::<PostOrder>()
            .reentrant(false)
            .visit(grep.root())
            .map(|n| n.range())
            .collect();
        assert_eq!(recur, visit);
    }
}