Skip to main content

thread_ast_engine/
node.rs

// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
//
// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
//! # AST Node Representation and Navigation
//!
//! Core types for representing and navigating Abstract Syntax Tree nodes.
//!
//! ## Key Types
//!
//! - [`Node`] - A single AST node with navigation and matching capabilities
//! - [`Root`] - The root of an AST tree, owns the source code and tree structure
//! - [`Position`] - Represents a position in source code (line/column)
//!
//! ## Usage
//!
//! ```rust,no_run
//! # use thread_ast_engine::Language;
//! # use thread_ast_engine::tree_sitter::LanguageExt;
//! # use thread_ast_engine::MatcherExt;
//! let ast = Language::Tsx.ast_grep("function foo() { return 42; }");
//! let root_node = ast.root();
//!
//! // Navigate the tree
//! for child in root_node.children() {
//!     println!("Child kind: {}", child.kind());
//! }
//!
//! // Find specific patterns
//! if let Some(func) = root_node.find("function $NAME() { $$$BODY }") {
//!     println!("Found function: {}", func.get_env().get_match("NAME").unwrap().text());
//! }
//! ```
36
37use crate::Doc;
38use crate::Language;
39#[cfg(feature = "matching")]
40use crate::matcher::{Matcher, MatcherExt, NodeMatch};
41#[cfg(feature = "matching")]
42use crate::replacer::Replacer;
43use crate::source::{Content, Edit as E, SgNode};
44
45type Edit<D> = E<<D as Doc>::Source>;
46
47use std::borrow::Cow;
48
/// A location inside the source text.
///
/// Line numbers are zero-based, so line 0 is the first line. The struct itself
/// stores byte-based data (the byte column within the line and the absolute
/// byte offset); the zero-based, character-based column that is easier for
/// humans to consume is derived on demand by [`Position::column`].
///
/// # Note
///
/// Computing the character column from byte positions is an O(n) operation,
/// so avoid calling [`Position::column`] in performance-critical loops.
///
/// # Example
///
/// ```rust,no_run
/// # use thread_ast_engine::Language;
/// # use thread_ast_engine::tree_sitter::LanguageExt;
/// let ast = Language::Tsx.ast_grep("let x = 42;\nlet y = 24;");
/// let root = ast.root();
///
/// let start_pos = root.start_pos();
/// assert_eq!(start_pos.line(), 0);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Position {
    /// Zero-based line number (line 0 = first line)
    line: usize,
    /// Zero-based byte offset within the line
    byte_column: usize,
    /// Absolute byte offset from start of file
    byte_offset: usize,
}
80
81impl Position {
82    #[must_use]
83    pub const fn new(line: usize, byte_column: usize, byte_offset: usize) -> Self {
84        Self {
85            line,
86            byte_column,
87            byte_offset,
88        }
89    }
90    #[must_use]
91    pub const fn line(&self) -> usize {
92        self.line
93    }
94    /// Returns the column in terms of characters.
95    /// Note: node does not have to be a node of matching position.
96    pub fn column<D: Doc>(&self, node: &Node<'_, D>) -> usize {
97        let source = node.get_doc().get_source();
98        source.get_char_column(self.byte_column, self.byte_offset)
99    }
100    #[must_use]
101    pub const fn byte_point(&self) -> (usize, usize) {
102        (self.line, self.byte_column)
103    }
104}
105
106/// Root of an AST tree that owns the source code and parsed tree structure.
107///
108/// Root acts as the entry point for all AST operations. It manages the document
109/// (source code + parsed tree) and provides methods to get the root node and
110/// perform tree-wide operations like replacements.
111///
112/// # Generic Parameters
113///
114/// - `D: Doc` - The document type that holds source code and language information
115///
116/// # Example
117///
118/// ```rust,no_run
119/// # use thread_ast_engine::Language;
120/// # use thread_ast_engine::tree_sitter::LanguageExt;
121/// # use thread_ast_engine::MatcherExt;
122/// let mut ast = Language::Tsx.ast_grep("let x = 42;");
123/// let root_node = ast.root();
124///
125/// // Perform tree-wide replacements
126/// ast.replace("let $VAR = $VALUE", "const $VAR = $VALUE");
127/// println!("{}", ast.generate());
128/// ```
129#[derive(Clone, Debug)]
130pub struct Root<D: Doc> {
131    pub(crate) doc: D,
132}
133
134impl<D: Doc> Root<D> {
135    pub const fn doc(doc: D) -> Self {
136        Self { doc }
137    }
138
139    pub fn lang(&self) -> &D::Lang {
140        self.doc.get_lang()
141    }
142    /// The root node represents the entire source
143    pub fn root(&self) -> Node<'_, D> {
144        Node {
145            inner: self.doc.root_node(),
146            root: self,
147        }
148    }
149
150    // extract non generic implementation to reduce code size
151    pub fn edit(&mut self, edit: &Edit<D>) -> Result<&mut Self, String> {
152        self.doc.do_edit(edit)?;
153        Ok(self)
154    }
155
156    #[cfg(feature = "matching")]
157    pub fn replace<M: Matcher, R: Replacer<D>>(
158        &mut self,
159        pattern: M,
160        replacer: R,
161    ) -> Result<bool, String> {
162        let root = self.root();
163        if let Some(edit) = root.replace(pattern, replacer) {
164            drop(root); // rust cannot auto drop root if D is not specified
165            self.edit(&edit)?;
166            Ok(true)
167        } else {
168            Ok(false)
169        }
170    }
171
172    /// Adopt the `tree_sitter` as the descendant of the root and return the wrapped sg Node.
173    /// It assumes `inner` is under the root and will panic at dev build if wrong node is used.
174    pub fn adopt<'r>(&'r self, inner: D::Node<'r>) -> Node<'r, D> {
175        debug_assert!(self.check_lineage(&inner));
176        Node { inner, root: self }
177    }
178
179    fn check_lineage(&self, inner: &D::Node<'_>) -> bool {
180        let mut node = inner.clone();
181        while let Some(n) = node.parent() {
182            node = n;
183        }
184        node.node_id() == self.doc.root_node().node_id()
185    }
186
187    /// P.S. I am your father.
188    #[doc(hidden)]
189    pub unsafe fn readopt<'a: 'b, 'b>(&'a self, node: &mut Node<'b, D>) {
190        debug_assert!(self.check_lineage(&node.inner));
191        node.root = self;
192    }
193}
194
195/// A single node in an Abstract Syntax Tree.
196///
197/// Node represents a specific element in the parsed AST, such as a function declaration,
198/// variable assignment, or expression. Each node knows its position in the source code,
199/// its type (kind), and provides methods for navigation and pattern matching.
200///
201/// # Lifetime
202///
203/// The lifetime `'r` ties the node to its root AST, ensuring memory safety.
204/// Nodes cannot outlive the Root that owns the underlying tree structure.
205///
206/// # Example
207///
208/// ```rust,no_run
209/// # use thread_ast_engine::Language;
210/// # use thread_ast_engine::tree_sitter::LanguageExt;
211/// # use thread_ast_engine::matcher::MatcherExt;
212/// let ast = Language::Tsx.ast_grep("function hello() { return 'world'; }");
213/// let root_node = ast.root();
214///
215/// // Check the node type
216/// println!("Root kind: {}", root_node.kind());
217///
218/// // Navigate to children
219/// for child in root_node.children() {
220///     println!("Child: {} at {}:{}", child.kind(),
221///         child.start_pos().line(), child.start_pos().column(&child));
222/// }
223///
224/// // Find specific patterns
225/// if let Some(return_stmt) = root_node.find("return $VALUE") {
226///     let value = return_stmt.get_env().get_match("VALUE").unwrap();
227///     println!("Returns: {}", value.text());
228/// }
229/// ```
230#[derive(Clone, Debug)]
231pub struct Node<'r, D: Doc> {
232    pub(crate) inner: D::Node<'r>,
233    pub(crate) root: &'r Root<D>,
234}
235
/// Numeric identifier of an AST node kind, i.e. the id returned by
/// `Node::kind_id` for kinds whose names (see `Node::kind`) look like
/// `function_declaration` or `identifier`.
pub type KindId = u16;
238
239/// APIs for Node inspection
240impl<'r, D: Doc> Node<'r, D> {
241    pub const fn get_doc(&self) -> &'r D {
242        &self.root.doc
243    }
244    pub fn node_id(&self) -> usize {
245        self.inner.node_id()
246    }
247    pub fn is_leaf(&self) -> bool {
248        self.inner.is_leaf()
249    }
250    /// if has no named children.
251    /// N.B. it is different from `is_named` && `is_leaf`
252    // see https://github.com/ast-grep/ast-grep/issues/276
253    pub fn is_named_leaf(&self) -> bool {
254        self.inner.is_named_leaf()
255    }
256    pub fn is_error(&self) -> bool {
257        self.inner.is_error()
258    }
259    pub fn kind(&self) -> Cow<'_, str> {
260        self.inner.kind()
261    }
262    pub fn kind_id(&self) -> KindId {
263        self.inner.kind_id()
264    }
265
266    pub fn is_named(&self) -> bool {
267        self.inner.is_named()
268    }
269    pub fn is_missing(&self) -> bool {
270        self.inner.is_missing()
271    }
272
273    /// byte offsets of start and end.
274    pub fn range(&self) -> std::ops::Range<usize> {
275        self.inner.range()
276    }
277
278    /// Nodes' start position in terms of zero-based rows and columns.
279    pub fn start_pos(&self) -> Position {
280        self.inner.start_pos()
281    }
282
283    /// Nodes' end position in terms of rows and columns.
284    pub fn end_pos(&self) -> Position {
285        self.inner.end_pos()
286    }
287
288    pub fn text(&self) -> Cow<'r, str> {
289        self.root.doc.get_node_text(&self.inner)
290    }
291
292    pub fn lang(&self) -> &'r D::Lang {
293        self.root.lang()
294    }
295
296    /// the underlying tree-sitter Node
297    pub fn get_inner_node(&self) -> D::Node<'r> {
298        self.inner.clone()
299    }
300
301    pub const fn root(&self) -> &'r Root<D> {
302        self.root
303    }
304}
305
/// Relational checks.
///
/// Corresponds to inside/has/precedes/follows
#[cfg(feature = "matching")]
impl<D: Doc> Node<'_, D> {
    /// Whether this node itself is matched by `m`.
    pub fn matches<M: Matcher>(&self, m: M) -> bool {
        let me = self.clone();
        m.match_node(me).is_some()
    }

    /// Whether any node yielded by [`Node::ancestors`] is matched by `m`.
    pub fn inside<M: Matcher>(&self, m: M) -> bool {
        self.ancestors().any(|n| m.match_node(n).is_some())
    }

    /// Whether any node of the `dfs` traversal, except its first item, is
    /// matched by `m`.
    pub fn has<M: Matcher>(&self, m: M) -> bool {
        self.dfs().skip(1).any(|n| m.match_node(n).is_some())
    }

    /// Whether any node yielded by [`Node::next_all`] is matched by `m`.
    pub fn precedes<M: Matcher>(&self, m: M) -> bool {
        self.next_all().any(|n| m.match_node(n).is_some())
    }

    /// Whether any node yielded by [`Node::prev_all`] is matched by `m`.
    pub fn follows<M: Matcher>(&self, m: M) -> bool {
        self.prev_all().any(|n| m.match_node(n).is_some())
    }
}
331
332/// tree traversal API
333impl<'r, D: Doc> Node<'r, D> {
334    #[must_use]
335    pub fn parent(&self) -> Option<Self> {
336        let inner = self.inner.parent()?;
337        Some(Node {
338            inner,
339            root: self.root,
340        })
341    }
342
343    pub fn children(&self) -> impl ExactSizeIterator<Item = Node<'r, D>> + '_ {
344        self.inner.children().map(|inner| Node {
345            inner,
346            root: self.root,
347        })
348    }
349
350    #[must_use]
351    pub fn child(&self, nth: usize) -> Option<Self> {
352        let inner = self.inner.child(nth)?;
353        Some(Node {
354            inner,
355            root: self.root,
356        })
357    }
358
359    pub fn field(&self, name: &str) -> Option<Self> {
360        let inner = self.inner.field(name)?;
361        Some(Node {
362            inner,
363            root: self.root,
364        })
365    }
366
367    pub fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
368        let inner = self.inner.child_by_field_id(field_id)?;
369        Some(Node {
370            inner,
371            root: self.root,
372        })
373    }
374
375    pub fn field_children(&self, name: &str) -> impl Iterator<Item = Node<'r, D>> + '_ {
376        let field_id = self.lang().field_to_id(name);
377        self.inner.field_children(field_id).map(|inner| Node {
378            inner,
379            root: self.root,
380        })
381    }
382
383    /// Returns all ancestors nodes of `self`.
384    /// Using cursor is overkill here because adjust cursor is too expensive.
385    pub fn ancestors(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
386        let root = self.root.doc.root_node();
387        self.inner.ancestors(root).map(|inner| Node {
388            inner,
389            root: self.root,
390        })
391    }
392    #[must_use]
393    pub fn next(&self) -> Option<Self> {
394        let inner = self.inner.next()?;
395        Some(Node {
396            inner,
397            root: self.root,
398        })
399    }
400
401    /// Returns all sibling nodes next to `self`.
402    // NOTE: Need go to parent first, then move to current node by byte offset.
403    // This is because tree_sitter cursor is scoped to the starting node.
404    // See https://github.com/tree-sitter/tree-sitter/issues/567
405    pub fn next_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
406        self.inner.next_all().map(|inner| Node {
407            inner,
408            root: self.root,
409        })
410    }
411
412    #[must_use]
413    pub fn prev(&self) -> Option<Self> {
414        let inner = self.inner.prev()?;
415        Some(Node {
416            inner,
417            root: self.root,
418        })
419    }
420
421    pub fn prev_all(&self) -> impl Iterator<Item = Node<'r, D>> + '_ {
422        self.inner.prev_all().map(|inner| Node {
423            inner,
424            root: self.root,
425        })
426    }
427
428    pub fn dfs<'s>(&'s self) -> impl Iterator<Item = Node<'r, D>> + 's {
429        self.inner.dfs().map(|inner| Node {
430            inner,
431            root: self.root,
432        })
433    }
434
435    #[cfg(feature = "matching")]
436    pub fn find<M: Matcher>(&self, pat: M) -> Option<NodeMatch<'r, D>> {
437        pat.find_node(self.clone())
438    }
439    #[cfg(feature = "matching")]
440    pub fn find_all<'s, M: Matcher + 's>(
441        &'s self,
442        pat: M,
443    ) -> impl Iterator<Item = NodeMatch<'r, D>> + 's {
444        let kinds = pat.potential_kinds();
445        self.dfs().filter_map(move |cand| {
446            if let Some(k) = &kinds
447                && !k.contains(cand.kind_id().into())
448            {
449                return None;
450            }
451            pat.match_node(cand)
452        })
453    }
454}
455
456/// Tree manipulation API
457impl<D: Doc> Node<'_, D> {
458    #[cfg(feature = "matching")]
459    pub fn replace<M: Matcher, R: Replacer<D>>(&self, matcher: M, replacer: R) -> Option<Edit<D>> {
460        let matched = matcher.find_node(self.clone())?;
461        let edit = matched.make_edit(&matcher, &replacer);
462        Some(edit)
463    }
464
465    pub fn after(&self) -> Edit<D> {
466        todo!()
467    }
468    pub fn before(&self) -> Edit<D> {
469        todo!()
470    }
471    pub fn append(&self) -> Edit<D> {
472        todo!()
473    }
474    pub fn prepend(&self) -> Edit<D> {
475        todo!()
476    }
477
478    /// Empty children. Remove all child node
479    pub fn empty(&self) -> Option<Edit<D>> {
480        let mut children = self.children().peekable();
481        let start = children.peek()?.range().start;
482        let end = children.last()?.range().end;
483        Some(Edit::<D> {
484            position: start,
485            deleted_length: end - start,
486            inserted_text: Vec::new(),
487        })
488    }
489
490    /// Remove the node itself
491    pub fn remove(&self) -> Edit<D> {
492        let range = self.range();
493        Edit::<D> {
494            position: range.start,
495            deleted_length: range.end - range.start,
496            inserted_text: Vec::new(),
497        }
498    }
499}
500
#[cfg(test)]
mod test {
    use crate::language::{Language, Tsx};
    use crate::tree_sitter::LanguageExt;

