Skip to main content

amql_engine/extractor/
go_test.rs

1//! Go test and benchmark extractor.
2//!
3//! Detects `func Test*(t *testing.T)` and `func Benchmark*(b *testing.B)`
4//! patterns in Go source files. Uses tree-sitter-go for parsing.
5
6use super::BuiltinExtractor;
7use crate::store::Annotation;
8use crate::types::{AttrName, Binding, RelativePath, TagName};
9use rustc_hash::FxHashMap;
10use serde_json::Value as JsonValue;
11use std::cell::RefCell;
12
/// Built-in extractor that reports Go test and benchmark functions.
///
/// This is a stateless unit type: all behaviour lives in its
/// `BuiltinExtractor` implementation. The common traits are derived so the
/// extractor can be printed in diagnostics, copied freely, and constructed
/// through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoTestExtractor;
15
16impl BuiltinExtractor for GoTestExtractor {
17    fn name(&self) -> &str {
18        "go_test"
19    }
20
21    fn extensions(&self) -> &[&str] {
22        &[".go"]
23    }
24
25    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
26        // Only process _test.go files
27        let path: &str = file.as_ref();
28        if !path.ends_with("_test.go") {
29            return vec![];
30        }
31
32        let tree = match parse_go(source) {
33            Some(t) => t,
34            None => return vec![],
35        };
36        let mut annotations = Vec::new();
37        visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
38        annotations
39    }
40}
41
42fn visit_node(
43    node: tree_sitter::Node,
44    src: &[u8],
45    file: &RelativePath,
46    annotations: &mut Vec<Annotation>,
47) {
48    if node.kind() == "function_declaration" {
49        if let Some(ann) = extract_test_func(node, src, file) {
50            annotations.push(ann);
51        }
52    }
53
54    let mut cursor = node.walk();
55    for child in node.named_children(&mut cursor) {
56        visit_node(child, src, file, annotations);
57    }
58}
59
60/// Detect `func TestFoo(t *testing.T)` or `func BenchmarkFoo(b *testing.B)`.
61fn extract_test_func(
62    node: tree_sitter::Node,
63    src: &[u8],
64    file: &RelativePath,
65) -> Option<Annotation> {
66    let name_node = node.child_by_field_name("name")?;
67    let name = node_text(&name_node, src);
68
69    let (tag, prefix) = if name.starts_with("Test") {
70        ("test", "Test")
71    } else if name.starts_with("Benchmark") {
72        ("benchmark", "Benchmark")
73    } else {
74        return None;
75    };
76
77    // Validate the parameter type: *testing.T or *testing.B
78    let params = node.child_by_field_name("parameters")?;
79    let param_text = node_text(&params, src);
80    let expected_type = if tag == "test" {
81        "testing.T"
82    } else {
83        "testing.B"
84    };
85    if !param_text.contains(expected_type) {
86        return None;
87    }
88
89    let test_name = name.strip_prefix(prefix).unwrap_or(&name);
90
91    let mut attrs = FxHashMap::default();
92    attrs.insert(
93        AttrName::from("name"),
94        JsonValue::String(test_name.to_string()),
95    );
96
97    Some(Annotation {
98        tag: TagName::from(tag),
99        attrs,
100        binding: Binding::from(name),
101        file: file.clone(),
102        children: vec![],
103    })
104}
105
106fn node_text(node: &tree_sitter::Node, src: &[u8]) -> String {
107    node.utf8_text(src).unwrap_or("").to_string()
108}
109
110// ---------------------------------------------------------------------------
111// Parser cache (thread-local)
112// ---------------------------------------------------------------------------
113
thread_local! {
    // Per-thread slot for a tree-sitter parser. It starts out as `None`;
    // `parse_go` creates the parser on first use and keeps it here so later
    // parses on the same thread reuse it.
    static GO_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
}
117
118fn parse_go(source: &str) -> Option<tree_sitter::Tree> {
119    GO_PARSER.with(|cell| {
120        let mut opt = cell.borrow_mut();
121        let parser = opt.get_or_insert_with(|| {
122            let mut p = tree_sitter::Parser::new();
123            p.set_language(&tree_sitter_go::LANGUAGE.into())
124                .expect("Failed to set Go language for tree-sitter");
125            p
126        });
127        parser.parse(source, None)
128    })
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the extractor over `source` as though it lived in a `_test.go` file.
    fn extract_from_test_file(source: &str) -> Vec<Annotation> {
        let test_file = RelativePath::from("pkg/handler_test.go");
        GoTestExtractor.extract(source, &test_file)
    }

    /// Annotations tagged `test`, in extraction order.
    fn only_tests(found: &[Annotation]) -> Vec<&Annotation> {
        found.iter().filter(|ann| ann.tag == "test").collect()
    }

    /// Annotations tagged `benchmark`, in extraction order.
    fn only_benchmarks(found: &[Annotation]) -> Vec<&Annotation> {
        found.iter().filter(|ann| ann.tag == "benchmark").collect()
    }

    #[test]
    fn detects_test_functions() {
        let go_src = r#"package handler

import "testing"

func TestCreateUser(t *testing.T) {
	// test body
}

func TestDeleteUser(t *testing.T) {
	// test body
}
"#;

        let found = extract_from_test_file(go_src);
        let tests = only_tests(&found);

        assert_eq!(tests.len(), 2, "should find 2 test functions");
        assert_eq!(tests[0].binding, "TestCreateUser", "first test binding");
        assert_eq!(tests[1].binding, "TestDeleteUser", "second test binding");

        let expected_name = JsonValue::String("CreateUser".to_string());
        assert_eq!(
            tests[0].attrs.get("name"),
            Some(&expected_name),
            "test name without prefix"
        );
    }

    #[test]
    fn detects_benchmark_functions() {
        let go_src = r#"package handler

import "testing"

func BenchmarkSerialize(b *testing.B) {
	for i := 0; i < b.N; i++ {
		serialize()
	}
}
"#;

        let found = extract_from_test_file(go_src);
        let benchmarks = only_benchmarks(&found);

        assert_eq!(benchmarks.len(), 1, "should find 1 benchmark");
        assert_eq!(
            benchmarks[0].binding, "BenchmarkSerialize",
            "benchmark binding"
        );

        let expected_name = JsonValue::String("Serialize".to_string());
        assert_eq!(
            benchmarks[0].attrs.get("name"),
            Some(&expected_name),
            "benchmark name without prefix"
        );
    }

    #[test]
    fn ignores_non_test_files() {
        let go_src = r#"package handler

import "testing"

func TestCreateUser(t *testing.T) {}
"#;
        // Same content as a real test, but the path lacks the `_test.go` suffix.
        let plain_file = RelativePath::from("pkg/handler.go");

        let found = GoTestExtractor.extract(go_src, &plain_file);

        assert!(found.is_empty(), "should ignore non-_test.go files");
    }

    #[test]
    fn ignores_helper_functions() {
        let go_src = r#"package handler

import "testing"

func TestMain(m *testing.M) {
	os.Exit(m.Run())
}

func helperSetup(t *testing.T) {
	// not a test
}
"#;

        let found = extract_from_test_file(go_src);

        assert!(
            found.is_empty(),
            "should ignore TestMain (testing.M) and helpers"
        );
    }

    #[test]
    fn detects_mixed_tests_and_benchmarks() {
        let go_src = r#"package handler

import "testing"

func TestParse(t *testing.T) {}
func BenchmarkParse(b *testing.B) {}
func TestFormat(t *testing.T) {}
"#;

        let found = extract_from_test_file(go_src);

        let test_count = only_tests(&found).len();
        let benchmark_count = only_benchmarks(&found).len();
        assert_eq!(test_count, 2, "should find 2 tests");
        assert_eq!(benchmark_count, 1, "should find 1 benchmark");
    }
}