Skip to main content

amql_engine/extractor/
go_http.rs

1//! Go HTTP route and middleware extractor.
2//!
//! Detects HTTP handler registrations across common Go HTTP frameworks:
//! net/http and gorilla/mux (`HandleFunc`, `Handle`), chi (`Get`, `Post`, ...),
//! and gin (`GET`, `POST`, ...), along with middleware (`Use`) and route groups
//! (`Route`, `Group`). Uses tree-sitter-go for parsing.
6
7use super::BuiltinExtractor;
8use crate::store::Annotation;
9use crate::types::{AttrName, Binding, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use serde_json::Value as JsonValue;
12use std::cell::RefCell;
13
/// Lowercased HTTP verbs. A selector call whose method name lowercases to one
/// of these (chi's `Get`, gin's `GET`, ...) is treated as a route registration.
const HTTP_METHODS: &[&str] = &["get", "post", "put", "delete", "patch", "options", "head"];

/// Exact, case-sensitive method names (net/http, gorilla/mux) that register a
/// handler without naming an HTTP verb.
const HANDLE_FUNCS: &[&str] = &["HandleFunc", "Handle"];
19
/// Built-in Go HTTP route/middleware extractor.
///
/// A unit struct that holds no state of its own, so it can be freely
/// constructed, copied, and debug-printed.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoHttpExtractor;
22
23impl BuiltinExtractor for GoHttpExtractor {
24    fn name(&self) -> &str {
25        "go-http"
26    }
27
28    fn extensions(&self) -> &[&str] {
29        &[".go"]
30    }
31
32    fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
33        let tree = match parse_go(source) {
34            Some(t) => t,
35            None => return vec![],
36        };
37        let mut annotations = Vec::new();
38        visit_node(tree.root_node(), source.as_bytes(), file, &mut annotations);
39        annotations
40    }
41}
42
43/// Recursively visit tree-sitter nodes looking for route registrations.
44fn visit_node(
45    node: tree_sitter::Node,
46    src: &[u8],
47    file: &RelativePath,
48    annotations: &mut Vec<Annotation>,
49) {
50    if node.kind() == "call_expression" {
51        extract_route_call(node, src, file, annotations);
52    }
53
54    let mut cursor = node.walk();
55    for child in node.named_children(&mut cursor) {
56        visit_node(child, src, file, annotations);
57    }
58}
59
60/// Detect route/middleware registrations from a call expression.
61///
62/// Patterns:
63///   - `r.Get("/path", handler)` (chi style, case-insensitive method match)
64///   - `r.GET("/path", handler)` (gin style)
65///   - `http.HandleFunc("/path", handler)` (net/http)
66///   - `mux.HandleFunc("/path", handler)` (gorilla/mux)
67///   - `r.Use(middleware)` (middleware registration)
68///   - `r.Route("/path", func(r chi.Router) { ... })` (chi group)
69fn extract_route_call(
70    node: tree_sitter::Node,
71    src: &[u8],
72    file: &RelativePath,
73    annotations: &mut Vec<Annotation>,
74) {
75    // call_expression > selector_expression (object.method)
76    let callee = match node.child_by_field_name("function") {
77        Some(n) if n.kind() == "selector_expression" => n,
78        _ => return,
79    };
80    let method = match callee.child_by_field_name("field") {
81        Some(n) => node_text(&n, src),
82        None => return,
83    };
84    let args = match node.child_by_field_name("arguments") {
85        Some(n) => n,
86        None => return,
87    };
88
89    let method_lower = method.to_lowercase();
90
91    // Middleware: r.Use(...)
92    if method_lower == "use" {
93        let handler = first_named_child_text(&args, src).unwrap_or_default();
94        let mut attrs = FxHashMap::default();
95        attrs.insert(
96            AttrName::from("handler"),
97            JsonValue::String(handler.clone()),
98        );
99        annotations.push(Annotation {
100            tag: TagName::from("middleware"),
101            attrs,
102            binding: Binding::from(format!("USE {handler}")),
103            file: file.clone(),
104            children: vec![],
105        });
106        return;
107    }
108
109    // Route group: r.Route("/path", func(...) { ... })
110    if method == "Route" || method == "Group" {
111        if let Some(path) = first_string_arg(&args, src) {
112            let mut attrs = FxHashMap::default();
113            attrs.insert(AttrName::from("path"), JsonValue::String(path.clone()));
114            annotations.push(Annotation {
115                tag: TagName::from("route-group"),
116                attrs,
117                binding: Binding::from(format!("GROUP {path}")),
118                file: file.clone(),
119                children: vec![],
120            });
121        }
122        return;
123    }
124
125    // HTTP method routes: r.Get("/path", handler) or r.GET("/path", handler)
126    let is_http_method = HTTP_METHODS.contains(&method_lower.as_str());
127    let is_handle_func = HANDLE_FUNCS.contains(&method.as_str());
128
129    if !is_http_method && !is_handle_func {
130        return;
131    }
132
133    let path = match first_string_arg(&args, src) {
134        Some(p) => p,
135        None => return,
136    };
137
138    let handler_text = second_named_child_text(&args, src).unwrap_or_default();
139
140    let mut attrs = FxHashMap::default();
141    attrs.insert(AttrName::from("path"), JsonValue::String(path.clone()));
142    attrs.insert(AttrName::from("handler"), JsonValue::String(handler_text));
143
144    if is_http_method {
145        let upper = method_lower.to_uppercase();
146        attrs.insert(AttrName::from("method"), JsonValue::String(upper.clone()));
147        let binding = format!("{upper} {path}");
148
149        // Check for middleware args (> 2 named children in arg list)
150        if named_child_count(&args) > 2 {
151            attrs.insert(AttrName::from("hasMiddleware"), JsonValue::Bool(true));
152        }
153
154        annotations.push(Annotation {
155            tag: TagName::from("route"),
156            attrs,
157            binding: Binding::from(binding),
158            file: file.clone(),
159            children: vec![],
160        });
161    } else {
162        // HandleFunc / Handle — method unknown at registration time
163        let binding = format!("HANDLE {path}");
164        annotations.push(Annotation {
165            tag: TagName::from("route"),
166            attrs,
167            binding: Binding::from(binding),
168            file: file.clone(),
169            children: vec![],
170        });
171    }
172}
173
174// ---------------------------------------------------------------------------
175// tree-sitter helpers
176// ---------------------------------------------------------------------------
177
178fn node_text(node: &tree_sitter::Node, src: &[u8]) -> String {
179    node.utf8_text(src).unwrap_or("").to_string()
180}
181
182/// Extract the string value from an `interpreted_string_literal` or `raw_string_literal`.
183fn string_value(node: &tree_sitter::Node, src: &[u8]) -> String {
184    let text = node.utf8_text(src).unwrap_or("");
185    if text.len() >= 2 {
186        // Strip surrounding quotes: "..." or `...`
187        text[1..text.len() - 1].to_string()
188    } else {
189        text.to_string()
190    }
191}
192
193/// Return the first string literal argument from an argument list.
194fn first_string_arg(args: &tree_sitter::Node, src: &[u8]) -> Option<String> {
195    let mut cursor = args.walk();
196    for child in args.named_children(&mut cursor) {
197        if child.kind() == "interpreted_string_literal" || child.kind() == "raw_string_literal" {
198            return Some(string_value(&child, src));
199        }
200    }
201    None
202}
203
204/// Return the text of the first named child in an argument list.
205fn first_named_child_text(args: &tree_sitter::Node, src: &[u8]) -> Option<String> {
206    let mut cursor = args.walk();
207    let result = args
208        .named_children(&mut cursor)
209        .next()
210        .map(|n| node_text(&n, src));
211    result
212}
213
214/// Return the text of the second named child in an argument list.
215fn second_named_child_text(args: &tree_sitter::Node, src: &[u8]) -> Option<String> {
216    let mut cursor = args.walk();
217    let result = args
218        .named_children(&mut cursor)
219        .nth(1)
220        .map(|n| node_text(&n, src));
221    result
222}
223
224fn named_child_count(node: &tree_sitter::Node) -> usize {
225    let mut cursor = node.walk();
226    node.named_children(&mut cursor).count()
227}
228
229// ---------------------------------------------------------------------------
230// Parser cache (thread-local)
231// ---------------------------------------------------------------------------
232
233thread_local! {
234    static GO_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
235}
236
237fn parse_go(source: &str) -> Option<tree_sitter::Tree> {
238    GO_PARSER.with(|cell| {
239        let mut opt = cell.borrow_mut();
240        let parser = opt.get_or_insert_with(|| {
241            let mut p = tree_sitter::Parser::new();
242            p.set_language(&tree_sitter_go::LANGUAGE.into())
243                .expect("Failed to set Go language for tree-sitter");
244            p
245        });
246        parser.parse(source, None)
247    })
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Extract annotations from `source`, as if it lived at `server/routes.go`.
    fn run(source: &str) -> Vec<Annotation> {
        GoHttpExtractor.extract(source, &RelativePath::from("server/routes.go"))
    }

    /// The annotations carrying `tag`, in the order they were extracted.
    fn tagged<'a>(anns: &'a [Annotation], tag: &'static str) -> Vec<&'a Annotation> {
        anns.iter().filter(|ann| ann.tag == tag).collect()
    }

    #[test]
    fn detects_chi_style_routes() {
        // A chi router registering three verbs.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\tr := chi.NewRouter()\n",
            "\tr.Get(\"/users\", listUsers)\n",
            "\tr.Post(\"/users\", createUser)\n",
            "\tr.Delete(\"/users/{id}\", deleteUser)\n",
            "}\n",
        );

        let anns = run(source);
        let routes = tagged(&anns, "route");

        assert_eq!(routes.len(), 3, "should find 3 routes");
        let expected = [
            ("GET /users", "first route binding"),
            ("POST /users", "second route binding"),
            ("DELETE /users/{id}", "third route binding"),
        ];
        for (route, (binding, msg)) in routes.iter().zip(expected) {
            assert_eq!(route.binding, binding, "{msg}");
        }
        assert_eq!(
            routes[0].attrs.get("method"),
            Some(&JsonValue::String("GET".to_string())),
            "method attr"
        );
    }

    #[test]
    fn detects_gin_style_routes() {
        // gin registers routes with upper-case verb methods.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\tr := gin.Default()\n",
            "\tr.GET(\"/health\", healthCheck)\n",
            "\tr.POST(\"/api/generate\", generate)\n",
            "}\n",
        );

        let anns = run(source);
        let routes = tagged(&anns, "route");

        assert_eq!(routes.len(), 2, "should find 2 routes");
        let expected = [
            ("GET /health", "health route"),
            ("POST /api/generate", "generate route"),
        ];
        for (route, (binding, msg)) in routes.iter().zip(expected) {
            assert_eq!(route.binding, binding, "{msg}");
        }
    }

    #[test]
    fn detects_http_handle_func() {
        // net/http and mux-style registrations carry no HTTP verb.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\thttp.HandleFunc(\"/api/tags\", TagsHandler)\n",
            "\thttp.HandleFunc(\"/api/chat\", ChatHandler)\n",
            "\tmux.Handle(\"/static/\", fileServer)\n",
            "}\n",
        );

        let anns = run(source);
        let routes = tagged(&anns, "route");

        assert_eq!(routes.len(), 3, "should find 3 routes");
        let expected = [
            ("HANDLE /api/tags", "HandleFunc binding"),
            ("HANDLE /api/chat", "HandleFunc binding"),
            ("HANDLE /static/", "Handle binding"),
        ];
        for (route, (binding, msg)) in routes.iter().zip(expected) {
            assert_eq!(route.binding, binding, "{msg}");
        }
    }

    #[test]
    fn detects_middleware() {
        // Two separate Use calls, one middleware each.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\tr := chi.NewRouter()\n",
            "\tr.Use(middleware.Logger)\n",
            "\tr.Use(middleware.Recoverer)\n",
            "}\n",
        );

        let anns = run(source);
        let middleware = tagged(&anns, "middleware");

        assert_eq!(middleware.len(), 2, "should find 2 middleware");
        assert_eq!(
            middleware[0].binding, "USE middleware.Logger",
            "logger middleware"
        );
    }

    #[test]
    fn detects_route_groups() {
        // A chi Route group containing one nested registration.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\tr := chi.NewRouter()\n",
            "\tr.Route(\"/api/v1\", func(r chi.Router) {\n",
            "\t\tr.Get(\"/users\", listUsers)\n",
            "\t})\n",
            "}\n",
        );

        let anns = run(source);

        let groups = tagged(&anns, "route-group");
        assert_eq!(groups.len(), 1, "should find 1 route group");
        assert_eq!(groups[0].binding, "GROUP /api/v1", "group binding");

        let routes = tagged(&anns, "route");
        assert_eq!(routes.len(), 1, "should find nested route");
    }

    #[test]
    fn detects_method_routes_with_receiver() {
        // Registrations made through a field on a struct method receiver.
        let source = concat!(
            "package server\n",
            "\n",
            "func (s *Server) setupRoutes() {\n",
            "\ts.router.Get(\"/api/items\", s.listItems)\n",
            "\ts.router.Post(\"/api/items\", s.createItem)\n",
            "}\n",
        );

        let anns = run(source);
        let routes = tagged(&anns, "route");

        assert_eq!(routes.len(), 2, "should find routes from method receiver");
        let expected = [
            ("GET /api/items", "first route"),
            ("POST /api/items", "second route"),
        ];
        for (route, (binding, msg)) in routes.iter().zip(expected) {
            assert_eq!(route.binding, binding, "{msg}");
        }
    }

    #[test]
    fn detects_routes_with_middleware_args() {
        // Extra handlers between the path and the final handler count as middleware.
        let source = concat!(
            "package main\n",
            "\n",
            "func main() {\n",
            "\tr := chi.NewRouter()\n",
            "\tr.Get(\"/admin\", authMiddleware, rateLimiter, adminHandler)\n",
            "}\n",
        );

        let anns = run(source);
        let routes = tagged(&anns, "route");

        assert_eq!(routes.len(), 1, "should find 1 route");
        assert_eq!(
            routes[0].attrs.get("hasMiddleware"),
            Some(&JsonValue::Bool(true)),
            "should detect middleware args"
        );
    }
}