    /// The root of a non-empty program is not a leaf.
    #[test]
    fn test_is_leaf() {
        let ast = Tsx.ast_grep("let a = 123");
        let root = ast.root();
        assert!(!root.is_leaf());
    }

    /// The program has one statement, made of the keyword and the declarator.
    #[test]
    fn test_children() {
        let ast = Tsx.ast_grep("let a = 123");
        let root = ast.root();
        let top_level: Vec<_> = root.children().collect();
        assert_eq!(top_level.len(), 1);
        let statement = &top_level[0];
        let texts: Vec<_> = statement
            .children()
            .map(|c| c.text().to_string())
            .collect();
        assert_eq!(texts, vec!["let", "a = 123"]);
    }

    /// Emptying the root deletes the whole source and inserts nothing.
    #[test]
    fn test_empty() {
        let ast = Tsx.ast_grep("let a = 123");
        let root = ast.root();
        let edit = root.empty().unwrap();
        assert_eq!(edit.inserted_text.len(), 0);
        assert_eq!(edit.deleted_length, 11);
        assert_eq!(edit.position, 0);
    }

    /// The `kind` field of the declaration holds the `let` keyword.
    #[test]
    fn test_field_children() {
        let ast = Tsx.ast_grep("let a = 123");
        let decl = ast.root().find("let a = $A").unwrap();
        let kinds: Vec<_> = decl.field_children("kind").collect();
        assert_eq!(kinds.len(), 1);
        assert_eq!(kinds[0].text(), "let");
    }

