axum_distributed_routing_macros/
lib.rs

1use std::{collections::HashMap, str::FromStr};
2
3use syn::{
4    ext::IdentExt, parenthesized, parse::Parse, parse_macro_input, punctuated::Punctuated, Attribute, Block, Ident, LitStr, PatType, Token, Type
5};
6
/// HTTP method a route is registered for (`method = GET`, ...).
///
/// Each variant corresponds to the axum routing function of the same
/// lowercase name, e.g. `Method::Get` -> `axum::routing::get`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Method {
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Head,
    Options,
    Trace,
    Connect,
}
18
19struct Args {
20    path: String,
21    path_params: HashMap<Ident, Type>,
22    query_params: Option<Type>,
23    body_params: Option<Type>,
24    parameters: Punctuated<PatType, Token![,]>,
25    name: Ident,
26    group: Type,
27    return_type: Type,
28    method: Method,
29    handler_attributes: Vec<Attribute>,
30    handler: Block,
31}
32
33impl Parse for Args {
34    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
35        let mut path = None;
36        let mut path_params = HashMap::new();
37        let mut query_params = None;
38        let mut body_params = None;
39        let mut parameters: Punctuated<PatType, Token![,]> = Punctuated::new();
40        let mut name = None;
41        let mut return_type = None;
42        let mut handler = None;
43        let mut method = None;
44        let mut group = None;
45        let mut handler_attributes = Vec::new();
46
47        while !input.is_empty() {
48            if input.peek(Token![#]) || input.peek(Token![async]) {
49                if handler.is_some() {
50                    return Err(syn::Error::new(
51                        input.span(),
52                        "Handler is already defined",
53                    ));
54                }
55
56                handler_attributes = input.call(Attribute::parse_outer)?;
57                input.parse::<Token![async]>()?;
58
59                name = Some(input.parse()?);
60                
61                if input.peek(syn::token::Paren) {
62                    let content;
63                    let _ = parenthesized!(content in input);
64                    parameters = Punctuated::parse_terminated(&content)?;
65                }
66
67                if input.peek(Token![->]) {
68                    input.parse::<Token![->]>()?;
69                    return_type = Some(input.parse()?);
70                }
71                handler = Some(input.parse()?);
72            } else {
73                let ident: Ident = input.call(Ident::parse_any)?;
74
75                match ident.to_string().as_str() {
76                    "method" => {
77                        // Expects equal sign
78                        input.parse::<syn::Token![=]>()?;
79
80                        let method_ident = input.parse::<Ident>()?;
81                        match method_ident.to_string().as_str() {
82                            "GET" => method = Some(Method::Get),
83                            "POST" => method = Some(Method::Post),
84                            "PUT" => method = Some(Method::Put),
85                            "PATCH" => method = Some(Method::Patch),
86                            "DELETE" => method = Some(Method::Delete),
87                            "HEAD" => method = Some(Method::Head),
88                            "OPTIONS" => method = Some(Method::Options),
89                            "TRACE" => method = Some(Method::Trace),
90                            "CONNECT" => method = Some(Method::Connect),
91                            m => {
92                                return Err(syn::Error::new(
93                                    ident.span(),
94                                    format!("Unknown method {}", m),
95                                ));
96                            }
97                        }
98                    }
99                    "group" => {
100                        // Expects equal sign
101                        input.parse::<syn::Token![=]>()?;
102
103                        group = Some(input.parse()?);
104                    }
105                    "path" => {
106                        // Expects equal sign
107                        input.parse::<syn::Token![=]>()?;
108
109                        let path_str: LitStr = input.parse()?;
110                        let (path_, path_params_) = Self::parse_path(path_str)?;
111                        path = Some(path_);
112                        path_params = path_params_;
113                    }
114                    "query" => {
115                        // Expects equal sign
116                        input.parse::<syn::Token![=]>()?;
117
118                        query_params = Some(input.parse()?);
119                    }
120                    "body" => {
121                        // Expects equal sign
122                        input.parse::<syn::Token![=]>()?;
123
124                        body_params = Some(input.parse()?);
125                    }
126                    _ => {
127                        return Err(syn::Error::new(
128                            ident.span(),
129                            format!(
130                                "Unknown attribute '{}'. Allowed attributes are: 'method', 'group', 'path', 'query', 'body'.",
131                                ident
132                            ),
133                        ));
134                    }
135                }
136            }
137
138            if !input.is_empty() {
139                input.parse::<syn::Token![,]>()?;
140            }
141        }
142
143        if path.is_none() {
144            return Err(syn::Error::new(
145                proc_macro2::Span::call_site(),
146                "Missing path",
147            ));
148        }
149
150        if name.is_none() {
151            return Err(syn::Error::new(
152                proc_macro2::Span::call_site(),
153                "Missing name",
154            ));
155        }
156
157        if return_type.is_none() {
158            return Err(syn::Error::new(
159                proc_macro2::Span::call_site(),
160                "Missing return type",
161            ));
162        }
163
164        if handler.is_none() {
165            return Err(syn::Error::new(
166                proc_macro2::Span::call_site(),
167                "Missing handler",
168            ));
169        }
170
171        if method.is_none() {
172            return Err(syn::Error::new(
173                proc_macro2::Span::call_site(),
174                "Missing method",
175            ));
176        }
177
178        if group.is_none() {
179            return Err(syn::Error::new(
180                proc_macro2::Span::call_site(),
181                "Missing group",
182            ));
183        }
184
185        Ok(Args {
186            name: name.unwrap(),
187            return_type: return_type.unwrap(),
188            group: group.unwrap(),
189            method: method.unwrap(),
190            handler_attributes,
191            handler: handler.unwrap(),
192            path: path.unwrap(),
193            path_params,
194            query_params,
195            body_params,
196            parameters,
197        })
198    }
199}
200
201#[derive(PartialEq)]
202enum ParsePathState {
203    Path,
204    PathParamName,
205    PathParamType,
206}
207
208impl Args {
209    fn parse_path(literal: LitStr) -> syn::Result<(String, HashMap<Ident, Type>)> {
210        let path = literal.value();
211        let mut real_path = String::new();
212        let mut path_params = HashMap::new();
213        let mut state = ParsePathState::Path;
214        let mut current_name = String::new();
215        let mut current_type = String::new();
216
217        for c in path.chars() {
218            match c {
219                '{' => {
220                    if state == ParsePathState::Path {
221                        state = ParsePathState::PathParamName;
222                    } else {
223                        return Err(syn::Error::new(
224                            literal.span(),
225                            "Expected one of character, `:` or `}`, found `{`",
226                        ));
227                    }
228                }
229                '}' => {
230                    if state == ParsePathState::PathParamType {
231                        let param_name = proc_macro2::TokenStream::from_str(&current_name)
232                            .map_err(|_| {
233                                syn::Error::new(literal.span(), "Invalid path parameter name")
234                            })?;
235                        let param_type = proc_macro2::TokenStream::from_str(&current_type)
236                            .map_err(|_| {
237                                syn::Error::new(literal.span(), "Invalid path parameter type")
238                            })?;
239                        path_params.insert(syn::parse2(param_name)?, syn::parse2(param_type)?);
240
241                        real_path.push(':');
242                        real_path.push_str(&current_name);
243
244                        current_name = String::new();
245                        current_type = String::new();
246                        state = ParsePathState::Path;
247                    } else if state == ParsePathState::PathParamName {
248                        return Err(syn::Error::new(
249                            literal.span(),
250                            "Expected one of character or `:`, found `}`",
251                        ));
252                    } else {
253                        return Err(syn::Error::new(
254                            literal.span(),
255                            "Expected one of character or `{`, found `}`",
256                        ));
257                    }
258                }
259                ':' => {
260                    if state == ParsePathState::PathParamName {
261                        state = ParsePathState::PathParamType;
262                    } else if state != ParsePathState::Path {
263                        return Err(syn::Error::new(
264                            literal.span(),
265                            "Expected one of character or `{`, found `:`",
266                        ));
267                    } else {
268                        return Err(syn::Error::new(
269                            literal.span(),
270                            "Expected one of character or `}`, found `:`",
271                        ));
272                    }
273                }
274                _ => match state {
275                    ParsePathState::Path => {
276                        real_path.push(c);
277                    }
278                    ParsePathState::PathParamName => {
279                        current_name.push(c);
280                    }
281                    ParsePathState::PathParamType => {
282                        current_type.push(c);
283                    }
284                },
285            }
286        }
287
288        if state != ParsePathState::Path {
289            return Err(syn::Error::new(
290                literal.span(),
291                "Expected one of character or `}`, found EOF",
292            ));
293        }
294
295        Ok((path, path_params))
296    }
297}
298
299/// Creates a route and add it to the group
300///
301/// # Example
302/// ```
303/// route!(
304///     group = Routes,
305///     path = "/echo/{str:String}",
306///     method = GET,
307///     async test_fn -> String { str }
308/// );
309/// ```
310#[proc_macro]
311pub fn route(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
312    // TODO: cleanup
313    let args = parse_macro_input!(attr as Args);
314
315    let path_params = args.path_params;
316    let path_idents = path_params.keys().collect::<Vec<_>>();
317    let path_types = path_params.values().collect::<Vec<_>>();
318
319    let path_params = if path_params.is_empty() {
320        quote::quote! {}
321    } else {
322        quote::quote! {
323            axum::extract::Path((#(#path_idents),*)): axum::extract::Path<(#(#path_types),*)>,
324        }
325    };
326
327    let query_params = if let Some(q) = args.query_params {
328        quote::quote! { axum::extract::Query(query): axum::extract::Query<#q>, }
329    } else {
330        quote::quote! {}
331    };
332
333    let body_params = if let Some(b) = args.body_params {
334        quote::quote! { body: #b, }
335    } else {
336        quote::quote! {}
337    };
338
339    let route_name = Ident::new(
340        &format!(
341            "ROUTE_{}",
342            stringcase::macro_case(args.name.to_string().as_str())
343        ),
344        proc_macro2::Span::call_site(),
345    );
346    let name = args.name;
347    let path = args.path;
348    let parameters = if !args.parameters.trailing_punct() && !args.parameters.is_empty() {
349        let parameters = args.parameters;
350        quote::quote! { #parameters, }
351    } else {
352        let parameters = args.parameters;
353        quote::quote! { #parameters }
354    };
355    let return_type = args.return_type;
356    let block = args.handler;
357    let group = args.group;
358    let handler_attributes = args.handler_attributes;
359
360    let handler_def = quote::quote! {
361        #(#handler_attributes)*
362        async fn #name(#path_params #query_params #parameters #body_params) -> #return_type #block
363    };
364
365    let handler = match args.method {
366        Method::Get => quote::quote! { axum::routing::get(#name) },
367        Method::Post => quote::quote! { axum::routing::post(#name) },
368        Method::Put => quote::quote! { axum::routing::put(#name) },
369        Method::Patch => quote::quote! { axum::routing::patch(#name) },
370        Method::Delete => quote::quote! { axum::routing::delete(#name) },
371        Method::Head => quote::quote! { axum::routing::head(#name) },
372        Method::Options => quote::quote! { axum::routing::options(#name) },
373        Method::Trace => quote::quote! { axum::routing::trace(#name) },
374        Method::Connect => quote::quote! { axum::routing::connect(#name) },
375    };
376
377    let result = quote::quote! {
378        #handler_def
379
380        pub static #route_name: #group = #group::new(#path, |r, _| r.route(#path, #handler));
381
382        axum_distributed_routing::inventory::submit! {
383            #route_name
384        }
385    };
386
387    result.into()
388}