tree_sitter_traversal/
lib.rs

1//! Iterators to traverse tree-sitter [`Tree`]s using a [`TreeCursor`],
2//! with a [`Cursor`] trait to allow for traversing arbitrary n-ary trees.
3//!
4//! # Examples
5//!
6//! Basic usage:
7//!
8//! ```
9//! # #[cfg(feature = "tree-sitter")]
10//! # {
11//! use tree_sitter::{Node, Tree};
12//! use std::collections::HashSet;
13//! use std::iter::FromIterator;
14//!
15//! use tree_sitter_traversal::{traverse, traverse_tree, Order};
16//! # fn get_tree() -> Tree {
17//! #     use tree_sitter::Parser;
18//! #     let mut parser = Parser::new();
19//! #     let lang = tree_sitter_rust::language();
20//! #     parser.set_language(&lang).expect("Error loading Rust grammar");
21//! #     return parser.parse("fn double(x: usize) -> usize { x * 2 }", None).expect("Error parsing provided code");
22//! # }
23//!
//! // `get_tree` is a helper hidden from the rendered docs; imagine it returns a valid Tree with >1 node
25//! let tree: Tree = get_tree();
26//! let preorder: Vec<Node<'_>> = traverse(tree.walk(), Order::Pre).collect::<Vec<_>>();
27//! let postorder: Vec<Node<'_>> = traverse_tree(&tree, Order::Post).collect::<Vec<_>>();
28//! // For any tree with more than just a root node,
29//! // the order of preorder and postorder will be different
30//! assert_ne!(preorder, postorder);
//! // However, they will have the same number of nodes
32//! assert_eq!(preorder.len(), postorder.len());
33//! // Specifically, they will have the exact same nodes, just in a different order
34//! assert_eq!(
35//!     <HashSet<_>>::from_iter(preorder.into_iter()),
36//!     <HashSet<_>>::from_iter(postorder.into_iter())
37//! );
38//! # }
39//! ```
40//!
41//! [`Tree`]: tree_sitter::Tree
42//! [`TreeCursor`]: tree_sitter::TreeCursor
43//! [`Cursor`]: crate::Cursor
44#![no_std]
45#![warn(clippy::pedantic)]
46#![warn(clippy::nursery)]
47#![warn(clippy::cargo)]
48
49use core::iter::FusedIterator;
50
/// A stateful cursor over an n-ary tree.
///
/// The cursor always points at exactly one node. The `goto_*` methods try to
/// move it to a neighbouring node and report whether the move happened, while
/// [`node`](Cursor::node) reads the node the cursor is currently on.
pub trait Cursor {
    /// Type of the nodes this cursor points at.
    type Node;

    /// Try to move the cursor down to the first child of its current node.
    ///
    /// Returns `true` when the cursor moved, and `false` when the current node
    /// has no children.
    fn goto_first_child(&mut self) -> bool;

    /// Try to move the cursor up to the parent of its current node.
    ///
    /// Returns `true` when the cursor moved, and `false` when there is no
    /// parent node, i.e. the cursor was already on the root node.
    fn goto_parent(&mut self) -> bool;

    /// Try to move the cursor sideways to the next sibling of its current node.
    ///
    /// Returns `true` when the cursor moved, and `false` when the current node
    /// has no next sibling.
    fn goto_next_sibling(&mut self) -> bool;

    /// Return the node the cursor is currently pointing at.
    fn node(&self) -> Self::Node;
}
80
81impl<'a, T> Cursor for &'a mut T
82where
83    T: Cursor,
84{
85    type Node = T::Node;
86
87    fn goto_first_child(&mut self) -> bool {
88        T::goto_first_child(self)
89    }
90
91    fn goto_parent(&mut self) -> bool {
92        T::goto_parent(self)
93    }
94
95    fn goto_next_sibling(&mut self) -> bool {
96        T::goto_next_sibling(self)
97    }
98
99    fn node(&self) -> Self::Node {
100        T::node(self)
101    }
102}
103
/// [`Cursor`] implementation for tree-sitter's [`TreeCursor`].
///
/// Each trait method delegates to the [`TreeCursor`] method of the same name.
///
/// [`TreeCursor`]: tree_sitter::TreeCursor
/// [`Cursor`]: crate::Cursor
#[cfg(feature = "tree-sitter")]
impl<'a> Cursor for tree_sitter::TreeCursor<'a> {
    type Node = tree_sitter::Node<'a>;

    // The calls below are spelled as paths on the concrete type so that it is
    // obvious they target `TreeCursor`'s own methods rather than recursing
    // into this trait implementation.

    fn goto_first_child(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_first_child(self)
    }

    fn goto_parent(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_parent(self)
    }

    fn goto_next_sibling(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_next_sibling(self)
    }

