amql_engine/extractor/
go_test.rs1use super::BuiltinExtractor;
7use crate::store::Annotation;
8use crate::types::{AttrName, Binding, RelativePath, TagName};
9use rustc_hash::FxHashMap;
10use serde_json::Value as JsonValue;
11use std::cell::RefCell;

/// Built-in extractor that annotates Go test (`TestXxx`) and benchmark
/// (`BenchmarkXxx`) functions declared in `_test.go` files.
///
/// It is a stateless unit struct, so it is cheap to copy and has a trivial
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoTestExtractor;

16impl BuiltinExtractor for GoTestExtractor {
17 fn name(&self) -> &str {
18 "go_test"
19 }
20
21 fn extensions(&self) -> &[&str] {
22 &[".go"]
23 }
24
25 fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
26 let path: &str = file.as_ref();
28 if !path.ends_with("_test.go") {
29 return vec![];
30 }
31
32 let tree = match parse_go(source) {
33 Some(t) => t,
34 None => return vec![],
35 };
36 let mut annotations = Vec::new();
37 visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
38 annotations
39 }
40}
41
42fn visit_node(
43 node: tree_sitter::Node,
44 src: &[u8],
45 file: &RelativePath,
46 annotations: &mut Vec<Annotation>,
47) {
48 if node.kind() == "function_declaration" {
49 if let Some(ann) = extract_test_func(node, src, file) {
50 annotations.push(ann);
51 }
52 }
53
54 let mut cursor = node.walk();
55 for child in node.named_children(&mut cursor) {
56 visit_node(child, src, file, annotations);
57 }
58}
59
60fn extract_test_func(
62 node: tree_sitter::Node,
63 src: &[u8],
64 file: &RelativePath,
65) -> Option<Annotation> {
66 let name_node = node.child_by_field_name("name")?;
67 let name = node_text(&name_node, src);
68
69 let (tag, prefix) = if name.starts_with("Test") {
70 ("test", "Test")
71 } else if name.starts_with("Benchmark") {
72 ("benchmark", "Benchmark")
73 } else {
74 return None;
75 };
76
77 let params = node.child_by_field_name("parameters")?;
79 let param_text = node_text(¶ms, src);
80 let expected_type = if tag == "test" {
81 "testing.T"
82 } else {
83 "testing.B"
84 };
85 if !param_text.contains(expected_type) {
86 return None;
87 }
88
89 let test_name = name.strip_prefix(prefix).unwrap_or(&name);
90
91 let mut attrs = FxHashMap::default();
92 attrs.insert(
93 AttrName::from("name"),
94 JsonValue::String(test_name.to_string()),
95 );
96
97 Some(Annotation {
98 tag: TagName::from(tag),
99 attrs,
100 binding: Binding::from(name),
101 file: file.clone(),
102 children: vec![],
103 })
104}
105
106fn node_text(node: &tree_sitter::Node, src: &[u8]) -> String {
107 node.utf8_text(src).unwrap_or("").to_string()
108}
109
110thread_local! {
115 static GO_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
116}
117
118fn parse_go(source: &str) -> Option<tree_sitter::Tree> {
119 GO_PARSER.with(|cell| {
120 let mut opt = cell.borrow_mut();
121 let parser = opt.get_or_insert_with(|| {
122 let mut p = tree_sitter::Parser::new();
123 p.set_language(&tree_sitter_go::LANGUAGE.into())
124 .expect("Failed to set Go language for tree-sitter");
125 p
126 });
127 parser.parse(source, None)
128 })
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `source` as though it were `pkg/handler_test.go`.
    fn run(source: &str) -> Vec<Annotation> {
        GoTestExtractor.extract(source, &RelativePath::from("pkg/handler_test.go"))
    }

    /// Collects references to the annotations tagged `tag`, in extraction order.
    fn with_tag<'a>(anns: &'a [Annotation], tag: &'static str) -> Vec<&'a Annotation> {
        anns.iter().filter(|ann| ann.tag == tag).collect()
    }

    #[test]
    fn detects_test_functions() {
        let anns = run(r#"package handler

import "testing"

func TestCreateUser(t *testing.T) {
	// test body
}

func TestDeleteUser(t *testing.T) {
	// test body
}
"#);

        let tests = with_tag(&anns, "test");
        assert_eq!(tests.len(), 2, "should find 2 test functions");
        assert_eq!(tests[0].binding, "TestCreateUser", "first test binding");
        assert_eq!(tests[1].binding, "TestDeleteUser", "second test binding");
        assert_eq!(
            tests[0].attrs.get("name"),
            Some(&JsonValue::String("CreateUser".to_string())),
            "test name without prefix"
        );
    }

    #[test]
    fn detects_benchmark_functions() {
        let anns = run(r#"package handler

import "testing"

func BenchmarkSerialize(b *testing.B) {
	for i := 0; i < b.N; i++ {
		serialize()
	}
}
"#);

        let benchmarks = with_tag(&anns, "benchmark");
        assert_eq!(benchmarks.len(), 1, "should find 1 benchmark");
        assert_eq!(
            benchmarks[0].binding, "BenchmarkSerialize",
            "benchmark binding"
        );
        assert_eq!(
            benchmarks[0].attrs.get("name"),
            Some(&JsonValue::String("Serialize".to_string())),
            "benchmark name without prefix"
        );
    }

    #[test]
    fn ignores_non_test_files() {
        let source = r#"package handler

import "testing"

func TestCreateUser(t *testing.T) {}
"#;
        let plain_file = RelativePath::from("pkg/handler.go");

        let anns = GoTestExtractor.extract(source, &plain_file);

        assert!(anns.is_empty(), "should ignore non-_test.go files");
    }

    #[test]
    fn ignores_helper_functions() {
        let anns = run(r#"package handler

import "testing"

func TestMain(m *testing.M) {
	os.Exit(m.Run())
}

func helperSetup(t *testing.T) {
	// not a test
}
"#);

        assert!(
            anns.is_empty(),
            "should ignore TestMain (testing.M) and helpers"
        );
    }

    #[test]
    fn detects_mixed_tests_and_benchmarks() {
        let anns = run(r#"package handler

import "testing"

func TestParse(t *testing.T) {}
func BenchmarkParse(b *testing.B) {}
func TestFormat(t *testing.T) {}
"#);

        assert_eq!(with_tag(&anns, "test").len(), 2, "should find 2 tests");
        assert_eq!(
            with_tag(&anns, "benchmark").len(),
            1,
            "should find 1 benchmark"
        );
    }
}