Skip to main content

ast_grep_core/tree_sitter/
traversal.rs

1//! # Traverse Node AST
2//!
3//! ast-grep supports common tree traversal algorithms, including
4//! * Pre order traversal
5//! * Post order traversal
6//! * Level order traversal
7//!
//! Note that tree traversal can also be used with a Matcher. A traversal with a Matcher will
//! produce a [`NodeMatch`] sequence where every item satisfies the Matcher.
10//!
11//! It is also possible to specify the reentrancy of a traversal.
12//! That is, we can control whether a matching node should be visited when it is nested within another match.
13//! For example, suppose we want to find all usages of calling `foo` in the source `foo(foo())`.
14//! The code has two matching calls and we can configure a traversal
15//! to report only the inner one, only the outer one or both.
16//!
17//! Pre and Post order traversals in this module are implemented using tree-sitter's cursor API without extra heap allocation.
18//! It is recommended to use traversal instead of tree recursion to avoid stack overflow and memory overhead.
19//! Level order is also included for completeness and should be used sparingly.
20
21use super::StrDoc;
22use crate::matcher::{Matcher, MatcherExt};
23use crate::tree_sitter::LanguageExt;
24use crate::{Doc, Node, NodeMatch, Root};
25
26use tree_sitter as ts;
27
28use std::collections::VecDeque;
29use std::marker::PhantomData;
30
31pub struct Visitor<M, A = PreOrder> {
32  /// Whether a node will match if it contains or is contained in another match.
33  reentrant: bool,
34  /// Whether visit named node only
35  named_only: bool,
36  /// optional matcher to filter nodes
37  matcher: M,
38  /// The algorithm to traverse the tree, can be pre/post/level order
39  algorithm: PhantomData<A>,
40}
41
42impl<M> Visitor<M> {
43  pub fn new(matcher: M) -> Visitor<M> {
44    Visitor {
45      reentrant: true,
46      named_only: false,
47      matcher,
48      algorithm: PhantomData,
49    }
50  }
51}
52
53impl<M, A> Visitor<M, A> {
54  pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
55    Visitor {
56      reentrant: self.reentrant,
57      named_only: self.named_only,
58      matcher: self.matcher,
59      algorithm: PhantomData,
60    }
61  }
62
63  pub fn reentrant(self, reentrant: bool) -> Self {
64    Self { reentrant, ..self }
65  }
66
67  pub fn named_only(self, named_only: bool) -> Self {
68    Self { named_only, ..self }
69  }
70}
71
72impl<M, A> Visitor<M, A>
73where
74  A: Algorithm,
75{
76  pub fn visit<L: LanguageExt>(
77    self,
78    node: Node<'_, StrDoc<L>>,
79  ) -> Visit<'_, StrDoc<L>, A::Traversal<'_, L>, M>
80  where
81    M: Matcher,
82  {
83    let traversal = A::traverse(node);
84    Visit {
85      reentrant: self.reentrant,
86      named: self.named_only,
87      matcher: self.matcher,
88      traversal,
89      lang: PhantomData,
90    }
91  }
92}
93
94pub struct Visit<'t, D, T, M> {
95  reentrant: bool,
96  named: bool,
97  matcher: M,
98  traversal: T,
99  lang: PhantomData<&'t D>,
100}
101impl<'t, D, T, M> Visit<'t, D, T, M>
102where
103  D: Doc + 't,
104  T: Traversal<'t, D>,
105  M: Matcher,
106{
107  #[inline]
108  fn mark_match(&mut self, depth: Option<usize>) {
109    if !self.reentrant {
110      self.traversal.calibrate_for_match(depth);
111    }
112  }
113}
114
115impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
116where
117  D: Doc + 't,
118  T: Traversal<'t, D>,
119  M: Matcher,
120{
121  type Item = NodeMatch<'t, D>;
122  fn next(&mut self) -> Option<Self::Item> {
123    loop {
124      let match_depth = self.traversal.get_current_depth();
125      let node = self.traversal.next()?;
126      let pass_named = !self.named || node.is_named();
127      if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
128        self.mark_match(Some(match_depth));
129        return Some(node_match);
130      } else {
131        self.mark_match(None);
132      }
133    }
134  }
135}
136
/// Type-level selector of the traversal algorithm used by [`Visitor`].
pub trait Algorithm {
  /// The concrete traversal created for a tree of language `L`.
  type Traversal<'t, L: LanguageExt>: Traversal<'t, StrDoc<L>>;
  /// Builds the traversal that starts at `node`.
  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L>;
}
141
/// Marker type selecting pre-order traversal, backed by [`Pre`].
pub struct PreOrder;
impl Algorithm for PreOrder {
  type Traversal<'t, L: LanguageExt> = Pre<'t, L>;
  /// Creates a [`Pre`] traversal starting at `node`.
  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
    Pre::new(&node)
  }
}
/// Marker type selecting post-order traversal, backed by [`Post`].
pub struct PostOrder;
impl Algorithm for PostOrder {
  type Traversal<'t, L: LanguageExt> = Post<'t, L>;
  /// Creates a [`Post`] traversal starting at `node`.
  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
    Post::new(&node)
  }
}
156
/// A traversal iterates over nodes using a traversal algorithm.
/// The `next` method should only handle normal, reentrant iteration.
/// If reentrancy is not desired, the traversal should mutate its cursor in `calibrate_for_match`.
/// `Visit` tracks the depth of the matched node and passes it in, so a traversal
/// does not need an extra field for it.
pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
  /// Calibrates the cursor position to skip overlapping matches.
  /// The node depth is passed if the node matched, otherwise `None`.
  fn calibrate_for_match(&mut self, depth: Option<usize>);
  /// Returns the current depth of the cursor.
  /// Cursor depth is incremented by 1 when moving from parent to child.
  /// Cursor depth at the starting node is 0.
  fn get_current_depth(&self) -> usize;
}
170
171/// Represents a pre-order traversal
172pub struct TsPre<'tree> {
173  cursor: ts::TreeCursor<'tree>,
174  // record the starting node, if we return back to starting point
175  // we should terminate the dfs.
176  start_id: Option<usize>,
177  current_depth: usize,
178}
179
180impl<'tree> TsPre<'tree> {
181  pub fn new(node: &ts::Node<'tree>) -> Self {
182    Self {
183      cursor: node.walk(),
184      start_id: Some(node.id()),
185      current_depth: 0,
186    }
187  }
188  fn step_down(&mut self) -> bool {
189    if self.cursor.goto_first_child() {
190      self.current_depth += 1;
191      true
192    } else {
193      false
194    }
195  }
196
197  // retrace back to ancestors and find next node to explore
198  fn trace_up(&mut self, start: usize) {
199    let cursor = &mut self.cursor;
200    while cursor.node().id() != start {
201      // try visit sibling nodes
202      if cursor.goto_next_sibling() {
203        return;
204      }
205      self.current_depth -= 1;
206      // go back to parent node
207      if !cursor.goto_parent() {
208        // it should never fail here. However, tree-sitter has bad parsing bugs
209        // stop to avoid panic. https://github.com/ast-grep/ast-grep/issues/713
210        break;
211      }
212    }
213    // terminate traversal here
214    self.start_id = None;
215  }
216}
217
218/// Amortized time complexity is O(NlgN), depending on branching factor.
219impl<'tree> Iterator for TsPre<'tree> {
220  type Item = ts::Node<'tree>;
221  // 1. Yield the node itself
222  // 2. Try visit the child node until no child available
223  // 3. Try visit next sibling after going back to parent
224  // 4. Repeat step 3 until returning to the starting node
225  fn next(&mut self) -> Option<Self::Item> {
226    // start_id will always be Some until the dfs terminates
227    let start = self.start_id?;
228    let cursor = &mut self.cursor;
229    let inner = cursor.node(); // get current node
230    let ret = Some(inner);
231    // try going to children first
232    if self.step_down() {
233      return ret;
234    }
235    // if no child available, go to ancestor nodes
236    // until we get to the starting point
237    self.trace_up(start);
238    ret
239  }
240}
241
242pub struct Pre<'tree, L: LanguageExt> {
243  root: &'tree Root<StrDoc<L>>,
244  inner: TsPre<'tree>,
245}
246impl<'tree, L: LanguageExt> Iterator for Pre<'tree, L> {
247  type Item = Node<'tree, StrDoc<L>>;
248  fn next(&mut self) -> Option<Self::Item> {
249    let inner = self.inner.next()?;
250    Some(self.root.adopt(inner))
251  }
252}
253
254impl<'t, L: LanguageExt> Pre<'t, L> {
255  pub fn new(node: &Node<'t, StrDoc<L>>) -> Self {
256    let inner = TsPre::new(&node.inner);
257    Self {
258      root: node.root,
259      inner,
260    }
261  }
262}
263
264impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Pre<'t, L> {
265  fn calibrate_for_match(&mut self, depth: Option<usize>) {
266    // not entering the node, ignore
267    let Some(depth) = depth else {
268      return;
269    };
270    // if already entering sibling or traced up, ignore
271    if self.inner.current_depth <= depth {
272      return;
273    }
274    debug_assert!(self.inner.current_depth > depth);
275    if let Some(start) = self.inner.start_id {
276      // revert the step down
277      self.inner.cursor.goto_parent();
278      self.inner.trace_up(start);
279    }
280  }
281
282  #[inline]
283  fn get_current_depth(&self) -> usize {
284    self.inner.current_depth
285  }
286}
287
/// Pre-order cursor traversal where the caller decides whether to enter children.
///
/// This is useful when matching the current node determines whether its whole
/// subtree can be skipped. Unlike [`Pre`] plus reentrancy calibration, this
/// traversal does not step into a child before the caller has made that choice.
pub struct Prune<'tree, L: LanguageExt> {
  /// Cursor resting on the current node.
  cursor: ts::TreeCursor<'tree>,
  /// Root used to turn raw tree-sitter nodes into [`Node`]s.
  root: &'tree Root<StrDoc<L>>,
  /// Id of the starting node; `None` once the traversal has finished.
  start_id: Option<usize>,
  /// Depth of the cursor relative to the starting node, which is at depth 0.
  current_depth: usize,
}
299
/// Opaque marker for a subtree in a [`Prune`] traversal.
///
/// Callers can store this when visiting a node and later ask whether traversal
/// has moved past that node's subtree without depending on cursor depth.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PruneSubtree<'tree> {
  /// Id of the node at the root of the marked subtree.
  root_id: usize,
  /// Traversal depth at which that node was visited.
  root_depth: usize,
  /// Ties the marker to the lifetime of the tree it was taken from.
  tree: PhantomData<&'tree ()>,
}
310
311impl<'tree, L: LanguageExt> Prune<'tree, L> {
312  pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
313    Self {
314      cursor: node.inner.walk(),
315      root: node.root,
316      start_id: Some(node.inner.id()),
317      current_depth: 0,
318    }
319  }
320
321  pub fn current_node(&self) -> Option<Node<'tree, StrDoc<L>>> {
322    self.start_id.map(|_| self.root.adopt(self.cursor.node()))
323  }
324
325  pub fn current_subtree(&self) -> PruneSubtree<'tree> {
326    debug_assert!(self.start_id.is_some());
327    PruneSubtree {
328      root_id: self.cursor.node().id(),
329      root_depth: self.current_depth,
330      tree: PhantomData,
331    }
332  }
333
334  pub fn has_left_subtree(&self, subtree: PruneSubtree<'tree>) -> bool {
335    if self.start_id.is_none() {
336      return true;
337    }
338    self.current_depth <= subtree.root_depth && self.cursor.node().id() != subtree.root_id
339  }
340
341  pub fn descend(&mut self) {
342    if self.cursor.goto_first_child() {
343      self.current_depth += 1;
344      return;
345    }
346    self.skip_subtree();
347  }
348
349  pub fn skip_subtree(&mut self) {
350    let Some(start) = self.start_id else {
351      return;
352    };
353    while self.cursor.node().id() != start {
354      if self.cursor.goto_next_sibling() {
355        return;
356      }
357      self.current_depth = self.current_depth.saturating_sub(1);
358      if !self.cursor.goto_parent() {
359        break;
360      }
361    }
362    self.start_id = None;
363  }
364}
365
/// Represents a post-order traversal
pub struct Post<'tree, L: LanguageExt> {
  /// Cursor resting on the node that will be yielded next.
  cursor: ts::TreeCursor<'tree>,
  /// Root used to turn raw tree-sitter nodes into [`Node`]s.
  root: &'tree Root<StrDoc<L>>,
  /// Id of the starting node; `None` once the traversal has finished.
  start_id: Option<usize>,
  /// Depth of the cursor relative to the starting node, which is at depth 0.
  current_depth: usize,
  /// Depth bookkeeping used by `calibrate_for_match` to skip ancestors of a match.
  match_depth: usize,
}
374
375/// Amortized time complexity is O(NlgN), depending on branching factor.
376impl<'tree, L: LanguageExt> Post<'tree, L> {
377  pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
378    let mut ret = Self {
379      cursor: node.inner.walk(),
380      root: node.root,
381      start_id: Some(node.inner.id()),
382      current_depth: 0,
383      match_depth: 0,
384    };
385    ret.trace_down();
386    ret
387  }
388  fn trace_down(&mut self) {
389    while self.cursor.goto_first_child() {
390      self.current_depth += 1;
391    }
392  }
393  fn step_up(&mut self) {
394    self.current_depth -= 1;
395    self.cursor.goto_parent();
396  }
397}
398
399/// Amortized time complexity is O(NlgN), depending on branching factor.
400impl<'tree, L: LanguageExt> Iterator for Post<'tree, L> {
401  type Item = Node<'tree, StrDoc<L>>;
402  fn next(&mut self) -> Option<Self::Item> {
403    // start_id will always be Some until the dfs terminates
404    let start = self.start_id?;
405    let cursor = &mut self.cursor;
406    let node = self.root.adopt(cursor.node());
407    // return to start
408    if node.inner.id() == start {
409      self.start_id = None
410    } else if cursor.goto_next_sibling() {
411      // try visit sibling
412      self.trace_down();
413    } else {
414      self.step_up();
415    }
416    Some(node)
417  }
418}
419
420impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Post<'t, L> {
421  fn calibrate_for_match(&mut self, depth: Option<usize>) {
422    if let Some(depth) = depth {
423      // Later matches' depth should always be greater than former matches.
424      // because we bump match_depth in `step_up` during traversal.
425      debug_assert!(depth >= self.match_depth);
426      self.match_depth = depth;
427      return;
428    }
429    // found new nodes to explore in trace_down, skip calibration.
430    if self.current_depth >= self.match_depth {
431      return;
432    }
433    let Some(start) = self.start_id else {
434      return;
435    };
436    while self.cursor.node().id() != start {
437      self.match_depth = self.current_depth;
438      if self.cursor.goto_next_sibling() {
439        // try visit sibling
440        self.trace_down();
441        return;
442      }
443      self.step_up();
444    }
445    // terminate because all ancestors are skipped
446    self.start_id = None;
447  }
448
449  #[inline]
450  fn get_current_depth(&self) -> usize {
451    self.current_depth
452  }
453}
454
455/// Represents a level-order traversal.
456/// It is implemented with [`VecDeque`] since quadratic backtracking is too time consuming.
457/// Though level-order is not used as frequently as other DFS traversals,
458/// traversing a big AST with level-order should be done with caution since it might increase the memory usage.
459pub struct Level<'tree, L: LanguageExt> {
460  deque: VecDeque<ts::Node<'tree>>,
461  cursor: ts::TreeCursor<'tree>,
462  root: &'tree Root<StrDoc<L>>,
463}
464
465impl<'tree, L: LanguageExt> Level<'tree, L> {
466  pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
467    let mut deque = VecDeque::new();
468    deque.push_back(node.inner);
469    let cursor = node.inner.walk();
470    Self {
471      deque,
472      cursor,
473      root: node.root,
474    }
475  }
476}
477
478/// Time complexity is O(N). Space complexity is O(N)
479impl<'tree, L: LanguageExt> Iterator for Level<'tree, L> {
480  type Item = Node<'tree, StrDoc<L>>;
481  fn next(&mut self) -> Option<Self::Item> {
482    let inner = self.deque.pop_front()?;
483    let children = inner.children(&mut self.cursor);
484    self.deque.extend(children);
485    Some(self.root.adopt(inner))
486  }
487}
488
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use std::ops::Range;

  // recursive pre order as baseline
  fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret = vec![node.range()];
    ret.extend(node.children().flat_map(pre_order));
    ret
  }

  // recursive post order as baseline
  fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
    ret.push(node.range());
    ret
  }

  // the cursor based pre order must agree with the recursive baseline
  fn pre_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let recursive = pre_order(node);
    assert_eq!(iterative, recursive);
  }

  // the cursor based post order must agree with the recursive baseline
  fn post_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let recursive = post_order(node);
    assert_eq!(iterative, recursive);
  }

  const CASES: &[&str] = &[
    "console.log('hello world')",
    "let a = (a, b, c)",
    "function test() { let a = 1; let b = 2; a === b}",
    "[[[[[[]]]]], 1 , 2 ,3]",
    "class A { test() { class B {} } }",
  ];

  #[test]
  fn test_pre_order() {
    for case in CASES {
      pre_order_equivalent(case);
    }
  }

  #[test]
  fn test_prune_pre_order_skips_subtree() {
    let grep = Tsx.ast_grep(
      r#"
function a() { foo(); }
function b() { bar(); }
"#,
    );
    let node = grep.root();
    let mut traversal = Prune::new(&node);
    let mut visited = vec![];
    while let Some(node) = traversal.current_node() {
      let kind = node.kind().into_owned();
      let skip = kind == "function_declaration";
      visited.push(kind);
      if skip {
        traversal.skip_subtree();
      } else {
        traversal.descend();
      }
    }

    assert_eq!(
      visited,
      vec![
        "program".to_string(),
        "function_declaration".to_string(),
        "function_declaration".to_string()
      ]
    );
  }

  #[test]
  fn test_prune_subtree_scope_tracks_exit() {
    let grep = Tsx.ast_grep(
      r#"
function a() { foo(); }
function b() { bar(); }
"#,
    );
    let node = grep.root();
    let mut traversal = Prune::new(&node);
    traversal.descend();
    let subtree = traversal.current_subtree();
    assert!(!traversal.has_left_subtree(subtree));
    traversal.descend();
    assert!(!traversal.has_left_subtree(subtree));
    while traversal.current_node().is_some() && !traversal.has_left_subtree(subtree) {
      traversal.skip_subtree();
    }

    let node = traversal
      .current_node()
      .expect("traversal should move to the next sibling");
    assert_eq!(node.kind().as_ref(), "function_declaration");
    assert!(traversal.has_left_subtree(subtree));
  }

  #[test]
  fn test_post_order() {
    for case in CASES {
      post_order_equivalent(case);
    }
  }

  #[test]
  fn test_different_order() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
      let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
      let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
      assert_ne!(pre, post);
      assert_ne!(pre, level);
      assert_ne!(post, level);
    }
  }

  #[test]
  fn test_fused_traversal() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let mut pre = Pre::new(&node);
      let mut post = Post::new(&node);
      while pre.next().is_some() {}
      while post.next().is_some() {}
      assert!(pre.next().is_none());
      assert!(pre.next().is_none());
      assert!(post.next().is_none());
      assert!(post.next().is_none());
    }
  }

  #[test]
  fn test_non_root_traverse() {
    let grep = Tsx.ast_grep("let a = 123; let b = 123;");
    let node = grep.root();
    let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let node2 = node.child(0).unwrap();
    let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
    let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
    // traversal should stop at node
    assert_ne!(pre, pre2);
    assert_ne!(post, post2);
    // child traversal should be a part of parent traversal
    assert!(pre[1..].starts_with(&pre2));
    assert!(post.starts_with(&post2));
  }

  // non-reentrant pre order baseline: a match hides all of its descendants
  fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    if node.matches(matcher) {
      vec![node.range()]
    } else {
      node
        .children()
        .flat_map(|n| pre_order_with_matcher(n, matcher))
        .collect()
    }
  }

  // non-reentrant post order baseline: a match hides all of its ancestors
  fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node
      .children()
      .flat_map(|n| post_order_with_matcher(n, matcher))
      .collect();
    if ret.is_empty() && node.matches(matcher) {
      ret.push(node.range());
    }
    ret
  }

  const MATCHER_CASES: &[&str] = &[
    "Some(123)",
    "Some(1, 2, Some(2))",
    "NoMatch",
    "NoMatch(Some(123))",
    "Some(1, Some(2), Some(3))",
    "Some(1, Some(2), Some(Some(3)))",
  ];

  #[test]
  fn test_pre_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = pre_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }
  #[test]
  fn test_post_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = post_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .algorithm::<PostOrder>()
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }

  // match a leaf node will trace_up the cursor
  #[test]
  fn test_traversal_leaf() {
    let matcher = "true";
    let case = "((((true))));true";
    let grep = Tsx.ast_grep(case);
    let recur = pre_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
    let recur = post_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .algorithm::<PostOrder>()
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
  }
}