    fn node(&self) -> Self::Node {
        tree_sitter::TreeCursor::node(self)
    }
}
128
/// Order in which the nodes of an n-ary tree are visited.
///
/// Only pre-order and post-order are offered, as those are the orders that
/// make sense for trees with an arbitrary number of children.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Order {
    /// Visit a node before any of its children.
    Pre,
    /// Visit a node after all of its children.
    Post,
}
136
137/// Iterative traversal of the tree; serves as a reference for both
138/// `PreorderTraverse` and `PostorderTraverse`, as they both will call the exact same
139/// cursor methods in the exact same order as this function for a given tree; the order
140/// is also the same as `traverse_recursive`.
141#[allow(dead_code)]
142fn traverse_iterative<C: Cursor, F>(mut c: C, order: Order, mut cb: F)
143where
144    F: FnMut(C::Node),
145{
146    loop {
147        // This is the first time we've encountered the node, so we'll call if preorder
148        if order == Order::Pre {
149            cb(c.node());
150        }
151
152        // Keep travelling down the tree as far as we can
153        if c.goto_first_child() {
154            continue;
155        }
156
157        let node = c.node();
158
159        // If we can't travel any further down, try going to next sibling and repeating
160        if c.goto_next_sibling() {
161            // If we succeed in going to the previous nodes sibling,
162            // we won't be encountering that node again, so we'll call if postorder
163            if order == Order::Post {
164                cb(node);
165            }
166            continue;
167        }
168
169        // Otherwise, we must travel back up; we'll loop until we reach the root or can
170        // go to the next sibling of a node again.
171        loop {
172            // Since we're retracing back up the tree, this is the last time we'll encounter
173            // this node, so we'll call if postorder
174            if order == Order::Post {
175                cb(c.node());
176            }
177            if !c.goto_parent() {
178                // We have arrived back at the root, so we are done.
179                return;
180            }
181
182            let node = c.node();
183
184            if c.goto_next_sibling() {
185                // If we succeed in going to the previous node's sibling,
186                // we will go back to travelling down that sibling's tree, and we also
187                // won't be encountering the previous node again, so we'll call if postorder
188                if order == Order::Post {
189                    cb(node);
190                }
191                break;
192            }
193        }
194    }
195}
196
197/// Idiomatic recursive traversal of the tree; this version is easier to understand
198/// conceptually, but the recursion is actually unnecessary and can cause stack overflow.
199#[allow(dead_code)]
200fn traverse_recursive<C: Cursor, F>(mut c: C, order: Order, mut cb: F)
201where
202    F: FnMut(C::Node),
203{
204    traverse_helper(&mut c, order, &mut cb);
205}
206
207fn traverse_helper<C: Cursor, F>(c: &mut C, order: Order, cb: &mut F)
208where
209    F: FnMut(C::Node),
210{
211    // If preorder, call the callback when we first touch the node
212    if order == Order::Pre {
213        cb(c.node());
214    }
215    if c.goto_first_child() {
216        // If there is a child, recursively call on
217        // that child and all its siblings
218        loop {
219            traverse_helper(c, order, cb);
220            if !c.goto_next_sibling() {
221                break;
222            }
223        }
224        // Make sure to reset back to the original node;
225        // this must always return true, as we only get here if we go to a child
226        // of the original node.
227        assert!(c.goto_parent());
228    }
229    // If preorder, call the callback after the recursive calls on child nodes
230    if order == Order::Post {
231        cb(c.node());
232    }
233}
234
235struct PreorderTraverse<C> {
236    cursor: Option<C>,
237}
238
239impl<C> PreorderTraverse<C> {
240    pub const fn new(c: C) -> Self {
241        Self { cursor: Some(c) }
242    }
243}
244
245impl<C> Iterator for PreorderTraverse<C>
246where
247    C: Cursor,
248{
249    type Item = C::Node;
250
251    fn next(&mut self) -> Option<Self::Item> {
252        let c = match self.cursor.as_mut() {
253            None => {
254                return None;
255            }
256            Some(c) => c,
257        };
258
259        // We will always return the node we were on at the start;
260        // the node we traverse to will either be returned on the next iteration,
261        // or will be back to the root node, at which point we'll clear out
262        // the reference to the cursor
263        let node = c.node();
264
265        // First, try to go to a child or a sibling; if either succeed, this will be the
266        // first time we touch that node, so it'll be the next starting node
267        if c.goto_first_child() || c.goto_next_sibling() {
268            return Some(node);
269        }
270
271        loop {
272            // If we can't go to the parent, then that means we've reached the root, and our
273            // iterator will be done in the next iteration
274            if !c.goto_parent() {
275                self.cursor = None;
276                break;
277            }
278
279            // If we get to a sibling, then this will be the first time we touch that node,
280            // so it'll be the next starting node
281            if c.goto_next_sibling() {
282                break;
283            }
284        }
285
286        Some(node)
287    }
288}
289
290struct PostorderTraverse<C> {
291    cursor: Option<C>,
292    retracing: bool,
293}
294
295impl<C> PostorderTraverse<C> {
296    pub const fn new(c: C) -> Self {
297        Self {
298            cursor: Some(c),
299            retracing: false,
300        }
301    }
302}
303
304impl<C> Iterator for PostorderTraverse<C>
305where
306    C: Cursor,
307{
308    type Item = C::Node;
309
310    fn next(&mut self) -> Option<Self::Item> {
311        let c = match self.cursor.as_mut() {
312            None => {
313                return None;
314            }
315            Some(c) => c,
316        };
317
318        // For the postorder traversal, we will only return a node when we are travelling back up
319        // the tree structure. Therefore, we go all the way to the leaves of the tree immediately,
320        // and only when we are retracing do we return elements
321        if !self.retracing {
322            while c.goto_first_child() {}
323        }
324
325        // Much like in preorder traversal, we want to return the node we were previously at.
326        // We know this will be the last time we touch this node, as we will either be going
327        // to its next sibling or retracing back up the tree
328        let node = c.node();
329        if c.goto_next_sibling() {
330            // If we successfully go to a sibling of this node, we want to go back down
331            // the tree on the next iteration
332            self.retracing = false;
333        } else {
334            // If we weren't already retracing, we are now; travel upwards until we can
335            // go to the next sibling or reach the root again
336            self.retracing = true;
337            if !c.goto_parent() {
338                // We've reached the root again, and our iteration is done
339                self.cursor = None;
340            }
341        }
342
343        Some(node)
344    }
345}
346
347// Used for visibility purposes, in case this struct becomes public
348struct Traverse<C> {
349    inner: TraverseInner<C>,
350}
351
352enum TraverseInner<C> {
353    Post(PostorderTraverse<C>),
354    Pre(PreorderTraverse<C>),
355}
356
357impl<C> Traverse<C> {
358    pub const fn new(c: C, order: Order) -> Self {
359        let inner = match order {
360            Order::Pre => TraverseInner::Pre(PreorderTraverse::new(c)),
361            Order::Post => TraverseInner::Post(PostorderTraverse::new(c)),
362        };
363        Self { inner }
364    }
365}
366
#[cfg(feature = "tree-sitter")]
impl<'a> Traverse<tree_sitter::TreeCursor<'a>> {
    /// Build a traversal over `tree` in the given `order`, using a fresh cursor
    /// obtained from `tree.walk()`.
    // Not called anywhere in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub fn from_tree(tree: &'a tree_sitter::Tree, order: Order) -> Self {
        let cursor = tree.walk();
        Self::new(cursor, order)
    }
}
374
/// Convenience function to traverse a tree-sitter [`Tree`] in the order given by `order`.
///
/// This is shorthand for calling [`traverse`] with a fresh cursor obtained from
/// `tree.walk()`.
///
/// [`Tree`]: tree_sitter::Tree
#[must_use]
#[cfg(feature = "tree-sitter")]
pub fn traverse_tree(
    tree: &tree_sitter::Tree,
    order: Order,
) -> impl FusedIterator<Item = tree_sitter::Node<'_>> {
    traverse(tree.walk(), order)
}
386
387/// Traverse an n-ary tree using `cursor`, returning the nodes of the tree through an iterator
388/// in an order according to `order`.
389///
390/// # Panics
391///
392/// `cursor` must be at the root of the tree
393/// (i.e. `cursor.goto_parent()` must return false)
394pub fn traverse<C: Cursor>(mut cursor: C, order: Order) -> impl FusedIterator<Item = C::Node> {
395    assert!(!cursor.goto_parent());
396    Traverse::new(cursor, order)
397}
398
399impl<C> Iterator for Traverse<C>
400where
401    C: Cursor,
402{
403    type Item = C::Node;
404
405    fn next(&mut self) -> Option<Self::Item> {
406        match self.inner {
407            TraverseInner::Post(ref mut i) => i.next(),
408            TraverseInner::Pre(ref mut i) => i.next(),
409        }
410    }
411}
412
413// We know that PreorderTraverse and PostorderTraverse are fused due to their implementation,
414// so we can add this bound for free.
415impl<C> FusedIterator for Traverse<C> where C: Cursor {}
416
#[cfg(test)]
#[cfg(feature = "tree-sitter")]
mod tree_sitter_tests {
    use super::*;