    const MULTI_LINE: &str = "
if (a) {
  test(1)
} else {
  x
}
";

    #[test]
    fn test_display_context() {
        // src, matcher, lead, trail
        let cases = [
            ["i()", "i()", "", ""],
            ["i()", "i", "", "()"],
            [MULTI_LINE, "test", "  ", "(1)"],
        ];
        // display context should not panic
        for [src, matcher, lead, trail] in cases {
            let ast = Tsx.ast_grep(src);
            let found = ast.root().find(matcher).expect("should match");
            let display = found.display_context(0, 0);
            assert_eq!(display.leading, lead);
            assert_eq!(display.trailing, trail);
        }
    }

    #[test]
    fn test_multi_line_context() {
        // src, matcher, lead, trail
        let cases = [
            ["i()", "i()", "", ""],
            [MULTI_LINE, "test", "if (a) {\n  ", "(1)\n} else {"],
        ];
        // display context should not panic
        for [src, matcher, lead, trail] in cases {
            let ast = Tsx.ast_grep(src);
            let found = ast.root().find(matcher).expect("should match");
            let display = found.display_context(1, 1);
            assert_eq!(display.leading, lead);
            assert_eq!(display.trailing, trail);
        }
    }

    /// Only the outermost of two nested matches is replaced.
    #[test]
    fn test_replace_all_nested() {
        let ast = Tsx.ast_grep("Some(Some(1))");
        let root = ast.root();
        let edits = root.replace_all("Some($A)", "$A");
        assert_eq!(edits.len(), 1);
        assert_eq!(edits[0].inserted_text, "Some(1)".as_bytes());
    }

