1use anyhow::{Context, Result};
2use tree_sitter::{Language, Node, Parser};
3
/// A semantic unit (function, type, module, ...) found in a parsed source file.
///
/// Produced by [`AstParser::parse`]. Byte offsets are taken directly from the underlying
/// tree-sitter node; line numbers are converted to 1-based.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AstNode {
    /// The tree-sitter node kind, e.g. `"function_item"` or `"class_definition"`.
    pub kind: String,
    /// Byte offset of the first byte of the node in the source text.
    pub start_byte: usize,
    /// Byte offset just past the last byte of the node (exclusive end).
    pub end_byte: usize,
    /// 1-based line on which the node starts.
    pub start_line: usize,
    /// 1-based line on which the node ends.
    pub end_line: usize,
}
13
14pub struct AstParser {
16 parser: Parser,
17 _language: Language,
18 language_name: String,
19}
20
21impl AstParser {
22 pub fn new(extension: &str) -> Result<Self> {
24 let (language, language_name) = match extension.to_lowercase().as_str() {
25 "rs" => (tree_sitter_rust::LANGUAGE.into(), "Rust"),
26 "py" => (tree_sitter_python::LANGUAGE.into(), "Python"),
27 "js" | "mjs" | "cjs" | "jsx" => (tree_sitter_javascript::LANGUAGE.into(), "JavaScript"),
28 "ts" | "tsx" => (
29 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
30 "TypeScript",
31 ),
32 "go" => (tree_sitter_go::LANGUAGE.into(), "Go"),
33 "java" => (tree_sitter_java::LANGUAGE.into(), "Java"),
34 "swift" => (tree_sitter_swift::LANGUAGE.into(), "Swift"),
35 "c" | "h" => (tree_sitter_c::LANGUAGE.into(), "C"),
36 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => {
37 (tree_sitter_cpp::LANGUAGE.into(), "C++")
38 }
39 "cs" => (tree_sitter_c_sharp::LANGUAGE.into(), "C#"),
40 "rb" => (tree_sitter_ruby::LANGUAGE.into(), "Ruby"),
41 "php" => (tree_sitter_php::LANGUAGE_PHP.into(), "PHP"),
42 _ => anyhow::bail!("Unsupported language for AST parsing: {}", extension),
43 };
44
45 let mut parser = Parser::new();
46 parser
47 .set_language(&language)
48 .context("Failed to set parser language")?;
49
50 Ok(Self {
51 parser,
52 _language: language,
53 language_name: language_name.to_string(),
54 })
55 }
56
57 pub fn parse(&mut self, source_code: &str) -> Result<Vec<AstNode>> {
59 let tree = self
60 .parser
61 .parse(source_code, None)
62 .context("Failed to parse source code")?;
63
64 let root_node = tree.root_node();
65 let mut nodes = Vec::new();
66
67 self.extract_semantic_units(root_node, source_code, &mut nodes);
69
70 Ok(nodes)
71 }
72
73 fn extract_semantic_units(&self, node: Node, _source_code: &str, result: &mut Vec<AstNode>) {
75 let target_kinds = match self.language_name.as_str() {
77 "Rust" => vec![
78 "function_item",
79 "impl_item",
80 "trait_item",
81 "struct_item",
82 "enum_item",
83 "mod_item",
84 ],
85 "Python" => vec![
86 "function_definition",
87 "class_definition",
88 "decorated_definition",
89 ],
90 "JavaScript" | "TypeScript" => vec![
91 "function_declaration",
92 "function_expression",
93 "arrow_function",
94 "method_definition",
95 "class_declaration",
96 ],
97 "Go" => vec![
98 "function_declaration",
99 "method_declaration",
100 "type_declaration",
101 ],
102 "Java" => vec![
103 "method_declaration",
104 "class_declaration",
105 "interface_declaration",
106 "constructor_declaration",
107 ],
108 "Swift" => vec![
109 "function_declaration",
110 "class_declaration",
111 "protocol_declaration",
112 "struct_declaration",
113 "enum_declaration",
114 "extension_declaration",
115 "deinit_declaration",
116 "initializer_declaration",
117 "subscript_declaration",
118 ],
119 "C" => vec![
120 "function_definition",
121 "struct_specifier",
122 "enum_specifier",
123 "union_specifier",
124 "type_definition",
125 ],
126 "C++" => vec![
127 "function_definition",
128 "class_specifier",
129 "struct_specifier",
130 "enum_specifier",
131 "union_specifier",
132 "namespace_definition",
133 "template_declaration",
134 ],
135 "C#" => vec![
136 "method_declaration",
137 "class_declaration",
138 "struct_declaration",
139 "interface_declaration",
140 "enum_declaration",
141 "namespace_declaration",
142 "constructor_declaration",
143 "property_declaration",
144 ],
145 "Ruby" => vec![
146 "method",
147 "singleton_method",
148 "class",
149 "singleton_class",
150 "module",
151 ],
152 "PHP" => vec![
153 "function_definition",
154 "method_declaration",
155 "class_declaration",
156 "interface_declaration",
157 "trait_declaration",
158 "namespace_definition",
159 ],
160 _ => vec![],
161 };
162
163 let kind = node.kind();
165 if target_kinds.contains(&kind) {
166 let start_position = node.start_position();
167 let end_position = node.end_position();
168
169 result.push(AstNode {
170 kind: kind.to_string(),
171 start_byte: node.start_byte(),
172 end_byte: node.end_byte(),
173 start_line: start_position.row + 1, end_line: end_position.row + 1,
175 });
176 }
177
178 let mut cursor = node.walk();
180 for child in node.children(&mut cursor) {
181 self.extract_semantic_units(child, _source_code, result);
182 }
183 }
184
185 pub fn language_name(&self) -> &str {
187 &self.language_name
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the parser for `extension`, panicking on any failure.
    fn parse_with(extension: &str, source: &str) -> (AstParser, Vec<AstNode>) {
        let mut parser = AstParser::new(extension).expect("extension should be supported");
        let nodes = parser.parse(source).expect("source should parse");
        (parser, nodes)
    }

    /// Number of extracted nodes with the given kind.
    fn count(nodes: &[AstNode], kind: &str) -> usize {
        nodes.iter().filter(|n| n.kind == kind).count()
    }

    #[test]
    fn test_rust_parsing() {
        let source = r#"
fn main() {
    println!("Hello, world!");
}

struct MyStruct {
    field: i32,
}

impl MyStruct {
    fn new() -> Self {
        MyStruct { field: 0 }
    }
}
"#;

        let (parser, nodes) = parse_with("rs", source);
        assert_eq!(parser.language_name(), "Rust");

        // `main` plus the `new` nested inside the impl block.
        assert_eq!(count(&nodes, "function_item"), 2);
        assert_eq!(count(&nodes, "struct_item"), 1);
        assert_eq!(count(&nodes, "impl_item"), 1);

        // Nodes come out in document order with 1-based lines; the raw string starts
        // with a newline, so `fn main` spans lines 2-4.
        let first = &nodes[0];
        assert_eq!(first.kind, "function_item");
        assert_eq!(first.start_line, 2);
        assert_eq!(first.end_line, 4);
        assert!(source[first.start_byte..first.end_byte].starts_with("fn main"));
    }

    #[test]
    fn test_python_parsing() {
        let source = r#"
def hello():
    print("Hello")

class MyClass:
    def __init__(self):
        self.value = 0

    def method(self):
        return self.value
"#;

        let (parser, nodes) = parse_with("py", source);
        assert_eq!(parser.language_name(), "Python");
        assert_eq!(count(&nodes, "function_definition"), 3);
        assert_eq!(count(&nodes, "class_definition"), 1);
    }

    #[test]
    fn test_javascript_parsing() {
        let source = r#"
function hello() {
    console.log("Hello");
}

const arrow = () => {
    return 42;
};

class MyClass {
    constructor() {
        this.value = 0;
    }

    method() {
        return this.value;
    }
}
"#;

        let (parser, nodes) = parse_with("js", source);
        assert_eq!(parser.language_name(), "JavaScript");
        assert_eq!(count(&nodes, "function_declaration"), 1);
        assert_eq!(count(&nodes, "arrow_function"), 1);
        assert_eq!(count(&nodes, "class_declaration"), 1);
        // Both the constructor and `method`.
        assert_eq!(count(&nodes, "method_definition"), 2);
    }

    #[test]
    fn test_typescript_parsing() {
        let source = r#"
function add(a: number, b: number): number {
    return a + b;
}

class Counter {
    increment(): void {}
}
"#;

        let (parser, nodes) = parse_with("ts", source);
        assert_eq!(parser.language_name(), "TypeScript");
        assert_eq!(count(&nodes, "function_declaration"), 1);
        assert_eq!(count(&nodes, "class_declaration"), 1);
        assert_eq!(count(&nodes, "method_definition"), 1);
    }

    #[test]
    fn test_go_parsing() {
        let source = r#"
package main

type Point struct {
	X int
}

func (p Point) Norm() int {
	return p.X
}

func add(a, b int) int {
	return a + b
}
"#;

        let (parser, nodes) = parse_with("go", source);
        assert_eq!(parser.language_name(), "Go");
        assert_eq!(count(&nodes, "type_declaration"), 1);
        assert_eq!(count(&nodes, "method_declaration"), 1);
        assert_eq!(count(&nodes, "function_declaration"), 1);
    }

    #[test]
    fn test_java_parsing() {
        let source = r#"
class MyClass {
    private int value;

    MyClass() {
        value = 0;
    }

    int getValue() {
        return value;
    }
}
"#;

        let (parser, nodes) = parse_with("java", source);
        assert_eq!(parser.language_name(), "Java");
        assert_eq!(count(&nodes, "class_declaration"), 1);
        assert_eq!(count(&nodes, "constructor_declaration"), 1);
        assert_eq!(count(&nodes, "method_declaration"), 1);
    }

    #[test]
    fn test_swift_parsing() {
        let source = r#"
func greet(name: String) {
    print("Hello, \(name)!")
}

class MyClass {
    var value: Int

    init(value: Int) {
        self.value = value
    }

    func method() -> Int {
        return value
    }
}
"#;

        let (parser, nodes) = parse_with("swift", source);
        assert_eq!(parser.language_name(), "Swift");
        assert!(nodes.iter().any(|n| n.kind == "function_declaration"));
        assert!(nodes.iter().any(|n| n.kind == "class_declaration"));
    }

    #[test]
    fn test_c_parsing() {
        let source = r#"
int add(int a, int b) {
    return a + b;
}

struct Point {
    int x;
    int y;
};
"#;

        let (parser, nodes) = parse_with("c", source);
        assert_eq!(parser.language_name(), "C");
        assert!(nodes.iter().any(|n| n.kind == "function_definition"));
        assert!(nodes.iter().any(|n| n.kind == "struct_specifier"));
    }

    #[test]
    fn test_cpp_parsing() {
        let source = r#"
class MyClass {
public:
    int value;
    MyClass() : value(0) {}
    int getValue() { return value; }
};

namespace MyNamespace {
    void function() {}
}
"#;

        let (parser, nodes) = parse_with("cpp", source);
        assert_eq!(parser.language_name(), "C++");
        assert!(nodes.iter().any(|n| n.kind == "class_specifier"));
        assert!(nodes.iter().any(|n| n.kind == "namespace_definition"));
        assert!(nodes.iter().any(|n| n.kind == "function_definition"));
    }

    #[test]
    fn test_csharp_parsing() {
        let source = r#"
class MyClass {
    private int value;

    public MyClass() {
        value = 0;
    }

    public int GetValue() {
        return value;
    }
}
"#;

        let (parser, nodes) = parse_with("cs", source);
        assert_eq!(parser.language_name(), "C#");
        assert_eq!(count(&nodes, "class_declaration"), 1);
        assert_eq!(count(&nodes, "constructor_declaration"), 1);
        assert_eq!(count(&nodes, "method_declaration"), 1);
    }

    #[test]
    fn test_ruby_parsing() {
        let source = r#"
def hello(name)
  puts "Hello, #{name}!"
end

class MyClass
  def initialize(value)
    @value = value
  end

  def method
    @value
  end
end
"#;

        let (parser, nodes) = parse_with("rb", source);
        assert_eq!(parser.language_name(), "Ruby");
        assert_eq!(count(&nodes, "method"), 3);
        assert_eq!(count(&nodes, "class"), 1);
    }

    #[test]
    fn test_php_parsing() {
        let source = r#"
<?php
function hello($name) {
    echo "Hello, $name!";
}

class MyClass {
    private $value;

    public function __construct($value) {
        $this->value = $value;
    }

    public function getValue() {
        return $this->value;
    }
}
?>
"#;

        let (parser, nodes) = parse_with("php", source);
        assert_eq!(parser.language_name(), "PHP");
        assert_eq!(count(&nodes, "function_definition"), 1);
        assert_eq!(count(&nodes, "class_declaration"), 1);
        assert_eq!(count(&nodes, "method_declaration"), 2);
    }

    #[test]
    fn test_extension_is_case_insensitive() {
        let parser = AstParser::new("RS").expect("upper-case extension should be accepted");
        assert_eq!(parser.language_name(), "Rust");

        let parser = AstParser::new("Py").expect("mixed-case extension should be accepted");
        assert_eq!(parser.language_name(), "Python");
    }

    #[test]
    fn test_unsupported_language() {
        let err = AstParser::new("xyz").err().expect("xyz should be rejected");
        assert!(err.to_string().contains("xyz"));
        assert!(AstParser::new("").is_err());
    }
}
425}