Skip to main content

amql_engine/extractor/
react.rs

//! React component and hook extractor.
//!
//! Detects function components (capitalized name whose body contains JSX),
//! hook calls (`use`, or `use` followed by an uppercase letter), and
//! `memo`/`forwardRef` wrappers. Suspense boundaries are not detected.
//! Uses tree-sitter with the TSX grammar for parsing.
6
7use super::BuiltinExtractor;
8use crate::store::Annotation;
9use crate::types::{AttrName, Binding, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use serde_json::Value as JsonValue;
12use std::cell::RefCell;
13
/// Hook names that ship with React itself.
///
/// Any other hook-shaped callee (see `is_hook_name`) is reported with
/// `custom: true`. Only membership matters to callers; order is irrelevant.
const BUILTIN_HOOKS: &[&str] = &[
    // Original hooks API.
    "useState",
    "useEffect",
    "useContext",
    "useReducer",
    "useCallback",
    "useMemo",
    "useRef",
    "useImperativeHandle",
    "useLayoutEffect",
    "useDebugValue",
    // Added in React 18.
    "useDeferredValue",
    "useTransition",
    "useId",
    "useSyncExternalStore",
    "useInsertionEffect",
    // Added in React 19 (`useFormStatus` is exported by react-dom).
    "useOptimistic",
    "useFormStatus",
    "useActionState",
    "use",
];
36
/// Callee spellings treated as component-wrapping higher-order functions.
///
/// Matching is done on the exact callee source text, which is why both the
/// bare and the `React.`-qualified forms are listed.
const HOC_WRAPPERS: &[&str] = &["memo", "React.memo", "forwardRef", "React.forwardRef"];
39
40/// Built-in React component/hook extractor.
41pub struct ReactExtractor;
42
43impl BuiltinExtractor for ReactExtractor {
44    fn name(&self) -> &str {
45        "react"
46    }
47
48    fn extensions(&self) -> &[&str] {
49        &[".tsx", ".jsx", ".ts", ".js"]
50    }
51
52    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
53        let tree = match parse_tsx(source, file) {
54            Some(t) => t,
55            None => return vec![],
56        };
57        let mut annotations = Vec::new();
58        visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
59        annotations
60    }
61}
62
63fn visit_node(
64    node: tree_sitter::Node,
65    src: &[u8],
66    file: &RelativePath,
67    annotations: &mut Vec<Annotation>,
68) {
69    // Function declarations: function Foo() { return <div/>; }
70    if node.kind() == "function_declaration" {
71        if let Some(ann) = extract_function_component(node, src, file) {
72            annotations.push(ann);
73        }
74    }
75
76    // Variable declarations: const Foo = () => <div/>; or const Foo = memo(...)
77    if node.kind() == "lexical_declaration" || node.kind() == "variable_declaration" {
78        extract_variable_declarations(node, src, file, annotations);
79    }
80
81    // Export statements: export default function Foo() {}
82    // Return early to avoid double-visiting the child declaration.
83    if node.kind() == "export_statement" {
84        extract_export_statement(node, src, file, annotations);
85        return;
86    }
87
88    // Expression statements: hook calls at top level
89    if node.kind() == "expression_statement" {
90        extract_hook_call_statement(node, src, file, annotations);
91    }
92
93    let mut cursor = node.walk();
94    for child in node.named_children(&mut cursor) {
95        visit_node(child, src, file, annotations);
96    }
97}
98
99/// Extract a function component from `function Foo() { ... }`.
100fn extract_function_component(
101    node: tree_sitter::Node,
102    src: &[u8],
103    file: &RelativePath,
104) -> Option<Annotation> {
105    let name_node = node.child_by_field_name("name")?;
106    let name = node_text(name_node, src);
107
108    if !is_component_name(&name) {
109        return None;
110    }
111
112    if !body_contains_jsx(node, src) {
113        return None;
114    }
115
116    let mut attrs = collect_export_attrs(node);
117    attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
118
119    Some(Annotation {
120        tag: TagName::from("component"),
121        attrs,
122        binding: Binding::from(name),
123        file: file.clone(),
124        children: vec![],
125    })
126}
127
128/// Extract components/hooks from variable declarations.
129fn extract_variable_declarations(
130    node: tree_sitter::Node,
131    src: &[u8],
132    file: &RelativePath,
133    annotations: &mut Vec<Annotation>,
134) {
135    let mut cursor = node.walk();
136    for declarator in node.named_children(&mut cursor) {
137        if declarator.kind() != "variable_declarator" {
138            continue;
139        }
140        let name_node = match declarator.child_by_field_name("name") {
141            Some(n) if n.kind() == "identifier" => n,
142            _ => continue,
143        };
144        let name = node_text(name_node, src);
145        let init = match declarator.child_by_field_name("value") {
146            Some(n) => n,
147            None => continue,
148        };
149
150        // Check for HOC wrappers: memo(...), forwardRef(...)
151        if init.kind() == "call_expression" {
152            if let Some(ann) = extract_hoc_component(&name, init, src, file, node) {
153                annotations.push(ann);
154                continue;
155            }
156
157            // Hook assignment: const value = useCustomHook()
158            if let Some(ann) = extract_hook_assignment(&name, init, src, file) {
159                annotations.push(ann);
160                continue;
161            }
162        }
163
164        // Arrow/function component: const Foo = () => <div/>
165        if is_component_name(&name)
166            && (init.kind() == "arrow_function" || init.kind() == "function_expression")
167            && body_contains_jsx(init, src)
168        {
169            let mut attrs = collect_export_attrs(node);
170            attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
171
172            annotations.push(Annotation {
173                tag: TagName::from("component"),
174                attrs,
175                binding: Binding::from(name),
176                file: file.clone(),
177                children: vec![],
178            });
179        }
180    }
181}
182
183/// Extract component from `memo(...)` or `forwardRef(...)`.
184fn extract_hoc_component(
185    name: &str,
186    call_node: tree_sitter::Node,
187    src: &[u8],
188    file: &RelativePath,
189    parent: tree_sitter::Node,
190) -> Option<Annotation> {
191    let callee = call_node.child_by_field_name("function")?;
192    let callee_text = node_text(callee, src);
193
194    if !HOC_WRAPPERS.contains(&callee_text.as_str()) {
195        return None;
196    }
197
198    if !is_component_name(name) {
199        return None;
200    }
201
202    let wrapper = if callee_text.contains("memo") {
203        "memo"
204    } else {
205        "forwardRef"
206    };
207
208    let mut attrs = collect_export_attrs(parent);
209    attrs.insert(AttrName::from("name"), JsonValue::String(name.to_string()));
210    attrs.insert(AttrName::from(wrapper), JsonValue::Bool(true));
211
212    Some(Annotation {
213        tag: TagName::from("component"),
214        attrs,
215        binding: Binding::from(name.to_string()),
216        file: file.clone(),
217        children: vec![],
218    })
219}
220
221/// Extract hook call from `const x = useFoo()`.
222fn extract_hook_assignment(
223    _name: &str,
224    call_node: tree_sitter::Node,
225    src: &[u8],
226    file: &RelativePath,
227) -> Option<Annotation> {
228    let callee = call_node.child_by_field_name("function")?;
229    let callee_text = node_text(callee, src);
230
231    if !is_hook_name(&callee_text) {
232        return None;
233    }
234
235    let custom = !BUILTIN_HOOKS.contains(&callee_text.as_str());
236
237    let mut attrs = FxHashMap::default();
238    attrs.insert(
239        AttrName::from("name"),
240        JsonValue::String(callee_text.clone()),
241    );
242    if custom {
243        attrs.insert(AttrName::from("custom"), JsonValue::Bool(true));
244    }
245
246    Some(Annotation {
247        tag: TagName::from("hook"),
248        attrs,
249        binding: Binding::from(callee_text),
250        file: file.clone(),
251        children: vec![],
252    })
253}
254
255/// Extract hook calls from expression statements (no assignment).
256fn extract_hook_call_statement(
257    node: tree_sitter::Node,
258    src: &[u8],
259    file: &RelativePath,
260    annotations: &mut Vec<Annotation>,
261) {
262    let call = match node.named_child(0) {
263        Some(n) if n.kind() == "call_expression" => n,
264        _ => return,
265    };
266    let callee = match call.child_by_field_name("function") {
267        Some(n) => node_text(n, src),
268        None => return,
269    };
270
271    if !is_hook_name(&callee) {
272        return;
273    }
274
275    let custom = !BUILTIN_HOOKS.contains(&callee.as_str());
276
277    let mut attrs = FxHashMap::default();
278    attrs.insert(AttrName::from("name"), JsonValue::String(callee.clone()));
279    if custom {
280        attrs.insert(AttrName::from("custom"), JsonValue::Bool(true));
281    }
282
283    annotations.push(Annotation {
284        tag: TagName::from("hook"),
285        attrs,
286        binding: Binding::from(callee),
287        file: file.clone(),
288        children: vec![],
289    });
290}
291
292/// Extract components from export statements.
293fn extract_export_statement(
294    node: tree_sitter::Node,
295    src: &[u8],
296    file: &RelativePath,
297    annotations: &mut Vec<Annotation>,
298) {
299    let mut cursor = node.walk();
300    for child in node.named_children(&mut cursor) {
301        if child.kind() == "function_declaration" {
302            if let Some(mut ann) = extract_function_component(child, src, file) {
303                ann.attrs
304                    .insert(AttrName::from("export"), JsonValue::Bool(true));
305                let text = node_text(node, src);
306                if text.starts_with("export default") {
307                    ann.attrs
308                        .insert(AttrName::from("default"), JsonValue::Bool(true));
309                }
310                annotations.push(ann);
311                return;
312            }
313        }
314        if child.kind() == "lexical_declaration" || child.kind() == "variable_declaration" {
315            extract_variable_declarations(child, src, file, annotations);
316        }
317    }
318}
319
320// ---------------------------------------------------------------------------
321// Helpers
322// ---------------------------------------------------------------------------
323
/// Returns `true` when `name` begins with an ASCII uppercase letter, which is
/// the React naming convention for components. An empty name is never a
/// component.
fn is_component_name(name: &str) -> bool {
    name.chars()
        .next()
        .is_some_and(|first| first.is_ascii_uppercase())
}
328
/// Returns `true` for the bare `use` callee, or for `use` followed by an ASCII
/// uppercase letter (`useState`, `useCustomData`). Names such as `user`,
/// `use_state`, or qualified forms like `React.useState` are rejected.
fn is_hook_name(name: &str) -> bool {
    match name.strip_prefix("use") {
        // The bare `use(...)` API.
        Some("") => true,
        Some(rest) => rest
            .chars()
            .next()
            .is_some_and(|c| c.is_ascii_uppercase()),
        None => false,
    }
}
336
337/// Check if a function/arrow body contains JSX elements.
338fn body_contains_jsx(node: tree_sitter::Node, src: &[u8]) -> bool {
339    let text = node.utf8_text(src).unwrap_or("");
340    // Quick heuristic: check for JSX-like patterns in the source text
341    text.contains('<') && (text.contains("/>") || text.contains("</"))
342}
343
344fn node_text(node: tree_sitter::Node, src: &[u8]) -> String {
345    node.utf8_text(src).unwrap_or("").to_string()
346}
347
348/// Check for export/default modifiers on a parent statement node.
349fn collect_export_attrs(node: tree_sitter::Node) -> FxHashMap<AttrName, JsonValue> {
350    let mut attrs = FxHashMap::default();
351    if let Some(parent) = node.parent() {
352        if parent.kind() == "export_statement" {
353            attrs.insert(AttrName::from("export"), JsonValue::Bool(true));
354            let mut cursor = parent.walk();
355            for child in parent.children(&mut cursor) {
356                if child.kind() == "default" {
357                    attrs.insert(AttrName::from("default"), JsonValue::Bool(true));
358                    break;
359                }
360            }
361        }
362    }
363    attrs
364}
365
366// ---------------------------------------------------------------------------
367// Parser cache (thread-local)
368// ---------------------------------------------------------------------------
369
370thread_local! {
371    static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
372}
373
374fn parse_tsx(source: &str, _file: &RelativePath) -> Option<tree_sitter::Tree> {
375    TSX_PARSER.with(|cell| {
376        let mut opt = cell.borrow_mut();
377        let parser = opt.get_or_insert_with(|| {
378            let mut p = tree_sitter::Parser::new();
379            p.set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
380                .expect("Failed to set TSX language");
381            p
382        });
383        parser.parse(source, None)
384    })
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `source` as though it were `src/App.tsx`.
    fn run(source: &str) -> Vec<Annotation> {
        ReactExtractor.extract(source, &RelativePath::from("src/App.tsx"))
    }

    /// Borrows the annotations tagged `tag`, preserving their order.
    fn tagged<'a>(anns: &'a [Annotation], tag: &'static str) -> Vec<&'a Annotation> {
        anns.iter().filter(|ann| ann.tag == tag).collect()
    }

    #[test]
    fn detects_function_component() {
        let anns = run(r#"
function App() {
    return <div>Hello</div>;
}
"#);

        let components = tagged(&anns, "component");
        assert_eq!(components.len(), 1, "should find 1 component");
        assert_eq!(components[0].binding, "App", "component name");
        assert_eq!(
            components[0].attrs.get("name"),
            Some(&JsonValue::String("App".to_string())),
            "name attr"
        );
    }

    #[test]
    fn detects_arrow_component() {
        let anns = run(r#"
const Header = () => {
    return <header>Title</header>;
};
"#);

        let components = tagged(&anns, "component");
        assert_eq!(components.len(), 1, "should find 1 arrow component");
        assert_eq!(components[0].binding, "Header", "component name");
    }

    #[test]
    fn detects_hook_call() {
        let anns = run(r#"
function App() {
    const [count, setCount] = useState(0);
    useEffect(() => {}, []);
    return <div>{count}</div>;
}
"#);

        let hooks = tagged(&anns, "hook");
        assert_eq!(hooks.len(), 1, "should find useEffect hook statement");
        assert_eq!(hooks[0].binding, "useEffect", "hook name");
    }

    #[test]
    fn detects_memo_wrapper() {
        let anns = run(r#"
const Card = memo(({ title }) => {
    return <div>{title}</div>;
});
"#);

        let components = tagged(&anns, "component");
        assert_eq!(components.len(), 1, "should find 1 memo component");
        assert_eq!(components[0].binding, "Card", "component name");
        assert_eq!(
            components[0].attrs.get("memo"),
            Some(&JsonValue::Bool(true)),
            "should have memo attr"
        );
    }

    #[test]
    fn detects_forward_ref() {
        let anns = run(r#"
const Input = forwardRef((props, ref) => {
    return <input ref={ref} />;
});
"#);

        let components = tagged(&anns, "component");
        assert_eq!(components.len(), 1, "should find 1 forwardRef component");
        assert_eq!(
            components[0].attrs.get("forwardRef"),
            Some(&JsonValue::Bool(true)),
            "should have forwardRef attr"
        );
    }

    #[test]
    fn detects_custom_hook() {
        let anns = run(r#"
function App() {
    const data = useCustomData();
    return <div />;
}
"#);

        let hooks = tagged(&anns, "hook");
        assert_eq!(hooks.len(), 1, "should find 1 custom hook");
        assert_eq!(hooks[0].binding, "useCustomData", "hook name");
        assert_eq!(
            hooks[0].attrs.get("custom"),
            Some(&JsonValue::Bool(true)),
            "should mark as custom"
        );
    }

    #[test]
    fn ignores_non_component_functions() {
        let anns = run(r#"
function helper() {
    return 42;
}

const utils = () => {
    return "hello";
};
"#);

        assert!(anns.is_empty(), "should not detect non-component functions");
    }

    #[test]
    fn detects_exported_component() {
        let anns = run(r#"
export default function App() {
    return <div>Hello</div>;
}
"#);

        let components = tagged(&anns, "component");
        assert_eq!(components.len(), 1, "should find 1 exported component");
        assert_eq!(
            components[0].attrs.get("export"),
            Some(&JsonValue::Bool(true)),
            "should have export attr"
        );
        assert_eq!(
            components[0].attrs.get("default"),
            Some(&JsonValue::Bool(true)),
            "should have default attr"
        );
    }
}