codex_patcher/ts/
parser.rs1use crate::ts::errors::TreeSitterError;
2use ast_grep_language::{LanguageExt, SupportLang};
3use tree_sitter::{Parser, Tree};
4
/// A Rust language edition, identified by its year.
///
/// Defaults to [`RustEdition::E2021`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RustEdition {
    /// The 2015 edition.
    E2015,
    /// The 2018 edition.
    E2018,
    /// The 2021 edition (the `Default`).
    #[default]
    E2021,
    /// The 2024 edition.
    E2024,
}

impl RustEdition {
    /// Converts an edition year string into a [`RustEdition`].
    ///
    /// Only the exact strings `"2015"`, `"2018"`, `"2021"` and `"2024"` are
    /// recognised. Anything else, including a year with surrounding
    /// whitespace, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Textual year -> variant. Keep in sync with the enum above.
        const KNOWN: [(&str, RustEdition); 4] = [
            ("2015", RustEdition::E2015),
            ("2018", RustEdition::E2018),
            ("2021", RustEdition::E2021),
            ("2024", RustEdition::E2024),
        ];

        KNOWN
            .iter()
            .find(|(year, _)| *year == s)
            .map(|&(_, edition)| edition)
    }
}
27
28pub struct RustParser {
30 parser: Parser,
31 edition: RustEdition,
32}
33
34impl RustParser {
35 pub fn new() -> Result<Self, TreeSitterError> {
37 Self::with_edition(RustEdition::default())
38 }
39
40 pub fn with_edition(edition: RustEdition) -> Result<Self, TreeSitterError> {
42 let mut parser = Parser::new();
43 let ts_lang = SupportLang::Rust.get_ts_language();
45 parser
46 .set_language(&ts_lang)
47 .map_err(|_| TreeSitterError::LanguageSet)?;
48
49 Ok(Self { parser, edition })
50 }
51
52 pub fn edition(&self) -> RustEdition {
54 self.edition
55 }
56
57 pub fn parse(&mut self, source: &str) -> Result<Tree, TreeSitterError> {
59 self.parser
60 .parse(source, None)
61 .ok_or(TreeSitterError::ParseFailed)
62 }
63
64 pub fn parse_with_source<'a>(
66 &mut self,
67 source: &'a str,
68 ) -> Result<ParsedSource<'a>, TreeSitterError> {
69 let tree = self.parse(source)?;
70 Ok(ParsedSource { source, tree })
71 }
72}
73
74impl Default for RustParser {
75 fn default() -> Self {
76 Self::new().expect("failed to create default RustParser")
77 }
78}
79
80pub struct ParsedSource<'a> {
82 pub source: &'a str,
83 pub tree: Tree,
84}
85
86impl<'a> ParsedSource<'a> {
87 pub fn root_node(&self) -> tree_sitter::Node<'_> {
89 self.tree.root_node()
90 }
91
92 pub fn has_errors(&self) -> bool {
94 has_error_nodes(self.tree.root_node())
95 }
96
97 pub fn error_nodes(&self) -> Vec<ErrorNode> {
99 let mut errors = Vec::new();
100 collect_error_nodes(self.tree.root_node(), &mut errors);
101 errors
102 }
103
104 pub fn node_text(&self, node: tree_sitter::Node<'_>) -> &'a str {
106 &self.source[node.byte_range()]
107 }
108}
109
110#[derive(Debug, Clone)]
112pub struct ErrorNode {
113 pub byte_start: usize,
114 pub byte_end: usize,
115 pub start_point: tree_sitter::Point,
116 pub end_point: tree_sitter::Point,
117}
118
119fn has_error_nodes(node: tree_sitter::Node<'_>) -> bool {
120 if node.is_error() || node.is_missing() {
121 return true;
122 }
123
124 let mut cursor = node.walk();
125 for child in node.children(&mut cursor) {
126 if has_error_nodes(child) {
127 return true;
128 }
129 }
130
131 false
132}
133
134fn collect_error_nodes(node: tree_sitter::Node<'_>, errors: &mut Vec<ErrorNode>) {
135 if node.is_error() || node.is_missing() {
136 errors.push(ErrorNode {
137 byte_start: node.start_byte(),
138 byte_end: node.end_byte(),
139 start_point: node.start_position(),
140 end_point: node.end_position(),
141 });
142 }
143
144 let mut cursor = node.walk();
145 for child in node.children(&mut cursor) {
146 collect_error_nodes(child, errors);
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_valid_rust() {
        let mut parser = RustParser::new().unwrap();
        let source = "fn main() { println!(\"hello\"); }";
        let parsed = parser.parse_with_source(source).unwrap();

        assert!(!parsed.has_errors());
        assert!(parsed.error_nodes().is_empty());
        assert_eq!(parsed.root_node().kind(), "source_file");
    }

    #[test]
    fn parse_invalid_rust() {
        let mut parser = RustParser::new().unwrap();
        let source = "fn main( { }";
        let parsed = parser.parse_with_source(source).unwrap();

        assert!(parsed.has_errors());
        let errors = parsed.error_nodes();
        assert!(!errors.is_empty());
        // Every reported span must lie inside the parsed text.
        for error in &errors {
            assert!(error.byte_start <= error.byte_end);
            assert!(error.byte_end <= source.len());
        }
    }

    #[test]
    fn node_text_of_root_is_whole_source() {
        let mut parser = RustParser::new().unwrap();
        let source = "fn main() {}";
        let parsed = parser.parse_with_source(source).unwrap();

        assert_eq!(parsed.node_text(parsed.root_node()), source);
    }

    #[test]
    fn parser_reports_its_edition() {
        assert_eq!(RustParser::new().unwrap().edition(), RustEdition::E2021);
        assert_eq!(
            RustParser::with_edition(RustEdition::E2018)
                .unwrap()
                .edition(),
            RustEdition::E2018
        );
        assert_eq!(RustParser::default().edition(), RustEdition::default());
    }

    #[test]
    fn parser_is_reusable_across_inputs() {
        let mut parser = RustParser::new().unwrap();
        assert!(!parser.parse_with_source("struct A;").unwrap().has_errors());
        assert!(parser.parse_with_source("struct").unwrap().has_errors());
        assert!(!parser.parse_with_source("struct B;").unwrap().has_errors());
    }

    #[test]
    fn edition_parsing() {
        assert_eq!(RustEdition::parse("2015"), Some(RustEdition::E2015));
        assert_eq!(RustEdition::parse("2018"), Some(RustEdition::E2018));
        assert_eq!(RustEdition::parse("2021"), Some(RustEdition::E2021));
        assert_eq!(RustEdition::parse("2024"), Some(RustEdition::E2024));
        assert_eq!(RustEdition::parse("invalid"), None);
        assert_eq!(RustEdition::parse(""), None);
        assert_eq!(RustEdition::parse(" 2021"), None);
    }
}