Skip to main content

// codex_patcher/ts/parser.rs

1use crate::ts::errors::TreeSitterError;
2use ast_grep_language::{LanguageExt, SupportLang};
3use tree_sitter::{Parser, Tree};
4
/// Rust language edition, recorded for grammar compatibility checking.
///
/// The default edition is [`RustEdition::E2021`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RustEdition {
    /// Edition 2015.
    E2015,
    /// Edition 2018.
    E2018,
    /// Edition 2021 (the default).
    #[default]
    E2021,
    /// Edition 2024.
    E2024,
}

impl RustEdition {
    /// Parse an edition from the `edition` string found in `Cargo.toml`.
    ///
    /// Only the exact strings `"2015"`, `"2018"`, `"2021"` and `"2024"` are
    /// recognised; any other input (including padded or empty strings)
    /// yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Exhaustive table of the accepted spellings and their variants.
        const KNOWN: [(&str, RustEdition); 4] = [
            ("2015", RustEdition::E2015),
            ("2018", RustEdition::E2018),
            ("2021", RustEdition::E2021),
            ("2024", RustEdition::E2024),
        ];

        KNOWN
            .iter()
            .find(|(label, _)| *label == s)
            .map(|&(_, edition)| edition)
    }
}
27
28/// Tree-sitter parser wrapper for Rust source code.
29pub struct RustParser {
30    parser: Parser,
31    edition: RustEdition,
32}
33
34impl RustParser {
35    /// Create a new Rust parser with the default edition (2021).
36    pub fn new() -> Result<Self, TreeSitterError> {
37        Self::with_edition(RustEdition::default())
38    }
39
40    /// Create a new Rust parser targeting a specific edition.
41    pub fn with_edition(edition: RustEdition) -> Result<Self, TreeSitterError> {
42        let mut parser = Parser::new();
43        // Get the tree-sitter Language from ast-grep-language
44        let ts_lang = SupportLang::Rust.get_ts_language();
45        parser
46            .set_language(&ts_lang)
47            .map_err(|_| TreeSitterError::LanguageSet)?;
48
49        Ok(Self { parser, edition })
50    }
51
52    /// Get the configured edition.
53    pub fn edition(&self) -> RustEdition {
54        self.edition
55    }
56
57    /// Parse source code into a tree-sitter Tree.
58    pub fn parse(&mut self, source: &str) -> Result<Tree, TreeSitterError> {
59        self.parser
60            .parse(source, None)
61            .ok_or(TreeSitterError::ParseFailed)
62    }
63
64    /// Parse source code and return the tree along with the source.
65    pub fn parse_with_source<'a>(
66        &mut self,
67        source: &'a str,
68    ) -> Result<ParsedSource<'a>, TreeSitterError> {
69        let tree = self.parse(source)?;
70        Ok(ParsedSource { source, tree })
71    }
72}
73
74impl Default for RustParser {
75    fn default() -> Self {
76        Self::new().expect("failed to create default RustParser")
77    }
78}
79
80/// A parsed source file with its tree-sitter tree.
81pub struct ParsedSource<'a> {
82    pub source: &'a str,
83    pub tree: Tree,
84}
85
86impl<'a> ParsedSource<'a> {
87    /// Get the root node of the tree.
88    pub fn root_node(&self) -> tree_sitter::Node<'_> {
89        self.tree.root_node()
90    }
91
92    /// Check if the tree contains any ERROR nodes.
93    pub fn has_errors(&self) -> bool {
94        has_error_nodes(self.tree.root_node())
95    }
96
97    /// Get all ERROR nodes in the tree.
98    pub fn error_nodes(&self) -> Vec<ErrorNode> {
99        let mut errors = Vec::new();
100        collect_error_nodes(self.tree.root_node(), &mut errors);
101        errors
102    }
103
104    /// Extract text for a node's byte range.
105    pub fn node_text(&self, node: tree_sitter::Node<'_>) -> &'a str {
106        &self.source[node.byte_range()]
107    }
108}
109
110/// Information about an ERROR node in the parse tree.
111#[derive(Debug, Clone)]
112pub struct ErrorNode {
113    pub byte_start: usize,
114    pub byte_end: usize,
115    pub start_point: tree_sitter::Point,
116    pub end_point: tree_sitter::Point,
117}
118
119fn has_error_nodes(node: tree_sitter::Node<'_>) -> bool {
120    if node.is_error() || node.is_missing() {
121        return true;
122    }
123
124    let mut cursor = node.walk();
125    for child in node.children(&mut cursor) {
126        if has_error_nodes(child) {
127            return true;
128        }
129    }
130
131    false
132}
133
134fn collect_error_nodes(node: tree_sitter::Node<'_>, errors: &mut Vec<ErrorNode>) {
135    if node.is_error() || node.is_missing() {
136        errors.push(ErrorNode {
137            byte_start: node.start_byte(),
138            byte_end: node.end_byte(),
139            start_point: node.start_position(),
140            end_point: node.end_position(),
141        });
142    }
143
144    let mut cursor = node.walk();
145    for child in node.children(&mut cursor) {
146        collect_error_nodes(child, errors);
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed source parses into an error-free `source_file` tree.
    #[test]
    fn parse_valid_rust() {
        let source = "fn main() { println!(\"hello\"); }";
        let mut rust_parser = RustParser::new().unwrap();
        let parsed = rust_parser.parse_with_source(source).unwrap();

        let root_kind = parsed.root_node().kind();
        assert!(!parsed.has_errors());
        assert_eq!(root_kind, "source_file");
    }

    /// Malformed source still yields a tree, but one that reports errors.
    #[test]
    fn parse_invalid_rust() {
        let source = "fn main( { }";
        let mut rust_parser = RustParser::new().unwrap();
        let parsed = rust_parser.parse_with_source(source).unwrap();

        assert!(parsed.has_errors());
        let found = parsed.error_nodes();
        assert!(!found.is_empty());
    }

    /// Edition strings map to their variants; unknown strings map to `None`.
    #[test]
    fn edition_parsing() {
        let cases = [
            ("2021", Some(RustEdition::E2021)),
            ("2024", Some(RustEdition::E2024)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(RustEdition::parse(input), expected);
        }
    }
}