Skip to main content

aptu_coder_core/languages/
go.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
3use tree_sitter::Node;
4
/// Tree-sitter query for extracting Go elements (functions, methods, and types).
///
/// Captures:
/// - `@func_name`: the `name` identifier of a `function_declaration`.
/// - `@method_name`: the `name` field identifier of a `method_declaration`.
/// - `@function`: the whole function or method declaration node.
/// - `@type_name`: the `name` of a `type_spec` whose `type` is a `struct_type`
///   or an `interface_type`.
/// - `@class`: the whole `type_spec` node for such a struct or interface.
///
/// `type_spec` nodes with any other underlying type are not matched.
pub const ELEMENT_QUERY: &str = r"
(function_declaration
  name: (identifier) @func_name) @function
(method_declaration
  name: (field_identifier) @method_name) @function
(type_spec
  name: (type_identifier) @type_name
  type: (struct_type)) @class
(type_spec
  name: (type_identifier) @type_name
  type: (interface_type)) @class
";
18
/// Tree-sitter query for extracting function calls.
///
/// Both patterns capture the callee name as `@call`:
/// - a `call_expression` whose `function` is a bare `identifier`;
/// - a `call_expression` whose `function` is a `selector_expression`, in which
///   case only the trailing `field` identifier is captured (the operand before
///   the dot is not part of the capture).
pub const CALL_QUERY: &str = r"
(call_expression
  function: (identifier) @call)
(call_expression
  function: (selector_expression field: (field_identifier) @call))