    // The crate is `no_std`, so `std` has to be pulled in explicitly for the tests.
    extern crate std;
    use std::vec::Vec;
    use tree_sitter::{Parser, Tree};

    /// A small, valid Rust function.
    const EX1: &str = r#"
fn double(x: usize) -> usize {
    return 2 * x;
}"#;

    /// Deliberately invalid Rust source.
    const EX2: &str = r#"
// Intentionally invalid code below

"123

const DOUBLE = 2;

function double(x: usize) -> usize {
    return DOUBLE * x;
}"#;

    /// Empty source.
    const EX3: &str = "";

    /// A C++ function, parsed with the C++ grammar.
    const CPPEX: &str = r#"
    std::vector<unsigned char> readFile(const std::string& filePath) {
        std::ifstream file(filePath, std::ios::binary);
        if (!file) {
            std::cerr << "Failed to open file: " << filePath << std::endl;
            return {};
        }
    
        file.seekg(0, std::ios::end);
        std::streampos fileSize = file.tellg();
        file.seekg(0, std::ios::beg);
    
        std::vector<unsigned char> buffer(fileSize);
        file.read(reinterpret_cast<char*>(buffer.data()), fileSize);
    
        return buffer;
    }
    "#;

    /// Check, for one tree and one order, that the recursive callback traversal,
    /// the iterative callback traversal and the iterator all produce the same
    /// sequence of nodes.
    fn generate_traversals(tree: &Tree, order: Order) {
        let mut via_recursion = Vec::new();
        traverse_recursive(tree.walk(), order, |node| via_recursion.push(node));

        let mut via_loop = Vec::new();
        traverse_iterative(tree.walk(), order, |node| via_loop.push(node));

        let via_iterator: Vec<_> = traverse(tree.walk(), order).collect();

        assert_eq!(via_recursion, via_loop);
        assert_eq!(via_loop, via_iterator);
    }

