blaze_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote, quote_spanned, ToTokens};
4
5use syn::{
6    parse::Parse,
7    parse::ParseStream,
8    parse_macro_input,
9    punctuated::Punctuated,
10    token::{Comma, Eq},
11    Expr, ExprMethodCall, ExprTry, Ident, ItemFn, Pat, ReturnType, Signature, Type,
12};
13
14struct RouteArgs {
15    method: syn::Ident,
16    path: syn::LitStr,
17    default: syn::LitBool,
18}
19
20struct RoutesInput {
21    routes: Punctuated<Ident, Comma>,
22}
23
24impl Parse for RoutesInput {
25    fn parse(input: ParseStream) -> syn::Result<Self> {
26        let routes = Punctuated::<Ident, Comma>::parse_terminated(input)?;
27        Ok(RoutesInput { routes })
28    }
29}
30
31impl Parse for RouteArgs {
32    fn parse(input: ParseStream) -> syn::Result<Self> {
33        let mut method: Option<syn::Ident> = None;
34        let mut path: Option<syn::LitStr> = None;
35        let mut default: syn::LitBool = syn::LitBool::new(false, input.span());
36
37        while !input.is_empty() {
38            if input.peek(syn::Ident) && input.peek2(Eq) {
39                let ident: syn::Ident = input.parse()?;
40                if ident == "default" {
41                    input.parse::<Eq>()?;
42                    default = input.parse()?;
43                }
44            } else if method.is_none() && input.peek(syn::Ident) {
45                let ident: syn::Ident = input.parse()?;
46                let method_str = ident.to_string().to_uppercase();
47                method = Some(syn::Ident::new(&method_str, ident.span()));
48            } else if path.is_none() && input.peek(syn::LitStr) {
49                path = Some(input.parse()?);
50            }
51
52            if input.peek(Comma) {
53                input.parse::<Comma>()?;
54            } else {
55                break;
56            }
57        }
58
59        let method = method.unwrap_or_else(|| syn::Ident::new("ALL", input.span()));
60        let path = path.unwrap_or_else(|| syn::LitStr::new("/", input.span()));
61
62        Ok(RouteArgs { method, path, default })
63    }
64}
65
66fn transform_serve_call(expr: &Expr) -> Option<quote::__private::TokenStream> {
67    match expr {
68        Expr::MethodCall(ExprMethodCall { receiver, method, args, .. }) => {
69            if method.to_string() == "serve" {
70                Some(quote! { #receiver.#method(#args).await })
71            } else if method.to_string() == "service" {
72                if let Expr::Path(path) = &args[0] {
73                    let ident = &path.path.segments.last().unwrap().ident;
74                    let route_fn_ident = format_ident!("__ROUTE_{}", ident.to_string().to_uppercase());
75                    Some(quote! { #route_fn_ident(&mut #receiver); })
76                } else {
77                    None
78                }
79            } else {
80                None
81            }
82        }
83        Expr::Try(ExprTry { expr, .. }) => {
84            if let Expr::MethodCall(ExprMethodCall { receiver, method, args, .. }) = expr.as_ref() {
85                if method.to_string() == "serve" {
86                    Some(quote! { #receiver.#method(#args).await? })
87                } else {
88                    None
89                }
90            } else {
91                None
92            }
93        }
94        _ => None,
95    }
96}
97
98#[proc_macro_attribute]
99pub fn main(_attr: TokenStream, item: TokenStream) -> TokenStream {
100    let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item);
101    let Signature { ident, generics, inputs, output, .. } = sig;
102
103    let return_type = match output {
104        ReturnType::Default => quote! { ::std::io::Result<()> },
105        ReturnType::Type(_, ty) => quote! { #ty },
106    };
107
108    let mut new_body = Vec::new();
109
110    for stmt in block.stmts.iter() {
111        match stmt {
112            syn::Stmt::Expr(expr, _) => {
113                if let Some(new_expr) = transform_serve_call(expr) {
114                    new_body.push(new_expr);
115                } else {
116                    new_body.push(stmt.to_token_stream());
117                }
118            }
119            _ => new_body.push(stmt.to_token_stream()),
120        }
121    }
122
123    let gen = quote! {
124        #(#attrs)*
125        #vis fn #ident #generics(#inputs) -> #return_type {
126            let rt = ::tokio::runtime::Runtime::new().unwrap();
127            rt.block_on(async {
128                #(#new_body)*;
129                Ok(())
130            })
131        }
132    };
133
134    gen.into()
135}
136
137#[proc_macro_attribute]
138pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
139    let args = parse_macro_input!(attr as RouteArgs);
140
141    let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item as ItemFn);
142    let Signature { ident, inputs, output, .. } = sig.clone();
143
144    let method = &args.method;
145    let path = &args.path;
146    let is_default = &args.default;
147
148    let is_result = match &output {
149        ReturnType::Default => false,
150        ReturnType::Type(_, ty) => matches!(&**ty, Type::Path(type_path) if type_path.path.segments.last().map_or(false, |s| s.ident == "Result" || s.ident == "HttpResponse")),
151    };
152
153    let parameters: Vec<_> = path
154        .value()
155        .split('/')
156        .filter_map(|segment| {
157            if segment.starts_with('{') && segment.ends_with('}') {
158                Some(segment[1..segment.len() - 1].to_string())
159            } else {
160                None
161            }
162        })
163        .collect();
164
165    if inputs.len() != parameters.len() + 1 {
166        return syn::Error::new_spanned(sig, format!("Route handler must have {} arguments", parameters.len() + 1))
167            .to_compile_error()
168            .into();
169    }
170
171    if sig.asyncness.is_none() {
172        return syn::Error::new_spanned(sig, "Route handler must be async").to_compile_error().into();
173    }
174
175    let route_fn_ident = quote::format_ident!("__ROUTE_{}", ident.to_string().to_uppercase());
176
177    // Generate parameter extraction
178    let mut param_extractions = Vec::new();
179    let mut function_args = vec![quote!(req.clone())];
180    for (_, input) in inputs.iter().enumerate().skip(1) {
181        if let syn::FnArg::Typed(pat_type) = input {
182            if let Pat::Ident(pat_ident) = &*pat_type.pat {
183                let param_name = &pat_ident.ident;
184                let param_str = param_name.to_string();
185                param_extractions.push(quote! {
186                    let #param_name = req.route_param(#param_str)
187                        .or_else(|| req.query_param(#param_str))
188                        .cloned()
189                        .unwrap_or_default();
190                });
191                function_args.push(quote!(#param_name));
192            }
193        }
194    }
195
196    let param_extraction = quote! { #(#param_extractions)* };
197    let function_call = quote! { #ident(#(#function_args),*) };
198
199    let handler_body = if is_result {
200        quote_spanned! {Span::call_site()=>
201            #param_extraction
202            match #function_call.await {
203                Ok(responder) => Ok(Box::new(responder) as Box<dyn crate::Responder>),
204                Err(err) => Err(err),
205            }
206        }
207    } else {
208        quote_spanned! {Span::call_site()=>
209            #param_extraction
210            Ok(Box::new(#function_call.await) as Box<dyn crate::Responder>)
211        }
212    };
213
214    let gen = quote! {
215        #(#attrs)*
216        #vis #sig #block
217
218        pub fn #route_fn_ident(router: &mut crate::Router) {
219            if #is_default {
220                router.add_default(|req: crate::Request| Box::pin(async move {
221                    #handler_body
222                }));
223            } else {
224                router.add(crate::Method::#method, #path.to_string(),
225                |req: crate::Request| Box::pin(async move {
226                    #handler_body
227                }));
228            }
229        }
230    };
231
232    gen.into()
233}
234
235#[proc_macro]
236pub fn routes(input: TokenStream) -> TokenStream {
237    let input = parse_macro_input!(input as RoutesInput);
238    let routes = input.routes;
239
240    let route_services = routes.iter().map(|route| {
241        let route_fn_ident = format_ident!("__ROUTE_{}", route.to_string().to_uppercase());
242        quote! {
243            #route_fn_ident(&mut router);
244        }
245    });
246
247    let gen = quote! {
248        {
249            let mut router = crate::Router::new();
250            #(#route_services)*
251            router
252        }
253    };
254
255    gen.into()
256}