axum_distributed_routing_macros/
lib.rs

1use std::{collections::HashMap, str::FromStr};
2
3use syn::{
4    Attribute, Block, Ident, LitStr, PatType, Token, Type, ext::IdentExt, parenthesized,
5    parse::Parse, parse_macro_input, punctuated::Punctuated,
6};
7
/// HTTP method a `route!` invocation is registered for.
///
/// Each variant is turned into a call to the `axum::routing` function with the
/// matching lower-case name (`Get` -> `axum::routing::get`, ...) when the route
/// is emitted.
enum Method {
    /// Written as `method = GET`.
    Get,
    /// Written as `method = POST`.
    Post,
    /// Written as `method = PUT`.
    Put,
    /// Written as `method = PATCH`.
    Patch,
    /// Written as `method = DELETE`.
    Delete,
    /// Written as `method = HEAD`.
    Head,
    /// Written as `method = OPTIONS`.
    Options,
    /// Written as `method = TRACE`.
    Trace,
    /// Written as `method = CONNECT`.
    Connect,
}
19
20struct Args {
21    path: String,
22    path_params: HashMap<Ident, Type>,
23    query_params: Option<Type>,
24    body_params: Option<Type>,
25    parameters: Punctuated<PatType, Token![,]>,
26    name: Ident,
27    group: Type,
28    return_type: Type,
29    method: Method,
30    handler_attributes: Vec<Attribute>,
31    handler: Block,
32}
33
34impl Parse for Args {
35    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
36        let mut path = None;
37        let mut path_params = HashMap::new();
38        let mut query_params = None;
39        let mut body_params = None;
40        let mut parameters: Punctuated<PatType, Token![,]> = Punctuated::new();
41        let mut name = None;
42        let mut return_type = None;
43        let mut handler = None;
44        let mut method = None;
45        let mut group = None;
46        let mut handler_attributes = Vec::new();
47
48        while !input.is_empty() {
49            if input.peek(Token![#]) || input.peek(Token![async]) {
50                if handler.is_some() {
51                    return Err(syn::Error::new(input.span(), "Handler is already defined"));
52                }
53
54                handler_attributes = input.call(Attribute::parse_outer)?;
55                input.parse::<Token![async]>()?;
56
57                name = Some(input.parse()?);
58
59                if input.peek(syn::token::Paren) {
60                    let content;
61                    let _ = parenthesized!(content in input);
62                    parameters = Punctuated::parse_terminated(&content)?;
63                }
64
65                if input.peek(Token![->]) {
66                    input.parse::<Token![->]>()?;
67                    return_type = Some(input.parse()?);
68                }
69                handler = Some(input.parse()?);
70            } else {
71                let ident: Ident = input.call(Ident::parse_any)?;
72
73                match ident.to_string().as_str() {
74                    "method" => {
75                        // Expects equal sign
76                        input.parse::<syn::Token![=]>()?;
77
78                        let method_ident = input.parse::<Ident>()?;
79                        match method_ident.to_string().as_str() {
80                            "GET" => method = Some(Method::Get),
81                            "POST" => method = Some(Method::Post),
82                            "PUT" => method = Some(Method::Put),
83                            "PATCH" => method = Some(Method::Patch),
84                            "DELETE" => method = Some(Method::Delete),
85                            "HEAD" => method = Some(Method::Head),
86                            "OPTIONS" => method = Some(Method::Options),
87                            "TRACE" => method = Some(Method::Trace),
88                            "CONNECT" => method = Some(Method::Connect),
89                            m => {
90                                return Err(syn::Error::new(
91                                    ident.span(),
92                                    format!("Unknown method {m}"),
93                                ));
94                            }
95                        }
96                    }
97                    "group" => {
98                        // Expects equal sign
99                        input.parse::<syn::Token![=]>()?;
100
101                        group = Some(input.parse()?);
102                    }
103                    "path" => {
104                        // Expects equal sign
105                        input.parse::<syn::Token![=]>()?;
106
107                        let path_str: LitStr = input.parse()?;
108                        let (path_, path_params_) = Self::parse_path(path_str)?;
109                        path = Some(path_);
110                        path_params = path_params_;
111                    }
112                    "query" => {
113                        // Expects equal sign
114                        input.parse::<syn::Token![=]>()?;
115
116                        query_params = Some(input.parse()?);
117                    }
118                    "body" => {
119                        // Expects equal sign
120                        input.parse::<syn::Token![=]>()?;
121
122                        body_params = Some(input.parse()?);
123                    }
124                    _ => {
125                        return Err(syn::Error::new(
126                            ident.span(),
127                            format!(
128                                "Unknown attribute '{ident}'. Allowed attributes are: 'method', 'group', 'path', 'query', 'body'.",
129                            ),
130                        ));
131                    }
132                }
133            }
134
135            if !input.is_empty() {
136                input.parse::<syn::Token![,]>()?;
137            }
138        }
139
140        if path.is_none() {
141            return Err(syn::Error::new(
142                proc_macro2::Span::call_site(),
143                "Missing path",
144            ));
145        }
146
147        if name.is_none() {
148            return Err(syn::Error::new(
149                proc_macro2::Span::call_site(),
150                "Missing name",
151            ));
152        }
153
154        if return_type.is_none() {
155            return Err(syn::Error::new(
156                proc_macro2::Span::call_site(),
157                "Missing return type",
158            ));
159        }
160
161        if handler.is_none() {
162            return Err(syn::Error::new(
163                proc_macro2::Span::call_site(),
164                "Missing handler",
165            ));
166        }
167
168        if method.is_none() {
169            return Err(syn::Error::new(
170                proc_macro2::Span::call_site(),
171                "Missing method",
172            ));
173        }
174
175        if group.is_none() {
176            return Err(syn::Error::new(
177                proc_macro2::Span::call_site(),
178                "Missing group",
179            ));
180        }
181
182        Ok(Args {
183            name: name.unwrap(),
184            return_type: return_type.unwrap(),
185            group: group.unwrap(),
186            method: method.unwrap(),
187            handler_attributes,
188            handler: handler.unwrap(),
189            path: path.unwrap(),
190            path_params,
191            query_params,
192            body_params,
193            parameters,
194        })
195    }
196}
197
/// States of the character-by-character scanner used by `Args::parse_path`.
#[derive(PartialEq)]
enum ParsePathState {
    /// Outside any capture: characters are plain path text.
    Path,
    /// After a `{`: accumulating the capture's name until its `:`.
    PathParamName,
    /// After the capture's `:`: accumulating its type until the closing `}`.
    PathParamType,
}
204
205impl Args {
206    fn parse_path(literal: LitStr) -> syn::Result<(String, HashMap<Ident, Type>)> {
207        let path = literal.value();
208        let mut real_path = String::new();
209        let mut path_params = HashMap::new();
210        let mut state = ParsePathState::Path;
211        let mut current_name = String::new();
212        let mut current_type = String::new();
213
214        for c in path.chars() {
215            match c {
216                '{' => {
217                    if state == ParsePathState::Path {
218                        state = ParsePathState::PathParamName;
219                    } else {
220                        return Err(syn::Error::new(
221                            literal.span(),
222                            "Expected one of character, `:` or `}`, found `{`",
223                        ));
224                    }
225                }
226                '}' => {
227                    if state == ParsePathState::PathParamType {
228                        let param_name = proc_macro2::TokenStream::from_str(&current_name)
229                            .map_err(|_| {
230                                syn::Error::new(literal.span(), "Invalid path parameter name")
231                            })?;
232                        let param_type = proc_macro2::TokenStream::from_str(&current_type)
233                            .map_err(|_| {
234                                syn::Error::new(literal.span(), "Invalid path parameter type")
235                            })?;
236                        path_params.insert(syn::parse2(param_name)?, syn::parse2(param_type)?);
237
238                        real_path.push(':');
239                        real_path.push_str(&current_name);
240
241                        current_name = String::new();
242                        current_type = String::new();
243                        state = ParsePathState::Path;
244                    } else if state == ParsePathState::PathParamName {
245                        return Err(syn::Error::new(
246                            literal.span(),
247                            "Expected one of character or `:`, found `}`",
248                        ));
249                    } else {
250                        return Err(syn::Error::new(
251                            literal.span(),
252                            "Expected one of character or `{`, found `}`",
253                        ));
254                    }
255                }
256                ':' => {
257                    if state == ParsePathState::PathParamName {
258                        state = ParsePathState::PathParamType;
259                    } else if state != ParsePathState::Path {
260                        return Err(syn::Error::new(
261                            literal.span(),
262                            "Expected one of character or `{`, found `:`",
263                        ));
264                    } else {
265                        return Err(syn::Error::new(
266                            literal.span(),
267                            "Expected one of character or `}`, found `:`",
268                        ));
269                    }
270                }
271                _ => match state {
272                    ParsePathState::Path => {
273                        real_path.push(c);
274                    }
275                    ParsePathState::PathParamName => {
276                        current_name.push(c);
277                    }
278                    ParsePathState::PathParamType => {
279                        current_type.push(c);
280                    }
281                },
282            }
283        }
284
285        if state != ParsePathState::Path {
286            return Err(syn::Error::new(
287                literal.span(),
288                "Expected one of character or `}`, found EOF",
289            ));
290        }
291
292        Ok((path, path_params))
293    }
294}
295
296/// Creates a route and add it to the group
297///
298/// # Example
299/// ```
300/// route!(
301///     group = Routes,
302///     path = "/echo/{str:String}",
303///     method = GET,
304///     async test_fn -> String { str }
305/// );
306/// ```
307#[proc_macro]
308pub fn route(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
309    // TODO: cleanup
310    let args = parse_macro_input!(attr as Args);
311
312    let path_params = args.path_params;
313    let path_idents = path_params.keys().collect::<Vec<_>>();
314    let path_types = path_params.values().collect::<Vec<_>>();
315
316    let path_params = if path_params.is_empty() {
317        quote::quote! {}
318    } else {
319        quote::quote! {
320            axum::extract::Path((#(#path_idents),*)): axum::extract::Path<(#(#path_types),*)>,
321        }
322    };
323
324    let query_params = if let Some(q) = args.query_params {
325        quote::quote! { query: #q, }
326    } else {
327        quote::quote! {}
328    };
329
330    let body_params = if let Some(b) = args.body_params {
331        quote::quote! { body: #b, }
332    } else {
333        quote::quote! {}
334    };
335
336    let route_name = Ident::new(
337        &format!(
338            "ROUTE_{}",
339            stringcase::macro_case(args.name.to_string().as_str())
340        ),
341        proc_macro2::Span::call_site(),
342    );
343    let name = args.name;
344    let path = args.path;
345    let parameters = if !args.parameters.trailing_punct() && !args.parameters.is_empty() {
346        let parameters = args.parameters;
347        quote::quote! { #parameters, }
348    } else {
349        let parameters = args.parameters;
350        quote::quote! { #parameters }
351    };
352    let return_type = args.return_type;
353    let block = args.handler;
354    let group = args.group;
355    let handler_attributes = args.handler_attributes;
356
357    let handler_def = quote::quote! {
358        #(#handler_attributes)*
359        async fn #name(#path_params #query_params #parameters #body_params) -> #return_type #block
360    };
361
362    let handler = match args.method {
363        Method::Get => quote::quote! { axum::routing::get(#name) },
364        Method::Post => quote::quote! { axum::routing::post(#name) },
365        Method::Put => quote::quote! { axum::routing::put(#name) },
366        Method::Patch => quote::quote! { axum::routing::patch(#name) },
367        Method::Delete => quote::quote! { axum::routing::delete(#name) },
368        Method::Head => quote::quote! { axum::routing::head(#name) },
369        Method::Options => quote::quote! { axum::routing::options(#name) },
370        Method::Trace => quote::quote! { axum::routing::trace(#name) },
371        Method::Connect => quote::quote! { axum::routing::connect(#name) },
372    };
373
374    let result = quote::quote! {
375        #handler_def
376
377        pub static #route_name: #group = #group::new(#path, |r, _| r.route(#path, #handler));
378
379        axum_distributed_routing::inventory::submit! {
380            #route_name
381        }
382    };
383
384    result.into()
385}