axum_distributed_routing_macros/
lib.rs

1use std::{collections::HashMap, str::FromStr};
2
3use syn::{
4    Block, Field, Ident, LitStr, PatType, Token, Type,
5    ext::IdentExt,
6    parenthesized,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10};
11
12/// Either a block with members or a type name
13enum TypeNameOrDef {
14    Type(Type),
15    Def(Punctuated<Field, Token![,]>),
16}
17
/// HTTP method a route is registered for, written upper case in the macro (`method = GET`).
///
/// Each variant is expanded to the `axum::routing` function with the lower-case name
/// (`Get` -> `axum::routing::get`, `Connect` -> `axum::routing::connect`, ...).
enum Method {
    /// `method = GET`
    Get,
    /// `method = POST`
    Post,
    /// `method = PUT`
    Put,
    /// `method = PATCH`
    Patch,
    /// `method = DELETE`
    Delete,
    /// `method = HEAD`
    Head,
    /// `method = OPTIONS`
    Options,
    /// `method = TRACE`
    Trace,
    /// `method = CONNECT`
    Connect,
}
29
30struct Args {
31    path: String,
32    path_params: HashMap<Ident, Type>,
33    query_params: Option<TypeNameOrDef>,
34    body_params: Option<TypeNameOrDef>,
35    parameters: Punctuated<PatType, Token![,]>,
36    name: Ident,
37    group: Type,
38    return_type: Type,
39    method: Method,
40    is_async: bool,
41    handler: Block,
42}
43
44impl Parse for TypeNameOrDef {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        if input.peek(syn::token::Brace) {
47            let content;
48            let _ = syn::braced!(content in input);
49            Ok(TypeNameOrDef::Def(Punctuated::parse_terminated_with(
50                &content,
51                Field::parse_named,
52            )?))
53        } else {
54            Ok(TypeNameOrDef::Type(input.parse()?))
55        }
56    }
57}
58
59impl Parse for Args {
60    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
61        let mut path = None;
62        let mut path_params = HashMap::new();
63        let mut query_params = None;
64        let mut body_params = None;
65        let mut parameters: Punctuated<PatType, Token![,]> = Punctuated::new();
66        let mut name = None;
67        let mut return_type = None;
68        let mut handler = None;
69        let mut is_async = false;
70        let mut method = None;
71        let mut group = None;
72
73        while !input.is_empty() {
74            let ident: Ident = input.call(Ident::parse_any)?;
75
76            match ident.to_string().as_str() {
77                "method" => {
78                    // Expects equal sign
79                    input.parse::<syn::Token![=]>()?;
80
81                    let method_ident = input.parse::<Ident>()?;
82                    match method_ident.to_string().as_str() {
83                        "GET" => method = Some(Method::Get),
84                        "POST" => method = Some(Method::Post),
85                        "PUT" => method = Some(Method::Put),
86                        "PATCH" => method = Some(Method::Patch),
87                        "DELETE" => method = Some(Method::Delete),
88                        "HEAD" => method = Some(Method::Head),
89                        "OPTIONS" => method = Some(Method::Options),
90                        "TRACE" => method = Some(Method::Trace),
91                        "CONNECT" => method = Some(Method::Connect),
92                        m => {
93                            return Err(syn::Error::new(
94                                proc_macro2::Span::call_site(),
95                                format!("Unknown method {}", m),
96                            ));
97                        }
98                    }
99                }
100                "group" => {
101                    // Expects equal sign
102                    input.parse::<syn::Token![=]>()?;
103
104                    group = Some(input.parse()?);
105                }
106                "path" => {
107                    // Expects equal sign
108                    input.parse::<syn::Token![=]>()?;
109
110                    let path_str: LitStr = input.parse()?;
111                    let (path_, path_params_) = Self::parse_path(path_str.value())?;
112                    path = Some(path_);
113                    path_params = path_params_;
114                }
115                "query" => {
116                    // Expects equal sign
117                    input.parse::<syn::Token![=]>()?;
118
119                    query_params = Some(input.parse::<TypeNameOrDef>()?);
120                }
121                "body" => {
122                    // Expects equal sign
123                    input.parse::<syn::Token![=]>()?;
124
125                    body_params = Some(input.parse::<TypeNameOrDef>()?);
126                }
127                _ => {
128                    if ident.to_string().as_str() == "async" {
129                        is_async = true;
130                        name = Some(input.parse()?);
131                    } else {
132                        name = Some(ident);
133                    }
134
135                    if input.peek(syn::token::Paren) {
136                        let content;
137                        let _ = parenthesized!(content in input);
138                        parameters = Punctuated::parse_terminated(&content)?;
139                    }
140
141                    if input.peek(Token![->]) {
142                        input.parse::<Token![->]>()?;
143                        return_type = Some(input.parse()?);
144                    }
145                    handler = Some(input.parse()?);
146                }
147            }
148
149            if !input.is_empty() {
150                input.parse::<syn::Token![,]>()?;
151            }
152        }
153
154        if path.is_none() {
155            return Err(syn::Error::new(
156                proc_macro2::Span::call_site(),
157                "Missing path",
158            ));
159        }
160
161        if name.is_none() {
162            return Err(syn::Error::new(
163                proc_macro2::Span::call_site(),
164                "Missing name",
165            ));
166        }
167
168        if return_type.is_none() {
169            return Err(syn::Error::new(
170                proc_macro2::Span::call_site(),
171                "Missing return type",
172            ));
173        }
174
175        if handler.is_none() {
176            return Err(syn::Error::new(
177                proc_macro2::Span::call_site(),
178                "Missing handler",
179            ));
180        }
181
182        if method.is_none() {
183            return Err(syn::Error::new(
184                proc_macro2::Span::call_site(),
185                "Missing method",
186            ));
187        }
188
189        if group.is_none() {
190            return Err(syn::Error::new(
191                proc_macro2::Span::call_site(),
192                "Missing group",
193            ));
194        }
195
196        Ok(Args {
197            name: name.unwrap(),
198            return_type: return_type.unwrap(),
199            is_async,
200            group: group.unwrap(),
201            method: method.unwrap(),
202            handler: handler.unwrap(),
203            path: path.unwrap(),
204            path_params,
205            query_params,
206            body_params,
207            parameters,
208        })
209    }
210}
211
/// States of the character-by-character scanner in `Args::parse_path`.
#[derive(PartialEq)]
enum ParsePathState {
    /// Outside any capture: characters belong to the literal route path.
    Path,
    /// After `{`: characters accumulate into the capture name until `:`.
    PathParamName,
    /// After the name's `:`: characters accumulate into the capture type until `}`.
    PathParamType,
}
218
219impl Args {
220    fn parse_path(path: String) -> syn::Result<(String, HashMap<Ident, Type>)> {
221        let mut real_path = String::new();
222        let mut path_params = HashMap::new();
223        let mut state = ParsePathState::Path;
224        let mut current_name = String::new();
225        let mut current_type = String::new();
226
227        for c in path.chars() {
228            match c {
229                '{' => {
230                    if state == ParsePathState::Path {
231                        state = ParsePathState::PathParamName;
232                    } else {
233                        return Err(syn::Error::new(
234                            proc_macro2::Span::call_site(),
235                            "Invalid path",
236                        ));
237                    }
238                }
239                '}' => {
240                    if state == ParsePathState::PathParamType {
241                        let param_name = proc_macro2::TokenStream::from_str(&current_name)
242                            .map_err(|_| {
243                                syn::Error::new(proc_macro2::Span::call_site(), "Invalid path")
244                            })?;
245                        let param_type = proc_macro2::TokenStream::from_str(&current_type)
246                            .map_err(|_| {
247                                syn::Error::new(proc_macro2::Span::call_site(), "Invalid path")
248                            })?;
249                        path_params.insert(syn::parse2(param_name)?, syn::parse2(param_type)?);
250
251                        real_path.push(':');
252                        real_path.push_str(&current_name);
253
254                        current_name = String::new();
255                        current_type = String::new();
256                        state = ParsePathState::Path;
257                    } else {
258                        return Err(syn::Error::new(
259                            proc_macro2::Span::call_site(),
260                            "Invalid path",
261                        ));
262                    }
263                }
264                ':' => {
265                    if state == ParsePathState::PathParamName {
266                        state = ParsePathState::PathParamType;
267                    } else {
268                        return Err(syn::Error::new(
269                            proc_macro2::Span::call_site(),
270                            "Invalid path",
271                        ));
272                    }
273                }
274                _ => match state {
275                    ParsePathState::Path => {
276                        real_path.push(c);
277                    }
278                    ParsePathState::PathParamName => {
279                        current_name.push(c);
280                    }
281                    ParsePathState::PathParamType => {
282                        current_type.push(c);
283                    }
284                },
285            }
286        }
287
288        if state != ParsePathState::Path {
289            return Err(syn::Error::new(
290                proc_macro2::Span::call_site(),
291                "Invalid path",
292            ));
293        }
294
295        Ok((path, path_params))
296    }
297}
298
299/// Creates a route and add it to the group
300///
301/// # Example
302/// ```
303/// route!(
304///     group = Routes,
305///     path = "/echo/{str:String}",
306///     method = GET,
307///     async test_fn -> String { str }
308/// );
309/// ```
310#[proc_macro]
311pub fn route(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
312    // TODO: cleanup
313    let args = parse_macro_input!(attr as Args);
314
315    let path_params = args.path_params;
316    let path_idents = path_params.keys().collect::<Vec<_>>();
317    let path_types = path_params.values().collect::<Vec<_>>();
318
319    let path_params = if path_params.is_empty() {
320        quote::quote! {}
321    } else {
322        quote::quote! {
323            axum::extract::Path((#(#path_idents),*)): axum::extract::Path<(#(#path_types),*)>,
324        }
325    };
326
327    let (query_params_def, query_params) = if let Some(q) = args.query_params {
328        match q {
329            TypeNameOrDef::Type(t) => (
330                quote::quote! {},
331                quote::quote! { axum::extract::Query(query): axum::extract::Query<#t>, },
332            ),
333            TypeNameOrDef::Def(d) => {
334                let def_name = Ident::new(
335                    format!(
336                        "{}QueryParams",
337                        stringcase::pascal_case(args.name.to_string().as_str())
338                    )
339                    .as_str(),
340                    proc_macro2::Span::call_site(),
341                );
342                (
343                    quote::quote! {
344                        #[derive(serde::Deserialize)]
345                        struct #def_name #d
346                    },
347                    quote::quote! { axum::extract::Query(query): axum::extract::Query<#def_name>, },
348                )
349            }
350        }
351    } else {
352        (quote::quote! {}, quote::quote! {})
353    };
354
355    let (body_params_def, body_params) = if let Some(b) = args.body_params {
356        match b {
357            TypeNameOrDef::Type(t) => (
358                quote::quote! {},
359                quote::quote! { axum::extract::Form(body): axum::extract::Form<#t>, },
360            ),
361            TypeNameOrDef::Def(d) => {
362                let def_name = Ident::new(
363                    format!(
364                        "{}BodyParams",
365                        stringcase::pascal_case(args.name.to_string().as_str())
366                    )
367                    .as_str(),
368                    proc_macro2::Span::call_site(),
369                );
370                (
371                    quote::quote! {
372                        #[derive(serde::Deserialize)]
373                        struct #def_name #d
374                    },
375                    quote::quote! { axum::extract::Form(body): axum::extract::Form<#def_name>, },
376                )
377            }
378        }
379    } else {
380        (quote::quote! {}, quote::quote! {})
381    };
382
383    let route_name = Ident::new(
384        &format!(
385            "ROUTE_{}",
386            stringcase::macro_case(args.name.to_string().as_str())
387        ),
388        proc_macro2::Span::call_site(),
389    );
390    let path = args.path;
391    let parameters = args.parameters;
392    let return_type = args.return_type;
393    let block = args.handler;
394    let group = args.group;
395
396    let async_keyword = if args.is_async {
397        quote::quote! { async }
398    } else {
399        quote::quote! {}
400    };
401
402    let handler = quote::quote! {
403        #async_keyword |#path_params #query_params #body_params #parameters| -> #return_type #block
404    };
405
406    let handler = match args.method {
407        Method::Get => quote::quote! { axum::routing::get(#handler) },
408        Method::Post => quote::quote! { axum::routing::post(#handler) },
409        Method::Put => quote::quote! { axum::routing::put(#handler) },
410        Method::Patch => quote::quote! { axum::routing::patch(#handler) },
411        Method::Delete => quote::quote! { axum::routing::delete(#handler) },
412        Method::Head => quote::quote! { axum::routing::head(#handler) },
413        Method::Options => quote::quote! { axum::routing::options(#handler) },
414        Method::Trace => quote::quote! { axum::routing::trace(#handler) },
415        Method::Connect => quote::quote! { axum::routing::connect(#handler) },
416    };
417
418    let result = quote::quote! {
419        #query_params_def
420        #body_params_def
421
422        pub static #route_name: #group = #group::new(#path, |r, _| r.route(#path, #handler));
423
424        axum_distributed_routing::inventory::submit! {
425            #route_name
426        }
427    };
428
429    result.into()
430}