    /// Parse `code` with the Rust grammar.
    fn get_tree(code: &str) -> Tree {
        let language = tree_sitter_rust::language();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("Error loading Rust grammar");
        parser
            .parse(code, None)
            .expect("Error parsing provided code")
    }

    /// Parse `code` with the C++ grammar.
    fn get_cpp_tree(code: &str) -> Tree {
        let language = tree_sitter_cpp::language();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("Error loading Cpp grammar");
        parser
            .parse(code, None)
            .expect("Error parsing provided code")
    }

    #[test]
    fn test_equivalence() {
        for code in [EX1, EX2, EX3] {
            let tree = get_tree(code);
            generate_traversals(&tree, Order::Pre);
            generate_traversals(&tree, Order::Post);
        }
    }

    #[test]
    fn test_cpp() {
        let tree = get_cpp_tree(CPPEX);
        generate_traversals(&tree, Order::Pre);
        generate_traversals(&tree, Order::Post);
    }

    #[test]
    fn test_postconditions() {
        let parsed = get_tree(EX1);
        let mut walk = parsed.walk();
        for order in [Order::Pre, Order::Post] {
            let mut iter = traverse(&mut walk, order);
            // Run the iterator to exhaustion.
            for _ in iter.by_ref() {}
            // It must keep returning `None` afterwards, i.e. be fused ...
            assert!(iter.next().is_none());
            // ... no matter how often it is asked.
            assert!(iter.next().is_none());
            drop(iter);
            // The borrowed cursor is back on the root node and can be reused.
            assert_eq!(walk.node(), parsed.root_node());
        }
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        // The precondition check must reject a cursor that is not on the root.
        let parsed = get_tree(EX1);
        let mut walk = parsed.walk();
        walk.goto_first_child();
        traverse(&mut walk, Order::Pre).count();
    }

    /// Mirrors the example from the crate-level documentation.
    #[test]
    fn example() {
        use std::collections::HashSet;
        use tree_sitter::Node;

        let tree = get_tree(EX1);
        let preorder: Vec<Node<'_>> = traverse(tree.walk(), Order::Pre).collect();
        let postorder: Vec<Node<'_>> = traverse_tree(&tree, Order::Post).collect();
        assert_ne!(preorder, postorder);
        assert_eq!(preorder.len(), postorder.len());

        let preorder_set: HashSet<_> = preorder.into_iter().collect();
        let postorder_set: HashSet<_> = postorder.into_iter().collect();
        assert_eq!(preorder_set, postorder_set);
    }
}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Cursor over a tree that consists of nothing but its root; the node type
    /// is the unit type.
    struct Root;

    impl Cursor for Root {
        type Node = ();

        // The root is the only node, so no move can ever succeed.

        fn goto_first_child(&mut self) -> bool {
            false
        }

        fn goto_parent(&mut self) -> bool {
            false
        }

        fn goto_next_sibling(&mut self) -> bool {
            false
        }

        fn node(&self) -> Self::Node {}
    }

    /// A single-node tree yields exactly one node in either order.
    #[test]
    fn test_root() {
        for order in [Order::Pre, Order::Post] {
            assert_eq!(1, traverse(Root, order).count());
        }
    }
}