// axum_typed_routing_macros/lib.rs

1use compilation::CompiledRoute;
2use parsing::{Method, Route};
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use std::collections::HashMap;
6use syn::{
7    parse::{Parse, ParseStream},
8    punctuated::Punctuated,
9    token::{Colon, Comma, Slash},
10    FnArg, GenericArgument, ItemFn, LitStr, Meta, PathArguments, Signature, Type,
11};
12#[macro_use]
13extern crate quote;
14#[macro_use]
15extern crate syn;
16
17mod compilation;
18mod parsing;
19
20/// A macro that generates statically-typed routes for axum handlers.
21///
22/// # Syntax
23/// ```ignore
24/// #[route(<METHOD> "<PATH>" [with <STATE>])]
25/// ```
26/// - `METHOD` is the HTTP method, such as `GET`, `POST`, `PUT`, etc.
27/// - `PATH` is the path of the route, with optional path parameters and query parameters,
28///     e.g. `/item/:id?amount&offset`.
29/// - `STATE` is the type of axum-state, passed to the handler. This is optional, and if not
30///    specified, the state type is guessed based on the parameters of the handler.
31///
32/// # Example
33/// ```
34/// use axum::extract::{State, Json};
35/// use axum_typed_routing_macros::route;
36///
37/// #[route(GET "/item/:id?amount&offset")]
38/// async fn item_handler(
39///     id: u32,
40///     amount: Option<u32>,
41///     offset: Option<u32>,
42///     State(state): State<String>,
43///     Json(json): Json<u32>,
44/// ) -> String {
45///     todo!("handle request")
46/// }
47/// ```
48///
49/// # State type
50/// Normally, the state-type is guessed based on the parameters of the function:
51/// If the function has a parameter of type `[..]::State<T>`, then `T` is used as the state type.
52/// This should work for most cases, however when not sufficient, the state type can be specified
53/// explicitly using the `with` keyword:
54/// ```ignore
55/// #[route(GET "/item/:id?amount&offset" with String)]
56/// ```
57///
58/// # Internals
59/// The macro expands to a function with signature `fn() -> (&'static str, axum::routing::MethodRouter<S>)`.
60/// The first element of the tuple is the path, and the second is axum's `MethodRouter`.
61///
/// The path and query are extracted using axum's `extract::Path` and `extract::Query` extractors, which
/// are the first and second parameters of a generated inner function wrapping the handler. The handler's
/// remaining parameters follow them.
64#[proc_macro_attribute]
65pub fn route(attr: TokenStream, mut item: TokenStream) -> TokenStream {
66    match _route(attr, item.clone(), false) {
67        Ok(tokens) => tokens.into(),
68        Err(err) => {
69            let err: TokenStream = err.to_compile_error().into();
70            item.extend(err);
71            item
72        }
73    }
74}
75
76/// Same as [`macro@route`], but with support for OpenApi using `aide`. See [`macro@route`] for more
77/// information and examples.
78///
79/// # Syntax
80/// ```ignore
81/// #[api_route(<METHOD> "<PATH>" [with <STATE>] [{
82///     summary: "<SUMMARY>",
83///     description: "<DESCRIPTION>",
84///     id: "<ID>",
85///     tags: ["<TAG>", ..],
86///     hidden: <bool>,
87///     security: { <SCHEME>: ["<SCOPE>", ..], .. },
88///     responses: { <CODE>: <TYPE>, .. },
89///     transform: |op| { .. },
90/// }])]
91/// ```
/// - `summary` is the OpenApi summary. If not specified, the first line of the function's
///   doc-comments is used.
/// - `description` is the OpenApi description. If not specified, the rest of the function's
///   doc-comments are used.
94/// - `id` is the OpenApi operationId. If not specified, the function's name is used.
95/// - `tags` are the OpenApi tags.
96/// - `hidden` sets whether docs should be hidden for this route.
97/// - `security` is the OpenApi security requirements.
98/// - `responses` are the OpenApi responses.
/// - `transform` is a closure that takes a `TransformOperation` and returns a `TransformOperation`.
///   It is applied after the other options, so it may override them (see the `aide` crate for more
///   information).
101///
102/// # Example
103/// ```
104/// use axum::extract::{State, Json};
105/// use axum_typed_routing_macros::api_route;
106///
107/// #[api_route(GET "/item/:id?amount&offset" with String {
108///     summary: "Get an item",
109///     description: "Get an item by id",
110///     id: "get-item",
111///     tags: ["items"],
112///     hidden: false,
113///     security: { "bearer": ["read:items"] },
114///     responses: { 200: String },
115///     transform: |op| op.tag("private"),
116/// })]
117/// async fn item_handler(
118///     id: u32,
119///     amount: Option<u32>,
120///     offset: Option<u32>,
121///     State(state): State<String>,
122/// ) -> String {
123///     todo!("handle request")
124/// }
125/// ```
126#[proc_macro_attribute]
127pub fn api_route(attr: TokenStream, mut item: TokenStream) -> TokenStream {
128    match _route(attr, item.clone(), true) {
129        Ok(tokens) => tokens.into(),
130        Err(err) => {
131            let err: TokenStream = err.to_compile_error().into();
132            item.extend(err);
133            item
134        }
135    }
136}
137
138fn _route(attr: TokenStream, item: TokenStream, with_aide: bool) -> syn::Result<TokenStream2> {
139    // Parse the route and function
140    let route = syn::parse::<Route>(attr)?;
141    let function = syn::parse::<ItemFn>(item)?;
142
143    // Now we can compile the route
144    let route = CompiledRoute::from_route(route, &function, with_aide)?;
145    let path_extractor = route.path_extractor();
146    let query_extractor = route.query_extractor();
147    let query_params_struct = route.query_params_struct(with_aide);
148    let state_type = &route.state;
149    let axum_path = route.to_axum_path_string();
150    let http_method = route.method.to_axum_method_name();
151    let remaining_numbered_pats = route.remaining_pattypes_numbered(&function.sig.inputs);
152    let extracted_idents = route.extracted_idents();
153    let remaining_numbered_idents = remaining_numbered_pats.iter().map(|pat_type| &pat_type.pat);
154    let route_docs = route.to_doc_comments();
155
156    // Get the variables we need for code generation
157    let fn_name = &function.sig.ident;
158    let fn_output = &function.sig.output;
159    let vis = &function.vis;
160    let asyncness = &function.sig.asyncness;
161    let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl();
162    let ty_generics = ty_generics.as_turbofish();
163    let fn_docs = function
164        .attrs
165        .iter()
166        .filter(|attr| attr.path().is_ident("doc"));
167
168    let (aide_ident_docs, inner_fn_call, method_router_ty) = if with_aide {
169        let http_method = format_ident!("{}_with", http_method);
170        let summary = route
171            .get_oapi_summary()
172            .map(|summary| quote! { .summary(#summary) });
173        let description = route
174            .get_oapi_description()
175            .map(|description| quote! { .description(#description) });
176        let hidden = route
177            .get_oapi_hidden()
178            .map(|hidden| quote! { .hidden(#hidden) });
179        let tags = route.get_oapi_tags();
180        let id = route
181            .get_oapi_id(&function.sig)
182            .map(|id| quote! { .id(#id) });
183        let transform = route.get_oapi_transform()?;
184        let responses = route.get_oapi_responses();
185        let response_code = responses.iter().map(|response| &response.0);
186        let response_type = responses.iter().map(|response| &response.1);
187        let security = route.get_oapi_security();
188        let schemes = security.iter().map(|sec| &sec.0);
189        let scopes = security.iter().map(|sec| &sec.1);
190
191        (
192            route.ide_documentation_for_aide_methods(),
193            quote! {
194                ::aide::axum::routing::#http_method(
195                    __inner__function__ #ty_generics,
196                    |__op__| {
197                        let __op__ = __op__
198                            #summary
199                            #description
200                            #hidden
201                            #id
202                            #(.tag(#tags))*
203                            #(.security_requirement_scopes::<Vec<&'static str>, _>(#schemes, vec![#(#scopes),*]))*
204                            #(.response::<#response_code, #response_type>())*
205                            ;
206                        #transform
207                        __op__
208                    }
209                )
210            },
211            quote! { ::aide::axum::routing::ApiMethodRouter },
212        )
213    } else {
214        (
215            quote!(),
216            quote! { ::axum::routing::#http_method(__inner__function__ #ty_generics) },
217            quote! { ::axum::routing::MethodRouter },
218        )
219    };
220
221    // Generate the code
222    Ok(quote! {
223        #(#fn_docs)*
224        #route_docs
225        #vis fn #fn_name #impl_generics() -> (&'static str, #method_router_ty<#state_type>) #where_clause {
226
227            #query_params_struct
228
229            #aide_ident_docs
230            #asyncness fn __inner__function__ #impl_generics(
231                #path_extractor
232                #query_extractor
233                #remaining_numbered_pats
234            ) #fn_output #where_clause {
235                #function
236
237                #fn_name #ty_generics(#(#extracted_idents,)* #(#remaining_numbered_idents,)* ).await
238            }
239
240            (#axum_path, #inner_fn_call)
241        }
242    })
243}