ast_grep_core/
traversal.rs

1//! # Traverse Node AST
2//!
3//! ast-grep supports common tree traversal algorithms, including
4//! * Pre order traversal
5//! * Post order traversal
6//! * Level order traversal
7//!
//! Note tree traversal can also be used with a Matcher. A traversal with a Matcher will
//! produce a [`NodeMatch`] sequence where all items satisfy the Matcher.
10//!
11//! It is also possible to specify the reentrancy of a traversal.
12//! That is, we can control whether a matching node should be visited when it is nested within another match.
13//! For example, suppose we want to find all usages of calling `foo` in the source `foo(foo())`.
14//! The code has two matching calls and we can configure a traversal
15//! to report only the inner one, only the outer one or both.
16//!
17//! Pre and Post order traversals in this module are implemented using tree-sitter's cursor API without extra heap allocation.
18//! It is recommended to use traversal instead of tree recursion to avoid stack overflow and memory overhead.
19//! Level order is also included for completeness and should be used sparingly.
20
21use crate::matcher::{Matcher, MatcherExt};
22use crate::{Doc, Node, NodeMatch, Root};
23
24use tree_sitter as ts;
25
26use std::collections::VecDeque;
27use std::iter::FusedIterator;
28use std::marker::PhantomData;
29
/// Configurable walker that yields the nodes of a subtree accepted by a matcher.
///
/// Create it with [`Visitor::new`], adjust it with `algorithm`, `reentrant` and
/// `named_only`, then call `visit` to obtain the iterator.
pub struct Visitor<M, A = PreOrder> {
  /// Whether a node will match if it contains or is contained in another match.
  reentrant: bool,
  /// Whether to visit named nodes only
  named_only: bool,
  /// matcher used to filter the traversed nodes
  matcher: M,
  /// The algorithm to traverse the tree, e.g. [`PreOrder`] (the default) or [`PostOrder`]
  algorithm: PhantomData<A>,
}
40
41impl<M> Visitor<M> {
42  pub fn new(matcher: M) -> Visitor<M> {
43    Visitor {
44      reentrant: true,
45      named_only: false,
46      matcher,
47      algorithm: PhantomData,
48    }
49  }
50}
51
52impl<M, A> Visitor<M, A> {
53  pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
54    Visitor {
55      reentrant: self.reentrant,
56      named_only: self.named_only,
57      matcher: self.matcher,
58      algorithm: PhantomData,
59    }
60  }
61
62  pub fn reentrant(self, reentrant: bool) -> Self {
63    Self { reentrant, ..self }
64  }
65
66  pub fn named_only(self, named_only: bool) -> Self {
67    Self { named_only, ..self }
68  }
69}
70
71impl<M, A> Visitor<M, A>
72where
73  A: Algorithm,
74{
75  pub fn visit<D: Doc>(self, node: Node<D>) -> Visit<'_, D, A::Traversal<'_, D>, M>
76  where
77    M: Matcher<D::Lang>,
78  {
79    let traversal = A::traverse(node);
80    Visit {
81      reentrant: self.reentrant,
82      named: self.named_only,
83      matcher: self.matcher,
84      traversal,
85      lang: PhantomData,
86    }
87  }
88}
89
/// Iterator returned by [`Visitor::visit`], yielding the [`NodeMatch`]es found by a traversal.
pub struct Visit<'t, D, T, M> {
  /// When `false`, the traversal is calibrated after each node to skip overlapping matches.
  reentrant: bool,
  /// When `true`, unnamed nodes are never passed to the matcher.
  named: bool,
  /// Filter deciding which traversed nodes are yielded.
  matcher: M,
  /// Underlying node traversal (e.g. [`Pre`] or [`Post`]).
  traversal: T,
  /// Ties the iterator to the document lifetime `'t`; no `D` value is stored.
  lang: PhantomData<&'t D>,
}
97impl<'t, D, T, M> Visit<'t, D, T, M>
98where
99  D: Doc + 't,
100  T: Traversal<'t, D>,
101  M: Matcher<D::Lang>,
102{
103  #[inline]
104  fn mark_match(&mut self, depth: Option<usize>) {
105    if !self.reentrant {
106      self.traversal.calibrate_for_match(depth);
107    }
108  }
109}
110
111impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
112where
113  D: Doc + 't,
114  T: Traversal<'t, D>,
115  M: Matcher<D::Lang>,
116{
117  type Item = NodeMatch<'t, D>;
118  fn next(&mut self) -> Option<Self::Item> {
119    loop {
120      let match_depth = self.traversal.get_current_depth();
121      let node = self.traversal.next()?;
122      let pass_named = !self.named || node.is_named();
123      if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
124        self.mark_match(Some(match_depth));
125        return Some(node_match);
126      } else {
127        self.mark_match(None);
128      }
129    }
130  }
131}
132
/// A tree traversal strategy selectable through [`Visitor::algorithm`].
///
/// The implementors in this module, [`PreOrder`] and [`PostOrder`], are unit
/// structs that only pick the concrete [`Traversal`] type.
pub trait Algorithm {
  /// The traversal iterator produced for a document `D`.
  type Traversal<'t, D: 't + Doc>: Traversal<'t, D>;
  /// Starts a traversal of `node`'s subtree.
  fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D>;
}
137
138pub struct PreOrder;
139impl Algorithm for PreOrder {
140  type Traversal<'t, D: 't + Doc> = Pre<'t, D>;
141  fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D> {
142    Pre::new(&node)
143  }
144}
145pub struct PostOrder;
146impl Algorithm for PostOrder {
147  type Traversal<'t, D: 't + Doc> = Post<'t, D>;
148  fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D> {
149    Post::new(&node)
150  }
151}
152
/// A node iterator driven by a tree traversal algorithm (see [`Pre`] and [`Post`]).
///
/// The `next` method should only handle normal, reentrant iteration.
/// If reentrancy is not desired, the traversal should move its cursor in
/// `calibrate_for_match` so that overlapping matches are skipped.
///
/// [`Visit`] reads `get_current_depth` right before each `next` call and, for
/// non-reentrant visits, calls `calibrate_for_match` right after it, passing that
/// depth back when the yielded node matched. The traversal therefore does not need
/// to remember the depth of the node it just yielded.
pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
  /// Calibrate cursor position to skip overlapping matches.
  /// `depth` is the depth of the node just yielded if it matched, otherwise `None`.
  /// When this is called, `next` has already advanced the cursor past that node.
  fn calibrate_for_match(&mut self, depth: Option<usize>);
  /// Returns the current depth of the cursor, i.e. the depth of the node that the
  /// next call to `next` will yield.
  /// Cursor depth is incremented by 1 when moving from parent to child.
  /// Cursor depth at the starting node of the traversal is 0.
  fn get_current_depth(&self) -> usize;
}
166
/// Represents a pre-order traversal
///
/// Nodes are yielded parent-before-children, starting with the node passed to
/// [`Pre::new`]; the walk never leaves that node's subtree.
pub struct Pre<'tree, D: Doc> {
  /// tree-sitter cursor positioned on the next node to yield.
  cursor: ts::TreeCursor<'tree>,
  /// Root used to wrap raw tree-sitter nodes into [`Node`]s.
  root: &'tree Root<D>,
  /// Id of the starting node: if we return back to the starting point
  /// we should terminate the dfs. `None` once the traversal has finished.
  start_id: Option<usize>,
  /// Depth of `cursor` relative to the starting node (which is depth 0).
  current_depth: usize,
}
176
177impl<'tree, D: Doc> Pre<'tree, D> {
178  pub fn new(node: &Node<'tree, D>) -> Self {
179    Self {
180      cursor: node.inner.walk(),
181      root: node.root,
182      start_id: Some(node.inner.id()),
183      current_depth: 0,
184    }
185  }
186  fn step_down(&mut self) -> bool {
187    if self.cursor.goto_first_child() {
188      self.current_depth += 1;
189      true
190    } else {
191      false
192    }
193  }
194
195  // retrace back to ancestors and find next node to explore
196  fn trace_up(&mut self, start: usize) {
197    let cursor = &mut self.cursor;
198    while cursor.node().id() != start {
199      // try visit sibling nodes
200      if cursor.goto_next_sibling() {
201        return;
202      }
203      self.current_depth -= 1;
204      // go back to parent node
205      if !cursor.goto_parent() {
206        // it should never fail here. However, tree-sitter has bad parsing bugs
207        // stop to avoid panic. https://github.com/ast-grep/ast-grep/issues/713
208        break;
209      }
210    }
211    // terminate traversal here
212    self.start_id = None;
213  }
214}
215
216/// Amortized time complexity is O(NlgN), depending on branching factor.
217impl<'tree, D: Doc> Iterator for Pre<'tree, D> {
218  type Item = Node<'tree, D>;
219  // 1. Yield the node itself
220  // 2. Try visit the child node until no child available
221  // 3. Try visit next sibling after going back to parent
222  // 4. Repeat step 3 until returning to the starting node
223  fn next(&mut self) -> Option<Self::Item> {
224    // start_id will always be Some until the dfs terminates
225    let start = self.start_id?;
226    let cursor = &mut self.cursor;
227    let inner = cursor.node(); // get current node
228    let ret = Some(self.root.adopt(inner));
229    // try going to children first
230    if self.step_down() {
231      return ret;
232    }
233    // if no child available, go to ancestor nodes
234    // until we get to the starting point
235    self.trace_up(start);
236    ret
237  }
238}
239impl<D: Doc> FusedIterator for Pre<'_, D> {}
240
241impl<'t, D: Doc> Traversal<'t, D> for Pre<'t, D> {
242  fn calibrate_for_match(&mut self, depth: Option<usize>) {
243    // not entering the node, ignore
244    let Some(depth) = depth else {
245      return;
246    };
247    // if already entering sibling or traced up, ignore
248    if self.current_depth <= depth {
249      return;
250    }
251    debug_assert!(self.current_depth > depth);
252    if let Some(start) = self.start_id {
253      // revert the step down
254      self.cursor.goto_parent();
255      self.trace_up(start);
256    }
257  }
258
259  #[inline]
260  fn get_current_depth(&self) -> usize {
261    self.current_depth
262  }
263}
264
265/// Represents a post-order traversal
266pub struct Post<'tree, D: Doc> {
267  cursor: ts::TreeCursor<'tree>,
268  root: &'tree Root<D>,
269  start_id: Option<usize>,
270  current_depth: usize,
271  match_depth: usize,
272}
273
274/// Amortized time complexity is O(NlgN), depending on branching factor.
275impl<'tree, D: Doc> Post<'tree, D> {
276  pub fn new(node: &Node<'tree, D>) -> Self {
277    let mut ret = Self {
278      cursor: node.inner.walk(),
279      root: node.root,
280      start_id: Some(node.inner.id()),
281      current_depth: 0,
282      match_depth: 0,
283    };
284    ret.trace_down();
285    ret
286  }
287  fn trace_down(&mut self) {
288    while self.cursor.goto_first_child() {
289      self.current_depth += 1;
290    }
291  }
292  fn step_up(&mut self) {
293    self.current_depth -= 1;
294    self.cursor.goto_parent();
295  }
296}
297
298/// Amortized time complexity is O(NlgN), depending on branching factor.
299impl<'tree, D: Doc> Iterator for Post<'tree, D> {
300  type Item = Node<'tree, D>;
301  fn next(&mut self) -> Option<Self::Item> {
302    // start_id will always be Some until the dfs terminates
303    let start = self.start_id?;
304    let cursor = &mut self.cursor;
305    let node = self.root.adopt(cursor.node());
306    // return to start
307    if node.inner.id() == start {
308      self.start_id = None
309    } else if cursor.goto_next_sibling() {
310      // try visit sibling
311      self.trace_down();
312    } else {
313      self.step_up();
314    }
315    Some(node)
316  }
317}
318
319impl<D: Doc> FusedIterator for Post<'_, D> {}
320
321impl<'t, D: Doc> Traversal<'t, D> for Post<'t, D> {
322  fn calibrate_for_match(&mut self, depth: Option<usize>) {
323    if let Some(depth) = depth {
324      // Later matches' depth should always be greater than former matches.
325      // because we bump match_depth in `step_up` during traversal.
326      debug_assert!(depth >= self.match_depth);
327      self.match_depth = depth;
328      return;
329    }
330    // found new nodes to explore in trace_down, skip calibration.
331    if self.current_depth >= self.match_depth {
332      return;
333    }
334    let Some(start) = self.start_id else {
335      return;
336    };
337    while self.cursor.node().id() != start {
338      self.match_depth = self.current_depth;
339      if self.cursor.goto_next_sibling() {
340        // try visit sibling
341        self.trace_down();
342        return;
343      }
344      self.step_up();
345    }
346    // terminate because all ancestors are skipped
347    self.start_id = None;
348  }
349
350  #[inline]
351  fn get_current_depth(&self) -> usize {
352    self.current_depth
353  }
354}
355
/// Represents a level-order traversal.
/// It is implemented with [`VecDeque`] since quadratic backtracking is too time consuming.
/// Though level-order is not used as frequently as the depth-first traversals ([`Pre`], [`Post`]),
/// traversing a big AST with level-order should be done with caution since it might increase the memory usage.
pub struct Level<'tree, D: Doc> {
  /// FIFO of nodes discovered but not yet yielded.
  deque: VecDeque<ts::Node<'tree>>,
  /// Scratch cursor reused to enumerate each node's children.
  cursor: ts::TreeCursor<'tree>,
  /// Root used to wrap raw tree-sitter nodes into [`Node`]s.
  root: &'tree Root<D>,
}
365
366impl<'tree, D: Doc> Level<'tree, D> {
367  pub fn new(node: &Node<'tree, D>) -> Self {
368    let mut deque = VecDeque::new();
369    deque.push_back(node.inner.clone());
370    let cursor = node.inner.walk();
371    Self {
372      deque,
373      cursor,
374      root: node.root,
375    }
376  }
377}
378
379/// Time complexity is O(N). Space complexity is O(N)
380impl<'tree, D: Doc> Iterator for Level<'tree, D> {
381  type Item = Node<'tree, D>;
382  fn next(&mut self) -> Option<Self::Item> {
383    let inner = self.deque.pop_front()?;
384    let children = inner.children(&mut self.cursor);
385    self.deque.extend(children);
386    Some(self.root.adopt(inner))
387  }
388}
389impl<D: Doc> FusedIterator for Level<'_, D> {}
390
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::{Language, Tsx};
  use crate::StrDoc;
  use std::ops::Range;

  // recursive pre order as baseline
  fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret = vec![node.range()];
    ret.extend(node.children().flat_map(pre_order));
    ret
  }

  // recursion baseline
  fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
    ret.push(node.range());
    ret
  }

  // breadth-first baseline built on `Node::children` instead of a tree-sitter cursor
  fn level_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret = vec![];
    let mut queue = VecDeque::new();
    queue.push_back(node);
    while let Some(n) = queue.pop_front() {
      ret.push(n.range());
      queue.extend(n.children());
    }
    ret
  }

  fn pre_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let recursive = pre_order(node);
    assert_eq!(iterative, recursive);
  }

  fn post_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let recursive = post_order(node);
    assert_eq!(iterative, recursive);
  }

  fn level_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
    let baseline = level_order(node);
    assert_eq!(iterative, baseline);
  }

  const CASES: &[&str] = &[
    "console.log('hello world')",
    "let a = (a, b, c)",
    "function test() { let a = 1; let b = 2; a === b}",
    "[[[[[[]]]]], 1 , 2 ,3]",
    "class A { test() { class B {} } }",
  ];

  #[test]
  fn test_pre_order() {
    for case in CASES {
      pre_order_equivalent(case);
    }
  }

  #[test]
  fn test_post_order() {
    for case in CASES {
      post_order_equivalent(case);
    }
  }

  #[test]
  fn test_level_order() {
    for case in CASES {
      level_order_equivalent(case);
    }
  }

  #[test]
  fn test_different_order() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
      let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
      let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
      assert_ne!(pre, post);
      assert_ne!(pre, level);
      assert_ne!(post, level);
    }
  }

  #[test]
  fn test_fused_traversal() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let mut pre = Pre::new(&node);
      let mut post = Post::new(&node);
      let mut level = Level::new(&node);
      while pre.next().is_some() {}
      while post.next().is_some() {}
      while level.next().is_some() {}
      assert!(pre.next().is_none());
      assert!(pre.next().is_none());
      assert!(post.next().is_none());
      assert!(post.next().is_none());
      assert!(level.next().is_none());
      assert!(level.next().is_none());
    }
  }

  #[test]
  fn test_non_root_traverse() {
    let grep = Tsx.ast_grep("let a = 123; let b = 123;");
    let node = grep.root();
    let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let node2 = node.child(0).unwrap();
    let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
    let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
    // traversal should stop at node
    assert_ne!(pre, pre2);
    assert_ne!(post, post2);
    // child traversal should be a part of parent traversal
    assert!(pre[1..].starts_with(&pre2));
    assert!(post.starts_with(&post2));
  }

  fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    if node.matches(matcher) {
      vec![node.range()]
    } else {
      node
        .children()
        .flat_map(|n| pre_order_with_matcher(n, matcher))
        .collect()
    }
  }

  fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node
      .children()
      .flat_map(|n| post_order_with_matcher(n, matcher))
      .collect();
    if ret.is_empty() && node.matches(matcher) {
      ret.push(node.range());
    }
    ret
  }

  const MATCHER_CASES: &[&str] = &[
    "Some(123)",
    "Some(1, 2, Some(2))",
    "NoMatch",
    "NoMatch(Some(123))",
    "Some(1, Some(2), Some(3))",
    "Some(1, Some(2), Some(Some(3)))",
  ];

  #[test]
  fn test_pre_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = pre_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }
  #[test]
  fn test_post_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = post_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .algorithm::<PostOrder>()
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }

  // match a leaf node will trace_up the cursor
  #[test]
  fn test_traversal_leaf() {
    let matcher = "true";
    let case = "((((true))));true";
    let grep = Tsx.ast_grep(case);
    let recur = pre_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
    let recur = post_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .algorithm::<PostOrder>()
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
  }
}