1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4use std::collections::BTreeMap;
5
6pub fn extract(file: &ParsedFile) -> Option<Capability> {
8 let root = file.tree.root_node();
9 let source = file.source.as_bytes();
10
11 let mut endpoints = Vec::new();
12
13 for i in 0..root.named_child_count() {
14 let node = root.named_child(i).unwrap();
15
16 if node.kind() == "decorated_definition" {
17 let decorators = collect_decorators(&node, source);
18 let inner = match get_inner_definition(&node) {
19 Some(n) => n,
20 None => continue,
21 };
22
23 if inner.kind() == "function_definition" {
24 for dec in &decorators {
26 if let Some(endpoint) = extract_route_endpoint(&inner, source, dec) {
27 endpoints.push(endpoint);
28 }
29 }
30
31 if let Some(endpoint) = extract_drf_api_view(&inner, source, &decorators) {
33 endpoints.push(endpoint);
34 }
35 }
36
37 if inner.kind() == "class_definition" {
38 extract_drf_viewset(&inner, source, &decorators, &mut endpoints);
40 }
41 }
42 }
43
44 if endpoints.is_empty() {
45 return None;
46 }
47
48 let file_stem = file
50 .path
51 .rsplit('/')
52 .next()
53 .unwrap_or(&file.path)
54 .trim_end_matches(".py");
55 let capability_name = to_kebab_case(
56 &file_stem
57 .replace("_routes", "")
58 .replace("_views", "")
59 .replace("_router", "")
60 .replace("_api", ""),
61 );
62
63 let mut capability = Capability::new(capability_name, file.path.clone());
64 capability.endpoints = endpoints;
65 Some(capability)
66}
67
68fn extract_route_endpoint(
70 func_node: &tree_sitter::Node,
71 source: &[u8],
72 decorator: &DecoratorInfo,
73) -> Option<Endpoint> {
74 let http_method = match decorator.name.as_str() {
75 "get" | "GET" => HttpMethod::Get,
76 "post" | "POST" => HttpMethod::Post,
77 "put" | "PUT" => HttpMethod::Put,
78 "delete" | "DELETE" => HttpMethod::Delete,
79 "patch" | "PATCH" => HttpMethod::Patch,
80 "route" => {
81 extract_flask_method(&decorator.args)
83 }
84 _ => return None,
85 };
86
87 if !decorator.full_name.contains('.')
89 && !matches!(
90 decorator.name.as_str(),
91 "get" | "post" | "put" | "delete" | "patch"
92 )
93 {
94 return None;
95 }
96
97 let path = decorator
99 .args
100 .first()
101 .map(|a| clean_string_literal(a))
102 .unwrap_or_default();
103
104 if path.is_empty() && decorator.name != "route" {
105 return None;
106 }
107
108 let _func_name = get_def_name(func_node, source);
109
110 let input = extract_function_params(func_node, source);
112
113 let security = extract_security(func_node, source);
115
116 let behaviors = vec![Behavior {
117 name: "success".to_string(),
118 condition: None,
119 returns: ResponseSpec {
120 status: default_status(&http_method),
121 body: extract_return_type(func_node, source),
122 },
123 side_effects: Vec::new(),
124 }];
125
126 Some(Endpoint {
127 method: http_method,
128 path,
129 input,
130 validation: Vec::new(),
131 behaviors,
132 security,
133 })
134}
135
136fn extract_flask_method(args: &[String]) -> HttpMethod {
138 for arg in args {
139 if arg.contains("POST") {
140 return HttpMethod::Post;
141 }
142 if arg.contains("PUT") {
143 return HttpMethod::Put;
144 }
145 if arg.contains("DELETE") {
146 return HttpMethod::Delete;
147 }
148 if arg.contains("PATCH") {
149 return HttpMethod::Patch;
150 }
151 }
152 HttpMethod::Get
153}
154
155fn extract_drf_api_view(
157 func_node: &tree_sitter::Node,
158 source: &[u8],
159 decorators: &[DecoratorInfo],
160) -> Option<Endpoint> {
161 let api_view = decorators.iter().find(|d| d.name == "api_view")?;
162
163 let method = if !api_view.args.is_empty() {
164 extract_flask_method(&api_view.args)
165 } else {
166 HttpMethod::Get
167 };
168
169 let func_name = get_def_name(func_node, source);
170 let path = format!("/{}", func_name.replace('_', "-"));
171
172 Some(Endpoint {
173 method,
174 path,
175 input: None,
176 validation: Vec::new(),
177 behaviors: vec![Behavior {
178 name: "success".to_string(),
179 condition: None,
180 returns: ResponseSpec {
181 status: 200,
182 body: None,
183 },
184 side_effects: Vec::new(),
185 }],
186 security: None,
187 })
188}
189
190fn extract_drf_viewset(
192 class_node: &tree_sitter::Node,
193 source: &[u8],
194 _decorators: &[DecoratorInfo],
195 endpoints: &mut Vec<Endpoint>,
196) {
197 let bases = get_class_bases(class_node, source);
199 let is_viewset = bases.iter().any(|b| {
200 b.contains("ViewSet")
201 || b.contains("ModelViewSet")
202 || b.contains("GenericAPIView")
203 || b.contains("APIView")
204 });
205
206 if !is_viewset {
207 return;
208 }
209
210 let class_name = get_def_name(class_node, source);
211 let base_path = format!(
212 "/{}",
213 to_kebab_case(&class_name.replace("ViewSet", "").replace("View", ""))
214 );
215
216 let is_model_viewset = bases.iter().any(|b| b.contains("ModelViewSet"));
217
218 if is_model_viewset {
219 let crud = [
221 (HttpMethod::Get, format!("{}/", base_path), "list"),
222 (HttpMethod::Post, format!("{}/", base_path), "create"),
223 (HttpMethod::Get, format!("{}/:id/", base_path), "retrieve"),
224 (HttpMethod::Put, format!("{}/:id/", base_path), "update"),
225 (HttpMethod::Delete, format!("{}/:id/", base_path), "destroy"),
226 ];
227 for (method, path, _name) in crud {
228 endpoints.push(Endpoint {
229 method,
230 path,
231 input: None,
232 validation: Vec::new(),
233 behaviors: vec![Behavior {
234 name: "success".to_string(),
235 condition: None,
236 returns: ResponseSpec {
237 status: 200,
238 body: None,
239 },
240 side_effects: Vec::new(),
241 }],
242 security: None,
243 });
244 }
245 }
246}
247
248fn get_class_bases(class_node: &tree_sitter::Node, source: &[u8]) -> Vec<String> {
249 let mut bases = Vec::new();
250 if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
251 for i in 0..arg_list.named_child_count() {
252 let arg = arg_list.named_child(i).unwrap();
253 bases.push(node_text(&arg, source));
254 }
255 }
256 bases
257}
258
259fn extract_function_params(func_node: &tree_sitter::Node, source: &[u8]) -> Option<EndpointInput> {
260 let params = func_node.child_by_field_name("parameters")?;
261 let mut path_params = Vec::new();
262 let mut body = None;
263
264 for i in 0..params.named_child_count() {
265 let param = params.named_child(i).unwrap();
266
267 let param_name = match param.kind() {
268 "typed_parameter" | "typed_default_parameter" => param
269 .child_by_field_name("name")
270 .map(|n| node_text(&n, source))
271 .unwrap_or_default(),
272 "identifier" => node_text(¶m, source),
273 _ => continue,
274 };
275
276 if matches!(
278 param_name.as_str(),
279 "self" | "request" | "db" | "response" | "session"
280 ) {
281 continue;
282 }
283
284 let type_ann = param
286 .child_by_field_name("type")
287 .map(|t| node_text(&t, source));
288
289 if let Some(ref t) = type_ann {
290 if t.chars().next().is_some_and(|c| c.is_uppercase())
292 && !t.starts_with("Optional")
293 && !t.starts_with("int")
294 && !t.starts_with("str")
295 && !t.starts_with("float")
296 && !t.starts_with("bool")
297 {
298 body = Some(TypeRef {
299 name: t.clone(),
300 fields: BTreeMap::new(),
301 });
302 continue;
303 }
304 }
305
306 if !param_name.is_empty() {
308 path_params.push(Param {
309 name: param_name,
310 param_type: type_ann.unwrap_or_else(|| "str".to_string()),
311 required: true,
312 });
313 }
314 }
315
316 if body.is_none() && path_params.is_empty() {
317 return None;
318 }
319
320 Some(EndpointInput {
321 body,
322 path_params,
323 query_params: Vec::new(),
324 })
325}
326
327fn extract_security(func_node: &tree_sitter::Node, source: &[u8]) -> Option<SecurityConfig> {
328 let params = func_node.child_by_field_name("parameters")?;
329 let text = node_text(¶ms, source);
330
331 if text.contains("Depends") && (text.contains("current_user") || text.contains("auth")) {
333 return Some(SecurityConfig {
334 authentication: Some("required".to_string()),
335 roles: Vec::new(),
336 rate_limit: None,
337 cors: None,
338 });
339 }
340
341 None
342}
343
344fn extract_return_type(func_node: &tree_sitter::Node, source: &[u8]) -> Option<TypeRef> {
345 let return_type = func_node.child_by_field_name("return_type")?;
346 let type_text = node_text(&return_type, source);
347
348 if type_text == "None" || type_text.is_empty() {
349 return None;
350 }
351
352 Some(TypeRef {
353 name: type_text,
354 fields: BTreeMap::new(),
355 })
356}
357
358fn default_status(method: &HttpMethod) -> u16 {
359 match method {
360 HttpMethod::Post => 201,
361 HttpMethod::Delete => 204,
362 _ => 200,
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parses `source` as a Python file located at `path`, panicking if parsing fails.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        let (text, location) = (source.to_string(), path.to_string());
        ParsedFile::parse(text, location).unwrap()
    }

    /// FastAPI verb decorators each become one endpoint, in declaration order.
    #[test]
    fn test_fastapi_routes() {
        let source = r#"
from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/users/{user_id}")
async def get_user(user_id: int) -> User:
    return db.get(user_id)

@router.post("/users")
async def create_user(user: UserCreate, current_user: User = Depends(get_current_user)):
    return db.create(user)

@router.delete("/users/{user_id}")
async def delete_user(user_id: int):
    db.delete(user_id)
"#;

        let cap = extract(&parse_file(source, "routes/users.py")).unwrap();

        assert_eq!(cap.name, "users");
        assert_eq!(cap.endpoints.len(), 3);

        let [get, post, delete] = &cap.endpoints[..] else {
            unreachable!();
        };
        assert!(matches!(get.method, HttpMethod::Get));
        assert_eq!(get.path, "/users/{user_id}");
        assert!(matches!(post.method, HttpMethod::Post));
        assert!(matches!(delete.method, HttpMethod::Delete));
    }

    /// Flask `route` decorators take their method from `methods=[...]`.
    #[test]
    fn test_flask_routes() {
        let source = r#"
from flask import Blueprint

bp = Blueprint('items', __name__)

@bp.route("/items", methods=["GET"])
def list_items():
    return jsonify(items)

@bp.route("/items", methods=["POST"])
def create_item():
    return jsonify(item), 201
"#;

        let cap = extract(&parse_file(source, "views/items.py")).unwrap();

        assert_eq!(cap.name, "items");
        assert_eq!(cap.endpoints.len(), 2);

        let methods: Vec<&HttpMethod> = cap.endpoints.iter().map(|e| &e.method).collect();
        assert!(matches!(
            methods.as_slice(),
            [HttpMethod::Get, HttpMethod::Post]
        ));
    }

    /// A DRF `@api_view` function yields exactly one endpoint.
    #[test]
    fn test_drf_api_view() {
        let source = r#"
from rest_framework.decorators import api_view

@api_view(['GET', 'POST'])
def user_list(request):
    pass
"#;

        let cap = extract(&parse_file(source, "views/users.py")).unwrap();
        assert_eq!(cap.endpoints.len(), 1);
    }

    /// Files without decorated routes produce no capability at all.
    #[test]
    fn test_no_routes() {
        let source = r#"
def helper_function():
    return 42

class Calculator:
    def add(self, a, b):
        return a + b
"#;

        let parsed = parse_file(source, "utils.py");
        assert!(extract(&parsed).is_none());
    }
}
468}