1use super::StrDoc;
22use crate::matcher::{Matcher, MatcherExt};
23use crate::tree_sitter::LanguageExt;
24use crate::{Doc, Node, NodeMatch, Root};
25
26use tree_sitter as ts;
27
28use std::collections::VecDeque;
29use std::marker::PhantomData;
30
31pub struct Visitor<M, A = PreOrder> {
32 reentrant: bool,
34 named_only: bool,
36 matcher: M,
38 algorithm: PhantomData<A>,
40}
41
42impl<M> Visitor<M> {
43 pub fn new(matcher: M) -> Visitor<M> {
44 Visitor {
45 reentrant: true,
46 named_only: false,
47 matcher,
48 algorithm: PhantomData,
49 }
50 }
51}
52
53impl<M, A> Visitor<M, A> {
54 pub fn algorithm<Algo>(self) -> Visitor<M, Algo> {
55 Visitor {
56 reentrant: self.reentrant,
57 named_only: self.named_only,
58 matcher: self.matcher,
59 algorithm: PhantomData,
60 }
61 }
62
63 pub fn reentrant(self, reentrant: bool) -> Self {
64 Self { reentrant, ..self }
65 }
66
67 pub fn named_only(self, named_only: bool) -> Self {
68 Self { named_only, ..self }
69 }
70}
71
72impl<M, A> Visitor<M, A>
73where
74 A: Algorithm,
75{
76 pub fn visit<L: LanguageExt>(
77 self,
78 node: Node<'_, StrDoc<L>>,
79 ) -> Visit<'_, StrDoc<L>, A::Traversal<'_, L>, M>
80 where
81 M: Matcher,
82 {
83 let traversal = A::traverse(node);
84 Visit {
85 reentrant: self.reentrant,
86 named: self.named_only,
87 matcher: self.matcher,
88 traversal,
89 lang: PhantomData,
90 }
91 }
92}
93
94pub struct Visit<'t, D, T, M> {
95 reentrant: bool,
96 named: bool,
97 matcher: M,
98 traversal: T,
99 lang: PhantomData<&'t D>,
100}
101impl<'t, D, T, M> Visit<'t, D, T, M>
102where
103 D: Doc + 't,
104 T: Traversal<'t, D>,
105 M: Matcher,
106{
107 #[inline]
108 fn mark_match(&mut self, depth: Option<usize>) {
109 if !self.reentrant {
110 self.traversal.calibrate_for_match(depth);
111 }
112 }
113}
114
115impl<'t, D, T, M> Iterator for Visit<'t, D, T, M>
116where
117 D: Doc + 't,
118 T: Traversal<'t, D>,
119 M: Matcher,
120{
121 type Item = NodeMatch<'t, D>;
122 fn next(&mut self) -> Option<Self::Item> {
123 loop {
124 let match_depth = self.traversal.get_current_depth();
125 let node = self.traversal.next()?;
126 let pass_named = !self.named || node.is_named();
127 if let Some(node_match) = pass_named.then(|| self.matcher.match_node(node)).flatten() {
128 self.mark_match(Some(match_depth));
129 return Some(node_match);
130 } else {
131 self.mark_match(None);
132 }
133 }
134 }
135}
136
137pub trait Algorithm {
138 type Traversal<'t, L: LanguageExt>: Traversal<'t, StrDoc<L>>;
139 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L>;
140}
141
142pub struct PreOrder;
143impl Algorithm for PreOrder {
144 type Traversal<'t, L: LanguageExt> = Pre<'t, L>;
145 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
146 Pre::new(&node)
147 }
148}
149pub struct PostOrder;
150impl Algorithm for PostOrder {
151 type Traversal<'t, L: LanguageExt> = Post<'t, L>;
152 fn traverse<L: LanguageExt>(node: Node<'_, StrDoc<L>>) -> Self::Traversal<'_, L> {
153 Post::new(&node)
154 }
155}
156
157pub trait Traversal<'t, D: Doc + 't>: Iterator<Item = Node<'t, D>> {
162 fn calibrate_for_match(&mut self, depth: Option<usize>);
165 fn get_current_depth(&self) -> usize;
169}
170
171pub struct TsPre<'tree> {
173 cursor: ts::TreeCursor<'tree>,
174 start_id: Option<usize>,
177 current_depth: usize,
178}
179
180impl<'tree> TsPre<'tree> {
181 pub fn new(node: &ts::Node<'tree>) -> Self {
182 Self {
183 cursor: node.walk(),
184 start_id: Some(node.id()),
185 current_depth: 0,
186 }
187 }
188 fn step_down(&mut self) -> bool {
189 if self.cursor.goto_first_child() {
190 self.current_depth += 1;
191 true
192 } else {
193 false
194 }
195 }
196
197 fn trace_up(&mut self, start: usize) {
199 let cursor = &mut self.cursor;
200 while cursor.node().id() != start {
201 if cursor.goto_next_sibling() {
203 return;
204 }
205 self.current_depth -= 1;
206 if !cursor.goto_parent() {
208 break;
211 }
212 }
213 self.start_id = None;
215 }
216}
217
218impl<'tree> Iterator for TsPre<'tree> {
220 type Item = ts::Node<'tree>;
221 fn next(&mut self) -> Option<Self::Item> {
226 let start = self.start_id?;
228 let cursor = &mut self.cursor;
229 let inner = cursor.node(); let ret = Some(inner);
231 if self.step_down() {
233 return ret;
234 }
235 self.trace_up(start);
238 ret
239 }
240}
241
242pub struct Pre<'tree, L: LanguageExt> {
243 root: &'tree Root<StrDoc<L>>,
244 inner: TsPre<'tree>,
245}
246impl<'tree, L: LanguageExt> Iterator for Pre<'tree, L> {
247 type Item = Node<'tree, StrDoc<L>>;
248 fn next(&mut self) -> Option<Self::Item> {
249 let inner = self.inner.next()?;
250 Some(self.root.adopt(inner))
251 }
252}
253
254impl<'t, L: LanguageExt> Pre<'t, L> {
255 pub fn new(node: &Node<'t, StrDoc<L>>) -> Self {
256 let inner = TsPre::new(&node.inner);
257 Self {
258 root: node.root,
259 inner,
260 }
261 }
262}
263
264impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Pre<'t, L> {
265 fn calibrate_for_match(&mut self, depth: Option<usize>) {
266 let Some(depth) = depth else {
268 return;
269 };
270 if self.inner.current_depth <= depth {
272 return;
273 }
274 debug_assert!(self.inner.current_depth > depth);
275 if let Some(start) = self.inner.start_id {
276 self.inner.cursor.goto_parent();
278 self.inner.trace_up(start);
279 }
280 }
281
282 #[inline]
283 fn get_current_depth(&self) -> usize {
284 self.inner.current_depth
285 }
286}
287
288pub struct Post<'tree, L: LanguageExt> {
290 cursor: ts::TreeCursor<'tree>,
291 root: &'tree Root<StrDoc<L>>,
292 start_id: Option<usize>,
293 current_depth: usize,
294 match_depth: usize,
295}
296
297impl<'tree, L: LanguageExt> Post<'tree, L> {
299 pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
300 let mut ret = Self {
301 cursor: node.inner.walk(),
302 root: node.root,
303 start_id: Some(node.inner.id()),
304 current_depth: 0,
305 match_depth: 0,
306 };
307 ret.trace_down();
308 ret
309 }
310 fn trace_down(&mut self) {
311 while self.cursor.goto_first_child() {
312 self.current_depth += 1;
313 }
314 }
315 fn step_up(&mut self) {
316 self.current_depth -= 1;
317 self.cursor.goto_parent();
318 }
319}
320
321impl<'tree, L: LanguageExt> Iterator for Post<'tree, L> {
323 type Item = Node<'tree, StrDoc<L>>;
324 fn next(&mut self) -> Option<Self::Item> {
325 let start = self.start_id?;
327 let cursor = &mut self.cursor;
328 let node = self.root.adopt(cursor.node());
329 if node.inner.id() == start {
331 self.start_id = None
332 } else if cursor.goto_next_sibling() {
333 self.trace_down();
335 } else {
336 self.step_up();
337 }
338 Some(node)
339 }
340}
341
342impl<'t, L: LanguageExt> Traversal<'t, StrDoc<L>> for Post<'t, L> {
343 fn calibrate_for_match(&mut self, depth: Option<usize>) {
344 if let Some(depth) = depth {
345 debug_assert!(depth >= self.match_depth);
348 self.match_depth = depth;
349 return;
350 }
351 if self.current_depth >= self.match_depth {
353 return;
354 }
355 let Some(start) = self.start_id else {
356 return;
357 };
358 while self.cursor.node().id() != start {
359 self.match_depth = self.current_depth;
360 if self.cursor.goto_next_sibling() {
361 self.trace_down();
363 return;
364 }
365 self.step_up();
366 }
367 self.start_id = None;
369 }
370
371 #[inline]
372 fn get_current_depth(&self) -> usize {
373 self.current_depth
374 }
375}
376
377pub struct Level<'tree, L: LanguageExt> {
382 deque: VecDeque<ts::Node<'tree>>,
383 cursor: ts::TreeCursor<'tree>,
384 root: &'tree Root<StrDoc<L>>,
385}
386
387impl<'tree, L: LanguageExt> Level<'tree, L> {
388 pub fn new(node: &Node<'tree, StrDoc<L>>) -> Self {
389 let mut deque = VecDeque::new();
390 deque.push_back(node.inner);
391 let cursor = node.inner.walk();
392 Self {
393 deque,
394 cursor,
395 root: node.root,
396 }
397 }
398}
399
400impl<'tree, L: LanguageExt> Iterator for Level<'tree, L> {
402 type Item = Node<'tree, StrDoc<L>>;
403 fn next(&mut self) -> Option<Self::Item> {
404 let inner = self.deque.pop_front()?;
405 let children = inner.children(&mut self.cursor);
406 self.deque.extend(children);
407 Some(self.root.adopt(inner))
408 }
409}
410
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use std::ops::Range;

  /// Recursive reference implementation of pre-order: the node, then its children.
  fn pre_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret = vec![node.range()];
    ret.extend(node.children().flat_map(pre_order));
    ret
  }

  /// Recursive reference implementation of post-order: the children, then the node.
  fn post_order(node: Node<StrDoc<Tsx>>) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node.children().flat_map(post_order).collect();
    ret.push(node.range());
    ret
  }

  /// Asserts that `Pre` visits the same ranges as the recursive reference.
  fn pre_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let recursive = pre_order(node);
    assert_eq!(iterative, recursive);
  }

  /// Asserts that `Post` visits the same ranges as the recursive reference.
  fn post_order_equivalent(source: &str) {
    let grep = Tsx.ast_grep(source);
    let node = grep.root();
    let iterative: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let recursive = post_order(node);
    assert_eq!(iterative, recursive);
  }

  /// Sources shared by the plain traversal tests.
  const CASES: &[&str] = &[
    "console.log('hello world')",
    "let a = (a, b, c)",
    "function test() { let a = 1; let b = 2; a === b}",
    "[[[[[[]]]]], 1 , 2 ,3]",
    "class A { test() { class B {} } }",
  ];

  #[test]
  fn test_pre_order() {
    for case in CASES {
      pre_order_equivalent(case);
    }
  }

  #[test]
  fn test_post_order() {
    for case in CASES {
      post_order_equivalent(case);
    }
  }

  #[test]
  fn test_different_order() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
      let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
      let level: Vec<_> = Level::new(&node).map(|n| n.range()).collect();
      assert_ne!(pre, post);
      assert_ne!(pre, level);
      assert_ne!(post, level);
    }
  }

  #[test]
  fn test_fused_traversal() {
    for case in CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let mut pre = Pre::new(&node);
      let mut post = Post::new(&node);
      while pre.next().is_some() {}
      while post.next().is_some() {}
      // Polling repeatedly after exhaustion must keep returning `None`.
      assert!(pre.next().is_none());
      assert!(pre.next().is_none());
      assert!(post.next().is_none());
      assert!(post.next().is_none());
    }
  }

  #[test]
  fn test_non_root_traverse() {
    let grep = Tsx.ast_grep("let a = 123; let b = 123;");
    let node = grep.root();
    let pre: Vec<_> = Pre::new(&node).map(|n| n.range()).collect();
    let post: Vec<_> = Post::new(&node).map(|n| n.range()).collect();
    let node2 = node.child(0).unwrap();
    let pre2: Vec<_> = Pre::new(&node2).map(|n| n.range()).collect();
    let post2: Vec<_> = Post::new(&node2).map(|n| n.range()).collect();
    // Traversing from the first child must stay inside that child's subtree.
    assert_ne!(pre, pre2);
    assert_ne!(post, post2);
    // In pre-order the root comes first, followed by the first child's subtree.
    assert!(pre[1..].starts_with(&pre2));
    // In post-order the first child's subtree is visited first.
    assert!(post.starts_with(&post2));
  }

  /// Reference for a non-reentrant pre-order visit: a matching node is
  /// reported and its descendants are not searched.
  fn pre_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    if node.matches(matcher) {
      vec![node.range()]
    } else {
      node
        .children()
        .flat_map(|n| pre_order_with_matcher(n, matcher))
        .collect()
    }
  }

  /// Reference for a non-reentrant post-order visit: a matching node is
  /// reported only when none of its descendants produced a match.
  fn post_order_with_matcher(node: Node<StrDoc<Tsx>>, matcher: &str) -> Vec<Range<usize>> {
    let mut ret: Vec<_> = node
      .children()
      .flat_map(|n| post_order_with_matcher(n, matcher))
      .collect();
    if ret.is_empty() && node.matches(matcher) {
      ret.push(node.range());
    }
    ret
  }

  /// Sources shared by the visitor tests.
  const MATCHER_CASES: &[&str] = &[
    "Some(123)",
    "Some(1, 2, Some(2))",
    "NoMatch",
    "NoMatch(Some(123))",
    "Some(1, Some(2), Some(3))",
    "Some(1, Some(2), Some(Some(3)))",
  ];

  #[test]
  fn test_pre_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = pre_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }

  #[test]
  fn test_post_order_visitor() {
    let matcher = "Some($$$)";
    for case in MATCHER_CASES {
      let grep = Tsx.ast_grep(case);
      let node = grep.root();
      let recur = post_order_with_matcher(grep.root(), matcher);
      let visit: Vec<_> = Visitor::new(matcher)
        .algorithm::<PostOrder>()
        .reentrant(false)
        .visit(node)
        .map(|n| n.range())
        .collect();
      assert_eq!(recur, visit);
    }
  }

  #[test]
  fn test_traversal_leaf() {
    let matcher = "true";
    let case = "((((true))));true";
    let grep = Tsx.ast_grep(case);
    let recur = pre_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
    let recur = post_order_with_matcher(grep.root(), matcher);
    let visit: Vec<_> = Visitor::new(matcher)
      .algorithm::<PostOrder>()
      .reentrant(false)
      .visit(grep.root())
      .map(|n| n.range())
      .collect();
    assert_eq!(recur, visit);
  }
}