// axum_distributed_routing_macros/lib.rs
use std::{collections::HashMap, str::FromStr};

use syn::{
    Block, Field, Ident, LitStr, PatType, Token, Type,
    ext::IdentExt,
    parenthesized,
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
};

12enum TypeNameOrDef {
14 Type(Type),
15 Def(Punctuated<Field, Token![,]>),
16}
17
/// HTTP verb chosen with `method = …`.
///
/// The accepted spellings are the upper-case names below; each verb is emitted as the
/// `axum::routing` function of the same lower-case name.
enum Method {
    /// `GET` → `axum::routing::get`
    Get,
    /// `POST` → `axum::routing::post`
    Post,
    /// `PUT` → `axum::routing::put`
    Put,
    /// `PATCH` → `axum::routing::patch`
    Patch,
    /// `DELETE` → `axum::routing::delete`
    Delete,
    /// `HEAD` → `axum::routing::head`
    Head,
    /// `OPTIONS` → `axum::routing::options`
    Options,
    /// `TRACE` → `axum::routing::trace`
    Trace,
    /// `CONNECT` → `axum::routing::connect`
    Connect,
}

30struct Args {
31 path: String,
32 path_params: HashMap<Ident, Type>,
33 query_params: Option<TypeNameOrDef>,
34 body_params: Option<TypeNameOrDef>,
35 parameters: Punctuated<PatType, Token![,]>,
36 name: Ident,
37 group: Type,
38 return_type: Type,
39 method: Method,
40 is_async: bool,
41 handler: Block,
42}
43
44impl Parse for TypeNameOrDef {
45 fn parse(input: ParseStream) -> syn::Result<Self> {
46 if input.peek(syn::token::Brace) {
47 let content;
48 let _ = syn::braced!(content in input);
49 Ok(TypeNameOrDef::Def(Punctuated::parse_terminated_with(
50 &content,
51 Field::parse_named,
52 )?))
53 } else {
54 Ok(TypeNameOrDef::Type(input.parse()?))
55 }
56 }
57}
58
59impl Parse for Args {
60 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
61 let mut path = None;
62 let mut path_params = HashMap::new();
63 let mut query_params = None;
64 let mut body_params = None;
65 let mut parameters: Punctuated<PatType, Token![,]> = Punctuated::new();
66 let mut name = None;
67 let mut return_type = None;
68 let mut handler = None;
69 let mut is_async = false;
70 let mut method = None;
71 let mut group = None;
72
73 while !input.is_empty() {
74 let ident: Ident = input.call(Ident::parse_any)?;
75
76 match ident.to_string().as_str() {
77 "method" => {
78 input.parse::<syn::Token![=]>()?;
80
81 let method_ident = input.parse::<Ident>()?;
82 match method_ident.to_string().as_str() {
83 "GET" => method = Some(Method::Get),
84 "POST" => method = Some(Method::Post),
85 "PUT" => method = Some(Method::Put),
86 "PATCH" => method = Some(Method::Patch),
87 "DELETE" => method = Some(Method::Delete),
88 "HEAD" => method = Some(Method::Head),
89 "OPTIONS" => method = Some(Method::Options),
90 "TRACE" => method = Some(Method::Trace),
91 "CONNECT" => method = Some(Method::Connect),
92 m => {
93 return Err(syn::Error::new(
94 proc_macro2::Span::call_site(),
95 format!("Unknown method {}", m),
96 ));
97 }
98 }
99 }
100 "group" => {
101 input.parse::<syn::Token![=]>()?;
103
104 group = Some(input.parse()?);
105 }
106 "path" => {
107 input.parse::<syn::Token![=]>()?;
109
110 let path_str: LitStr = input.parse()?;
111 let (path_, path_params_) = Self::parse_path(path_str.value())?;
112 path = Some(path_);
113 path_params = path_params_;
114 }
115 "query" => {
116 input.parse::<syn::Token![=]>()?;
118
119 query_params = Some(input.parse::<TypeNameOrDef>()?);
120 }
121 "body" => {
122 input.parse::<syn::Token![=]>()?;
124
125 body_params = Some(input.parse::<TypeNameOrDef>()?);
126 }
127 _ => {
128 if ident.to_string().as_str() == "async" {
129 is_async = true;
130 name = Some(input.parse()?);
131 } else {
132 name = Some(ident);
133 }
134
135 if input.peek(syn::token::Paren) {
136 let content;
137 let _ = parenthesized!(content in input);
138 parameters = Punctuated::parse_terminated(&content)?;
139 }
140
141 if input.peek(Token![->]) {
142 input.parse::<Token![->]>()?;
143 return_type = Some(input.parse()?);
144 }
145 handler = Some(input.parse()?);
146 }
147 }
148
149 if !input.is_empty() {
150 input.parse::<syn::Token![,]>()?;
151 }
152 }
153
154 if path.is_none() {
155 return Err(syn::Error::new(
156 proc_macro2::Span::call_site(),
157 "Missing path",
158 ));
159 }
160
161 if name.is_none() {
162 return Err(syn::Error::new(
163 proc_macro2::Span::call_site(),
164 "Missing name",
165 ));
166 }
167
168 if return_type.is_none() {
169 return Err(syn::Error::new(
170 proc_macro2::Span::call_site(),
171 "Missing return type",
172 ));
173 }
174
175 if handler.is_none() {
176 return Err(syn::Error::new(
177 proc_macro2::Span::call_site(),
178 "Missing handler",
179 ));
180 }
181
182 if method.is_none() {
183 return Err(syn::Error::new(
184 proc_macro2::Span::call_site(),
185 "Missing method",
186 ));
187 }
188
189 if group.is_none() {
190 return Err(syn::Error::new(
191 proc_macro2::Span::call_site(),
192 "Missing group",
193 ));
194 }
195
196 Ok(Args {
197 name: name.unwrap(),
198 return_type: return_type.unwrap(),
199 is_async,
200 group: group.unwrap(),
201 method: method.unwrap(),
202 handler: handler.unwrap(),
203 path: path.unwrap(),
204 path_params,
205 query_params,
206 body_params,
207 parameters,
208 })
209 }
210}
211
/// States of the character-by-character scanner used by `Args::parse_path`.
#[derive(PartialEq)]
enum ParsePathState {
    /// Copying literal path characters.
    Path,
    /// After `{`: collecting the parameter name until `:`.
    PathParamName,
    /// After `:`: collecting the parameter type until `}`.
    PathParamType,
}

