project_rag/indexer/ast_parser.rs

1use anyhow::{Context, Result};
2use tree_sitter::{Language, Node, Parser};
3
/// A semantic code unit (function, class, impl, ...) located by [`AstParser`],
/// used as a chunk boundary.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstNode {
    /// Tree-sitter node kind, e.g. `"function_item"` or `"class_definition"`.
    pub kind: String,
    /// Byte offset (UTF-8) of the first byte of the node in the parsed source.
    pub start_byte: usize,
    /// Byte offset (UTF-8) one past the last byte of the node (exclusive).
    pub end_byte: usize,
    /// 1-based line on which the node starts.
    pub start_line: usize,
    /// 1-based line on which the node ends.
    pub end_line: usize,
}
13
14/// AST parser for extracting semantic code units
15pub struct AstParser {
16    parser: Parser,
17    _language: Language,
18    language_name: String,
19}
20
21impl AstParser {
22    /// Create a new AST parser for the given language
23    pub fn new(extension: &str) -> Result<Self> {
24        let (language, language_name) = match extension.to_lowercase().as_str() {
25            "rs" => (tree_sitter_rust::LANGUAGE.into(), "Rust"),
26            "py" => (tree_sitter_python::LANGUAGE.into(), "Python"),
27            "js" | "mjs" | "cjs" | "jsx" => (tree_sitter_javascript::LANGUAGE.into(), "JavaScript"),
28            "ts" | "tsx" => (
29                tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
30                "TypeScript",
31            ),
32            "go" => (tree_sitter_go::LANGUAGE.into(), "Go"),
33            "java" => (tree_sitter_java::LANGUAGE.into(), "Java"),
34            "swift" => (tree_sitter_swift::LANGUAGE.into(), "Swift"),
35            "c" | "h" => (tree_sitter_c::LANGUAGE.into(), "C"),
36            "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => {
37                (tree_sitter_cpp::LANGUAGE.into(), "C++")
38            }
39            "cs" => (tree_sitter_c_sharp::LANGUAGE.into(), "C#"),
40            "rb" => (tree_sitter_ruby::LANGUAGE.into(), "Ruby"),
41            "php" => (tree_sitter_php::LANGUAGE_PHP.into(), "PHP"),
42            _ => anyhow::bail!("Unsupported language for AST parsing: {}", extension),
43        };
44
45        let mut parser = Parser::new();
46        parser
47            .set_language(&language)
48            .context("Failed to set parser language")?;
49
50        Ok(Self {
51            parser,
52            _language: language,
53            language_name: language_name.to_string(),
54        })
55    }
56
57    /// Parse source code and extract semantic units (functions, classes, etc.)
58    pub fn parse(&mut self, source_code: &str) -> Result<Vec<AstNode>> {
59        let tree = self
60            .parser
61            .parse(source_code, None)
62            .context("Failed to parse source code")?;
63
64        let root_node = tree.root_node();
65        let mut nodes = Vec::new();
66
67        // Extract semantic units based on language
68        self.extract_semantic_units(root_node, source_code, &mut nodes);
69
70        Ok(nodes)
71    }
72
73    /// Extract semantic units (functions, classes, methods) from the AST
74    fn extract_semantic_units(&self, node: Node, _source_code: &str, result: &mut Vec<AstNode>) {
75        // Define node types we want to chunk by language
76        let target_kinds = match self.language_name.as_str() {
77            "Rust" => vec![
78                "function_item",
79                "impl_item",
80                "trait_item",
81                "struct_item",
82                "enum_item",
83                "mod_item",
84            ],
85            "Python" => vec![
86                "function_definition",
87                "class_definition",
88                "decorated_definition",
89            ],
90            "JavaScript" | "TypeScript" => vec![
91                "function_declaration",
92                "function_expression",
93                "arrow_function",
94                "method_definition",
95                "class_declaration",
96            ],
97            "Go" => vec![
98                "function_declaration",
99                "method_declaration",
100                "type_declaration",
101            ],
102            "Java" => vec![
103                "method_declaration",
104                "class_declaration",
105                "interface_declaration",
106                "constructor_declaration",
107            ],
108            "Swift" => vec![
109                "function_declaration",
110                "class_declaration",
111                "protocol_declaration",
112                "struct_declaration",
113                "enum_declaration",
114                "extension_declaration",
115                "deinit_declaration",
116                "initializer_declaration",
117                "subscript_declaration",
118            ],
119            "C" => vec![
120                "function_definition",
121                "struct_specifier",
122                "enum_specifier",
123                "union_specifier",
124                "type_definition",
125            ],
126            "C++" => vec![
127                "function_definition",
128                "class_specifier",
129                "struct_specifier",
130                "enum_specifier",
131                "union_specifier",
132                "namespace_definition",
133                "template_declaration",
134            ],
135            "C#" => vec![
136                "method_declaration",
137                "class_declaration",
138                "struct_declaration",
139                "interface_declaration",
140                "enum_declaration",
141                "namespace_declaration",
142                "constructor_declaration",
143                "property_declaration",
144            ],
145            "Ruby" => vec![
146                "method",
147                "singleton_method",
148                "class",
149                "singleton_class",
150                "module",
151            ],
152            "PHP" => vec![
153                "function_definition",
154                "method_declaration",
155                "class_declaration",
156                "interface_declaration",
157                "trait_declaration",
158                "namespace_definition",
159            ],
160            _ => vec![],
161        };
162
163        // Check if current node is a target kind
164        let kind = node.kind();
165        if target_kinds.contains(&kind) {
166            let start_position = node.start_position();
167            let end_position = node.end_position();
168
169            result.push(AstNode {
170                kind: kind.to_string(),
171                start_byte: node.start_byte(),
172                end_byte: node.end_byte(),
173                start_line: start_position.row + 1, // Tree-sitter uses 0-indexed rows
174                end_line: end_position.row + 1,
175            });
176        }
177
178        // Recursively process children
179        let mut cursor = node.walk();
180        for child in node.children(&mut cursor) {
181            self.extract_semantic_units(child, _source_code, result);
182        }
183    }
184
185    /// Get the language name
186    pub fn language_name(&self) -> &str {
187        &self.language_name
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a parser for `extension` and parse `source`, panicking on failure.
    fn parse_with(extension: &str, source: &str) -> (AstParser, Vec<AstNode>) {
        let mut parser = AstParser::new(extension).expect("extension should be supported");
        let nodes = parser.parse(source).expect("source should parse");
        (parser, nodes)
    }

    /// Number of extracted nodes whose tree-sitter kind equals `kind`.
    fn count_kind(nodes: &[AstNode], kind: &str) -> usize {
        nodes.iter().filter(|n| n.kind == kind).count()
    }

    #[test]
    fn test_rust_parsing() {
        let source = r#"
fn main() {
    println!("Hello, world!");
}

struct MyStruct {
    field: i32,
}

impl MyStruct {
    fn new() -> Self {
        MyStruct { field: 0 }
    }
}
"#;

        let (_, nodes) = parse_with("rs", source);

        // `main`, `MyStruct`, the impl block, and `new` nested inside the impl.
        assert_eq!(count_kind(&nodes, "function_item"), 2);
        assert_eq!(count_kind(&nodes, "struct_item"), 1);
        assert_eq!(count_kind(&nodes, "impl_item"), 1);

        // Lines are 1-based; the raw string starts with a newline, so `fn main` is line 2.
        let main_fn = nodes
            .iter()
            .find(|n| n.kind == "function_item")
            .expect("main should be extracted");
        assert_eq!((main_fn.start_line, main_fn.end_line), (2, 4));
        let text = &source[main_fn.start_byte..main_fn.end_byte];
        assert!(text.starts_with("fn main()"));
        assert!(text.ends_with('}'));
    }

    #[test]
    fn test_python_parsing() {
        let source = r#"
def hello():
    print("Hello")

class MyClass:
    def __init__(self):
        self.value = 0

    def method(self):
        return self.value
"#;

        let (_, nodes) = parse_with("py", source);

        // `hello`, `__init__`, and `method`.
        assert_eq!(count_kind(&nodes, "function_definition"), 3);
        assert_eq!(count_kind(&nodes, "class_definition"), 1);
    }

    #[test]
    fn test_javascript_parsing() {
        let source = r#"
function hello() {
    console.log("Hello");
}

const arrow = () => {
    return 42;
};

class MyClass {
    constructor() {
        this.value = 0;
    }

    method() {
        return this.value;
    }
}
"#;

        let (_, nodes) = parse_with("js", source);

        assert_eq!(count_kind(&nodes, "function_declaration"), 1);
        assert_eq!(count_kind(&nodes, "arrow_function"), 1);
        assert_eq!(count_kind(&nodes, "class_declaration"), 1);
        // `constructor` and `method`.
        assert_eq!(count_kind(&nodes, "method_definition"), 2);
    }

    #[test]
    fn test_swift_parsing() {
        let source = r#"
func greet(name: String) {
    print("Hello, \(name)!")
}

class MyClass {
    var value: Int

    init(value: Int) {
        self.value = value
    }

    func method() -> Int {
        return value
    }
}
"#;

        let (parser, nodes) = parse_with("swift", source);

        assert!(nodes.iter().any(|n| n.kind == "function_declaration"));
        assert!(nodes.iter().any(|n| n.kind == "class_declaration"));
        assert_eq!(parser.language_name(), "Swift");
    }

    #[test]
    fn test_unsupported_language() {
        let err = AstParser::new("xyz")
            .err()
            .expect("unknown extension must be rejected");
        assert!(err.to_string().contains("xyz"));
    }

    #[test]
    fn test_extension_is_case_insensitive() {
        let parser = AstParser::new("RS").expect("upper-case extension should be accepted");
        assert_eq!(parser.language_name(), "Rust");
    }

    #[test]
    fn test_empty_source() {
        let (_, nodes) = parse_with("rs", "");
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_c_parsing() {
        let source = r#"
int add(int a, int b) {
    return a + b;
}

struct Point {
    int x;
    int y;
};
"#;

        let (parser, nodes) = parse_with("c", source);

        assert!(nodes.iter().any(|n| n.kind == "function_definition"));
        assert!(nodes.iter().any(|n| n.kind == "struct_specifier"));
        assert_eq!(parser.language_name(), "C");
    }

    #[test]
    fn test_cpp_parsing() {
        let source = r#"
class MyClass {
public:
    int value;
    MyClass() : value(0) {}
    int getValue() { return value; }
};

namespace MyNamespace {
    void function() {}
}
"#;

        let (parser, nodes) = parse_with("cpp", source);

        assert!(nodes.iter().any(|n| n.kind == "class_specifier"));
        assert!(nodes.iter().any(|n| n.kind == "namespace_definition"));
        assert!(nodes.iter().any(|n| n.kind == "function_definition"));
        assert_eq!(parser.language_name(), "C++");
    }

    #[test]
    fn test_csharp_parsing() {
        let source = r#"
class MyClass {
    private int value;

    public MyClass() {
        value = 0;
    }

    public int GetValue() {
        return value;
    }
}
"#;

        let (parser, nodes) = parse_with("cs", source);

        assert!(nodes.iter().any(|n| n.kind == "class_declaration"));
        assert!(nodes.iter().any(|n| n.kind == "constructor_declaration"));
        assert!(nodes.iter().any(|n| n.kind == "method_declaration"));
        assert_eq!(parser.language_name(), "C#");
    }

    #[test]
    fn test_ruby_parsing() {
        let source = r#"
def hello(name)
  puts "Hello, #{name}!"
end

class MyClass
  def initialize(value)
    @value = value
  end

  def method
    @value
  end
end
"#;

        let (parser, nodes) = parse_with("rb", source);

        assert_eq!(count_kind(&nodes, "class"), 1);
        // `hello`, `initialize`, and `method`.
        assert_eq!(count_kind(&nodes, "method"), 3);
        assert_eq!(parser.language_name(), "Ruby");
    }

    #[test]
    fn test_php_parsing() {
        let source = r#"
<?php
function hello($name) {
    echo "Hello, $name!";
}

class MyClass {
    private $value;

    public function __construct($value) {
        $this->value = $value;
    }

    public function getValue() {
        return $this->value;
    }
}
?>
"#;

        let (parser, nodes) = parse_with("php", source);

        assert!(nodes.iter().any(|n| n.kind == "function_definition"));
        assert!(nodes.iter().any(|n| n.kind == "class_declaration"));
        // `__construct` and `getValue`.
        assert_eq!(count_kind(&nodes, "method_declaration"), 2);
        assert_eq!(parser.language_name(), "PHP");
    }
}
423        assert!(parser.language_name() == "PHP");
424    }
425}