Skip to main content

amql_engine/resolver/
rust.rs

//! Rust source file resolver using tree-sitter.
//!
//! Parses `.rs` files into `CodeElement` trees, extracting functions, structs,
//! enums, traits, impl blocks, modules, constants, statics, type aliases,
//! and macro definitions.
6
7use super::{CodeElement, SourceLocation};
8use crate::error::AqlError;
9use crate::types::{AttrName, CodeElementName, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use std::cell::RefCell;
12use std::path::Path;
13
/// Rust source file resolver using tree-sitter.
///
/// The resolver holds no state of its own (the tree-sitter parser it uses is cached
/// per thread), so it is a zero-sized handle that can be created and copied freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct RustResolver;
16
17impl super::CodeResolver for RustResolver {
18    fn resolve(&self, file_path: &Path) -> Result<CodeElement, AqlError> {
19        let source =
20            std::fs::read_to_string(file_path).map_err(|e| format!("Failed to read file: {e}"))?;
21        let root = parse_rust_source(&source, file_path)?;
22        Ok(root)
23    }
24
25    fn extensions(&self) -> &[&str] {
26        &[".rs"]
27    }
28
29    fn code_tags(&self) -> &[&str] {
30        &[
31            "function", "struct", "enum", "trait", "impl", "module", "const", "static", "type",
32            "macro",
33        ]
34    }
35}
36
// Thread-local cached tree-sitter parser to avoid re-creating on each call.
// The slot starts as `None`; `with_rust_parser` fills it on first use on each
// thread and reuses it afterwards. A `RefCell` is sufficient because a
// `thread_local!` value is only ever accessed from its own thread.
thread_local! {
    static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
}
41
42fn with_rust_parser<F, R>(f: F) -> Result<R, String>
43where
44    F: FnOnce(&mut tree_sitter::Parser) -> Result<R, String>,
45{
46    RUST_PARSER.with(|cell| {
47        let mut opt = cell.borrow_mut();
48        let parser = opt.get_or_insert_with(|| {
49            let mut p = tree_sitter::Parser::new();
50            p.set_language(&tree_sitter_rust::LANGUAGE.into())
51                .expect("Failed to set Rust language for tree-sitter");
52            p
53        });
54        f(parser)
55    })
56}
57
58/// Parse a Rust source string into a CodeElement tree.
59fn parse_rust_source(source: &str, file_path: &Path) -> Result<CodeElement, String> {
60    let tree = with_rust_parser(|parser| {
61        parser
62            .parse(source, None)
63            .ok_or_else(|| "Failed to parse source".to_string())
64    })?;
65
66    let root_node = tree.root_node();
67    let src = source.as_bytes();
68    let file_str = file_path.to_string_lossy().to_string();
69
70    let mut children = Vec::new();
71    let mut cursor = root_node.walk();
72    for child in root_node.named_children(&mut cursor) {
73        if let Some(element) = extract_element(&child, src, &file_str) {
74            children.push(element);
75        }
76    }
77
78    let filename = file_path
79        .file_name()
80        .map(|f| f.to_string_lossy().to_string())
81        .unwrap_or_else(|| file_str.clone());
82
83    Ok(CodeElement {
84        tag: TagName::from("module"),
85        name: CodeElementName::from(filename),
86        attrs: FxHashMap::default(),
87        children,
88        source: SourceLocation {
89            file: RelativePath::from(file_str),
90            line: 1,
91            column: 0,
92            end_line: Some(root_node.end_position().row + 1),
93            end_column: Some(root_node.end_position().column),
94            start_byte: root_node.start_byte(),
95            end_byte: root_node.end_byte(),
96        },
97    })
98}
99
100/// Extract a CodeElement from a tree-sitter node, if it's a recognized kind.
101fn extract_element(node: &tree_sitter::Node, src: &[u8], file: &str) -> Option<CodeElement> {
102    match node.kind() {
103        "function_item" => Some(extract_function(node, src, file)),
104        "struct_item" => Some(extract_named_element("struct", node, src, file)),
105        "enum_item" => Some(extract_named_element("enum", node, src, file)),
106        "trait_item" => Some(extract_trait(node, src, file)),
107        "impl_item" => Some(extract_impl(node, src, file)),
108        "mod_item" => Some(extract_named_element("module", node, src, file)),
109        "const_item" => Some(extract_named_element("const", node, src, file)),
110        "static_item" => Some(extract_static(node, src, file)),
111        "type_item" => Some(extract_named_element("type", node, src, file)),
112        "macro_definition" => Some(extract_named_element("macro", node, src, file)),
113        _ => None,
114    }
115}
116
117fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
118    node.utf8_text(src).unwrap_or("")
119}
120
121fn get_name(node: &tree_sitter::Node, src: &[u8]) -> CodeElementName {
122    CodeElementName::from(
123        node.child_by_field_name("name")
124            .map(|n| node_text(&n, src).to_string())
125            .unwrap_or_default(),
126    )
127}
128
129fn get_visibility(node: &tree_sitter::Node, src: &[u8]) -> Option<String> {
130    let mut cursor = node.walk();
131    for child in node.named_children(&mut cursor) {
132        if child.kind() == "visibility_modifier" {
133            return Some(node_text(&child, src).to_string());
134        }
135    }
136    None
137}
138
/// Keyword qualifiers detected on an item by `extract_modifiers`.
struct Modifiers {
    /// The `async` keyword was found.
    is_async: bool,
    /// The `unsafe` keyword was found.
    is_unsafe: bool,
    /// The `const` keyword was found.
    is_const: bool,
}
145
146/// Walk children once and extract all keyword modifiers in a single pass.
147fn extract_modifiers(node: &tree_sitter::Node, src: &[u8]) -> Modifiers {
148    let mut mods = Modifiers {
149        is_async: false,
150        is_unsafe: false,
151        is_const: false,
152    };
153
154    let mut cursor = node.walk();
155    for child in node.children(&mut cursor) {
156        let text = node_text(&child, src);
157        match text {
158            "async" => mods.is_async = true,
159            "unsafe" => mods.is_unsafe = true,
160            "const" => mods.is_const = true,
161            _ => {}
162        }
163        // Also check inside function_modifiers or similar wrapper nodes
164        if child.is_named() {
165            let mut inner_cursor = child.walk();
166            for inner in child.children(&mut inner_cursor) {
167                let inner_text = node_text(&inner, src);
168                match inner_text {
169                    "async" => mods.is_async = true,
170                    "unsafe" => mods.is_unsafe = true,
171                    "const" => mods.is_const = true,
172                    _ => {}
173                }
174            }
175        }
176    }
177
178    mods
179}
180
181fn make_source_location(node: &tree_sitter::Node, file: &str) -> SourceLocation {
182    let start = node.start_position();
183    let end = node.end_position();
184    SourceLocation {
185        file: RelativePath::from(file),
186        line: start.row + 1,
187        column: start.column,
188        end_line: Some(end.row + 1),
189        end_column: Some(end.column),
190        start_byte: node.start_byte(),
191        end_byte: node.end_byte(),
192    }
193}
194
195/// Build a CodeElement with name and visibility extracted from the node.
196/// Covers the common pattern shared by most element kinds.
197fn extract_named_element(
198    tag: &str,
199    node: &tree_sitter::Node,
200    src: &[u8],
201    file: &str,
202) -> CodeElement {
203    let name = get_name(node, src);
204    let mut attrs = FxHashMap::default();
205    attrs.insert(
206        AttrName::from("name"),
207        serde_json::Value::String(name.to_string()),
208    );
209    if let Some(vis) = get_visibility(node, src) {
210        attrs.insert(AttrName::from("visibility"), serde_json::Value::String(vis));
211    }
212    CodeElement {
213        tag: TagName::from(tag),
214        name,
215        attrs,
216        children: vec![],
217        source: make_source_location(node, file),
218    }
219}
220
221fn extract_function(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
222    let mut el = extract_named_element("function", node, src, file);
223    let mods = extract_modifiers(node, src);
224    if mods.is_async {
225        el.attrs
226            .insert(AttrName::from("async"), serde_json::Value::Bool(true));
227    }
228    if mods.is_unsafe {
229        el.attrs
230            .insert(AttrName::from("unsafe"), serde_json::Value::Bool(true));
231    }
232    if mods.is_const {
233        el.attrs
234            .insert(AttrName::from("const"), serde_json::Value::Bool(true));
235    }
236    el
237}
238
239fn extract_trait(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
240    let mut el = extract_named_element("trait", node, src, file);
241    if extract_modifiers(node, src).is_unsafe {
242        el.attrs
243            .insert(AttrName::from("unsafe"), serde_json::Value::Bool(true));
244    }
245    el
246}
247
248fn extract_impl(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
249    let type_name = node
250        .child_by_field_name("type")
251        .map(|n| node_text(&n, src).to_string())
252        .unwrap_or_default();
253
254    let trait_name = node
255        .child_by_field_name("trait")
256        .map(|n| node_text(&n, src).to_string());
257
258    let mut attrs = FxHashMap::default();
259    attrs.insert(
260        AttrName::from("type"),
261        serde_json::Value::String(type_name.clone()),
262    );
263    if let Some(ref t) = trait_name {
264        attrs.insert(
265            AttrName::from("trait"),
266            serde_json::Value::String(t.clone()),
267        );
268    }
269
270    // Extract method children from the declaration_list (body)
271    let mut children = Vec::new();
272    if let Some(body) = node.child_by_field_name("body") {
273        let mut cursor = body.walk();
274        for child in body.named_children(&mut cursor) {
275            if child.kind() == "function_item" {
276                children.push(extract_function(&child, src, file));
277            }
278        }
279    }
280
281    let name = if let Some(ref t) = trait_name {
282        format!("{t} for {type_name}")
283    } else {
284        type_name
285    };
286
287    CodeElement {
288        tag: TagName::from("impl"),
289        name: CodeElementName::from(name),
290        attrs,
291        children,
292        source: make_source_location(node, file),
293    }
294}
295
296fn extract_static(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
297    let mut el = extract_named_element("static", node, src, file);
298    let mut cursor = node.walk();
299    for child in node.named_children(&mut cursor) {
300        if child.kind() == "mutable_specifier" {
301            el.attrs
302                .insert(AttrName::from("mutable"), serde_json::Value::Bool(true));
303            break;
304        }
305    }
306    el
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Parse `source` as the contents of a file named `test.rs`, panicking on failure.
    fn parse_snippet(source: &str) -> CodeElement {
        parse_rust_source(source, Path::new("test.rs")).unwrap()
    }

    /// Assert that `el` has exactly the given tag and name.
    fn assert_tag_and_name(el: &CodeElement, tag: &str, name: &str) {
        assert_eq!(el.tag, tag);
        assert_eq!(el.name, name);
    }

    /// Assert that `el.attrs[key]` exists and equals `expected`.
    fn assert_attr(el: &CodeElement, key: &str, expected: Value) {
        assert_eq!(el.attrs.get(key), Some(&expected));
    }

    /// The `visibility` attr value expected on `pub` items.
    fn pub_vis() -> Value {
        Value::String("pub".to_string())
    }

    #[test]
    fn parses_simple_named_elements() {
        // Arrange: (source, expected tag, expected name, check `visibility == "pub"`)
        let cases = [
            ("struct Bar { x: i32 }", "struct", "Bar", false),
            ("pub struct Baz;", "struct", "Baz", true),
            ("pub enum Color { Red, Green, Blue }", "enum", "Color", true),
            ("pub const MAX: usize = 100;", "const", "MAX", true),
            (
                "pub type Result<T> = std::result::Result<T, Error>;",
                "type",
                "Result",
                false,
            ),
            ("macro_rules! my_macro { () => {} }", "macro", "my_macro", false),
            ("pub mod inner {}", "module", "inner", true),
        ];

        // Act
        let roots: Vec<CodeElement> = cases
            .iter()
            .map(|(source, ..)| parse_snippet(source))
            .collect();

        // Assert
        assert_eq!(roots[0].children.len(), 1);
        for ((_, tag, name, check_pub), root) in cases.iter().zip(&roots) {
            let el = &root.children[0];
            assert_tag_and_name(el, tag, name);
            if *check_pub {
                assert_attr(el, "visibility", pub_vis());
            }
        }
    }

    #[test]
    fn parses_function_variants() {
        // Arrange
        let pub_async_root = parse_snippet("pub async fn foo() {}");
        let unsafe_root = parse_snippet("unsafe fn danger() {}");

        // Act
        let (func, uf) = (&pub_async_root.children[0], &unsafe_root.children[0]);

        // Assert
        assert_eq!(pub_async_root.tag, "module");
        assert_eq!(pub_async_root.children.len(), 1);
        assert_tag_and_name(func, "function", "foo");
        assert_attr(func, "async", Value::Bool(true));
        assert_attr(func, "visibility", pub_vis());

        assert_tag_and_name(uf, "function", "danger");
        assert_attr(uf, "unsafe", Value::Bool(true));
    }

    #[test]
    fn parses_compound_elements() {
        // Arrange
        let impl_root = parse_snippet("impl Foo { fn bar() {} fn baz(&self) {} }");
        let trait_impl_root = parse_snippet("impl Display for Foo { fn fmt(&self) {} }");
        let trait_root = parse_snippet("pub trait Resolver { fn resolve(&self); }");
        let static_root = parse_snippet("static mut COUNTER: u32 = 0;");

        // Act
        let imp = &impl_root.children[0];
        let ti = &trait_impl_root.children[0];
        let tr = &trait_root.children[0];
        let st = &static_root.children[0];

        // Assert: inherent impl with two methods
        assert_eq!(impl_root.children.len(), 1);
        assert_eq!(imp.tag, "impl");
        assert_attr(imp, "type", Value::String("Foo".to_string()));
        assert_eq!(imp.children.len(), 2);
        assert_tag_and_name(&imp.children[0], "function", "bar");
        assert_eq!(imp.children[1].name, "baz");

        // Assert: trait impl records both trait and type
        assert_eq!(ti.tag, "impl");
        assert_attr(ti, "trait", Value::String("Display".to_string()));
        assert_attr(ti, "type", Value::String("Foo".to_string()));

        // Assert: trait and `static mut`
        assert_tag_and_name(tr, "trait", "Resolver");
        assert_tag_and_name(st, "static", "COUNTER");
        assert_attr(st, "mutable", Value::Bool(true));
    }
}