Skip to main content

amql_engine/extractor/
test.rs

1//! Test framework extractor (vitest, jest, mocha).
2//!
3//! Detects `describe()`, `it()`, `test()`, lifecycle hooks, and their
4//! `.only` / `.skip` modifiers. Preserves the describe > test hierarchy
5//! via nested `children`.
6
7use super::BuiltinExtractor;
8use crate::store::Annotation;
9use crate::types::{AttrName, Binding, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use serde_json::Value as JsonValue;
12use std::cell::RefCell;
13
/// Callee names that mark a test-framework call when invoked without a
/// modifier: the suite/test entry points followed by the lifecycle hooks.
const TEST_BLOCKS: &[&str] = &[
    // suites and individual tests
    "describe", "it", "test",
    // lifecycle hooks
    "beforeEach", "afterEach", "beforeAll", "afterAll",
];
24
/// Dotted callee names carrying a focus (`.only`) or skip (`.skip`) modifier.
///
/// Only `describe`, `it` and `test` appear here; no hook variants are listed.
const MODIFIED_BLOCKS: &[&str] = &[
    "describe.only", "describe.skip",
    "it.only", "it.skip",
    "test.only", "test.skip",
];
34
35/// Built-in test framework extractor.
36pub struct TestExtractor;
37
38impl BuiltinExtractor for TestExtractor {
39    fn name(&self) -> &str {
40        "test"
41    }
42
43    fn extensions(&self) -> &[&str] {
44        &[".ts", ".tsx", ".js", ".jsx", ".mts", ".mjs"]
45    }
46
47    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
48        let tree = match parse_ts(source, file) {
49            Some(t) => t,
50            None => return vec![],
51        };
52        let mut annotations = Vec::new();
53        visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
54        annotations
55    }
56}
57
58fn visit_node(
59    node: tree_sitter::Node,
60    src: &[u8],
61    file: &RelativePath,
62    annotations: &mut Vec<Annotation>,
63) {
64    if node.kind() == "expression_statement" {
65        if let Some(ann) = extract_test_block(node, src, file) {
66            annotations.push(ann);
67            return;
68        }
69    }
70
71    let mut cursor = node.walk();
72    for child in node.named_children(&mut cursor) {
73        visit_node(child, src, file, annotations);
74    }
75}
76
77/// Try to extract a test block (describe/it/test/hook) from an expression statement.
78fn extract_test_block(
79    node: tree_sitter::Node,
80    src: &[u8],
81    file: &RelativePath,
82) -> Option<Annotation> {
83    let call = node.named_child(0)?;
84    if call.kind() != "call_expression" {
85        return None;
86    }
87
88    let callee_node = call.child_by_field_name("function")?;
89    let callee = node_text(callee_node, src);
90
91    if !TEST_BLOCKS.contains(&callee.as_str()) && !MODIFIED_BLOCKS.contains(&callee.as_str()) {
92        return None;
93    }
94
95    let base_name = callee.split('.').next().unwrap_or(&callee);
96    let modifier = if callee.contains('.') {
97        callee.split('.').nth(1)
98    } else {
99        None
100    };
101
102    let args = call.child_by_field_name("arguments")?;
103    let args_children: Vec<_> = {
104        let mut cursor = args.walk();
105        args.named_children(&mut cursor).collect()
106    };
107    let first_arg = args_children.first();
108    let label = match first_arg {
109        Some(n) if n.kind() == "string" || n.kind() == "template_string" => {
110            string_literal_value(*n, src)
111        }
112        _ => callee.clone(),
113    };
114
115    let tag = match base_name {
116        "describe" => "describe",
117        "it" | "test" => "test",
118        _ => "hook",
119    };
120
121    let mut attrs = FxHashMap::default();
122    if let Some(mod_name) = &modifier {
123        attrs.insert(AttrName::from(*mod_name), JsonValue::Bool(true));
124    }
125    if tag == "hook" {
126        attrs.insert(
127            AttrName::from("kind"),
128            JsonValue::String(callee.to_string()),
129        );
130    }
131
132    // Extract children from callback body (second argument)
133    let mut children = Vec::new();
134    if args_children.len() >= 2 {
135        let callback = args_children[1];
136        if callback.kind() == "arrow_function" || callback.kind() == "function_expression" {
137            if let Some(body) = callback.child_by_field_name("body") {
138                if body.kind() == "statement_block" {
139                    let mut body_cursor = body.walk();
140                    for stmt in body.named_children(&mut body_cursor) {
141                        visit_node(stmt, src, file, &mut children);
142                    }
143                }
144            }
145        }
146    }
147
148    Some(Annotation {
149        tag: TagName::from(tag),
150        attrs,
151        binding: Binding::from(label),
152        file: file.clone(),
153        children,
154    })
155}
156
157// ---------------------------------------------------------------------------
158// tree-sitter helpers
159// ---------------------------------------------------------------------------
160
161fn node_text<'a>(node: tree_sitter::Node<'a>, src: &'a [u8]) -> String {
162    node.utf8_text(src).unwrap_or("").to_string()
163}
164
165fn string_literal_value(node: tree_sitter::Node, src: &[u8]) -> String {
166    let text = node.utf8_text(src).unwrap_or("");
167    if text.len() >= 2 {
168        text[1..text.len() - 1].to_string()
169    } else {
170        text.to_string()
171    }
172}
173
174// ---------------------------------------------------------------------------
175// Parser cache (thread-local)
176// ---------------------------------------------------------------------------
177
178thread_local! {
179    static TS_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
180    static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
181}
182
183fn parse_ts(source: &str, file: &RelativePath) -> Option<tree_sitter::Tree> {
184    let path: &str = file.as_ref();
185    let is_tsx = path.ends_with(".tsx") || path.ends_with(".jsx");
186
187    if is_tsx {
188        TSX_PARSER.with(|cell| {
189            let mut opt = cell.borrow_mut();
190            let parser = opt.get_or_insert_with(|| {
191                let mut p = tree_sitter::Parser::new();
192                p.set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
193                    .expect("Failed to set TSX language");
194                p
195            });
196            parser.parse(source, None)
197        })
198    } else {
199        TS_PARSER.with(|cell| {
200            let mut opt = cell.borrow_mut();
201            let parser = opt.get_or_insert_with(|| {
202                let mut p = tree_sitter::Parser::new();
203                p.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
204                    .expect("Failed to set TypeScript language");
205                p
206            });
207            parser.parse(source, None)
208        })
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `source` as if it were `src/app.test.ts`.
    fn run(source: &str) -> Vec<Annotation> {
        let file = RelativePath::from("src/app.test.ts");
        TestExtractor.extract(source, &file)
    }

    #[test]
    fn detects_describe_and_tests() {
        // Arrange
        let source = r#"
describe("UserService", () => {
    it("creates a user", () => {
        expect(true).toBe(true);
    });

    test("deletes a user", () => {
        expect(true).toBe(true);
    });
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 top-level describe");
        assert_eq!(anns[0].tag, "describe", "top-level should be describe");
        assert_eq!(anns[0].binding, "UserService", "describe label");
        assert_eq!(anns[0].children.len(), 2, "describe should have 2 children");
        assert_eq!(
            anns[0].children[0].tag, "test",
            "first child should be test"
        );
        assert_eq!(
            anns[0].children[0].binding, "creates a user",
            "first test label"
        );
        assert_eq!(
            anns[0].children[1].tag, "test",
            "second child should be test"
        );
        assert_eq!(
            anns[0].children[1].binding, "deletes a user",
            "second test label"
        );
    }

    #[test]
    fn detects_hooks() {
        // Arrange
        let source = r#"
describe("setup", () => {
    beforeEach(() => {
        reset();
    });

    afterAll(() => {
        cleanup();
    });

    it("works", () => {});
});
"#;

        // Act
        let anns = run(source);

        // Assert
        let desc = &anns[0];
        let hooks: Vec<_> = desc.children.iter().filter(|a| a.tag == "hook").collect();
        assert_eq!(hooks.len(), 2, "should find 2 hooks");
        assert_eq!(
            hooks[0].attrs.get("kind"),
            Some(&JsonValue::String("beforeEach".to_string())),
            "first hook kind"
        );
        assert_eq!(
            hooks[1].attrs.get("kind"),
            Some(&JsonValue::String("afterAll".to_string())),
            "second hook kind"
        );
    }

    #[test]
    fn detects_modifiers() {
        // Arrange
        let source = r#"
describe.only("focused", () => {
    it.skip("skipped test", () => {});
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 describe");
        assert_eq!(
            anns[0].attrs.get("only"),
            Some(&JsonValue::Bool(true)),
            "describe should have only modifier"
        );
        assert_eq!(
            anns[0].children[0].attrs.get("skip"),
            Some(&JsonValue::Bool(true)),
            "test should have skip modifier"
        );
    }

    #[test]
    fn nested_describes() {
        // Arrange
        let source = r#"
describe("outer", () => {
    describe("inner", () => {
        it("deep test", () => {});
    });
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 top-level describe");
        assert_eq!(
            anns[0].children.len(),
            1,
            "outer should have 1 child describe"
        );
        assert_eq!(
            anns[0].children[0].tag, "describe",
            "child should be describe"
        );
        assert_eq!(
            anns[0].children[0].children.len(),
            1,
            "inner should have 1 test"
        );
        assert_eq!(
            anns[0].children[0].children[0].binding, "deep test",
            "deep test label"
        );
    }

    #[test]
    fn detects_top_level_it() {
        // Arrange — top-level it() without describe wrapper
        let source = r#"
it("standalone test", () => {
    expect(true).toBe(true);
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find top-level it");
        assert_eq!(anns[0].tag, "test", "should be tagged as test");
        assert_eq!(anns[0].binding, "standalone test", "test label");
    }

    #[test]
    fn ignores_test_each() {
        // Arrange — `test.each(table)(name, fn)` has callee `test.each(table)`,
        // which is in neither TEST_BLOCKS nor MODIFIED_BLOCKS, so it must not
        // produce an annotation of its own.
        let source = r#"
describe("parameterized", () => {
    test.each([[1], [2]])("handles %i", (n) => {});
    test("simple test", () => {});
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 describe");
        assert_eq!(anns[0].children.len(), 1, "test.each should be ignored");
        assert_eq!(
            anns[0].children[0].binding, "simple test",
            "only the plain test should remain"
        );
    }

    #[test]
    fn strips_single_quote_and_template_delimiters() {
        // Arrange — labels written with '...' and `...` instead of "..."
        let source = r#"
describe('single quoted', () => {
    it(`template label`, () => {});
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 describe");
        assert_eq!(anns[0].binding, "single quoted", "single-quoted label");
        assert_eq!(
            anns[0].children[0].binding, "template label",
            "template-literal label"
        );
    }

    #[test]
    fn detects_function_expression_callback() {
        // Arrange — using function() instead of arrow function
        let source = r#"
describe("legacy", function() {
    it("old style test", function() {
        expect(true).toBe(true);
    });
});
"#;

        // Act
        let anns = run(source);

        // Assert
        assert_eq!(anns.len(), 1, "should find 1 describe");
        assert_eq!(
            anns[0].children.len(),
            1,
            "should find test inside function expression"
        );
    }
}