ast_grep_core/tree_sitter/
traversal.rs

1//! # Traverse Node AST
2//!
3//! ast-grep supports common tree traversal algorithms, including
4//! * Pre order traversal
5//! * Post order traversal
6//! * Level order traversal
7//!
//! Note that tree traversal can also be used with a [`Matcher`]. A traversal with a Matcher will
//! produce a [`NodeMatch`] sequence in which every item satisfies the Matcher.
10//!
11//! It is also possible to specify the reentrancy of a traversal.
12//! That is, we can control whether a matching node should be visited when it is nested within another match.
13//! For example, suppose we want to find all usages of calling `foo` in the source `foo(foo())`.
14//! The code has two matching calls and we can configure a traversal
15//! to report only the inner one, only the outer one or both.
16//!
17//! Pre and Post order traversals in this module are implemented using tree-sitter's cursor API without extra heap allocation.
18//! It is recommended to use traversal instead of tree recursion to avoid stack overflow and memory overhead.
19//! Level order is also included for completeness and should be used sparingly.
20
21use super::StrDoc;
22use crate::matcher::{Matcher, MatcherExt};
23use crate::tree_sitter::LanguageExt;
24use crate::{Doc, Node, NodeMatch, Root};
25
26use tree_sitter as ts;
27
28use std::collections::VecDeque;
29use std::marker::PhantomData;
30
31pub struct Visitor<M, A = PreOrder> {
32  /// Whether a node will match if it contains or is contained in another match.
33  reentrant: bool,
34  /// Whether visit named node only
35  named_only: bool,
36  /// optional matcher to filter nodes
37  matcher: M,
38  /// The algorithm to traverse the tree, can be pre/post/level order
39  algorithm: PhantomData<A>,
40}
41
42impl<M> Visitor<M> {
43  pub fn new(matcher: M) -> Visitor<M> {
44    Visitor {
45      reentrant: true,
46      named_only: false,
47      matcher,
48      algorithm: PhantomData,
49    }
50  }
51}
52
53impl<M, A> Visitor<M, A> {
54  pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
55    Visitor {
56      reentrant: self.reentrant,
57      named_only: self.named_only,
58      matcher: self.matcher,
59      algorithm: PhantomData,
60    }
61  }
62
63  pub fn reentrant(self, reentrant: bool) -> Self {
64    Self { reentrant, ..self }
65  }
66
67  pub fn named_only(self, named_only: bool) -> Self {
68    Self { named_only, ..self }
69  }
70}
71
72impl<M, A> Visitor<M, A>
73where
74  A: Algorithm,
75{
76  pub fn visit<L: LanguageExt>(
77    self,
78    node: Node<'_, StrDoc<L>>,
79  ) -> Visit<'_, StrDoc<L>, A::Traversal<'_, L>, M>
80  where
81    M: Matcher,
82  {
83    let traversal = A::traverse(node);
84    Visit {
85      reentrant: self.reentrant,
86      named: self.named_only,
87      matcher: self.matcher,
88      traversal,
89      lang: PhantomData,
90    }
91  }
92}
93
/// Iterator produced by [`Visitor::visit`].
///
/// It drives the traversal `T` and yields a [`NodeMatch`] for each node that the
/// matcher `M` accepts.
pub struct Visit<'t, D, T, M> {
  /// Skip unnamed nodes (mirrors `Visitor::named_only`).
  named: bool,
  /// Report matches nested in other matches (mirrors `Visitor::reentrant`).
  reentrant: bool,
  matcher: M,
  traversal: T,
  /// Ties the iterator to the document lifetime `'t` without storing a `D`.
  lang: PhantomData<&'t D>,
}
101impl<'t, D, T, M> Visit<'t, D, T, M>
102where
103  D: Doc + 't,
104  T: Traversal<'t, D>,
105  M: Matcher,
106{
107  #[inline]
108  fn mark_match(&mut self, depth: Option<usize>) {
109    if !self.reentrant {
110      self.traversal.calibrate_for_match(depth);
111    }
112  }
113}
114
115impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
116where
117  D: Doc + 't,
118  T: Traversal<'t, D>,
119  M: Matcher,
120{
121  type Item = NodeMatch<'t, D>;
122  fn next(&mut self) -> Option<Self::Item> {
123    loop {
124      let match_depth = self.traversal.get_current_depth();
125      let node = self.traversal.next()?;
126      let pass_named = !self.named || node.is_named();
127      if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
128        self.mark_match(Some(match_depth));
129        return Some(node_match);
130      } else {
131        self.mark_match(None);
132      }
133    }
134  }
135}
136
137pub trait Algorithm {
138  type Traversal<'t, L: LanguageExt>: Traversal<'t, StrDoc<L>>;
139  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L>;
140}
141
142pub struct PreOrder;
143impl Algorithm for PreOrder {
144  type Traversal<'t, L: LanguageExt> = Pre<'t, L>;
145  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
146    Pre::new(&node)
147  }
148}
149pub struct PostOrder;
150impl Algorithm for PostOrder {
151  type Traversal<'t, L: LanguageExt> = Post<'t, L>;
152  fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
153    Post::new(&node)
154  }
155}
156
157/// Traversal can iterate over node by using traversal algorithm.
158/// The `next` method should only handle normal, reentrant iteration.
159/// If reentrancy is not desired, traversal should mutate cursor in `calibrate_for_match`.
160/// Visit will maintain the matched node depth so traversal does not need to use extra field.
161pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
162  /// Calibrate cursor position to skip overlapping matches.
163  /// node depth will be passed if matched, otherwise None.
164  fn calibrate_for_match(&mut self, depth: Option<usize>);
165  /// Returns the current depth of cursor depth.
166  /// Cursor depth is incremented by 1 when moving from parent to child.
167  /// Cursor depth at Root node is 0.
168  fn get_current_depth(&self) -> usize;
169}
170
171/// Represents a pre-order traversal
172pub struct TsPre<'tree> {
173  cursor: ts::TreeCursor<'tree>,
174  // record the starting node, if we return back to starting point
175  // we should terminate the dfs.
176  start_id: Option<usize>,
177  current_depth: usize,
178}
179
180impl<'tree> TsPre<'tree> {
181  pub fn new(node: &ts::Node<'tree>) -> Self {
182    Self {
183      cursor: node.walk(),
184      start_id: Some(node.id()),
185      current_depth: 0,
186    }
187  }
188  fn step_down(&mut self) -> bool {
189    if self.cursor.goto_first_child() {
190      self.current_depth += 1;
191      true
192    } else {
193      false
194    }
195  }
196
197  // retrace back to ancestors and find next node to explore
198  fn trace_up(&mut self, start: usize) {
199    let cursor = &mut self.cursor;
200    while cursor.node().id() != start {
201      // try visit sibling nodes
202      if cursor.goto_next_sibling() {
203        return;
204      }
205      self.current_depth -= 1;
206      // go back to parent node
207      if !cursor.goto_parent() {
208        // it should never fail here. However, tree-sitter has bad parsing bugs
209        // stop to avoid panic. https://github.com/ast-grep/ast-grep/issues/713
210        break;
211      }
212    }
213    // terminate traversal here
214    self.start_id = None;
215  }
216}
217
218/// Amortized time complexity is O(NlgN), depending on branching factor.
219impl<'tree> Iterator for TsPre<'tree> {
220  type Item = ts::Node<'tree>;
221  // 1. Yield the node itself
222  // 2. Try visit the child node until no child available
223  // 3. Try visit next sibling after going back to parent
224  // 4. Repeat step 3 until returning to the starting node
225  fn next(&mut self) -> Option<Self::Item> {
226    // start_id will always be Some until the dfs terminates
227    let start = self.start_id?;
228    let cursor = &mut self.cursor;
229    let inner = cursor.node(); // get current node
230    let ret = Some(inner);
231    // try going to children first
232    if self.step_down() {
233      return ret;
234    }
235    // if no child available, go to ancestor nodes
236    // until we get to the starting point
237    self.trace_up(start);
238    ret
239  }
240}
241
242pub struct Pre<'tree, L: LanguageExt> {
243  root: &'tree Root<StrDoc<L>>,
244  inner: TsPre<'tree>,
245}
246impl<'tree, L: LanguageExt> Iterator for Pre<'tree, L> {
247  type Item = Node<'tree, StrDoc<L>>;
248  fn next(&mut self) -> Option<Self::Item> {
249    let inner = self.inner.next()?;
250    Some(self.root.adopt(inner))
251  }
252}
253
254impl<'t, L: LanguageExt> Pre<'t, L> {
255  pub fn new(node: &Node<'t, StrDoc<L>>) -> Self {
256    let inner = TsPre::new(&node.inner);
257    Self {
258      root: node.root,
259      inner,
260    }
261  }
262}
263
264impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Pre<'t, L> {
265  fn calibrate_for_match(&mut self, depth: Option<usize>) {
266    // not entering the node, ignore
267    let Some(depth) = depth else {
268      return;
269    };
270    // if already entering sibling or traced up, ignore
271    if self.inner.current_depth <= depth {
272      return;
273    }
274    debug_assert!(self.inner.current_depth > depth);
275    if let Some(start) = self.inner.start_id {
276      // revert the step down
277      self.inner.cursor.goto_parent();
278      self.inner.trace_up(start);
279    }
280  }
281
282  #[inline]
283  fn get_current_depth(&self) -> usize {
284    self.inner.current_depth
285  }
286}
287
288/// Represents a post-order traversal
289pub struct Post<'tree, L: LanguageExt> {
290  cursor: ts::TreeCursor<'tree>,
291  root: &'tree Root<StrDoc<L>>,
292  start_id: Option<usize>,
293  current_depth: usize,
294  match_depth: usize,
295}
296
297/// Amortized time complexity is O(NlgN), depending on branching factor.
298impl<'tree, L: LanguageExt> Post<'tree, L> {
299  pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
300    let mut ret = Self {
301      cursor: node.inner.walk(),
302      root: node.root,
303      start_id: Some(node.inner.id()),
304      current_depth: 0,
305      match_depth: 0,
306    };
307    ret.trace_down();
308    ret
309  }
310  fn trace_down(&mut self) {
311    while self.cursor.goto_first_child() {
312      self.current_depth += 1;
313    }
314  }
315  fn step_up(&mut self) {
316    self.current_depth -= 1;
317    self.cursor.goto_parent();
318  }
319}
320
321/// Amortized time complexity is O(NlgN), depending on branching factor.
322impl<'tree, L: LanguageExt> Iterator for Post<'tree, L> {
323  type Item = Node<'tree, StrDoc<L>>;
324  fn next(&mut self) -> Option<Self::Item> {
325    // start_id will always be Some until the dfs terminates
326    let start = self.start_id?;
327    let cursor = &mut self.cursor;
328    let node = self.root.adopt(cursor.node());
329    // return to start
330    if node.inner.id() == start {
331      self.start_id = None
332    } else if cursor.goto_next_sibling() {
333      // try visit sibling
334      self.trace_down();
335    } else {
336      self.step_up();
337    }
338    Some(node)
339  }
340}
341
342impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Post<'t, L> {
343  fn calibrate_for_match(&mut self, depth: Option<usize>) {
344    if let Some(depth) = depth {
345      // Later matches' depth should always be greater than former matches.
346      // because we bump match_depth in `step_up` during traversal.
347      debug_assert!(depth >= self.match_depth);
348      self.match_depth = depth;
349      return;
350    }
351    // found new nodes to explore in trace_down, skip calibration.
352    if self.current_depth >= self.match_depth {
353      return;
354    }
355    let Some(start) = self.start_id else {
356      return;
357    };
358    while self.cursor.node().id() != start {
359      self.match_depth = self.current_depth;
360      if self.cursor.goto_next_sibling() {
361        // try visit sibling
362        self.trace_down();
363        return;
364      }
365      self.step_up();
366    }
367    // terminate because all ancestors are skipped
368    self.start_id = None;
369  }
370
371  #[inline]
372  fn get_current_depth(&self) -> usize {
373    self.current_depth
374  }
375}
376
377/// Represents a level-order traversal.
378/// It is implemented with [`VecDeque`] since quadratic backtracking is too time consuming.
379/// Though level-order is not used as frequently as other DFS traversals,
380/// traversing a big AST with level-order should be done with caution since it might increase the memory usage.
381pub struct Level<'tree, L: LanguageExt> {
382  deque: VecDeque<ts::Node<'tree>>,
383  cursor: ts::TreeCursor<'tree>,
384  root: &'tree Root<StrDoc<L>>,
385}
386
387impl<'tree, L: LanguageExt> Level<'tree, L> {
388  pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
389    let mut deque = VecDeque::new();
390    deque.push_back(node.inner);
391    let cursor = node.inner.walk();
392    Self {
393      deque,
394      cursor,
395      root: node.root,
396    }
397  }
398}
399
400/// Time complexity is O(N). Space complexity is O(N)
401impl<'tree, L: LanguageExt> Iterator for Level<'tree, L> {
402  type Item = Node<'tree, StrDoc<L>>;
403  fn next(&mut self) -> Option<Self::Item> {
404    let inner = self.deque.pop_front()?;
405    let children = inner.children(&mut self.cursor);
406    self.deque.extend(children);
407    Some(self.root.adopt(inner))
408  }
409}
410
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use std::ops::Range;

  // recursive pre order as baseline
  fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret = vec![node.range()];
    ret.extend(node.children().flat_map(pre_order));
    ret
  }

  // recursion baseline
  fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
    ret.push(node.range());
    ret
  }

  fn pre_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let recursive = pre_order(node);
    assert_eq!(iterative, recursive);
  }

  fn post_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let recursive = post_order(node);
    assert_eq!(iterative, recursive);
  }

  const CASES: &[&str] = &[
    "console.log('hello world')",
    "let a = (a, b, c)",
    "function test() { let a = 1; let b = 2; a === b}",
    "[[[[[[]]]]], 1 , 2 ,3]",
    "class A { test() { class B {} } }",
  ];

  #[test]
  fn test_pre_order() {
    for case in CASES {
      pre_order_equivalent(case);
    }
  }

  #[test]
  fn test_post_order() {
    for case in CASES {
      post_order_equivalent(case);
    }
  }

  #[test]
  fn test_different_order() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
      let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
      let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
      assert_ne!(pre, post);
      assert_ne!(pre, level);
      assert_ne!(post, level);
    }
  }

  #[test]
  fn test_fused_traversal() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let mut pre = Pre::new(&node);
      let mut post = Post::new(&node);
      while pre.next().is_some() {}
      while post.next().is_some() {}
      assert!(pre.next().is_none());
      assert!(pre.next().is_none());
      assert!(post.next().is_none());
      assert!(post.next().is_none());
    }
  }

  #[test]
  fn test_non_root_traverse() {
    let grep = Tsx.ast_grep("let a = 123; let b = 123;");
    let node = grep.root();
    let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let node2 = node.child(0).unwrap();
    let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
    let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
    // traversal should stop at node
    assert_ne!(pre, pre2);
    assert_ne!(post, post2);
    // child traversal should be a part of parent traversal
    assert!(pre[1..].starts_with(&pre2));
    assert!(post.starts_with(&post2));
  }

  fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    if node.matches(matcher) {
      vec![node.range()]
    } else {
      node
        .children()
        .flat_map(|n| pre_order_with_matcher(n, matcher))
        .collect()
    }
  }

  fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node
      .children()
      .flat_map(|n| post_order_with_matcher(n, matcher))
      .collect();
    if ret.is_empty() && node.matches(matcher) {
      ret.push(node.range());
    }
    ret
  }

  const MATCHER_CASES: &[&str] = &[
    "Some(123)",
    "Some(1, 2, Some(2))",
    "NoMatch",
    "NoMatch(Some(123))",
    "Some(1, Some(2), Some(3))",
    "Some(1, Some(2), Some(Some(3)))",
  ];

  #[test]
  fn test_pre_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = pre_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }
  #[test]
  fn test_post_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = post_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .algorithm::<PostOrder>()
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }

  // matching a leaf node makes the cursor trace up
  #[test]
  fn test_traversal_leaf() {
    let matcher = "true";
    let case = "((((true))));true";
    let grep = Tsx.ast_grep(case);
    let recur = pre_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
    let recur = post_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .algorithm::<PostOrder>()
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
  }
}