// openapi_from_source/extractor/actix.rs
use crate::extractor::{
    HttpMethod, Parameter, ParameterLocation, RouteExtractor, RouteInfo, TypeInfo,
};
use crate::parser::ParsedFile;
use syn::punctuated::Punctuated;
use syn::{visit::Visit, Attribute, Expr, Lit, Meta};

/// Route extractor for Actix Web sources.
///
/// Routes are discovered from Actix route attributes on handler functions (for example
/// `#[get("/users/{id}")]`). Parameters and request bodies are then filled in from the
/// handlers' `Json`, `Path` and `Query` extractor arguments.
#[derive(Debug, Default, Clone, Copy)]
pub struct ActixExtractor;

10impl RouteExtractor for ActixExtractor {
11 fn extract_routes(&self, parsed_files: &[ParsedFile]) -> Vec<RouteInfo> {
12 let mut visitor = ActixVisitor::new();
13
14 for parsed_file in parsed_files {
16 visitor.visit_file(&parsed_file.syntax_tree);
17 }
18
19 visitor.analyze_handlers();
21
22 visitor.routes
23 }
24}
25
26struct ActixVisitor {
28 routes: Vec<RouteInfo>,
29 current_scope: String,
30 functions: std::collections::HashMap<String, syn::Signature>,
31}
32
33impl ActixVisitor {
34 fn new() -> Self {
35 Self {
36 routes: Vec::new(),
37 current_scope: String::new(),
38 functions: std::collections::HashMap::new(),
39 }
40 }
41
42 fn analyze_handlers(&mut self) {
44 let routes_to_update: Vec<_> = self
46 .routes
47 .iter()
48 .enumerate()
49 .map(|(idx, route)| (idx, route.handler_name.clone()))
50 .collect();
51
52 for (idx, handler_name) in routes_to_update {
53 if let Some(fn_sig) = self.functions.get(&handler_name) {
54 let (params, request_body) = self.parse_extractors(fn_sig);
55
56 let mut all_params = self.routes[idx].parameters.clone();
58 all_params.extend(params);
59
60 self.routes[idx].parameters = all_params;
61 self.routes[idx].request_body = request_body;
62 }
63 }
64 }
65
66 fn find_route_macros(&mut self, item_fn: &syn::ItemFn) {
68 let fn_name = item_fn.sig.ident.to_string();
69
70 for attr in &item_fn.attrs {
71 if let Some((method, path)) = self.parse_route_macro(attr) {
72 let full_path = self.combine_paths(&self.current_scope, &path);
73 let mut route = RouteInfo::new(full_path.clone(), method, fn_name.clone());
74 route.parameters = self.extract_path_parameters(&full_path);
75 self.routes.push(route);
76 }
77 }
78 }
79
80 fn parse_route_macro(&self, attr: &Attribute) -> Option<(HttpMethod, String)> {
82 let attr_name = attr.path().segments.last()?.ident.to_string();
84
85 let method = self.parse_http_method(&attr_name)?;
87
88 let path = match &attr.meta {
91 Meta::List(meta_list) => {
92 self.extract_path_from_tokens(&meta_list.tokens.to_string())
94 }
95 _ => None,
96 }?;
97
98 Some((method, path))
99 }
100
101 fn extract_path_from_tokens(&self, tokens: &str) -> Option<String> {
103 let cleaned = tokens.trim().trim_matches('"');
105 if cleaned.is_empty() {
106 None
107 } else {
108 Some(cleaned.to_string())
109 }
110 }
111
112 fn parse_http_method(&self, method: &str) -> Option<HttpMethod> {
114 match method.to_lowercase().as_str() {
115 "get" => Some(HttpMethod::Get),
116 "post" => Some(HttpMethod::Post),
117 "put" => Some(HttpMethod::Put),
118 "delete" => Some(HttpMethod::Delete),
119 "patch" => Some(HttpMethod::Patch),
120 "head" => Some(HttpMethod::Head),
121 "options" => Some(HttpMethod::Options),
122 _ => None,
123 }
124 }
125
126 fn combine_paths(&self, scope: &str, path: &str) -> String {
128 if scope.is_empty() {
129 return path.to_string();
130 }
131
132 let scope = scope.trim_end_matches('/');
133 let path = path.trim_start_matches('/');
134
135 if path.is_empty() {
136 scope.to_string()
137 } else {
138 format!("{}/{}", scope, path)
139 }
140 }
141
142 fn extract_path_parameters(&self, path: &str) -> Vec<Parameter> {
144 let mut parameters = Vec::new();
145
146 for segment in path.split('/') {
147 if segment.starts_with('{') && segment.ends_with('}') {
148 let param_name = segment
149 .trim_start_matches('{')
150 .trim_end_matches('}')
151 .to_string();
152 parameters.push(Parameter::new(
153 param_name,
154 ParameterLocation::Path,
155 TypeInfo::new("String".to_string()),
156 true,
157 ));
158 }
159 }
160
161 parameters
162 }
163
164 fn parse_extractors(&self, fn_sig: &syn::Signature) -> (Vec<Parameter>, Option<TypeInfo>) {
166 let mut parameters = Vec::new();
167 let mut request_body = None;
168
169 for input in &fn_sig.inputs {
170 if let syn::FnArg::Typed(pat_type) = input {
171 if let Some((extractor_type, inner_type)) = self.parse_extractor_type(&pat_type.ty)
173 {
174 match extractor_type.as_str() {
175 "Json" => {
176 request_body = Some(inner_type);
178 }
179 "Path" => {
180 parameters.push(Parameter::new(
182 "path_params".to_string(),
183 ParameterLocation::Path,
184 inner_type,
185 true,
186 ));
187 }
188 "Query" => {
189 parameters.push(Parameter::new(
191 "query_params".to_string(),
192 ParameterLocation::Query,
193 inner_type,
194 false,
195 ));
196 }
197 _ => {}
198 }
199 }
200 }
201 }
202
203 (parameters, request_body)
204 }
205
206 fn parse_extractor_type(&self, ty: &syn::Type) -> Option<(String, TypeInfo)> {
208 if let syn::Type::Path(type_path) = ty {
209 if let Some(segment) = type_path.path.segments.last() {
210 let extractor_name = segment.ident.to_string();
211
212 if matches!(extractor_name.as_str(), "Json" | "Path" | "Query") {
214 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
216 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
217 let type_info = self.extract_type_info(inner_ty);
218 return Some((extractor_name, type_info));
219 }
220 }
221 }
222 }
223 }
224 None
225 }
226
227 fn extract_type_info(&self, ty: &syn::Type) -> TypeInfo {
229 match ty {
230 syn::Type::Path(type_path) => {
231 if let Some(segment) = type_path.path.segments.last() {
232 let type_name = segment.ident.to_string();
233
234 if type_name == "Option" {
236 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
237 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
238 let inner_type_info = self.extract_type_info(inner_ty);
239 return TypeInfo::option(inner_type_info);
240 }
241 }
242 }
243
244 if type_name == "Vec" {
246 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
247 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
248 let inner_type_info = self.extract_type_info(inner_ty);
249 return TypeInfo::vec(inner_type_info);
250 }
251 }
252 }
253
254 TypeInfo::new(type_name)
256 } else {
257 TypeInfo::new("unknown".to_string())
258 }
259 }
260 _ => TypeInfo::new("unknown".to_string()),
261 }
262 }
263}
264
265impl<'ast> Visit<'ast> for ActixVisitor {
266 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
267 let fn_name = node.sig.ident.to_string();
269 self.functions.insert(fn_name, node.sig.clone());
270
271 self.find_route_macros(node);
273
274 syn::visit::visit_item_fn(self, node);
276 }
277
278 fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
279 let method_name = node.method.to_string();
280
281 if method_name == "scope" {
283 if let Some(scope_path) = self.extract_scope_path(node) {
284 let old_scope = self.current_scope.clone();
285 self.current_scope = self.combine_paths(&old_scope, &scope_path);
286
287 syn::visit::visit_expr_method_call(self, node);
289
290 self.current_scope = old_scope;
292 return;
293 }
294 }
295
296 syn::visit::visit_expr_method_call(self, node);
298 }
299}
300
301impl ActixVisitor {
302 fn extract_scope_path(&self, expr: &syn::ExprMethodCall) -> Option<String> {
304 if expr.args.is_empty() {
306 return None;
307 }
308
309 self.extract_string_literal(&expr.args[0])
310 }
311
312 fn extract_string_literal(&self, expr: &Expr) -> Option<String> {
314 match expr {
315 Expr::Lit(expr_lit) => {
316 if let Lit::Str(lit_str) = &expr_lit.lit {
317 Some(lit_str.value())
318 } else {
319 None
320 }
321 }
322 _ => None,
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `code` into a `ParsedFile` attributed to a dummy `test.rs` path.
    fn parse_code(code: &str) -> ParsedFile {
        let syntax_tree = syn::parse_file(code).expect("Failed to parse test code");
        ParsedFile {
            path: PathBuf::from("test.rs"),
            syntax_tree,
        }
    }

    /// Runs `ActixExtractor` over a single source snippet.
    fn extract(code: &str) -> Vec<RouteInfo> {
        ActixExtractor.extract_routes(&[parse_code(code)])
    }

    /// Parameters of `route` at `location`, in order.
    fn params_at(route: &RouteInfo, location: ParameterLocation) -> Vec<&Parameter> {
        route
            .parameters
            .iter()
            .filter(|param| param.location == location)
            .collect()
    }

    /// Names of all parameters of `route`, in order.
    fn param_names(route: &RouteInfo) -> Vec<&str> {
        route
            .parameters
            .iter()
            .map(|param| param.name.as_str())
            .collect()
    }

    #[test]
    fn test_simple_get_route() {
        let routes = extract(
            r#"
            use actix_web::{get, HttpResponse};

            #[get("/hello")]
            async fn hello() -> HttpResponse {
                HttpResponse::Ok().body("Hello, World!")
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/hello");
        assert_eq!(routes[0].method, HttpMethod::Get);
        assert_eq!(routes[0].handler_name, "hello");
    }

    #[test]
    fn test_multiple_http_methods() {
        let routes = extract(
            r#"
            use actix_web::{get, post, put, delete, patch, HttpResponse};

            #[get("/resource")]
            async fn get_resource() -> HttpResponse {
                HttpResponse::Ok().finish()
            }

            #[post("/resource")]
            async fn create_resource() -> HttpResponse {
                HttpResponse::Created().finish()
            }

            #[put("/resource")]
            async fn update_resource() -> HttpResponse {
                HttpResponse::Ok().finish()
            }

            #[delete("/resource")]
            async fn delete_resource() -> HttpResponse {
                HttpResponse::NoContent().finish()
            }

            #[patch("/resource")]
            async fn patch_resource() -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 5);

        let expected = [
            ("get_resource", HttpMethod::Get),
            ("create_resource", HttpMethod::Post),
            ("update_resource", HttpMethod::Put),
            ("delete_resource", HttpMethod::Delete),
            ("patch_resource", HttpMethod::Patch),
        ];
        for (handler, method) in &expected {
            let route = routes
                .iter()
                .find(|route| route.handler_name == *handler)
                .unwrap_or_else(|| panic!("missing route for handler {}", handler));
            assert_eq!(&route.method, method);
            assert_eq!(route.path, "/resource");
        }
    }

    #[test]
    fn test_path_parameters() {
        let routes = extract(
            r#"
            use actix_web::{get, HttpResponse};

            #[get("/users/{id}")]
            async fn get_user() -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/users/{id}");
        assert_eq!(routes[0].parameters.len(), 1);
        assert_eq!(routes[0].parameters[0].name, "id");
        assert_eq!(routes[0].parameters[0].location, ParameterLocation::Path);
        assert!(routes[0].parameters[0].required);
    }

    #[test]
    fn test_multiple_path_parameters() {
        let routes = extract(
            r#"
            use actix_web::{get, HttpResponse};

            #[get("/posts/{post_id}/comments/{comment_id}")]
            async fn get_comment() -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/posts/{post_id}/comments/{comment_id}");
        assert_eq!(param_names(&routes[0]), vec!["post_id", "comment_id"]);
    }

    #[test]
    fn test_scope_handling() {
        let routes = extract(
            r#"
            use actix_web::{web, get, HttpResponse, App};

            #[get("/users")]
            async fn list_users() -> HttpResponse {
                HttpResponse::Ok().finish()
            }

            #[get("/users/{id}")]
            async fn get_user() -> HttpResponse {
                HttpResponse::Ok().finish()
            }

            fn config(cfg: &mut web::ServiceConfig) {
                cfg.service(
                    web::scope("/api")
                        .service(list_users)
                        .service(get_user)
                );
            }
        "#,
        );

        assert_eq!(routes.len(), 2);

        // Handlers registered through `web::scope(..).service(handler)` keep their
        // unprefixed paths. Route attributes are read from the handler definitions, and
        // the scope call is never linked back to them.
        let paths: Vec<_> = routes.iter().map(|route| route.path.as_str()).collect();
        assert!(paths.contains(&"/users"));
        assert!(paths.contains(&"/users/{id}"));
    }

    #[test]
    fn test_json_extractor() {
        let routes = extract(
            r#"
            use actix_web::{post, web, HttpResponse};
            use serde::Deserialize;

            #[derive(Deserialize)]
            struct CreateUser {
                name: String,
                email: String,
            }

            #[post("/users")]
            async fn create_user(user: web::Json<CreateUser>) -> HttpResponse {
                HttpResponse::Created().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].handler_name, "create_user");

        let body = routes[0]
            .request_body
            .as_ref()
            .expect("a Json<T> extractor should become the request body");
        assert_eq!(body.name, "CreateUser");
    }

    #[test]
    fn test_path_extractor() {
        let routes = extract(
            r#"
            use actix_web::{get, web, HttpResponse};

            #[get("/users/{id}")]
            async fn get_user(path: web::Path<u32>) -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);

        let path_params = params_at(&routes[0], ParameterLocation::Path);
        // The template placeholder is always present...
        assert!(path_params.iter().any(|param| param.name == "id"));
        // ...and the `Path<u32>` extractor contributes its inner type.
        assert!(path_params.iter().any(|param| param.type_info.name == "u32"));
    }

    #[test]
    fn test_query_extractor() {
        let routes = extract(
            r#"
            use actix_web::{get, web, HttpResponse};
            use serde::Deserialize;

            #[derive(Deserialize)]
            struct Pagination {
                page: u32,
                limit: u32,
            }

            #[get("/users")]
            async fn list_users(query: web::Query<Pagination>) -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);

        let query_params = params_at(&routes[0], ParameterLocation::Query);
        assert_eq!(query_params.len(), 1);
        assert_eq!(query_params[0].type_info.name, "Pagination");
        assert!(!query_params[0].required);
    }

    #[test]
    fn test_multiple_extractors() {
        let routes = extract(
            r#"
            use actix_web::{post, web, HttpResponse};
            use serde::Deserialize;

            #[derive(Deserialize)]
            struct CreateComment {
                text: String,
            }

            #[post("/posts/{id}/comments")]
            async fn create_comment(
                path: web::Path<u32>,
                comment: web::Json<CreateComment>,
            ) -> HttpResponse {
                HttpResponse::Created().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert!(!params_at(&routes[0], ParameterLocation::Path).is_empty());

        let body = routes[0]
            .request_body
            .as_ref()
            .expect("a Json<T> extractor should become the request body");
        assert_eq!(body.name, "CreateComment");
    }

    #[test]
    fn test_nested_scope() {
        let routes = extract(
            r#"
            use actix_web::{web, get, HttpResponse};

            #[get("/profile")]
            async fn get_profile() -> HttpResponse {
                HttpResponse::Ok().finish()
            }

            fn config(cfg: &mut web::ServiceConfig) {
                cfg.service(
                    web::scope("/api")
                        .service(
                            web::scope("/v1")
                                .service(get_profile)
                        )
                );
            }
        "#,
        );

        // Nested `web::scope(..)` registrations are not linked to the handler either.
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/profile");
    }

    #[test]
    fn test_route_without_parameters() {
        let routes = extract(
            r#"
            use actix_web::{get, HttpResponse};

            #[get("/health")]
            async fn health_check() -> HttpResponse {
                HttpResponse::Ok().body("OK")
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/health");
        assert_eq!(routes[0].method, HttpMethod::Get);
        assert_eq!(routes[0].handler_name, "health_check");
        assert!(routes[0].parameters.is_empty());
        assert!(routes[0].request_body.is_none());
    }

    #[test]
    fn test_complex_path() {
        let routes = extract(
            r#"
            use actix_web::{get, HttpResponse};

            #[get("/api/v1/organizations/{org_id}/projects/{project_id}/tasks/{task_id}")]
            async fn get_task() -> HttpResponse {
                HttpResponse::Ok().finish()
            }
        "#,
        );

        assert_eq!(routes.len(), 1);
        assert_eq!(
            routes[0].path,
            "/api/v1/organizations/{org_id}/projects/{project_id}/tasks/{task_id}"
        );
        assert_eq!(
            param_names(&routes[0]),
            vec!["org_id", "project_id", "task_id"]
        );
    }

    #[test]
    fn test_functions_without_route_macros_are_ignored() {
        let routes = extract(
            r#"
            fn helper() -> u32 {
                42
            }

            #[cfg(test)]
            fn not_a_route() {}
        "#,
        );

        assert!(routes.is_empty());
    }
}