register_actix_routes/
lib.rs

1extern crate proc_macro;
2extern crate tabled;
3use once_cell::sync::Lazy;
4use proc_macro::TokenStream;
5use quote::quote;
6use std::sync::RwLock;
7use syn::{parse_macro_input, ItemFn, LitStr};
8
/// Metadata captured by `#[auto_register]` for one Actix handler function.
#[derive(Clone, Debug)]
struct RouteInfo {
    /// Scope / module key the handler was registered under (e.g. `"/events"`).
    prefix: String,
    /// Identifier of the handler function.
    handler_name: String,
    /// Route path taken from the Actix route attribute (e.g. `"/search"`); may be empty.
    path: String,
    /// HTTP method, upper-cased (e.g. `"GET"`).
    verb: String,
}
16
17// Use a global RwLock map for storing registrations per unique module key
18static REGISTRATION_MAP: Lazy<RwLock<std::collections::HashMap<String, Vec<RouteInfo>>>> =
19    Lazy::new(|| RwLock::new(std::collections::HashMap::new()));
20
21#[proc_macro_attribute]
22pub fn auto_register(attr: TokenStream, item: TokenStream) -> TokenStream {
23    // Parse the input function
24    let input_fn = parse_macro_input!(item as ItemFn);
25    let fn_name = input_fn.sig.ident.to_string();
26
27    // Parse the prefix as a string literal
28    let prefix = if !attr.is_empty() {
29        let parsed_attr = parse_macro_input!(attr as syn::LitStr);
30        parsed_attr.value()
31    } else {
32        panic!("Expected a prefix (e.g., \"/scope\") as the argument to auto_register");
33    };
34
35    // Extract the route path and HTTP verb from the function attributes
36    let mut route_path = None;
37    let mut verb = None;
38
39    for attr in &input_fn.attrs {
40        if let Some(segment) = attr.path().segments.last() {
41            if ["get", "post", "put", "delete", "patch"]
42                .contains(&segment.ident.to_string().as_str())
43            {
44                verb = Some(segment.ident.to_string().to_uppercase());
45                if let Ok(route_literal) = attr.parse_args::<LitStr>() {
46                    route_path = Some(route_literal.value());
47                }
48            }
49        }
50    }
51
52    // Validate the extracted route path and HTTP verb
53    if route_path.is_none() || verb.is_none() {
54        panic!(
55            "Could not extract the route path or verb from attributes on function '{}'. Ensure it has a valid Actix route macro like \
56            #[get(\"/path\")].",
57            fn_name
58        );
59    }
60
61    // Use empty route path if valid (e.g., `""`)
62    let route_info = RouteInfo {
63        prefix: prefix.clone(),
64        handler_name: fn_name.clone(),
65        path: route_path.unwrap_or_else(|| "".to_string()),
66        verb: verb.unwrap(),
67    };
68
69    // Safely store the route information
70    let mut map = REGISTRATION_MAP
71        .write()
72        .expect("Failed to acquire write lock");
73    map.entry(prefix.clone()).or_default().push(route_info);
74
75    // Generate the original function definition
76    let expanded = quote! {
77        #input_fn
78    };
79
80    TokenStream::from(expanded)
81}
82
83#[proc_macro]
84pub fn generate_register_service(input: TokenStream) -> TokenStream {
85    // Parse the macro arguments (prefix and optional use_scope flag)
86    let args = parse_macro_input!(input as syn::ExprArray);
87    let module_key: Option<String>;
88    let mut use_scope = false; // Default to not using the prefix as the scope
89
90    // Parse the arguments
91    if let Some(syn::Expr::Lit(syn::ExprLit {
92        lit: syn::Lit::Str(lit_str),
93        ..
94    })) = args.elems.iter().next()
95    {
96        module_key = Some(lit_str.value());
97    } else {
98        panic!("Expected the first argument to be a string literal representing the module key.");
99    }
100
101    if let Some(syn::Expr::Assign(syn::ExprAssign { left, right, .. })) = args.elems.iter().nth(1) {
102        if let syn::Expr::Path(path) = &**left {
103            if path.path.is_ident("use_scope") {
104                if let syn::Expr::Lit(syn::ExprLit {
105                    lit: syn::Lit::Bool(lit_bool),
106                    ..
107                }) = &**right
108                {
109                    use_scope = lit_bool.value();
110                } else {
111                    panic!("The value of `use_scope` must be a boolean.");
112                }
113            }
114        }
115    }
116
117    if module_key.is_none() {
118        panic!("Expected a module key as the first argument.");
119    }
120
121    // Safely read handler registrations for the specified module key
122    let map = REGISTRATION_MAP
123        .read()
124        .expect("Failed to acquire read lock");
125    let registrations = map.get(&module_key.unwrap()).cloned().unwrap_or_default();
126
127    // Group functions by their prefixes
128    let mut grouped_by_prefix: std::collections::HashMap<String, Vec<String>> =
129        std::collections::HashMap::new();
130    for RouteInfo {
131        prefix,
132        handler_name,
133        ..
134    } in registrations
135    {
136        grouped_by_prefix
137            .entry(prefix.clone())
138            .or_default()
139            .push(handler_name);
140    }
141
142    // Generate the registration function code
143    let mut registration_functions = Vec::new();
144    for (prefix, functions) in grouped_by_prefix {
145        let fn_calls = functions.iter().map(|fn_name| {
146            let fn_ident = syn::Ident::new(fn_name, proc_macro2::Span::call_site());
147            quote! {
148                .service(#fn_ident)
149            }
150        });
151
152        let scope_block = if use_scope {
153            quote! {
154                cfg.service(
155                    actix_web::web::scope(#prefix)
156                        #(#fn_calls)*
157                );
158            }
159        } else {
160            quote! {
161                cfg.service(
162                    actix_web::web::scope("")
163                        #(#fn_calls)*
164                );
165            }
166        };
167
168        registration_functions.push(scope_block);
169    }
170
171    let expanded = quote! {
172        pub fn register_service(cfg: &mut actix_web::web::ServiceConfig) {
173            #(#registration_functions)*
174        }
175    };
176
177    TokenStream::from(expanded)
178}
179
180#[proc_macro]
181pub fn generate_list_routes(_input: TokenStream) -> TokenStream {
182    // Safely read all handler registrations from the REGISTRATION_MAP
183    let map = REGISTRATION_MAP
184        .read()
185        .expect("Failed to acquire read lock");
186
187    // Collect all routes into a vector for table display
188    let mut rows = Vec::new();
189    for (scope, routes) in map.iter() {
190        for route in routes {
191            let scope_literal = syn::LitStr::new(scope, proc_macro2::Span::call_site());
192            let path_literal = syn::LitStr::new(&route.path, proc_macro2::Span::call_site());
193            let handler_literal =
194                syn::LitStr::new(&route.handler_name, proc_macro2::Span::call_site());
195            let verb_literal = syn::LitStr::new(&route.verb, proc_macro2::Span::call_site());
196
197            rows.push(quote! {
198                Route {
199                    scope: #scope_literal.to_string(),
200                    path: #path_literal.to_string(),
201                    handler: #handler_literal.to_string(),
202                    verb: #verb_literal.to_string(),
203                }
204            });
205        }
206    }
207
208    // Generate code for the `list_routes` function
209    let expanded = quote! {
210        pub fn list_routes() {
211            use tabled::{Table, Tabled};
212
213            #[derive(Tabled)]
214            struct Route {
215                #[tabled(rename = "Scope")]
216                scope: String,
217                #[tabled(rename = "Path")]
218                path: String,
219                #[tabled(rename = "Handler")]
220                handler: String,
221                #[tabled(rename = "Verb")]
222                verb: String,
223            }
224
225            let routes = vec![
226                #(#rows),*
227            ];
228
229            let table = Table::new(routes)
230                .with(tabled::settings::Style::modern())
231                .to_string();
232
233            println!("List of the automatically registered routes:");
234            println!("{}", table);
235        }
236    };
237
238    TokenStream::from(expanded)
239}