    #[test]
    fn test_replace_all_multiple_sorted() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2)");
        let root = ast.root();
        let edits = root.replace_all("Some($A)", "$A");
        // edits must be sorted by position
        assert_eq!(edits.len(), 2);
        assert_eq!(edits[0].inserted_text, "Some(1)".as_bytes());
        assert_eq!(edits[1].inserted_text, "2".as_bytes());
    }

    #[test]
    fn test_inside() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2)");
        let root = ast.root();
        let inner = root.find("Some(1)").expect("should exist");
        assert!(inner.inside("Some($A)"));
    }

    #[test]
    fn test_has() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2)");
        let root = ast.root();
        let outer = root.find("Some($A)").expect("should exist");
        assert!(outer.has("Some(1)"));
    }

    #[test]
    fn precedes() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2);");
        let root = ast.root();
        let first = root.find("Some($A);").expect("should exist");
        assert!(first.precedes("Some(2);"));
    }

    #[test]
    fn follows() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2);");
        let root = ast.root();
        let second = root.find("Some(2);").expect("should exist");
        assert!(second.follows("Some(Some(1));"));
    }

    #[test]
    fn test_field() {
        let ast = Tsx.ast_grep("class A{}");
        let root = ast.root();
        let class = root.find("class $C {}").expect("should exist");
        assert!(class.field("name").is_some());
        assert!(class.field("none").is_none());
    }

    #[test]
    fn test_child_by_field_id() {
        let ast = Tsx.ast_grep("class A{}");
        let root = ast.root();
        let class = root.find("class $C {}").expect("should exist");
        let name_id = Tsx.field_to_id("name").unwrap();
        assert!(class.child_by_field_id(name_id).is_some());
        assert!(class.child_by_field_id(name_id + 1).is_none());
    }

    /// Removing a statement deletes exactly its byte range.
    #[test]
    fn test_remove() {
        let ast = Tsx.ast_grep("Some(Some(1)); Some(2);");
        let root = ast.root();
        let second = root.find("Some(2);").expect("should exist");
        let edit = second.remove();
        assert_eq!(edit.position, 15);
        assert_eq!(edit.deleted_length, 8);
    }

    #[test]
    fn test_ascii_pos() {
        let ast = Tsx.ast_grep("a");
        let root = ast.root();
        let found = root.find("$A").expect("should exist");
        assert_eq!(found.start_pos().line(), 0);
        assert_eq!(found.start_pos().column(&*found), 0);
        assert_eq!(found.end_pos().line(), 0);
        assert_eq!(found.end_pos().column(&*found), 1);
    }

    /// Columns count characters, not bytes, so one crab is one column wide.
    #[test]
    fn test_unicode_pos() {
        let ast = Tsx.ast_grep("🦀");
        let root = ast.root();
        let found = root.find("$A").expect("should exist");
        assert_eq!(found.start_pos().line(), 0);
        assert_eq!(found.start_pos().column(&*found), 0);
        assert_eq!(found.end_pos().line(), 0);
        assert_eq!(found.end_pos().column(&*found), 1);

        let ast = Tsx.ast_grep("\n  🦀🦀");
        let root = ast.root();
        let found = root.find("$A").expect("should exist");
        assert_eq!(found.start_pos().line(), 1);
        assert_eq!(found.start_pos().column(&*found), 2);
        assert_eq!(found.end_pos().line(), 1);
        assert_eq!(found.end_pos().column(&*found), 4);
    }
}