219impl Args {
220 fn parse_path(path: String) -> syn::Result<(String, HashMap<Ident, Type>)> {
221 let mut real_path = String::new();
222 let mut path_params = HashMap::new();
223 let mut state = ParsePathState::Path;
224 let mut current_name = String::new();
225 let mut current_type = String::new();
226
227 for c in path.chars() {
228 match c {
229 '{' => {
230 if state == ParsePathState::Path {
231 state = ParsePathState::PathParamName;
232 } else {
233 return Err(syn::Error::new(
234 proc_macro2::Span::call_site(),
235 "Invalid path",
236 ));
237 }
238 }
239 '}' => {
240 if state == ParsePathState::PathParamType {
241 let param_name = proc_macro2::TokenStream::from_str(¤t_name)
242 .map_err(|_| {
243 syn::Error::new(proc_macro2::Span::call_site(), "Invalid path")
244 })?;
245 let param_type = proc_macro2::TokenStream::from_str(¤t_type)
246 .map_err(|_| {
247 syn::Error::new(proc_macro2::Span::call_site(), "Invalid path")
248 })?;
249 path_params.insert(syn::parse2(param_name)?, syn::parse2(param_type)?);
250
251 real_path.push(':');
252 real_path.push_str(¤t_name);
253
254 current_name = String::new();
255 current_type = String::new();
256 state = ParsePathState::Path;
257 } else {
258 return Err(syn::Error::new(
259 proc_macro2::Span::call_site(),
260 "Invalid path",
261 ));
262 }
263 }
264 ':' => {
265 if state == ParsePathState::PathParamName {
266 state = ParsePathState::PathParamType;
267 } else {
268 return Err(syn::Error::new(
269 proc_macro2::Span::call_site(),
270 "Invalid path",
271 ));
272 }
273 }
274 _ => match state {
275 ParsePathState::Path => {
276 real_path.push(c);
277 }
278 ParsePathState::PathParamName => {
279 current_name.push(c);
280 }
281 ParsePathState::PathParamType => {
282 current_type.push(c);
283 }
284 },
285 }
286 }
287
288 if state != ParsePathState::Path {
289 return Err(syn::Error::new(
290 proc_macro2::Span::call_site(),
291 "Invalid path",
292 ));
293 }
294
295 Ok((path, path_params))
296 }
297}
298
299#[proc_macro]
311pub fn route(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
312 let args = parse_macro_input!(attr as Args);
314
315 let path_params = args.path_params;
316 let path_idents = path_params.keys().collect::<Vec<_>>();
317 let path_types = path_params.values().collect::<Vec<_>>();
318
319 let path_params = if path_params.is_empty() {
320 quote::quote! {}
321 } else {
322 quote::quote! {
323 axum::extract::Path((#(#path_idents),*)): axum::extract::Path<(#(#path_types),*)>,
324 }
325 };
326
327 let (query_params_def, query_params) = if let Some(q) = args.query_params {
328 match q {
329 TypeNameOrDef::Type(t) => (
330 quote::quote! {},
331 quote::quote! { axum::extract::Query(query): axum::extract::Query<#t>, },
332 ),
333 TypeNameOrDef::Def(d) => {
334 let def_name = Ident::new(
335 format!(
336 "{}QueryParams",
337 stringcase::pascal_case(args.name.to_string().as_str())
338 )
339 .as_str(),
340 proc_macro2::Span::call_site(),
341 );
342 (
343 quote::quote! {
344 #[derive(serde::Deserialize)]
345 struct #def_name #d
346 },
347 quote::quote! { axum::extract::Query(query): axum::extract::Query<#def_name>, },
348 )
349 }
350 }
351 } else {
352 (quote::quote! {}, quote::quote! {})
353 };
354
355 let (body_params_def, body_params) = if let Some(b) = args.body_params {
356 match b {
357 TypeNameOrDef::Type(t) => (
358 quote::quote! {},
359 quote::quote! { axum::extract::Form(body): axum::extract::Form<#t>, },
360 ),
361 TypeNameOrDef::Def(d) => {
362 let def_name = Ident::new(
363 format!(
364 "{}BodyParams",
365 stringcase::pascal_case(args.name.to_string().as_str())
366 )
367 .as_str(),
368 proc_macro2::Span::call_site(),
369 );
370 (
371 quote::quote! {
372 #[derive(serde::Deserialize)]
373 struct #def_name #d
374 },
375 quote::quote! { axum::extract::Form(body): axum::extract::Form<#def_name>, },
376 )
377 }
378 }
379 } else {
380 (quote::quote! {}, quote::quote! {})
381 };
382
383 let route_name = Ident::new(
384 &format!(
385 "ROUTE_{}",
386 stringcase::macro_case(args.name.to_string().as_str())
387 ),
388 proc_macro2::Span::call_site(),
389 );
390 let path = args.path;
391 let parameters = args.parameters;
392 let return_type = args.return_type;
393 let block = args.handler;
394 let group = args.group;
395
396 let async_keyword = if args.is_async {
397 quote::quote! { async }
398 } else {
399 quote::quote! {}
400 };
401
402 let handler = quote::quote! {
403 #async_keyword |#path_params #query_params #body_params #parameters| -> #return_type #block
404 };
405
406 let handler = match args.method {
407 Method::Get => quote::quote! { axum::routing::get(#handler) },
408 Method::Post => quote::quote! { axum::routing::post(#handler) },
409 Method::Put => quote::quote! { axum::routing::put(#handler) },
410 Method::Patch => quote::quote! { axum::routing::patch(#handler) },
411 Method::Delete => quote::quote! { axum::routing::delete(#handler) },
412 Method::Head => quote::quote! { axum::routing::head(#handler) },
413 Method::Options => quote::quote! { axum::routing::options(#handler) },
414 Method::Trace => quote::quote! { axum::routing::trace(#handler) },
415 Method::Connect => quote::quote! { axum::routing::connect(#handler) },
416 };
417
418 let result = quote::quote! {
419 #query_params_def
420 #body_params_def
421
422 pub static #route_name: #group = #group::new(#path, |r, _| r.route(#path, #handler));
423
424 axum_distributed_routing::inventory::submit! {
425 #route_name
426 }
427 };
428
429 result.into()
430}