Skip to main content

pecto_python/extractors/
controller.rs

1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4use std::collections::BTreeMap;
5
6/// Extract endpoints from a Python file containing FastAPI, Flask, or Django routes.
7pub fn extract(file: &ParsedFile) -> Option<Capability> {
8    let root = file.tree.root_node();
9    let source = file.source.as_bytes();
10
11    let mut endpoints = Vec::new();
12
13    for i in 0..root.named_child_count() {
14        let node = root.named_child(i).unwrap();
15
16        if node.kind() == "decorated_definition" {
17            let decorators = collect_decorators(&node, source);
18            let inner = match get_inner_definition(&node) {
19                Some(n) => n,
20                None => continue,
21            };
22
23            if inner.kind() == "function_definition" {
24                // FastAPI/Flask route decorators
25                for dec in &decorators {
26                    if let Some(endpoint) = extract_route_endpoint(&inner, source, dec) {
27                        endpoints.push(endpoint);
28                    }
29                }
30
31                // Django REST Framework @api_view
32                if let Some(endpoint) = extract_drf_api_view(&inner, source, &decorators) {
33                    endpoints.push(endpoint);
34                }
35            }
36
37            if inner.kind() == "class_definition" {
38                // Django REST Framework ViewSets
39                extract_drf_viewset(&inner, source, &decorators, &mut endpoints);
40            }
41        }
42    }
43
44    if endpoints.is_empty() {
45        return None;
46    }
47
48    // Derive capability name from file
49    let file_stem = file
50        .path
51        .rsplit('/')
52        .next()
53        .unwrap_or(&file.path)
54        .trim_end_matches(".py");
55    let capability_name = to_kebab_case(
56        &file_stem
57            .replace("_routes", "")
58            .replace("_views", "")
59            .replace("_router", "")
60            .replace("_api", ""),
61    );
62
63    let mut capability = Capability::new(capability_name, file.path.clone());
64    capability.endpoints = endpoints;
65    Some(capability)
66}
67
68/// Extract endpoint from FastAPI/Flask route decorator.
69fn extract_route_endpoint(
70    func_node: &tree_sitter::Node,
71    source: &[u8],
72    decorator: &DecoratorInfo,
73) -> Option<Endpoint> {
74    let http_method = match decorator.name.as_str() {
75        "get" | "GET" => HttpMethod::Get,
76        "post" | "POST" => HttpMethod::Post,
77        "put" | "PUT" => HttpMethod::Put,
78        "delete" | "DELETE" => HttpMethod::Delete,
79        "patch" | "PATCH" => HttpMethod::Patch,
80        "route" => {
81            // Flask @app.route — check methods kwarg
82            extract_flask_method(&decorator.args)
83        }
84        _ => return None,
85    };
86
87    // Don't match standalone decorators that aren't route-like
88    if !decorator.full_name.contains('.')
89        && !matches!(
90            decorator.name.as_str(),
91            "get" | "post" | "put" | "delete" | "patch"
92        )
93    {
94        return None;
95    }
96
97    // Extract path from first argument
98    let path = decorator
99        .args
100        .first()
101        .map(|a| clean_string_literal(a))
102        .unwrap_or_default();
103
104    if path.is_empty() && decorator.name != "route" {
105        return None;
106    }
107
108    let _func_name = get_def_name(func_node, source);
109
110    // Extract parameters from function signature
111    let input = extract_function_params(func_node, source);
112
113    // Check for security dependencies (FastAPI Depends)
114    let security = extract_security(func_node, source);
115
116    let behaviors = vec![Behavior {
117        name: "success".to_string(),
118        condition: None,
119        returns: ResponseSpec {
120            status: default_status(&http_method),
121            body: extract_return_type(func_node, source),
122        },
123        side_effects: Vec::new(),
124    }];
125
126    Some(Endpoint {
127        method: http_method,
128        path,
129        input,
130        validation: Vec::new(),
131        behaviors,
132        security,
133    })
134}
135
136/// Extract HTTP method from Flask @app.route(methods=["GET", "POST"])
137fn extract_flask_method(args: &[String]) -> HttpMethod {
138    for arg in args {
139        if arg.contains("POST") {
140            return HttpMethod::Post;
141        }
142        if arg.contains("PUT") {
143            return HttpMethod::Put;
144        }
145        if arg.contains("DELETE") {
146            return HttpMethod::Delete;
147        }
148        if arg.contains("PATCH") {
149            return HttpMethod::Patch;
150        }
151    }
152    HttpMethod::Get
153}
154
155/// Extract endpoint from DRF @api_view decorator.
156fn extract_drf_api_view(
157    func_node: &tree_sitter::Node,
158    source: &[u8],
159    decorators: &[DecoratorInfo],
160) -> Option<Endpoint> {
161    let api_view = decorators.iter().find(|d| d.name == "api_view")?;
162
163    let method = if !api_view.args.is_empty() {
164        extract_flask_method(&api_view.args)
165    } else {
166        HttpMethod::Get
167    };
168
169    let func_name = get_def_name(func_node, source);
170    let path = format!("/{}", func_name.replace('_', "-"));
171
172    Some(Endpoint {
173        method,
174        path,
175        input: None,
176        validation: Vec::new(),
177        behaviors: vec![Behavior {
178            name: "success".to_string(),
179            condition: None,
180            returns: ResponseSpec {
181                status: 200,
182                body: None,
183            },
184            side_effects: Vec::new(),
185        }],
186        security: None,
187    })
188}
189
190/// Extract CRUD endpoints from DRF ViewSet class.
191fn extract_drf_viewset(
192    class_node: &tree_sitter::Node,
193    source: &[u8],
194    _decorators: &[DecoratorInfo],
195    endpoints: &mut Vec<Endpoint>,
196) {
197    // Check if class inherits from known ViewSet bases
198    let bases = get_class_bases(class_node, source);
199    let is_viewset = bases.iter().any(|b| {
200        b.contains("ViewSet")
201            || b.contains("ModelViewSet")
202            || b.contains("GenericAPIView")
203            || b.contains("APIView")
204    });
205
206    if !is_viewset {
207        return;
208    }
209
210    let class_name = get_def_name(class_node, source);
211    let base_path = format!(
212        "/{}",
213        to_kebab_case(&class_name.replace("ViewSet", "").replace("View", ""))
214    );
215
216    let is_model_viewset = bases.iter().any(|b| b.contains("ModelViewSet"));
217
218    if is_model_viewset {
219        // ModelViewSet provides standard CRUD
220        let crud = [
221            (HttpMethod::Get, format!("{}/", base_path), "list"),
222            (HttpMethod::Post, format!("{}/", base_path), "create"),
223            (HttpMethod::Get, format!("{}/:id/", base_path), "retrieve"),
224            (HttpMethod::Put, format!("{}/:id/", base_path), "update"),
225            (HttpMethod::Delete, format!("{}/:id/", base_path), "destroy"),
226        ];
227        for (method, path, _name) in crud {
228            endpoints.push(Endpoint {
229                method,
230                path,
231                input: None,
232                validation: Vec::new(),
233                behaviors: vec![Behavior {
234                    name: "success".to_string(),
235                    condition: None,
236                    returns: ResponseSpec {
237                        status: 200,
238                        body: None,
239                    },
240                    side_effects: Vec::new(),
241                }],
242                security: None,
243            });
244        }
245    }
246}
247
248fn get_class_bases(class_node: &tree_sitter::Node, source: &[u8]) -> Vec<String> {
249    let mut bases = Vec::new();
250    if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
251        for i in 0..arg_list.named_child_count() {
252            let arg = arg_list.named_child(i).unwrap();
253            bases.push(node_text(&arg, source));
254        }
255    }
256    bases
257}
258
259fn extract_function_params(func_node: &tree_sitter::Node, source: &[u8]) -> Option<EndpointInput> {
260    let params = func_node.child_by_field_name("parameters")?;
261    let mut path_params = Vec::new();
262    let mut body = None;
263
264    for i in 0..params.named_child_count() {
265        let param = params.named_child(i).unwrap();
266
267        let param_name = match param.kind() {
268            "typed_parameter" | "typed_default_parameter" => param
269                .child_by_field_name("name")
270                .map(|n| node_text(&n, source))
271                .unwrap_or_default(),
272            "identifier" => node_text(&param, source),
273            _ => continue,
274        };
275
276        // Skip self, request, db, response
277        if matches!(
278            param_name.as_str(),
279            "self" | "request" | "db" | "response" | "session"
280        ) {
281            continue;
282        }
283
284        // Check type annotation
285        let type_ann = param
286            .child_by_field_name("type")
287            .map(|t| node_text(&t, source));
288
289        if let Some(ref t) = type_ann {
290            // If type looks like a model (PascalCase), it's a body
291            if t.chars().next().is_some_and(|c| c.is_uppercase())
292                && !t.starts_with("Optional")
293                && !t.starts_with("int")
294                && !t.starts_with("str")
295                && !t.starts_with("float")
296                && !t.starts_with("bool")
297            {
298                body = Some(TypeRef {
299                    name: t.clone(),
300                    fields: BTreeMap::new(),
301                });
302                continue;
303            }
304        }
305
306        // Simple types are path params
307        if !param_name.is_empty() {
308            path_params.push(Param {
309                name: param_name,
310                param_type: type_ann.unwrap_or_else(|| "str".to_string()),
311                required: true,
312            });
313        }
314    }
315
316    if body.is_none() && path_params.is_empty() {
317        return None;
318    }
319
320    Some(EndpointInput {
321        body,
322        path_params,
323        query_params: Vec::new(),
324    })
325}
326
327fn extract_security(func_node: &tree_sitter::Node, source: &[u8]) -> Option<SecurityConfig> {
328    let params = func_node.child_by_field_name("parameters")?;
329    let text = node_text(&params, source);
330
331    // FastAPI: Depends(get_current_user)
332    if text.contains("Depends") && (text.contains("current_user") || text.contains("auth")) {
333        return Some(SecurityConfig {
334            authentication: Some("required".to_string()),
335            roles: Vec::new(),
336            rate_limit: None,
337            cors: None,
338        });
339    }
340
341    None
342}
343
344fn extract_return_type(func_node: &tree_sitter::Node, source: &[u8]) -> Option<TypeRef> {
345    let return_type = func_node.child_by_field_name("return_type")?;
346    let type_text = node_text(&return_type, source);
347
348    if type_text == "None" || type_text.is_empty() {
349        return None;
350    }
351
352    Some(TypeRef {
353        name: type_text,
354        fields: BTreeMap::new(),
355    })
356}
357
358fn default_status(method: &HttpMethod) -> u16 {
359    match method {
360        HttpMethod::Post => 201,
361        HttpMethod::Delete => 204,
362        _ => 200,
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parse `source` as a Python file located at `path`, panicking on parser failure.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string()).unwrap()
    }

    /// Run the extractor on `source` and return the capability it is expected to find.
    fn capability_of(source: &str, path: &str) -> Capability {
        let parsed = parse_file(source, path);
        extract(&parsed).unwrap()
    }

    #[test]
    fn test_fastapi_routes() {
        let source = r#"
from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/users/{user_id}")
async def get_user(user_id: int) -> User:
    return db.get(user_id)

@router.post("/users")
async def create_user(user: UserCreate, current_user: User = Depends(get_current_user)):
    return db.create(user)

@router.delete("/users/{user_id}")
async def delete_user(user_id: int):
    db.delete(user_id)
"#;

        let capability = capability_of(source, "routes/users.py");
        assert_eq!(capability.name, "users");
        assert_eq!(capability.endpoints.len(), 3);

        let [get, post, delete] = &capability.endpoints[..] else {
            unreachable!("endpoint count asserted above");
        };
        assert!(matches!(get.method, HttpMethod::Get));
        assert_eq!(get.path, "/users/{user_id}");
        assert!(matches!(post.method, HttpMethod::Post));
        assert!(matches!(delete.method, HttpMethod::Delete));
    }

    #[test]
    fn test_flask_routes() {
        let source = r#"
from flask import Blueprint

bp = Blueprint('items', __name__)

@bp.route("/items", methods=["GET"])
def list_items():
    return jsonify(items)

@bp.route("/items", methods=["POST"])
def create_item():
    return jsonify(item), 201
"#;

        let capability = capability_of(source, "views/items.py");
        assert_eq!(capability.name, "items");

        let methods: Vec<_> = capability.endpoints.iter().map(|e| &e.method).collect();
        assert_eq!(methods.len(), 2);
        assert!(matches!(methods[0], HttpMethod::Get));
        assert!(matches!(methods[1], HttpMethod::Post));
    }

    #[test]
    fn test_drf_api_view() {
        let source = r#"
from rest_framework.decorators import api_view

@api_view(['GET', 'POST'])
def user_list(request):
    pass
"#;

        assert_eq!(capability_of(source, "views/users.py").endpoints.len(), 1);
    }

    #[test]
    fn test_no_routes() {
        let source = r#"
def helper_function():
    return 42

class Calculator:
    def add(self, a, b):
        return a + b
"#;

        let parsed = parse_file(source, "utils.py");
        assert!(extract(&parsed).is_none());
    }
}