1use crate::matcher::{Matcher, MatcherExt};
22use crate::{Doc, Node, NodeMatch, Root};
23
24use tree_sitter as ts;
25
26use std::collections::VecDeque;
27use std::iter::FusedIterator;
28use std::marker::PhantomData;
29
30pub struct Visitor<M, A = PreOrder> {
31 reentrant: bool,
33 named_only: bool,
35 matcher: M,
37 algorithm: PhantomData<A>,
39}
40
41impl<M> Visitor<M> {
42 pub fn new(matcher: M) -> Visitor<M> {
43 Visitor {
44 reentrant: true,
45 named_only: false,
46 matcher,
47 algorithm: PhantomData,
48 }
49 }
50}
51
52impl<M, A> Visitor<M, A> {
53 pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
54 Visitor {
55 reentrant: self.reentrant,
56 named_only: self.named_only,
57 matcher: self.matcher,
58 algorithm: PhantomData,
59 }
60 }
61
62 pub fn reentrant(self, reentrant: bool) -> Self {
63 Self { reentrant, ..self }
64 }
65
66 pub fn named_only(self, named_only: bool) -> Self {
67 Self { named_only, ..self }
68 }
69}
70
71impl<M, A> Visitor<M, A>
72where
73 A: Algorithm,
74{
75 pub fn visit<D: Doc>(self, node: Node<D>) -> Visit<'_, D, A::Traversal<'_, D>, M>
76 where
77 M: Matcher<D::Lang>,
78 {
79 let traversal = A::traverse(node);
80 Visit {
81 reentrant: self.reentrant,
82 named: self.named_only,
83 matcher: self.matcher,
84 traversal,
85 lang: PhantomData,
86 }
87 }
88}
89
90pub struct Visit<'t, D, T, M> {
91 reentrant: bool,
92 named: bool,
93 matcher: M,
94 traversal: T,
95 lang: PhantomData<&'t D>,
96}
97impl<'t, D, T, M> Visit<'t, D, T, M>
98where
99 D: Doc + 't,
100 T: Traversal<'t, D>,
101 M: Matcher<D::Lang>,
102{
103 #[inline]
104 fn mark_match(&mut self, depth: Option<usize>) {
105 if !self.reentrant {
106 self.traversal.calibrate_for_match(depth);
107 }
108 }
109}
110
111impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
112where
113 D: Doc + 't,
114 T: Traversal<'t, D>,
115 M: Matcher<D::Lang>,
116{
117 type Item = NodeMatch<'t, D>;
118 fn next(&mut self) -> Option<Self::Item> {
119 loop {
120 let match_depth = self.traversal.get_current_depth();
121 let node = self.traversal.next()?;
122 let pass_named = !self.named || node.is_named();
123 if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
124 self.mark_match(Some(match_depth));
125 return Some(node_match);
126 } else {
127 self.mark_match(None);
128 }
129 }
130 }
131}
132
133pub trait Algorithm {
134 type Traversal<'t, D: 't + Doc>: Traversal<'t, D>;
135 fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D>;
136}
137
138pub struct PreOrder;
139impl Algorithm for PreOrder {
140 type Traversal<'t, D: 't + Doc> = Pre<'t, D>;
141 fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D> {
142 Pre::new(&node)
143 }
144}
145pub struct PostOrder;
146impl Algorithm for PostOrder {
147 type Traversal<'t, D: 't + Doc> = Post<'t, D>;
148 fn traverse<D: Doc>(node: Node<D>) -> Self::Traversal<'_, D> {
149 Post::new(&node)
150 }
151}
152
153pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
158 fn calibrate_for_match(&mut self, depth: Option<usize>);
161 fn get_current_depth(&self) -> usize;
165}
166
167pub struct Pre<'tree, D: Doc> {
169 cursor: ts::TreeCursor<'tree>,
170 root: &'tree Root<D>,
171 start_id: Option<usize>,
174 current_depth: usize,
175}
176
177impl<'tree, D: Doc> Pre<'tree, D> {
178 pub fn new(node: &Node<'tree, D>) -> Self {
179 Self {
180 cursor: node.inner.walk(),
181 root: node.root,
182 start_id: Some(node.inner.id()),
183 current_depth: 0,
184 }
185 }
186 fn step_down(&mut self) -> bool {
187 if self.cursor.goto_first_child() {
188 self.current_depth += 1;
189 true
190 } else {
191 false
192 }
193 }
194
195 fn trace_up(&mut self, start: usize) {
197 let cursor = &mut self.cursor;
198 while cursor.node().id() != start {
199 if cursor.goto_next_sibling() {
201 return;
202 }
203 self.current_depth -= 1;
204 if !cursor.goto_parent() {
206 break;
209 }
210 }
211 self.start_id = None;
213 }
214}
215
216impl<'tree, D: Doc> Iterator for Pre<'tree, D> {
218 type Item = Node<'tree, D>;
219 fn next(&mut self) -> Option<Self::Item> {
224 let start = self.start_id?;
226 let cursor = &mut self.cursor;
227 let inner = cursor.node(); let ret = Some(self.root.adopt(inner));
229 if self.step_down() {
231 return ret;
232 }
233 self.trace_up(start);
236 ret
237 }
238}
239impl<D: Doc> FusedIterator for Pre<'_, D> {}
240
241impl<'t, D: Doc> Traversal<'t, D> for Pre<'t, D> {
242 fn calibrate_for_match(&mut self, depth: Option<usize>) {
243 let Some(depth) = depth else {
245 return;
246 };
247 if self.current_depth <= depth {
249 return;
250 }
251 debug_assert!(self.current_depth > depth);
252 if let Some(start) = self.start_id {
253 self.cursor.goto_parent();
255 self.trace_up(start);
256 }
257 }
258
259 #[inline]
260 fn get_current_depth(&self) -> usize {
261 self.current_depth
262 }
263}
264
265pub struct Post<'tree, D: Doc> {
267 cursor: ts::TreeCursor<'tree>,
268 root: &'tree Root<D>,
269 start_id: Option<usize>,
270 current_depth: usize,
271 match_depth: usize,
272}
273
274impl<'tree, D: Doc> Post<'tree, D> {
276 pub fn new(node: &Node<'tree, D>) -> Self {
277 let mut ret = Self {
278 cursor: node.inner.walk(),
279 root: node.root,
280 start_id: Some(node.inner.id()),
281 current_depth: 0,
282 match_depth: 0,
283 };
284 ret.trace_down();
285 ret
286 }
287 fn trace_down(&mut self) {
288 while self.cursor.goto_first_child() {
289 self.current_depth += 1;
290 }
291 }
292 fn step_up(&mut self) {
293 self.current_depth -= 1;
294 self.cursor.goto_parent();
295 }
296}
297
298impl<'tree, D: Doc> Iterator for Post<'tree, D> {
300 type Item = Node<'tree, D>;
301 fn next(&mut self) -> Option<Self::Item> {
302 let start = self.start_id?;
304 let cursor = &mut self.cursor;
305 let node = self.root.adopt(cursor.node());
306 if node.inner.id() == start {
308 self.start_id = None
309 } else if cursor.goto_next_sibling() {
310 self.trace_down();
312 } else {
313 self.step_up();
314 }
315 Some(node)
316 }
317}
318
319impl<D: Doc> FusedIterator for Post<'_, D> {}
320
321impl<'t, D: Doc> Traversal<'t, D> for Post<'t, D> {
322 fn calibrate_for_match(&mut self, depth: Option<usize>) {
323 if let Some(depth) = depth {
324 debug_assert!(depth >= self.match_depth);
327 self.match_depth = depth;
328 return;
329 }
330 if self.current_depth >= self.match_depth {
332 return;
333 }
334 let Some(start) = self.start_id else {
335 return;
336 };
337 while self.cursor.node().id() != start {
338 self.match_depth = self.current_depth;
339 if self.cursor.goto_next_sibling() {
340 self.trace_down();
342 return;
343 }
344 self.step_up();
345 }
346 self.start_id = None;
348 }
349
350 #[inline]
351 fn get_current_depth(&self) -> usize {
352 self.current_depth
353 }
354}
355
356pub struct Level<'tree, D: Doc> {
361 deque: VecDeque<ts::Node<'tree>>,
362 cursor: ts::TreeCursor<'tree>,
363 root: &'tree Root<D>,
364}
365
366impl<'tree, D: Doc> Level<'tree, D> {
367 pub fn new(node: &Node<'tree, D>) -> Self {
368 let mut deque = VecDeque::new();
369 deque.push_back(node.inner.clone());
370 let cursor = node.inner.walk();
371 Self {
372 deque,
373 cursor,
374 root: node.root,
375 }
376 }
377}
378
379impl<'tree, D: Doc> Iterator for Level<'tree, D> {
381 type Item = Node<'tree, D>;
382 fn next(&mut self) -> Option<Self::Item> {
383 let inner = self.deque.pop_front()?;
384 let children = inner.children(&mut self.cursor);
385 self.deque.extend(children);
386 Some(self.root.adopt(inner))
387 }
388}
389impl<D: Doc> FusedIterator for Level<'_, D> {}
390
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::{Language, Tsx};
    use crate::StrDoc;
    use std::ops::Range;

    /// Recursive reference pre-order: the node first, then each child's subtree.
    fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
        let mut ret = vec![node.range()];
        ret.extend(node.children().flat_map(pre_order));
        ret
    }

    /// Recursive reference post-order: each child's subtree, then the node.
    fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
        let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
        ret.push(node.range());
        ret
    }

    /// Asserts that [`Pre`] visits the same ranges as the recursive reference.
    fn pre_order_equivalent(source: &str) {
        let grep = Tsx.ast_grep(source);
        let node = grep.root();
        let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
        let recursive = pre_order(node);
        assert_eq!(iterative, recursive);
    }

    /// Asserts that [`Post`] visits the same ranges as the recursive reference.
    fn post_order_equivalent(source: &str) {
        let grep = Tsx.ast_grep(source);
        let node = grep.root();
        let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
        let recursive = post_order(node);
        assert_eq!(iterative, recursive);
    }

    const CASES: &[&str] = &[
        "console.log('hello world')",
        "let a = (a, b, c)",
        "function test() { let a = 1; let b = 2; a === b}",
        "[[[[[[]]]]], 1 , 2 ,3]",
        "class A { test() { class B {} } }",
    ];

    #[test]
    fn test_pre_order() {
        for case in CASES {
            pre_order_equivalent(case);
        }
    }

    #[test]
    fn test_post_order() {
        for case in CASES {
            post_order_equivalent(case);
        }
    }

    #[test]
    fn test_different_order() {
        for case in CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
            let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
            let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
            assert_ne!(pre, post);
            assert_ne!(pre, level);
            assert_ne!(post, level);
        }
    }

    #[test]
    fn test_fused_traversal() {
        for case in CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let mut pre = Pre::new(&node);
            let mut post = Post::new(&node);
            while pre.next().is_some() {}
            while post.next().is_some() {}
            // Once exhausted, both traversals must keep returning `None`.
            assert!(pre.next().is_none());
            assert!(pre.next().is_none());
            assert!(post.next().is_none());
            assert!(post.next().is_none());
        }
    }

    #[test]
    fn test_non_root_traverse() {
        let grep = Tsx.ast_grep("let a = 123; let b = 123;");
        let node = grep.root();
        let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
        let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
        let node2 = node.child(0).unwrap();
        let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
        let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
        // Traversing only the first statement must not walk the whole program.
        assert_ne!(pre, pre2);
        assert_ne!(post, post2);
        // Pre-order yields the root first, then the first child's whole subtree.
        assert!(pre[1..].starts_with(&pre2));
        // Post-order yields the first child's whole subtree before anything else.
        assert!(post.starts_with(&post2));
    }

    /// Reference for non-reentrant pre-order matching: a matching node is
    /// reported and its descendants are not searched.
    fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
        if node.matches(matcher) {
            vec![node.range()]
        } else {
            node
                .children()
                .flat_map(|n| pre_order_with_matcher(n, matcher))
                .collect()
        }
    }

    /// Reference for non-reentrant post-order matching: a node is reported
    /// only when none of its descendants matched.
    fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
        let mut ret: Vec<_> = node
            .children()
            .flat_map(|n| post_order_with_matcher(n, matcher))
            .collect();
        if ret.is_empty() && node.matches(matcher) {
            ret.push(node.range());
        }
        ret
    }

    const MATCHER_CASES: &[&str] = &[
        "Some(123)",
        "Some(1, 2, Some(2))",
        "NoMatch",
        "NoMatch(Some(123))",
        "Some(1, Some(2), Some(3))",
        "Some(1, Some(2), Some(Some(3)))",
    ];

    #[test]
    fn test_pre_order_visitor() {
        let matcher = "Some($$$)";
        for case in MATCHER_CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let recur = pre_order_with_matcher(grep.root(), matcher);
            let visit: Vec<_> = Visitor::new(matcher)
                .reentrant(false)
                .visit(node)
                .map(|n| n.range())
                .collect();
            assert_eq!(recur, visit);
        }
    }

    #[test]
    fn test_post_order_visitor() {
        let matcher = "Some($$$)";
        for case in MATCHER_CASES {
            let grep = Tsx.ast_grep(case);
            let node = grep.root();
            let recur = post_order_with_matcher(grep.root(), matcher);
            let visit: Vec<_> = Visitor::new(matcher)
                .algorithm::<PostOrder>()
                .reentrant(false)
                .visit(node)
                .map(|n| n.range())
                .collect();
            assert_eq!(recur, visit);
        }
    }

    #[test]
    fn test_traversal_leaf() {
        let matcher = "true";
        let case = "((((true))));true";
        let grep = Tsx.ast_grep(case);
        let recur = pre_order_with_matcher(grep.root(), matcher);
        let visit: Vec<_> = Visitor::new(matcher)
            .reentrant(false)
            .visit(grep.root())
            .map(|n| n.range())
            .collect();
        assert_eq!(recur, visit);
        let recur = post_order_with_matcher(grep.root(), matcher);
        let visit: Vec<_> = Visitor::new(matcher)
            .algorithm::<PostOrder>()
            .reentrant(false)
            .visit(grep.root())
            .map(|n| n.range())
            .collect();
        assert_eq!(recur, visit);
    }
}