1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote, quote_spanned, ToTokens};
4
5use syn::{
6 parse::Parse,
7 parse::ParseStream,
8 parse_macro_input,
9 punctuated::Punctuated,
10 token::{Comma, Eq},
11 Expr, ExprMethodCall, ExprTry, Ident, ItemFn, Pat, ReturnType, Signature, Type,
12};
13
14struct RouteArgs {
15 method: syn::Ident,
16 path: syn::LitStr,
17 default: syn::LitBool,
18}
19
20struct RoutesInput {
21 routes: Punctuated<Ident, Comma>,
22}
23
24impl Parse for RoutesInput {
25 fn parse(input: ParseStream) -> syn::Result<Self> {
26 let routes = Punctuated::<Ident, Comma>::parse_terminated(input)?;
27 Ok(RoutesInput { routes })
28 }
29}
30
31impl Parse for RouteArgs {
32 fn parse(input: ParseStream) -> syn::Result<Self> {
33 let mut method: Option<syn::Ident> = None;
34 let mut path: Option<syn::LitStr> = None;
35 let mut default: syn::LitBool = syn::LitBool::new(false, input.span());
36
37 while !input.is_empty() {
38 if input.peek(syn::Ident) && input.peek2(Eq) {
39 let ident: syn::Ident = input.parse()?;
40 if ident == "default" {
41 input.parse::<Eq>()?;
42 default = input.parse()?;
43 }
44 } else if method.is_none() && input.peek(syn::Ident) {
45 let ident: syn::Ident = input.parse()?;
46 let method_str = ident.to_string().to_uppercase();
47 method = Some(syn::Ident::new(&method_str, ident.span()));
48 } else if path.is_none() && input.peek(syn::LitStr) {
49 path = Some(input.parse()?);
50 }
51
52 if input.peek(Comma) {
53 input.parse::<Comma>()?;
54 } else {
55 break;
56 }
57 }
58
59 let method = method.unwrap_or_else(|| syn::Ident::new("ALL", input.span()));
60 let path = path.unwrap_or_else(|| syn::LitStr::new("/", input.span()));
61
62 Ok(RouteArgs { method, path, default })
63 }
64}
65
66fn transform_serve_call(expr: &Expr) -> Option<quote::__private::TokenStream> {
67 match expr {
68 Expr::MethodCall(ExprMethodCall { receiver, method, args, .. }) => {
69 if method.to_string() == "serve" {
70 Some(quote! { #receiver.#method(#args).await })
71 } else if method.to_string() == "service" {
72 if let Expr::Path(path) = &args[0] {
73 let ident = &path.path.segments.last().unwrap().ident;
74 let route_fn_ident = format_ident!("__ROUTE_{}", ident.to_string().to_uppercase());
75 Some(quote! { #route_fn_ident(&mut #receiver); })
76 } else {
77 None
78 }
79 } else {
80 None
81 }
82 }
83 Expr::Try(ExprTry { expr, .. }) => {
84 if let Expr::MethodCall(ExprMethodCall { receiver, method, args, .. }) = expr.as_ref() {
85 if method.to_string() == "serve" {
86 Some(quote! { #receiver.#method(#args).await? })
87 } else {
88 None
89 }
90 } else {
91 None
92 }
93 }
94 _ => None,
95 }
96}
97
98#[proc_macro_attribute]
99pub fn main(_attr: TokenStream, item: TokenStream) -> TokenStream {
100 let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item);
101 let Signature { ident, generics, inputs, output, .. } = sig;
102
103 let return_type = match output {
104 ReturnType::Default => quote! { ::std::io::Result<()> },
105 ReturnType::Type(_, ty) => quote! { #ty },
106 };
107
108 let mut new_body = Vec::new();
109
110 for stmt in block.stmts.iter() {
111 match stmt {
112 syn::Stmt::Expr(expr, _) => {
113 if let Some(new_expr) = transform_serve_call(expr) {
114 new_body.push(new_expr);
115 } else {
116 new_body.push(stmt.to_token_stream());
117 }
118 }
119 _ => new_body.push(stmt.to_token_stream()),
120 }
121 }
122
123 let gen = quote! {
124 #(#attrs)*
125 #vis fn #ident #generics(#inputs) -> #return_type {
126 let rt = ::tokio::runtime::Runtime::new().unwrap();
127 rt.block_on(async {
128 #(#new_body)*;
129 Ok(())
130 })
131 }
132 };
133
134 gen.into()
135}
136
137#[proc_macro_attribute]
138pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
139 let args = parse_macro_input!(attr as RouteArgs);
140
141 let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item as ItemFn);
142 let Signature { ident, inputs, output, .. } = sig.clone();
143
144 let method = &args.method;
145 let path = &args.path;
146 let is_default = &args.default;
147
148 let is_result = match &output {
149 ReturnType::Default => false,
150 ReturnType::Type(_, ty) => matches!(&**ty, Type::Path(type_path) if type_path.path.segments.last().map_or(false, |s| s.ident == "Result" || s.ident == "HttpResponse")),
151 };
152
153 let parameters: Vec<_> = path
154 .value()
155 .split('/')
156 .filter_map(|segment| {
157 if segment.starts_with('{') && segment.ends_with('}') {
158 Some(segment[1..segment.len() - 1].to_string())
159 } else {
160 None
161 }
162 })
163 .collect();
164
165 if inputs.len() != parameters.len() + 1 {
166 return syn::Error::new_spanned(sig, format!("Route handler must have {} arguments", parameters.len() + 1))
167 .to_compile_error()
168 .into();
169 }
170
171 if sig.asyncness.is_none() {
172 return syn::Error::new_spanned(sig, "Route handler must be async").to_compile_error().into();
173 }
174
175 let route_fn_ident = quote::format_ident!("__ROUTE_{}", ident.to_string().to_uppercase());
176
177 let mut param_extractions = Vec::new();
179 let mut function_args = vec![quote!(req.clone())];
180 for (_, input) in inputs.iter().enumerate().skip(1) {
181 if let syn::FnArg::Typed(pat_type) = input {
182 if let Pat::Ident(pat_ident) = &*pat_type.pat {
183 let param_name = &pat_ident.ident;
184 let param_str = param_name.to_string();
185 param_extractions.push(quote! {
186 let #param_name = req.route_param(#param_str)
187 .or_else(|| req.query_param(#param_str))
188 .cloned()
189 .unwrap_or_default();
190 });
191 function_args.push(quote!(#param_name));
192 }
193 }
194 }
195
196 let param_extraction = quote! { #(#param_extractions)* };
197 let function_call = quote! { #ident(#(#function_args),*) };
198
199 let handler_body = if is_result {
200 quote_spanned! {Span::call_site()=>
201 #param_extraction
202 match #function_call.await {
203 Ok(responder) => Ok(Box::new(responder) as Box<dyn crate::Responder>),
204 Err(err) => Err(err),
205 }
206 }
207 } else {
208 quote_spanned! {Span::call_site()=>
209 #param_extraction
210 Ok(Box::new(#function_call.await) as Box<dyn crate::Responder>)
211 }
212 };
213
214 let gen = quote! {
215 #(#attrs)*
216 #vis #sig #block
217
218 pub fn #route_fn_ident(router: &mut crate::Router) {
219 if #is_default {
220 router.add_default(|req: crate::Request| Box::pin(async move {
221 #handler_body
222 }));
223 } else {
224 router.add(crate::Method::#method, #path.to_string(),
225 |req: crate::Request| Box::pin(async move {
226 #handler_body
227 }));
228 }
229 }
230 };
231
232 gen.into()
233}
234
235#[proc_macro]
236pub fn routes(input: TokenStream) -> TokenStream {
237 let input = parse_macro_input!(input as RoutesInput);
238 let routes = input.routes;
239
240 let route_services = routes.iter().map(|route| {
241 let route_fn_ident = format_ident!("__ROUTE_{}", route.to_string().to_uppercase());
242 quote! {
243 #route_fn_ident(&mut router);
244 }
245 });
246
247 let gen = quote! {
248 {
249 let mut router = crate::Router::new();
250 #(#route_services)*
251 router
252 }
253 };
254
255 gen.into()
256}