";
26
/// Tree-sitter query for extracting type references.
///
/// Captures every `type_identifier` node as `@type_ref`. The pattern has no
/// surrounding context, so it matches a `type_identifier` wherever one occurs
/// in the tree.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
31
/// Tree-sitter query for extracting Go imports.
///
/// Captures each whole `import_declaration` node as `@import_path`; the
/// individual import specs inside the declaration are not captured separately.
pub const IMPORT_QUERY: &str = r"
(import_declaration) @import_path
";
36
/// Tree-sitter query for extracting definition and use sites.
///
/// Captures:
/// - `@write.short`: identifiers directly inside the left-hand `expression_list`
///   of a `short_var_declaration` (`:=`).
/// - `@write.assign`: identifiers directly inside the left-hand `expression_list`
///   of an `assignment_statement`.
/// - `@write.var`: identifiers that are direct children of a `var_spec` inside a
///   `var_declaration`.
/// - `@writeread.inc` / `@writeread.dec`: the identifier operand of an
///   `inc_statement` / `dec_statement`.
/// - `@read.usage`: every `identifier` node.
///
/// The left-hand-side patterns only match bare identifiers; other assignable
/// expressions (for example selector or index expressions) are not captured as
/// writes.
///
/// NOTE(review): the final pattern has no surrounding context, so identifiers
/// matched by the write patterns are also captured as `@read.usage`. Presumably
/// the consumer of this query resolves that overlap; confirm against the caller.
pub const DEFUSE_QUERY: &str = r"
(short_var_declaration left: (expression_list (identifier) @write.short))
(assignment_statement left: (expression_list (identifier) @write.assign))
(var_declaration (var_spec (identifier) @write.var))
(inc_statement (identifier) @writeread.inc)
(dec_statement (identifier) @writeread.dec)
(identifier) @read.usage
";
46
47/// Extract function or method name from a Go function or method declaration.
48#[must_use]
49pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
50    if node.kind() != "function_declaration" && node.kind() != "method_declaration" {
51        return None;
52    }
53    node.child_by_field_name("name").and_then(|n| {
54        let end = n.end_byte();
55        if end <= source.len() {
56            Some(source[n.start_byte()..end].to_string())
57        } else {
58            None
59        }
60    })
61}
62
63/// Find receiver type for a Go method declaration.
64/// Walks the method_declaration.receiver field to find the type.
65#[must_use]
66pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
67    if node.kind() != "method_declaration" {
68        return None;
69    }
70
71    // Get the receiver field
72    let receiver = node.child_by_field_name("receiver")?;
73
74    // Iterate through receiver's children to find parameter_declaration
75    for i in 0..receiver.named_child_count() {
76        if let Some(param) = receiver.named_child(i as u32)
77            && param.kind() == "parameter_declaration"
78        {
79            // Get the type field from parameter_declaration
80            if let Some(type_node) = param.child_by_field_name("type") {
81                match type_node.kind() {
82                    "type_identifier" => {
83                        let end = type_node.end_byte();
84                        if end <= source.len() {
85                            return Some(source[type_node.start_byte()..end].to_string());
86                        }
87                    }
88                    "pointer_type" => {
89                        // pointer_type wraps the actual type_identifier
90                        if let Some(inner) = (0..type_node.named_child_count())
91                            .filter_map(|j| type_node.named_child(j as u32))
92                            .find(|n| n.kind() == "type_identifier")
93                        {
94                            let end = inner.end_byte();
95                            if end <= source.len() {
96                                return Some(source[inner.start_byte()..end].to_string());
97                            }
98                        }
99                    }
100                    _ => {}
101                }
102            }
103        }
104    }
105
106    None
107}
108
109/// Find method name for a receiver type.
110#[must_use]
111pub fn find_method_for_receiver(
112    node: &Node,
113    source: &str,
114    _depth: Option<usize>,
115) -> Option<String> {
116    if node.kind() != "method_declaration" && node.kind() != "function_declaration" {
117        return None;
118    }
119    node.child_by_field_name("name").and_then(|n| {
120        let start = n.start_byte();
121        let end = n.end_byte();
122        if end <= source.len() {
123            Some(source[start..end].to_string())
124        } else {
125            None
126        }
127    })
128}
129
130/// Extract inheritance information from a Go type node.
131#[must_use]
132pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
133    let mut inherits = Vec::new();
134
135    // Get the type field from type_spec
136    if let Some(type_field) = node.child_by_field_name("type") {
137        match type_field.kind() {
138            "struct_type" => {
139                // For struct embedding, walk children for field_declaration_list
140                for i in 0..type_field.named_child_count() {
141                    if let Some(field_list) = type_field.named_child(i as u32)
142                        && field_list.kind() == "field_declaration_list"
143                    {
144                        // Walk field_declaration_list for field_declaration without name
145                        for j in 0..field_list.named_child_count() {
146                            if let Some(field) = field_list.named_child(j as u32)
147                                && field.kind() == "field_declaration"
148                                && field.child_by_field_name("name").is_none()
149                            {
150                                // Embedded type has no name field
151                                if let Some(type_node) = field.child_by_field_name("type") {
152                                    let text =
153                                        &source[type_node.start_byte()..type_node.end_byte()];
154                                    inherits.push(text.to_string());
155                                }
156                            }
157                        }
158                    }
159                }
160            }
161            "interface_type" => {
162                // For interface embedding, walk children for type_elem
163                for i in 0..type_field.named_child_count() {
164                    if let Some(elem) = type_field.named_child(i as u32)
165                        && elem.kind() == "type_elem"
166                    {
167                        let text = &source[elem.start_byte()..elem.end_byte()];
168                        inherits.push(text.to_string());
169                    }
170                }
171            }
172            _ => {}
173        }
174    }
175
176    inherits
177}
178
#[cfg(all(test, feature = "lang-go"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::Parser;

    /// Parse `source` with the Go grammar, panicking if the parser cannot be
    /// configured or produces no tree.
    fn parse_go(source: &str) -> tree_sitter::Tree {
        let language: tree_sitter::Language = tree_sitter_go::LANGUAGE.into();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("failed to set Go language");
        parser.parse(source, None).expect("failed to parse source")
    }

    /// Return the first named child of the root whose kind is `kind`,
    /// panicking with `expected <kind>` when there is none.
    fn first_top_level<'t>(tree: &'t tree_sitter::Tree, kind: &str) -> tree_sitter::Node<'t> {
        let root = tree.root_node();
        (0..root.named_child_count())
            .filter_map(move |i| root.named_child(i as u32))
            .find(|candidate| candidate.kind() == kind)
            .unwrap_or_else(|| panic!("expected {kind}"))
    }

    /// Run def/use extraction for `symbol` over `src` and report
    /// `(any site was found, any Write site was found)`.
    fn defuse_flags(src: &str, symbol: &str) -> (bool, bool) {
        let sites = SemanticExtractor::extract_def_use_for_file(src, "go", symbol, "test.go", None);
        let wrote = sites
            .iter()
            .any(|site| matches!(site.kind, DefUseKind::Write));
        (!sites.is_empty(), wrote)
    }

    #[test]
    fn test_extract_inheritance_struct_no_embeds() {
        // Arrange: a struct whose only field is named, so nothing is embedded.
        let source = "package p\ntype Foo struct { x int }";
        let tree = parse_go(source);
        let root = tree.root_node();
        // The type_spec lives inside the top-level type_declaration.
        let type_spec = (0..root.named_child_count())
            .filter_map(|i| root.named_child(i as u32))
            .filter(|decl| decl.kind() == "type_declaration")
            .find_map(|decl| {
                (0..decl.named_child_count())
                    .filter_map(|j| decl.named_child(j as u32))
                    .find(|child| child.kind() == "type_spec")
            })
            .expect("expected type_spec node");

        // Act
        let result = extract_inheritance(&type_spec, source);

        // Assert
        assert!(
            result.is_empty(),
            "expected no inherited types, got {:?}",
            result
        );
    }

    #[test]
    fn test_find_method_for_receiver_wrong_kind() {
        // Arrange: the first top-level node is a type declaration, not a
        // method or function declaration.
        let source = "package p\ntype Bar struct {}";
        let tree = parse_go(source);
        let node = tree.root_node().named_child(0).expect("expected child");

        // Act + Assert
        assert_eq!(find_method_for_receiver(&node, source, None), None);
    }

    #[test]
    fn test_extract_function_name() {
        // Arrange: free function declaration
        let source = "package p\nfunc Foo() {}";
        let tree = parse_go(source);
        let func_node = first_top_level(&tree, "function_declaration");

        // Act
        let name = extract_function_name(&func_node, source, "go");

        // Assert
        assert_eq!(name, Some("Foo".to_string()));
    }

    #[test]
    fn test_extract_method_name() {
        // Arrange: method declaration
        let source = "package p\nfunc (r *Receiver) Bar() {}";
        let tree = parse_go(source);
        let method_node = first_top_level(&tree, "method_declaration");

        // Act
        let name = extract_function_name(&method_node, source, "go");

        // Assert
        assert_eq!(name, Some("Bar".to_string()));
    }

    #[test]
    fn test_extract_function_name_wrong_kind() {
        // Arrange: the first top-level node is not a function or method declaration.
        let source = "package p\ntype Baz struct {}";
        let tree = parse_go(source);
        let node = tree.root_node().named_child(0).expect("expected child");

        // Act + Assert
        assert_eq!(extract_function_name(&node, source, "go"), None);
    }

    #[test]
    fn test_find_receiver_type() {
        // Arrange: method with value receiver
        let source = "package p\nfunc (r Receiver) Foo() {}";
        let tree = parse_go(source);
        let method_node = first_top_level(&tree, "method_declaration");

        // Act
        let receiver = find_receiver_type(&method_node, source);

        // Assert
        assert_eq!(receiver, Some("Receiver".to_string()));
    }

    #[test]
    fn test_find_receiver_type_pointer() {
        // Arrange: method with pointer receiver
        let source = "package p\nfunc (r *Receiver) Foo() {}";
        let tree = parse_go(source);
        let method_node = first_top_level(&tree, "method_declaration");

        // Act
        let receiver = find_receiver_type(&method_node, source);

        // Assert: the pointer marker is not part of the reported type name.
        assert_eq!(receiver, Some("Receiver".to_string()));
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange + Act
        let (found, wrote) = defuse_flags("package p\nfunc main() { x := 1 }\n", "x");

        // Assert
        assert!(found, "defuse sites should not be empty");
        assert!(wrote, "should contain a Write DefUseSite");
    }

    #[test]
    fn test_defuse_go_short_var_decl() {
        // Arrange + Act: short var declaration := is Write
        let (found, wrote) = defuse_flags("package p\nfunc main() { x := 42 }\n", "x");

        // Assert
        assert!(found, "short var decl should produce defuse sites");
        assert!(wrote, "short var decl should be Write");
    }

    #[test]
    fn test_defuse_go_multi_return() {
        // Arrange: multi-return := captures all LHS identifiers as Write
        let src = "package p\nfunc main() { a, b := f() }\nfunc f() (int, int) { return 1, 2 }\n";

        // Act
        let (a_found, a_wrote) = defuse_flags(src, "a");
        let (b_found, b_wrote) = defuse_flags(src, "b");

        // Assert
        assert!(a_found, "multi-return a should produce defuse sites");
        assert!(b_found, "multi-return b should produce defuse sites");
        assert!(a_wrote, "multi-return a should be Write");
        assert!(b_wrote, "multi-return b should be Write");
    }

    #[test]
    fn test_defuse_go_blank_identifier() {
        // Arrange: blank identifier _ in multi-return
        let src =
            "package p\nfunc main() { _, err := f() }\nfunc f() (int, error) { return 1, nil }\n";

        // Act
        let (found, wrote) = defuse_flags(src, "_");

        // Assert: the blank identifier may be captured or excluded; this test
        // only documents that, if it is captured, it is reported as a Write.
        if found {
            assert!(wrote, "blank identifier if captured should be Write");
        }
    }
}