axum_distributed_routing_macros/
lib.rs1use std::{collections::HashMap, str::FromStr};
2
3use syn::{
4 Attribute, Block, Ident, LitStr, PatType, Token, Type, ext::IdentExt, parenthesized,
5 parse::Parse, parse_macro_input, punctuated::Punctuated,
6};
7
/// HTTP method that a `route!` handler is registered for.
///
/// When the route is expanded, each variant selects the `axum::routing`
/// function with the same lower-case name (`Get` -> `axum::routing::get`, ...).
enum Method {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `PATCH`
    Patch,
    /// `DELETE`
    Delete,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
    /// `TRACE`
    Trace,
    /// `CONNECT`
    Connect,
}
19
20struct Args {
21 path: String,
22 path_params: HashMap<Ident, Type>,
23 query_params: Option<Type>,
24 body_params: Option<Type>,
25 parameters: Punctuated<PatType, Token![,]>,
26 name: Ident,
27 group: Type,
28 return_type: Type,
29 method: Method,
30 handler_attributes: Vec<Attribute>,
31 handler: Block,
32}
33
34impl Parse for Args {
35 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
36 let mut path = None;
37 let mut path_params = HashMap::new();
38 let mut query_params = None;
39 let mut body_params = None;
40 let mut parameters: Punctuated<PatType, Token![,]> = Punctuated::new();
41 let mut name = None;
42 let mut return_type = None;
43 let mut handler = None;
44 let mut method = None;
45 let mut group = None;
46 let mut handler_attributes = Vec::new();
47
48 while !input.is_empty() {
49 if input.peek(Token![#]) || input.peek(Token![async]) {
50 if handler.is_some() {
51 return Err(syn::Error::new(input.span(), "Handler is already defined"));
52 }
53
54 handler_attributes = input.call(Attribute::parse_outer)?;
55 input.parse::<Token![async]>()?;
56
57 name = Some(input.parse()?);
58
59 if input.peek(syn::token::Paren) {
60 let content;
61 let _ = parenthesized!(content in input);
62 parameters = Punctuated::parse_terminated(&content)?;
63 }
64
65 if input.peek(Token![->]) {
66 input.parse::<Token![->]>()?;
67 return_type = Some(input.parse()?);
68 }
69 handler = Some(input.parse()?);
70 } else {
71 let ident: Ident = input.call(Ident::parse_any)?;
72
73 match ident.to_string().as_str() {
74 "method" => {
75 input.parse::<syn::Token![=]>()?;
77
78 let method_ident = input.parse::<Ident>()?;
79 match method_ident.to_string().as_str() {
80 "GET" => method = Some(Method::Get),
81 "POST" => method = Some(Method::Post),
82 "PUT" => method = Some(Method::Put),
83 "PATCH" => method = Some(Method::Patch),
84 "DELETE" => method = Some(Method::Delete),
85 "HEAD" => method = Some(Method::Head),
86 "OPTIONS" => method = Some(Method::Options),
87 "TRACE" => method = Some(Method::Trace),
88 "CONNECT" => method = Some(Method::Connect),
89 m => {
90 return Err(syn::Error::new(
91 ident.span(),
92 format!("Unknown method {m}"),
93 ));
94 }
95 }
96 }
97 "group" => {
98 input.parse::<syn::Token![=]>()?;
100
101 group = Some(input.parse()?);
102 }
103 "path" => {
104 input.parse::<syn::Token![=]>()?;
106
107 let path_str: LitStr = input.parse()?;
108 let (path_, path_params_) = Self::parse_path(path_str)?;
109 path = Some(path_);
110 path_params = path_params_;
111 }
112 "query" => {
113 input.parse::<syn::Token![=]>()?;
115
116 query_params = Some(input.parse()?);
117 }
118 "body" => {
119 input.parse::<syn::Token![=]>()?;
121
122 body_params = Some(input.parse()?);
123 }
124 _ => {
125 return Err(syn::Error::new(
126 ident.span(),
127 format!(
128 "Unknown attribute '{ident}'. Allowed attributes are: 'method', 'group', 'path', 'query', 'body'.",
129 ),
130 ));
131 }
132 }
133 }
134
135 if !input.is_empty() {
136 input.parse::<syn::Token![,]>()?;
137 }
138 }
139
140 if path.is_none() {
141 return Err(syn::Error::new(
142 proc_macro2::Span::call_site(),
143 "Missing path",
144 ));
145 }
146
147 if name.is_none() {
148 return Err(syn::Error::new(
149 proc_macro2::Span::call_site(),
150 "Missing name",
151 ));
152 }
153
154 if return_type.is_none() {
155 return Err(syn::Error::new(
156 proc_macro2::Span::call_site(),
157 "Missing return type",
158 ));
159 }
160
161 if handler.is_none() {
162 return Err(syn::Error::new(
163 proc_macro2::Span::call_site(),
164 "Missing handler",
165 ));
166 }
167
168 if method.is_none() {
169 return Err(syn::Error::new(
170 proc_macro2::Span::call_site(),
171 "Missing method",
172 ));
173 }
174
175 if group.is_none() {
176 return Err(syn::Error::new(
177 proc_macro2::Span::call_site(),
178 "Missing group",
179 ));
180 }
181
182 Ok(Args {
183 name: name.unwrap(),
184 return_type: return_type.unwrap(),
185 group: group.unwrap(),
186 method: method.unwrap(),
187 handler_attributes,
188 handler: handler.unwrap(),
189 path: path.unwrap(),
190 path_params,
191 query_params,
192 body_params,
193 parameters,
194 })
195 }
196}
197
198#[derive(PartialEq)]
199enum ParsePathState {
200 Path,
201 PathParamName,
202 PathParamType,
203}
204
205impl Args {
206 fn parse_path(literal: LitStr) -> syn::Result<(String, HashMap<Ident, Type>)> {
207 let path = literal.value();
208 let mut real_path = String::new();
209 let mut path_params = HashMap::new();
210 let mut state = ParsePathState::Path;
211 let mut current_name = String::new();
212 let mut current_type = String::new();
213
214 for c in path.chars() {
215 match c {
216 '{' => {
217 if state == ParsePathState::Path {
218 state = ParsePathState::PathParamName;
219 } else {
220 return Err(syn::Error::new(
221 literal.span(),
222 "Expected one of character, `:` or `}`, found `{`",
223 ));
224 }
225 }
226 '}' => {
227 if state == ParsePathState::PathParamType {
228 let param_name = proc_macro2::TokenStream::from_str(¤t_name)
229 .map_err(|_| {
230 syn::Error::new(literal.span(), "Invalid path parameter name")
231 })?;
232 let param_type = proc_macro2::TokenStream::from_str(¤t_type)
233 .map_err(|_| {
234 syn::Error::new(literal.span(), "Invalid path parameter type")
235 })?;
236 path_params.insert(syn::parse2(param_name)?, syn::parse2(param_type)?);
237
238 real_path.push(':');
239 real_path.push_str(¤t_name);
240
241 current_name = String::new();
242 current_type = String::new();
243 state = ParsePathState::Path;
244 } else if state == ParsePathState::PathParamName {
245 return Err(syn::Error::new(
246 literal.span(),
247 "Expected one of character or `:`, found `}`",
248 ));
249 } else {
250 return Err(syn::Error::new(
251 literal.span(),
252 "Expected one of character or `{`, found `}`",
253 ));
254 }
255 }
256 ':' => {
257 if state == ParsePathState::PathParamName {
258 state = ParsePathState::PathParamType;
259 } else if state != ParsePathState::Path {
260 return Err(syn::Error::new(
261 literal.span(),
262 "Expected one of character or `{`, found `:`",
263 ));
264 } else {
265 return Err(syn::Error::new(
266 literal.span(),
267 "Expected one of character or `}`, found `:`",
268 ));
269 }
270 }
271 _ => match state {
272 ParsePathState::Path => {
273 real_path.push(c);
274 }
275 ParsePathState::PathParamName => {
276 current_name.push(c);
277 }
278 ParsePathState::PathParamType => {
279 current_type.push(c);
280 }
281 },
282 }
283 }
284
285 if state != ParsePathState::Path {
286 return Err(syn::Error::new(
287 literal.span(),
288 "Expected one of character or `}`, found EOF",
289 ));
290 }
291
292 Ok((path, path_params))
293 }
294}
295
296#[proc_macro]
308pub fn route(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
309 let args = parse_macro_input!(attr as Args);
311
312 let path_params = args.path_params;
313 let path_idents = path_params.keys().collect::<Vec<_>>();
314 let path_types = path_params.values().collect::<Vec<_>>();
315
316 let path_params = if path_params.is_empty() {
317 quote::quote! {}
318 } else {
319 quote::quote! {
320 axum::extract::Path((#(#path_idents),*)): axum::extract::Path<(#(#path_types),*)>,
321 }
322 };
323
324 let query_params = if let Some(q) = args.query_params {
325 quote::quote! { query: #q, }
326 } else {
327 quote::quote! {}
328 };
329
330 let body_params = if let Some(b) = args.body_params {
331 quote::quote! { body: #b, }
332 } else {
333 quote::quote! {}
334 };
335
336 let route_name = Ident::new(
337 &format!(
338 "ROUTE_{}",
339 stringcase::macro_case(args.name.to_string().as_str())
340 ),
341 proc_macro2::Span::call_site(),
342 );
343 let name = args.name;
344 let path = args.path;
345 let parameters = if !args.parameters.trailing_punct() && !args.parameters.is_empty() {
346 let parameters = args.parameters;
347 quote::quote! { #parameters, }
348 } else {
349 let parameters = args.parameters;
350 quote::quote! { #parameters }
351 };
352 let return_type = args.return_type;
353 let block = args.handler;
354 let group = args.group;
355 let handler_attributes = args.handler_attributes;
356
357 let handler_def = quote::quote! {
358 #(#handler_attributes)*
359 async fn #name(#path_params #query_params #parameters #body_params) -> #return_type #block
360 };
361
362 let handler = match args.method {
363 Method::Get => quote::quote! { axum::routing::get(#name) },
364 Method::Post => quote::quote! { axum::routing::post(#name) },
365 Method::Put => quote::quote! { axum::routing::put(#name) },
366 Method::Patch => quote::quote! { axum::routing::patch(#name) },
367 Method::Delete => quote::quote! { axum::routing::delete(#name) },
368 Method::Head => quote::quote! { axum::routing::head(#name) },
369 Method::Options => quote::quote! { axum::routing::options(#name) },
370 Method::Trace => quote::quote! { axum::routing::trace(#name) },
371 Method::Connect => quote::quote! { axum::routing::connect(#name) },
372 };
373
374 let result = quote::quote! {
375 #handler_def
376
377 pub static #route_name: #group = #group::new(#path, |r, _| r.route(#path, #handler));
378
379 axum_distributed_routing::inventory::submit! {
380 #route_name
381 }
382 };
383
384 result.into()
385}