axum_typed_routing_macros/lib.rs
1use compilation::CompiledRoute;
2use parsing::{Method, Route};
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use std::collections::HashMap;
6use syn::{
7 parse::{Parse, ParseStream},
8 punctuated::Punctuated,
9 token::{Colon, Comma, Slash},
10 FnArg, GenericArgument, ItemFn, LitStr, Meta, PathArguments, Signature, Type,
11};
12#[macro_use]
13extern crate quote;
14#[macro_use]
15extern crate syn;
16
17mod compilation;
18mod parsing;
19
20/// A macro that generates statically-typed routes for axum handlers.
21///
22/// # Syntax
23/// ```ignore
24/// #[route(<METHOD> "<PATH>" [with <STATE>])]
25/// ```
26/// - `METHOD` is the HTTP method, such as `GET`, `POST`, `PUT`, etc.
27/// - `PATH` is the path of the route, with optional path parameters and query parameters,
28/// e.g. `/item/:id?amount&offset`.
29/// - `STATE` is the type of axum-state, passed to the handler. This is optional, and if not
30/// specified, the state type is guessed based on the parameters of the handler.
31///
32/// # Example
33/// ```
34/// use axum::extract::{State, Json};
35/// use axum_typed_routing_macros::route;
36///
37/// #[route(GET "/item/:id?amount&offset")]
38/// async fn item_handler(
39/// id: u32,
40/// amount: Option<u32>,
41/// offset: Option<u32>,
42/// State(state): State<String>,
43/// Json(json): Json<u32>,
44/// ) -> String {
45/// todo!("handle request")
46/// }
47/// ```
48///
49/// # State type
50/// Normally, the state-type is guessed based on the parameters of the function:
51/// If the function has a parameter of type `[..]::State<T>`, then `T` is used as the state type.
52/// This should work for most cases, however when not sufficient, the state type can be specified
53/// explicitly using the `with` keyword:
54/// ```ignore
55/// #[route(GET "/item/:id?amount&offset" with String)]
56/// ```
57///
58/// # Internals
59/// The macro expands to a function with signature `fn() -> (&'static str, axum::routing::MethodRouter<S>)`.
60/// The first element of the tuple is the path, and the second is axum's `MethodRouter`.
61///
62/// The path and query are extracted using axum's `extract::Path` and `extract::Query` extractors, as the first
63/// and second parameters of the function. The remaining parameters are the parameters of the handler.
64#[proc_macro_attribute]
65pub fn route(attr: TokenStream, mut item: TokenStream) -> TokenStream {
66 match _route(attr, item.clone(), false) {
67 Ok(tokens) => tokens.into(),
68 Err(err) => {
69 let err: TokenStream = err.to_compile_error().into();
70 item.extend(err);
71 item
72 }
73 }
74}
75
76/// Same as [`macro@route`], but with support for OpenApi using `aide`. See [`macro@route`] for more
77/// information and examples.
78///
79/// # Syntax
80/// ```ignore
81/// #[api_route(<METHOD> "<PATH>" [with <STATE>] [{
82/// summary: "<SUMMARY>",
83/// description: "<DESCRIPTION>",
84/// id: "<ID>",
85/// tags: ["<TAG>", ..],
86/// hidden: <bool>,
87/// security: { <SCHEME>: ["<SCOPE>", ..], .. },
88/// responses: { <CODE>: <TYPE>, .. },
89/// transform: |op| { .. },
90/// }])]
91/// ```
92/// - `summary` is the OpenApi summary. If not specified, the first line of the function's doc-comments
93/// - `description` is the OpenApi description. If not specified, the rest of the function's doc-comments
94/// - `id` is the OpenApi operationId. If not specified, the function's name is used.
95/// - `tags` are the OpenApi tags.
96/// - `hidden` sets whether docs should be hidden for this route.
97/// - `security` is the OpenApi security requirements.
98/// - `responses` are the OpenApi responses.
99/// - `transform` is a closure that takes an `TransformOperation` and returns an `TransformOperation`.
100/// This may override the other options. (see the crate `aide` for more information).
101///
102/// # Example
103/// ```
104/// use axum::extract::{State, Json};
105/// use axum_typed_routing_macros::api_route;
106///
107/// #[api_route(GET "/item/:id?amount&offset" with String {
108/// summary: "Get an item",
109/// description: "Get an item by id",
110/// id: "get-item",
111/// tags: ["items"],
112/// hidden: false,
113/// security: { "bearer": ["read:items"] },
114/// responses: { 200: String },
115/// transform: |op| op.tag("private"),
116/// })]
117/// async fn item_handler(
118/// id: u32,
119/// amount: Option<u32>,
120/// offset: Option<u32>,
121/// State(state): State<String>,
122/// ) -> String {
123/// todo!("handle request")
124/// }
125/// ```
126#[proc_macro_attribute]
127pub fn api_route(attr: TokenStream, mut item: TokenStream) -> TokenStream {
128 match _route(attr, item.clone(), true) {
129 Ok(tokens) => tokens.into(),
130 Err(err) => {
131 let err: TokenStream = err.to_compile_error().into();
132 item.extend(err);
133 item
134 }
135 }
136}
137
138fn _route(attr: TokenStream, item: TokenStream, with_aide: bool) -> syn::Result<TokenStream2> {
139 // Parse the route and function
140 let route = syn::parse::<Route>(attr)?;
141 let function = syn::parse::<ItemFn>(item)?;
142
143 // Now we can compile the route
144 let route = CompiledRoute::from_route(route, &function, with_aide)?;
145 let path_extractor = route.path_extractor();
146 let query_extractor = route.query_extractor();
147 let query_params_struct = route.query_params_struct(with_aide);
148 let state_type = &route.state;
149 let axum_path = route.to_axum_path_string();
150 let http_method = route.method.to_axum_method_name();
151 let remaining_numbered_pats = route.remaining_pattypes_numbered(&function.sig.inputs);
152 let extracted_idents = route.extracted_idents();
153 let remaining_numbered_idents = remaining_numbered_pats.iter().map(|pat_type| &pat_type.pat);
154 let route_docs = route.to_doc_comments();
155
156 // Get the variables we need for code generation
157 let fn_name = &function.sig.ident;
158 let fn_output = &function.sig.output;
159 let vis = &function.vis;
160 let asyncness = &function.sig.asyncness;
161 let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl();
162 let ty_generics = ty_generics.as_turbofish();
163 let fn_docs = function
164 .attrs
165 .iter()
166 .filter(|attr| attr.path().is_ident("doc"));
167
168 let (aide_ident_docs, inner_fn_call, method_router_ty) = if with_aide {
169 let http_method = format_ident!("{}_with", http_method);
170 let summary = route
171 .get_oapi_summary()
172 .map(|summary| quote! { .summary(#summary) });
173 let description = route
174 .get_oapi_description()
175 .map(|description| quote! { .description(#description) });
176 let hidden = route
177 .get_oapi_hidden()
178 .map(|hidden| quote! { .hidden(#hidden) });
179 let tags = route.get_oapi_tags();
180 let id = route
181 .get_oapi_id(&function.sig)
182 .map(|id| quote! { .id(#id) });
183 let transform = route.get_oapi_transform()?;
184 let responses = route.get_oapi_responses();
185 let response_code = responses.iter().map(|response| &response.0);
186 let response_type = responses.iter().map(|response| &response.1);
187 let security = route.get_oapi_security();
188 let schemes = security.iter().map(|sec| &sec.0);
189 let scopes = security.iter().map(|sec| &sec.1);
190
191 (
192 route.ide_documentation_for_aide_methods(),
193 quote! {
194 ::aide::axum::routing::#http_method(
195 __inner__function__ #ty_generics,
196 |__op__| {
197 let __op__ = __op__
198 #summary
199 #description
200 #hidden
201 #id
202 #(.tag(#tags))*
203 #(.security_requirement_scopes::<Vec<&'static str>, _>(#schemes, vec![#(#scopes),*]))*
204 #(.response::<#response_code, #response_type>())*
205 ;
206 #transform
207 __op__
208 }
209 )
210 },
211 quote! { ::aide::axum::routing::ApiMethodRouter },
212 )
213 } else {
214 (
215 quote!(),
216 quote! { ::axum::routing::#http_method(__inner__function__ #ty_generics) },
217 quote! { ::axum::routing::MethodRouter },
218 )
219 };
220
221 // Generate the code
222 Ok(quote! {
223 #(#fn_docs)*
224 #route_docs
225 #vis fn #fn_name #impl_generics() -> (&'static str, #method_router_ty<#state_type>) #where_clause {
226
227 #query_params_struct
228
229 #aide_ident_docs
230 #asyncness fn __inner__function__ #impl_generics(
231 #path_extractor
232 #query_extractor
233 #remaining_numbered_pats
234 ) #fn_output #where_clause {
235 #function
236
237 #fn_name #ty_generics(#(#extracted_idents,)* #(#remaining_numbered_idents,)* ).await
238 }
239
240 (#axum_path, #inner_fn_call)
241 }
242 })
243}