Skip to main content

amql_engine/extractor/
rust_structure.rs

1//! General-purpose Rust structure extractor.
2//!
3//! Extracts functions, structs, enums, traits, impl blocks, modules,
4//! constants, statics, type aliases, and macro definitions as annotations.
5//! Mirrors the Rust resolver logic but produces `Annotation` metadata
6//! instead of `CodeElement` trees, discarding implementation bodies.
7
8use super::BuiltinExtractor;
9use crate::store::Annotation;
10use crate::types::{AttrName, Binding, RelativePath, TagName};
11use rustc_hash::FxHashMap;
12use serde_json::Value as JsonValue;
13use std::cell::RefCell;
14
15/// General-purpose Rust structure extractor.
16pub struct RustStructureExtractor;
17
18impl BuiltinExtractor for RustStructureExtractor {
19    fn name(&self) -> &str {
20        "rust-structure"
21    }
22
23    fn extensions(&self) -> &[&str] {
24        &[".rs"]
25    }
26
27    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
28        let tree = match parse_rust(source) {
29            Some(t) => t,
30            None => return vec![],
31        };
32        let mut annotations = Vec::new();
33        let root = tree.root_node();
34        let src = source.as_bytes();
35        let mut cursor = root.walk();
36        for child in root.named_children(&mut cursor) {
37            if let Some(ann) = extract_element(&child, src, file) {
38                annotations.push(ann);
39            }
40        }
41        annotations
42    }
43}
44
45// ---------------------------------------------------------------------------
46// Tree-sitter helpers
47// ---------------------------------------------------------------------------
48
49fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
50    node.utf8_text(src).unwrap_or("")
51}
52
53fn get_name(node: &tree_sitter::Node, src: &[u8]) -> String {
54    node.child_by_field_name("name")
55        .map(|n| node_text(&n, src).to_string())
56        .unwrap_or_default()
57}
58
59fn get_visibility(node: &tree_sitter::Node, src: &[u8]) -> Option<String> {
60    let mut cursor = node.walk();
61    for child in node.named_children(&mut cursor) {
62        if child.kind() == "visibility_modifier" {
63            return Some(node_text(&child, src).to_string());
64        }
65    }
66    None
67}
68
69/// Extracted modifier flags from a single pass over node children.
70struct Modifiers {
71    is_async: bool,
72    is_unsafe: bool,
73    is_const: bool,
74}
75
76/// Walk children once and extract all keyword modifiers in a single pass.
77fn extract_modifiers(node: &tree_sitter::Node, src: &[u8]) -> Modifiers {
78    let mut mods = Modifiers {
79        is_async: false,
80        is_unsafe: false,
81        is_const: false,
82    };
83
84    let mut cursor = node.walk();
85    for child in node.children(&mut cursor) {
86        let text = node_text(&child, src);
87        match text {
88            "async" => mods.is_async = true,
89            "unsafe" => mods.is_unsafe = true,
90            "const" => mods.is_const = true,
91            _ => {}
92        }
93        if child.is_named() {
94            let mut inner_cursor = child.walk();
95            for inner in child.children(&mut inner_cursor) {
96                let inner_text = node_text(&inner, src);
97                match inner_text {
98                    "async" => mods.is_async = true,
99                    "unsafe" => mods.is_unsafe = true,
100                    "const" => mods.is_const = true,
101                    _ => {}
102                }
103            }
104        }
105    }
106
107    mods
108}
109
110fn make_annotation(
111    tag: &str,
112    binding: String,
113    attrs: FxHashMap<AttrName, JsonValue>,
114    file: &RelativePath,
115    children: Vec<Annotation>,
116) -> Annotation {
117    Annotation {
118        tag: TagName::from(tag),
119        attrs,
120        binding: Binding::from(binding),
121        file: file.clone(),
122        children,
123    }
124}
125
126// ---------------------------------------------------------------------------
127// Element extraction
128// ---------------------------------------------------------------------------
129
130/// Extract an annotation from a tree-sitter node, if it's a recognized kind.
131fn extract_element(
132    node: &tree_sitter::Node,
133    src: &[u8],
134    file: &RelativePath,
135) -> Option<Annotation> {
136    match node.kind() {
137        "function_item" => extract_function(node, src, file),
138        "struct_item" => extract_named_element("struct", node, src, file),
139        "enum_item" => extract_enum(node, src, file),
140        "trait_item" => extract_trait(node, src, file),
141        "impl_item" => extract_impl(node, src, file),
142        "mod_item" => extract_named_element("module", node, src, file),
143        "const_item" => extract_named_element("const", node, src, file),
144        "static_item" => extract_static(node, src, file),
145        "type_item" => extract_named_element("type", node, src, file),
146        "macro_definition" => extract_named_element("macro", node, src, file),
147        _ => None,
148    }
149}
150
151/// Build an annotation with name and visibility extracted from the node.
152fn extract_named_element(
153    tag: &str,
154    node: &tree_sitter::Node,
155    src: &[u8],
156    file: &RelativePath,
157) -> Option<Annotation> {
158    let name = get_name(node, src);
159    if name.is_empty() {
160        return None;
161    }
162    let mut attrs = FxHashMap::default();
163    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
164    if let Some(vis) = get_visibility(node, src) {
165        attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
166    }
167    Some(make_annotation(tag, name, attrs, file, vec![]))
168}
169
170fn extract_function(
171    node: &tree_sitter::Node,
172    src: &[u8],
173    file: &RelativePath,
174) -> Option<Annotation> {
175    let name = get_name(node, src);
176    if name.is_empty() {
177        return None;
178    }
179    let mut attrs = FxHashMap::default();
180    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
181    if let Some(vis) = get_visibility(node, src) {
182        attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
183    }
184    let mods = extract_modifiers(node, src);
185    if mods.is_async {
186        attrs.insert(AttrName::from("async"), JsonValue::Bool(true));
187    }
188    if mods.is_unsafe {
189        attrs.insert(AttrName::from("unsafe"), JsonValue::Bool(true));
190    }
191    if mods.is_const {
192        attrs.insert(AttrName::from("const"), JsonValue::Bool(true));
193    }
194    Some(make_annotation("function", name, attrs, file, vec![]))
195}
196
197fn extract_trait(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
198    let name = get_name(node, src);
199    if name.is_empty() {
200        return None;
201    }
202    let mut attrs = FxHashMap::default();
203    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
204    if let Some(vis) = get_visibility(node, src) {
205        attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
206    }
207    if extract_modifiers(node, src).is_unsafe {
208        attrs.insert(AttrName::from("unsafe"), JsonValue::Bool(true));
209    }
210
211    // Extract method children from trait body
212    let mut children = Vec::new();
213    if let Some(body) = node.child_by_field_name("body") {
214        let mut cursor = body.walk();
215        for child in body.named_children(&mut cursor) {
216            if child.kind() == "function_item" || child.kind() == "function_signature_item" {
217                let method_name = get_name(&child, src);
218                if !method_name.is_empty() {
219                    let mut method_attrs = FxHashMap::default();
220                    method_attrs.insert(
221                        AttrName::from("name"),
222                        JsonValue::String(method_name.clone()),
223                    );
224                    children.push(make_annotation(
225                        "method",
226                        method_name,
227                        method_attrs,
228                        file,
229                        vec![],
230                    ));
231                }
232            }
233        }
234    }
235
236    Some(make_annotation("trait", name, attrs, file, children))
237}
238
239fn extract_impl(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
240    let type_name = node
241        .child_by_field_name("type")
242        .map(|n| node_text(&n, src).to_string())
243        .unwrap_or_default();
244
245    if type_name.is_empty() {
246        return None;
247    }
248
249    let trait_name = node
250        .child_by_field_name("trait")
251        .map(|n| node_text(&n, src).to_string());
252
253    let mut attrs = FxHashMap::default();
254    attrs.insert(AttrName::from("type"), JsonValue::String(type_name.clone()));
255    if let Some(ref t) = trait_name {
256        attrs.insert(AttrName::from("trait"), JsonValue::String(t.clone()));
257    }
258
259    // Extract method children from the declaration_list (body)
260    let mut children = Vec::new();
261    if let Some(body) = node.child_by_field_name("body") {
262        let mut cursor = body.walk();
263        for child in body.named_children(&mut cursor) {
264            if child.kind() == "function_item" {
265                if let Some(ann) = extract_function(&child, src, file) {
266                    // Re-tag as method for impl context
267                    let method = Annotation {
268                        tag: TagName::from("method"),
269                        ..ann
270                    };
271                    children.push(method);
272                }
273            }
274        }
275    }
276
277    let binding = if let Some(ref t) = trait_name {
278        format!("{t} for {type_name}")
279    } else {
280        type_name
281    };
282
283    Some(make_annotation("impl", binding, attrs, file, children))
284}
285
286fn extract_enum(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
287    let name = get_name(node, src);
288    if name.is_empty() {
289        return None;
290    }
291    let mut attrs = FxHashMap::default();
292    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
293    if let Some(vis) = get_visibility(node, src) {
294        attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
295    }
296
297    // Extract variants as children
298    let mut children = Vec::new();
299    if let Some(body) = node.child_by_field_name("body") {
300        let mut cursor = body.walk();
301        for child in body.named_children(&mut cursor) {
302            if child.kind() == "enum_variant" {
303                let variant_name = get_name(&child, src);
304                if !variant_name.is_empty() {
305                    let mut variant_attrs = FxHashMap::default();
306                    variant_attrs.insert(
307                        AttrName::from("name"),
308                        JsonValue::String(variant_name.clone()),
309                    );
310                    children.push(make_annotation(
311                        "variant",
312                        variant_name,
313                        variant_attrs,
314                        file,
315                        vec![],
316                    ));
317                }
318            }
319        }
320    }
321
322    Some(make_annotation("enum", name, attrs, file, children))
323}
324
325fn extract_static(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
326    let name = get_name(node, src);
327    if name.is_empty() {
328        return None;
329    }
330    let mut attrs = FxHashMap::default();
331    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
332    if let Some(vis) = get_visibility(node, src) {
333        attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
334    }
335    let mut cursor = node.walk();
336    for child in node.named_children(&mut cursor) {
337        if child.kind() == "mutable_specifier" {
338            attrs.insert(AttrName::from("mutable"), JsonValue::Bool(true));
339            break;
340        }
341    }
342    Some(make_annotation("static", name, attrs, file, vec![]))
343}
344
345// ---------------------------------------------------------------------------
346// Parser cache (thread-local)
347// ---------------------------------------------------------------------------
348
349thread_local! {
350    static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
351}
352
353fn parse_rust(source: &str) -> Option<tree_sitter::Tree> {
354    RUST_PARSER.with(|cell| {
355        let mut opt = cell.borrow_mut();
356        let parser = opt.get_or_insert_with(|| {
357            let mut p = tree_sitter::Parser::new();
358            p.set_language(&tree_sitter_rust::LANGUAGE.into())
359                .expect("Failed to set Rust language");
360            p
361        });
362        parser.parse(source, None)
363    })
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the extractor on `source` as if it were `src/lib.rs`.
    fn run(source: &str) -> Vec<Annotation> {
        let file = RelativePath::from("src/lib.rs");
        RustStructureExtractor.extract(source, &file)
    }

    #[test]
    fn extracts_functions() {
        // Arrange
        let source = "pub async fn foo() {}\nunsafe fn danger() {}";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 2, "should find 2 functions");
        assert_eq!(anns[0].tag.as_ref(), "function", "should be function");
        assert_eq!(anns[0].binding.as_ref(), "foo", "function name");
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("async")),
            Some(&JsonValue::Bool(true)),
            "foo should be async"
        );
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("visibility")),
            Some(&JsonValue::String("pub".to_string())),
            "foo should be pub"
        );
        assert_eq!(
            anns[1].attrs.get(&AttrName::from("unsafe")),
            Some(&JsonValue::Bool(true)),
            "danger should be unsafe"
        );
    }

    #[test]
    fn extracts_const_fn_with_restricted_visibility() {
        // Arrange
        let source = "pub(crate) const fn zero() -> u32 { 0 }";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 function");
        assert_eq!(anns[0].binding.as_ref(), "zero", "function name");
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("const")),
            Some(&JsonValue::Bool(true)),
            "zero should be const"
        );
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("visibility")),
            Some(&JsonValue::String("pub(crate)".to_string())),
            "restricted visibility text should be kept verbatim"
        );
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("async")),
            None,
            "absent modifiers should not be recorded"
        );
    }

    #[test]
    fn extracts_structs_and_enums() {
        // Arrange
        let source = "pub struct Foo { x: i32 }\npub enum Color { Red, Green, Blue }";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 2, "should find struct + enum");
        assert_eq!(anns[0].tag.as_ref(), "struct", "should be struct");
        assert_eq!(anns[0].binding.as_ref(), "Foo", "struct name");
        assert_eq!(anns[1].tag.as_ref(), "enum", "should be enum");
        assert_eq!(anns[1].binding.as_ref(), "Color", "enum name");
        assert_eq!(anns[1].children.len(), 3, "should have 3 variants");
    }

    #[test]
    fn extracts_impl_with_methods() {
        // Arrange
        let source = "impl Foo { fn bar() {} fn baz(&self) {} }";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 impl");
        assert_eq!(anns[0].tag.as_ref(), "impl", "should be impl");
        assert_eq!(anns[0].binding.as_ref(), "Foo", "impl type");
        assert_eq!(anns[0].children.len(), 2, "should have 2 methods");
        assert_eq!(
            anns[0].children[0].tag.as_ref(),
            "method",
            "child should be method"
        );
        assert_eq!(anns[0].children[0].binding.as_ref(), "bar", "method name");
    }

    #[test]
    fn extracts_trait_impl() {
        // Arrange
        let source = "impl Display for Foo { fn fmt(&self) {} }";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 impl");
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("trait")),
            Some(&JsonValue::String("Display".to_string())),
            "should have trait attr"
        );
        assert_eq!(anns[0].binding.as_ref(), "Display for Foo", "impl binding");
    }

    #[test]
    fn extracts_traits() {
        // Arrange
        let source = "pub trait Resolver { fn resolve(&self); }";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 trait");
        assert_eq!(anns[0].tag.as_ref(), "trait", "should be trait");
        assert_eq!(anns[0].binding.as_ref(), "Resolver", "trait name");
    }

    #[test]
    fn extracts_trait_methods_as_children() {
        // Arrange: one required method and one method with a default body.
        let source = r#"pub trait Resolver { fn resolve(&self); fn label(&self) -> &str { "" } }"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 trait");
        let methods = &anns[0].children;
        assert_eq!(methods.len(), 2, "required + default method");
        assert_eq!(methods[0].tag.as_ref(), "method", "child should be method");
        assert_eq!(methods[0].binding.as_ref(), "resolve", "required method name");
        assert_eq!(methods[1].tag.as_ref(), "method", "child should be method");
        assert_eq!(methods[1].binding.as_ref(), "label", "default method name");
    }

    #[test]
    fn extracts_statics() {
        // Arrange
        let source = "static mut COUNTER: u32 = 0;";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 static");
        assert_eq!(anns[0].tag.as_ref(), "static", "should be static");
        assert_eq!(
            anns[0].attrs.get(&AttrName::from("mutable")),
            Some(&JsonValue::Bool(true)),
            "should be mutable"
        );
    }

    #[test]
    fn extracts_other_items() {
        // Arrange
        let source = "pub const MAX: usize = 100;\npub type Result<T> = std::result::Result<T, Error>;\nmacro_rules! my_macro { () => {} }\npub mod inner {}";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 4, "should find const + type + macro + mod");
        assert_eq!(anns[0].tag.as_ref(), "const", "should be const");
        assert_eq!(anns[0].binding.as_ref(), "MAX", "const name");
        assert_eq!(anns[1].tag.as_ref(), "type", "should be type");
        assert_eq!(anns[2].tag.as_ref(), "macro", "should be macro");
        assert_eq!(anns[3].tag.as_ref(), "module", "should be module");
    }

    #[test]
    fn ignores_unrecognized_items() {
        // Arrange: `use` declarations and comments produce no annotations.
        let source = "use std::fmt;\n// just a comment\nfn only() {}";

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "only the function should be annotated");
        assert_eq!(anns[0].tag.as_ref(), "function", "should be function");
        assert_eq!(anns[0].binding.as_ref(), "only", "function name");
    }
}