Skip to main content

amql_engine/resolver/
go.rs

1//! Go source file resolver using tree-sitter.
2//!
3//! Parses `.go` files into `CodeElement` trees, extracting functions, methods,
4//! structs, interfaces, type aliases, constants, and variables.
5
6use super::{CodeElement, SourceLocation};
7use crate::error::AqlError;
8use crate::types::{AttrName, CodeElementName, RelativePath, TagName};
9use rustc_hash::FxHashMap;
10use std::cell::RefCell;
11use std::path::Path;
12
/// Go source file resolver using tree-sitter.
///
/// A stateless unit type: all parsing state lives in a thread-local parser,
/// so instances are free to create, copy, and share.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoResolver;
15
16impl super::CodeResolver for GoResolver {
17    fn resolve(&self, file_path: &Path) -> Result<CodeElement, AqlError> {
18        let source =
19            std::fs::read_to_string(file_path).map_err(|e| format!("Failed to read file: {e}"))?;
20        let root = parse_go_source(&source, file_path)?;
21        Ok(root)
22    }
23
24    fn extensions(&self) -> &[&str] {
25        &[".go"]
26    }
27
28    fn code_tags(&self) -> &[&str] {
29        &[
30            "function",
31            "method",
32            "struct",
33            "interface",
34            "type",
35            "const",
36            "var",
37            "module",
38        ]
39    }
40}
41
42// Thread-local cached tree-sitter parser to avoid re-creating on each call.
43thread_local! {
44    static GO_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
45}
46
47fn with_go_parser<F, R>(f: F) -> Result<R, String>
48where
49    F: FnOnce(&mut tree_sitter::Parser) -> Result<R, String>,
50{
51    GO_PARSER.with(|cell| {
52        let mut opt = cell.borrow_mut();
53        let parser = opt.get_or_insert_with(|| {
54            let mut p = tree_sitter::Parser::new();
55            p.set_language(&tree_sitter_go::LANGUAGE.into())
56                .expect("Failed to set Go language for tree-sitter");
57            p
58        });
59        f(parser)
60    })
61}
62
63/// Parse a Go source string into a CodeElement tree.
64fn parse_go_source(source: &str, file_path: &Path) -> Result<CodeElement, String> {
65    let tree = with_go_parser(|parser| {
66        parser
67            .parse(source, None)
68            .ok_or_else(|| "Failed to parse source".to_string())
69    })?;
70
71    let root_node = tree.root_node();
72    let src = source.as_bytes();
73    let file_str = file_path.to_string_lossy().to_string();
74
75    let mut children = Vec::new();
76    let mut cursor = root_node.walk();
77    for child in root_node.named_children(&mut cursor) {
78        extract_elements(&child, src, &file_str, &mut children);
79    }
80
81    let filename = file_path
82        .file_name()
83        .map(|f| f.to_string_lossy().to_string())
84        .unwrap_or_else(|| file_str.clone());
85
86    Ok(CodeElement {
87        tag: TagName::from("module"),
88        name: CodeElementName::from(filename),
89        attrs: FxHashMap::default(),
90        children,
91        source: SourceLocation {
92            file: RelativePath::from(file_str),
93            line: 1,
94            column: 0,
95            end_line: Some(root_node.end_position().row + 1),
96            end_column: Some(root_node.end_position().column),
97            start_byte: root_node.start_byte(),
98            end_byte: root_node.end_byte(),
99        },
100    })
101}
102
103/// Extract CodeElements from a tree-sitter node and its children.
104fn extract_elements(
105    node: &tree_sitter::Node,
106    src: &[u8],
107    file: &str,
108    result: &mut Vec<CodeElement>,
109) {
110    match node.kind() {
111        "function_declaration" => {
112            result.push(extract_function(node, src, file));
113        }
114        "method_declaration" => {
115            result.push(extract_method(node, src, file));
116        }
117        "type_declaration" => {
118            // type_declaration wraps one or more type_spec children
119            let mut cursor = node.walk();
120            for child in node.named_children(&mut cursor) {
121                if child.kind() == "type_spec" {
122                    if let Some(el) = extract_type_spec(&child, src, file) {
123                        result.push(el);
124                    }
125                }
126            }
127        }
128        "const_declaration" => {
129            let mut cursor = node.walk();
130            for child in node.named_children(&mut cursor) {
131                if child.kind() == "const_spec" {
132                    result.push(extract_const_spec(&child, src, file));
133                }
134            }
135        }
136        "var_declaration" => {
137            let mut cursor = node.walk();
138            for child in node.named_children(&mut cursor) {
139                if child.kind() == "var_spec" {
140                    result.push(extract_var_spec(&child, src, file));
141                }
142            }
143        }
144        _ => {}
145    }
146}
147
148fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
149    node.utf8_text(src).unwrap_or("")
150}
151
152fn get_name(node: &tree_sitter::Node, src: &[u8]) -> CodeElementName {
153    CodeElementName::from(
154        node.child_by_field_name("name")
155            .map(|n| node_text(&n, src).to_string())
156            .unwrap_or_default(),
157    )
158}
159
160fn make_source_location(node: &tree_sitter::Node, file: &str) -> SourceLocation {
161    let start = node.start_position();
162    let end = node.end_position();
163    SourceLocation {
164        file: RelativePath::from(file),
165        line: start.row + 1,
166        column: start.column,
167        end_line: Some(end.row + 1),
168        end_column: Some(end.column),
169        start_byte: node.start_byte(),
170        end_byte: node.end_byte(),
171    }
172}
173
174/// Extract a function declaration.
175fn extract_function(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
176    let name = get_name(node, src);
177    let mut attrs = FxHashMap::default();
178    attrs.insert(
179        AttrName::from("name"),
180        serde_json::Value::String(name.to_string()),
181    );
182
183    // Check if exported (starts with uppercase)
184    let name_str = name.to_string();
185    if name_str.starts_with(|c: char| c.is_uppercase()) {
186        attrs.insert(AttrName::from("export"), serde_json::Value::Bool(true));
187    }
188
189    // Extract return type
190    if let Some(result_node) = node.child_by_field_name("result") {
191        attrs.insert(
192            AttrName::from("returnType"),
193            serde_json::Value::String(node_text(&result_node, src).to_string()),
194        );
195    }
196
197    CodeElement {
198        tag: TagName::from("function"),
199        name,
200        attrs,
201        children: vec![],
202        source: make_source_location(node, file),
203    }
204}
205
206/// Extract a method declaration (function with receiver).
207fn extract_method(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
208    let name = get_name(node, src);
209    let mut attrs = FxHashMap::default();
210    attrs.insert(
211        AttrName::from("name"),
212        serde_json::Value::String(name.to_string()),
213    );
214
215    // Extract receiver type
216    if let Some(receiver) = node.child_by_field_name("receiver") {
217        let receiver_text = node_text(&receiver, src);
218        // Strip parens and pointer: (s *Server) -> Server
219        let cleaned = receiver_text
220            .trim_matches(|c: char| c == '(' || c == ')')
221            .trim();
222        let type_name = cleaned
223            .split_whitespace()
224            .last()
225            .unwrap_or(cleaned)
226            .trim_start_matches('*');
227        attrs.insert(
228            AttrName::from("receiver"),
229            serde_json::Value::String(type_name.to_string()),
230        );
231
232        // Check for pointer receiver
233        if cleaned.contains('*') {
234            attrs.insert(AttrName::from("pointer"), serde_json::Value::Bool(true));
235        }
236    }
237
238    // Check if exported
239    let name_str = name.to_string();
240    if name_str.starts_with(|c: char| c.is_uppercase()) {
241        attrs.insert(AttrName::from("export"), serde_json::Value::Bool(true));
242    }
243
244    // Extract return type
245    if let Some(result_node) = node.child_by_field_name("result") {
246        attrs.insert(
247            AttrName::from("returnType"),
248            serde_json::Value::String(node_text(&result_node, src).to_string()),
249        );
250    }
251
252    CodeElement {
253        tag: TagName::from("method"),
254        name,
255        attrs,
256        children: vec![],
257        source: make_source_location(node, file),
258    }
259}
260
261/// Extract a type_spec — could be struct, interface, or type alias.
262fn extract_type_spec(node: &tree_sitter::Node, src: &[u8], file: &str) -> Option<CodeElement> {
263    let name = get_name(node, src);
264    let type_node = node.child_by_field_name("type")?;
265    let mut attrs = FxHashMap::default();
266    attrs.insert(
267        AttrName::from("name"),
268        serde_json::Value::String(name.to_string()),
269    );
270
271    // Check if exported
272    let name_str = name.to_string();
273    if name_str.starts_with(|c: char| c.is_uppercase()) {
274        attrs.insert(AttrName::from("export"), serde_json::Value::Bool(true));
275    }
276
277    let (tag, children) = match type_node.kind() {
278        "struct_type" => {
279            let children = extract_struct_fields(&type_node, src, file);
280            ("struct", children)
281        }
282        "interface_type" => {
283            let children = extract_interface_methods(&type_node, src, file);
284            ("interface", children)
285        }
286        _ => ("type", vec![]),
287    };
288
289    Some(CodeElement {
290        tag: TagName::from(tag),
291        name,
292        attrs,
293        children,
294        source: make_source_location(node, file),
295    })
296}
297
298/// Extract struct field declarations as children.
299fn extract_struct_fields(node: &tree_sitter::Node, src: &[u8], file: &str) -> Vec<CodeElement> {
300    let mut fields = Vec::new();
301    // field_declaration_list is a direct child of struct_type
302    let mut cursor = node.walk();
303    for child in node.named_children(&mut cursor) {
304        if child.kind() == "field_declaration_list" {
305            let mut inner_cursor = child.walk();
306            for field in child.named_children(&mut inner_cursor) {
307                if field.kind() == "field_declaration" {
308                    // Field name is the first identifier child
309                    let field_name = field
310                        .child_by_field_name("name")
311                        .map(|n| node_text(&n, src).to_string())
312                        .unwrap_or_default();
313                    if field_name.is_empty() {
314                        continue; // embedded type
315                    }
316                    let mut attrs = FxHashMap::default();
317                    attrs.insert(
318                        AttrName::from("name"),
319                        serde_json::Value::String(field_name.clone()),
320                    );
321                    if let Some(type_node) = field.child_by_field_name("type") {
322                        attrs.insert(
323                            AttrName::from("fieldType"),
324                            serde_json::Value::String(node_text(&type_node, src).to_string()),
325                        );
326                    }
327                    fields.push(CodeElement {
328                        tag: TagName::from("field"),
329                        name: CodeElementName::from(field_name),
330                        attrs,
331                        children: vec![],
332                        source: make_source_location(&field, file),
333                    });
334                }
335            }
336        }
337    }
338    fields
339}
340
341/// Extract interface method signatures as children.
342fn extract_interface_methods(node: &tree_sitter::Node, src: &[u8], file: &str) -> Vec<CodeElement> {
343    let mut methods = Vec::new();
344    let mut cursor = node.walk();
345    for child in node.named_children(&mut cursor) {
346        if child.kind() == "method_elem" {
347            let method_name = child
348                .child_by_field_name("name")
349                .map(|n| node_text(&n, src).to_string())
350                .unwrap_or_default();
351            if method_name.is_empty() {
352                continue;
353            }
354            let mut attrs = FxHashMap::default();
355            attrs.insert(
356                AttrName::from("name"),
357                serde_json::Value::String(method_name.clone()),
358            );
359            methods.push(CodeElement {
360                tag: TagName::from("method"),
361                name: CodeElementName::from(method_name),
362                attrs,
363                children: vec![],
364                source: make_source_location(&child, file),
365            });
366        }
367    }
368    methods
369}
370
371/// Extract a const_spec.
372fn extract_const_spec(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
373    let name = get_name(node, src);
374    let mut attrs = FxHashMap::default();
375    attrs.insert(
376        AttrName::from("name"),
377        serde_json::Value::String(name.to_string()),
378    );
379    let name_str = name.to_string();
380    if name_str.starts_with(|c: char| c.is_uppercase()) {
381        attrs.insert(AttrName::from("export"), serde_json::Value::Bool(true));
382    }
383    CodeElement {
384        tag: TagName::from("const"),
385        name,
386        attrs,
387        children: vec![],
388        source: make_source_location(node, file),
389    }
390}
391
392/// Extract a var_spec.
393fn extract_var_spec(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
394    let name = get_name(node, src);
395    let mut attrs = FxHashMap::default();
396    attrs.insert(
397        AttrName::from("name"),
398        serde_json::Value::String(name.to_string()),
399    );
400    let name_str = name.to_string();
401    if name_str.starts_with(|c: char| c.is_uppercase()) {
402        attrs.insert(AttrName::from("export"), serde_json::Value::Bool(true));
403    }
404    CodeElement {
405        tag: TagName::from("var"),
406        name,
407        attrs,
408        children: vec![],
409        source: make_source_location(node, file),
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as the contents of `test.go`, panicking if parsing fails.
    fn parse_snippet(source: &str) -> CodeElement {
        parse_go_source(source, Path::new("test.go")).unwrap()
    }

    /// Value the resolver stores for boolean flag attributes such as `export`.
    fn flag_set() -> serde_json::Value {
        serde_json::Value::Bool(true)
    }

    #[test]
    fn parses_functions() {
        let root =
            parse_snippet("package main\n\nfunc hello() {}\nfunc World() string { return \"\" }");
        let hello = &root.children[0];
        let world = &root.children[1];

        // Lowercase names stay package-private; uppercase names are exported.
        assert_eq!(hello.tag, "function", "hello tag");
        assert_eq!(hello.name, "hello", "hello name");
        assert_eq!(hello.attrs.get("export"), None, "hello not exported");

        assert_eq!(world.tag, "function", "World tag");
        assert_eq!(world.name, "World", "World name");
        assert_eq!(world.attrs.get("export"), Some(&flag_set()), "World is exported");
    }

    #[test]
    fn parses_methods() {
        let root =
            parse_snippet("package main\n\ntype Server struct{}\n\nfunc (s *Server) Handle() {}");
        // The root holds both the type declaration and the method; pick the method.
        let method = root
            .children
            .iter()
            .find(|child| child.tag.as_ref() == "method")
            .unwrap();

        assert_eq!(method.name, "Handle", "method name");
        assert_eq!(
            method.attrs.get("receiver"),
            Some(&serde_json::Value::String("Server".to_string())),
            "receiver type"
        );
        assert_eq!(method.attrs.get("pointer"), Some(&flag_set()), "pointer receiver");
    }

    #[test]
    fn parses_structs_and_interfaces() {
        let root = parse_snippet(
            "package main\n\ntype Config struct {\n\tHost string\n\tPort int\n}\n\ntype Handler interface {\n\tServeHTTP()\n}",
        );
        let config = &root.children[0];
        let handler = &root.children[1];

        // Struct fields and interface methods become children of their type.
        assert_eq!(config.tag, "struct", "Config tag");
        assert_eq!(config.name, "Config", "Config name");
        assert_eq!(config.children.len(), 2, "Config has 2 fields");
        assert_eq!(config.children[0].name, "Host", "first field");

        assert_eq!(handler.tag, "interface", "Handler tag");
        assert_eq!(handler.name, "Handler", "Handler name");
        assert_eq!(handler.children.len(), 1, "Handler has 1 method");
        assert_eq!(handler.children[0].name, "ServeHTTP", "interface method");
    }

    #[test]
    fn parses_consts_and_vars() {
        let root = parse_snippet("package main\n\nconst MaxRetries = 3\n\nvar defaultTimeout = 30");
        let (c, v) = (&root.children[0], &root.children[1]);

        assert_eq!(c.tag, "const", "const tag");
        assert_eq!(c.name, "MaxRetries", "const name");
        assert_eq!(c.attrs.get("export"), Some(&flag_set()), "const exported");

        assert_eq!(v.tag, "var", "var tag");
        assert_eq!(v.name, "defaultTimeout", "var name");
        assert_eq!(v.attrs.get("export"), None, "var not exported");
    }
}