Skip to main content

amql_engine/extractor/
express.rs

1//! Express route and middleware extractor.
2//!
3//! Detects `Router()` instances, route calls (`router.get("/path", ...)`),
4//! and middleware (`app.use("/path", ...)`). Uses tree-sitter for parsing.
5
6use super::BuiltinExtractor;
7use crate::store::Annotation;
8use crate::types::{AttrName, Binding, RelativePath, TagName};
9use rustc_hash::FxHashMap;
10use serde_json::Value as JsonValue;
11use std::cell::RefCell;
12
/// Lower-case method names that `extract_route_call` reports as routes.
///
/// `use` is not listed here; it is matched separately and reported as
/// middleware.
const HTTP_METHODS: &[&str] = &[
    "get",
    "post",
    "put",
    "delete",
    "patch",
    "options",
    "head",
    // Not an HTTP verb: Express's `all` registers a handler for every method.
    "all",
];
17
18/// Built-in Express route/middleware extractor.
19pub struct ExpressExtractor;
20
21impl BuiltinExtractor for ExpressExtractor {
22    fn name(&self) -> &str {
23        "express"
24    }
25
26    fn extensions(&self) -> &[&str] {
27        &[".ts", ".tsx", ".js", ".jsx", ".mts", ".mjs"]
28    }
29
30    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
31        let tree = match parse_ts(source, file) {
32            Some(t) => t,
33            None => return vec![],
34        };
35        let mut annotations = Vec::new();
36        visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
37        annotations
38    }
39}
40
41fn visit_node(
42    node: tree_sitter::Node,
43    src: &[u8],
44    file: &RelativePath,
45    annotations: &mut Vec<Annotation>,
46) {
47    // Router instance: const router = Router() or const router = express.Router()
48    if node.kind() == "lexical_declaration" || node.kind() == "variable_declaration" {
49        extract_router(node, src, file, annotations);
50    }
51
52    // Route/middleware calls: router.get("/path", ...) or app.use("/path", ...)
53    if node.kind() == "expression_statement" {
54        extract_route_call(node, src, file, annotations);
55    }
56
57    let mut cursor = node.walk();
58    for child in node.named_children(&mut cursor) {
59        visit_node(child, src, file, annotations);
60    }
61}
62
63/// Detect `const router = Router()` or `const router = express.Router()`.
64fn extract_router(
65    node: tree_sitter::Node,
66    src: &[u8],
67    file: &RelativePath,
68    annotations: &mut Vec<Annotation>,
69) {
70    let mut cursor = node.walk();
71    for declarator in node.named_children(&mut cursor) {
72        if declarator.kind() != "variable_declarator" {
73            continue;
74        }
75        let name_node = match declarator.child_by_field_name("name") {
76            Some(n) if n.kind() == "identifier" => n,
77            _ => continue,
78        };
79        let init = match declarator.child_by_field_name("value") {
80            Some(n) if n.kind() == "call_expression" => n,
81            _ => continue,
82        };
83        let callee = match init.child_by_field_name("function") {
84            Some(n) => node_text(n, src),
85            None => continue,
86        };
87        if callee == "Router" || callee == "express.Router" {
88            let binding = node_text(name_node, src);
89            let attrs = collect_export_attrs(node, src);
90            annotations.push(Annotation {
91                tag: TagName::from("router"),
92                attrs,
93                binding: Binding::from(binding),
94                file: file.clone(),
95                children: vec![],
96            });
97        }
98    }
99}
100
101/// Detect `router.get("/path", handler)` or `app.use("/path", middleware)`.
102fn extract_route_call(
103    node: tree_sitter::Node,
104    src: &[u8],
105    file: &RelativePath,
106    annotations: &mut Vec<Annotation>,
107) {
108    // expression_statement > call_expression > member_expression
109    let call = match node.named_child(0) {
110        Some(n) if n.kind() == "call_expression" => n,
111        _ => return,
112    };
113    let callee = match call.child_by_field_name("function") {
114        Some(n) if n.kind() == "member_expression" => n,
115        _ => return,
116    };
117    let object = match callee.child_by_field_name("object") {
118        Some(n) => node_text(n, src),
119        None => return,
120    };
121    let method = match callee.child_by_field_name("property") {
122        Some(n) => node_text(n, src),
123        None => return,
124    };
125    let args = match call.child_by_field_name("arguments") {
126        Some(n) => n,
127        None => return,
128    };
129    let first_arg = match first_named_child(&args) {
130        Some(n) if n.kind() == "string" || n.kind() == "template_string" => n,
131        _ => return,
132    };
133    let path = string_literal_value(first_arg, src);
134
135    let arg_count = named_child_count(&args);
136
137    if method == "use" {
138        let bind = format!("USE {path}");
139        let mut attrs = FxHashMap::default();
140        attrs.insert(AttrName::from("path"), JsonValue::String(path.to_string()));
141        attrs.insert(
142            AttrName::from("handler"),
143            JsonValue::String(object.to_string()),
144        );
145        annotations.push(Annotation {
146            tag: TagName::from("middleware"),
147            attrs,
148            binding: Binding::from(bind),
149            file: file.clone(),
150            children: vec![],
151        });
152    } else if HTTP_METHODS.contains(&method.as_str()) {
153        let bind = format!("{} {path}", method.to_uppercase());
154        let mut attrs = FxHashMap::default();
155        attrs.insert(
156            AttrName::from("method"),
157            JsonValue::String(method.to_uppercase()),
158        );
159        attrs.insert(AttrName::from("path"), JsonValue::String(path.to_string()));
160        attrs.insert(
161            AttrName::from("handler"),
162            JsonValue::String(object.to_string()),
163        );
164        if arg_count > 2 {
165            attrs.insert(AttrName::from("hasMiddleware"), JsonValue::Bool(true));
166        }
167        annotations.push(Annotation {
168            tag: TagName::from("route"),
169            attrs,
170            binding: Binding::from(bind),
171            file: file.clone(),
172            children: vec![],
173        });
174    }
175}
176
177/// Check for export/default modifiers on a statement node.
178fn collect_export_attrs(node: tree_sitter::Node, src: &[u8]) -> FxHashMap<AttrName, JsonValue> {
179    let mut attrs = FxHashMap::default();
180    if let Some(parent) = node.parent() {
181        if parent.kind() == "export_statement" {
182            attrs.insert(AttrName::from("export"), JsonValue::Bool(true));
183            // Check for `export default`
184            let text = node_text(parent, src);
185            if text.starts_with("export default") {
186                attrs.insert(AttrName::from("default"), JsonValue::Bool(true));
187            }
188        }
189    }
190    attrs
191}
192
193// ---------------------------------------------------------------------------
194// tree-sitter helpers
195// ---------------------------------------------------------------------------
196
197fn node_text<'a>(node: tree_sitter::Node<'a>, src: &'a [u8]) -> String {
198    node.utf8_text(src).unwrap_or("").to_string()
199}
200
201/// Extract the string value from a string literal or template string node.
202fn string_literal_value(node: tree_sitter::Node, src: &[u8]) -> String {
203    let text = node.utf8_text(src).unwrap_or("");
204    // Strip quotes: "...", '...', `...`
205    if text.len() >= 2 {
206        let inner = &text[1..text.len() - 1];
207        inner.to_string()
208    } else {
209        text.to_string()
210    }
211}
212
213fn first_named_child<'a>(node: &'a tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
214    let mut cursor = node.walk();
215    let result = node.named_children(&mut cursor).next();
216    result
217}
218
219fn named_child_count(node: &tree_sitter::Node) -> usize {
220    let mut cursor = node.walk();
221    let count = node.named_children(&mut cursor).count();
222    count
223}
224
225// ---------------------------------------------------------------------------
226// Parser cache (thread-local)
227// ---------------------------------------------------------------------------
228
229thread_local! {
230    static TS_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
231    static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
232}
233
234fn parse_ts(source: &str, file: &RelativePath) -> Option<tree_sitter::Tree> {
235    let path: &str = file.as_ref();
236    let is_tsx = path.ends_with(".tsx") || path.ends_with(".jsx");
237
238    if is_tsx {
239        TSX_PARSER.with(|cell| {
240            let mut opt = cell.borrow_mut();
241            let parser = opt.get_or_insert_with(|| {
242                let mut p = tree_sitter::Parser::new();
243                p.set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
244                    .expect("Failed to set TSX language");
245                p
246            });
247            parser.parse(source, None)
248        })
249    } else {
250        TS_PARSER.with(|cell| {
251            let mut opt = cell.borrow_mut();
252            let parser = opt.get_or_insert_with(|| {
253                let mut p = tree_sitter::Parser::new();
254                p.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
255                    .expect("Failed to set TypeScript language");
256                p
257            });
258            parser.parse(source, None)
259        })
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `source` as if it were `src/routes.ts`.
    fn run(source: &str) -> Vec<Annotation> {
        ExpressExtractor.extract(source, &RelativePath::from("src/routes.ts"))
    }

    /// A `Router()` call assigned to a `const` is reported once, bound to the
    /// variable name.
    #[test]
    fn detects_router_instance() {
        let anns = run(
            r#"import { Router } from "express";
const router = Router();
export default router;"#,
        );

        let routers: Vec<_> = anns.iter().filter(|a| a.tag == "router").collect();

        assert_eq!(routers.len(), 1, "should find 1 router instance");
        assert_eq!(routers[0].binding, "router", "binding should be router");
    }

    /// Each `router.<verb>(path, handler)` statement becomes one route, in
    /// source order, bound to `<VERB> <path>`.
    #[test]
    fn detects_route_calls() {
        let anns = run(
            r#"
const router = Router();
router.get("/users", getUsers);
router.post("/users", createUser);
router.delete("/users/:id", deleteUser);
"#,
        );

        let routes: Vec<_> = anns.iter().filter(|a| a.tag == "route").collect();
        assert_eq!(routes.len(), 3, "should find 3 routes");

        let expected = [
            ("GET /users", "first route binding"),
            ("POST /users", "second route binding"),
            ("DELETE /users/:id", "third route binding"),
        ];
        for (route, (binding, message)) in routes.iter().zip(expected) {
            assert_eq!(route.binding, binding, "{message}");
        }

        assert_eq!(
            routes[0].attrs.get("method"),
            Some(&JsonValue::String("GET".to_string())),
            "method attr should be GET"
        );
    }

    /// `app.use(path, ...)` is reported as middleware bound to `USE <path>`.
    #[test]
    fn detects_middleware() {
        let anns = run(
            r#"
const app = express();
app.use("/api", cors());
"#,
        );

        let mw: Vec<_> = anns.iter().filter(|a| a.tag == "middleware").collect();

        assert_eq!(mw.len(), 1, "should find 1 middleware");
        assert_eq!(mw[0].binding, "USE /api", "middleware binding");
        assert_eq!(
            mw[0].attrs.get("path"),
            Some(&JsonValue::String("/api".to_string())),
            "path attr"
        );
    }

    /// A route call with more than two arguments is flagged `hasMiddleware`.
    #[test]
    fn detects_route_with_middleware() {
        let anns =
            run(r#"router.get("/admin", authMiddleware, adminHandler, responseHandler);"#);

        let routes: Vec<_> = anns.iter().filter(|a| a.tag == "route").collect();

        assert_eq!(routes.len(), 1, "should find 1 route");
        assert_eq!(
            routes[0].attrs.get("hasMiddleware"),
            Some(&JsonValue::Bool(true)),
            "should mark hasMiddleware"
        );
    }

    /// Covers GET, POST, PUT, PATCH and DELETE registered one after another.
    /// The calls are separate statements, not a method chain.
    #[test]
    fn detects_chained_routes() {
        let anns = run(
            r#"
const router = Router();
router.get("/items", listItems);
router.post("/items", createItem);
router.put("/items/:id", updateItem);
router.patch("/items/:id", patchItem);
router.delete("/items/:id", deleteItem);
"#,
        );

        let routes: Vec<_> = anns.iter().filter(|a| a.tag == "route").collect();

        assert_eq!(routes.len(), 5, "should find all 5 HTTP method routes");
        assert_eq!(routes[2].binding, "PUT /items/:id", "PUT route binding");
        assert_eq!(routes[3].binding, "PATCH /items/:id", "PATCH route binding");
    }

    /// The middleware here does have a path; what is under test is that a
    /// template-literal (backtick) path is accepted and unquoted.
    #[test]
    fn detects_pathless_middleware() {
        let anns = run(
            r#"
app.use(`/api/v1`, rateLimiter);
"#,
        );

        let mw: Vec<_> = anns.iter().filter(|a| a.tag == "middleware").collect();

        assert_eq!(mw.len(), 1, "should find middleware with template path");
        assert_eq!(mw[0].binding, "USE /api/v1", "template string path");
    }
}