Skip to main content

axum_openapi3_derive/
lib.rs

//! Attribute macro for defining an endpoint.
2//! See [`axum_openapi3`](https://crates.io/crates/axum-openapi3) for more information.
3
4use handler_signature::{parse_handler_arguments, parse_handler_ret_type, HandlerArgument};
5use macro_arguments::MacroArgs;
6use quote::quote;
7use std::fmt::Write;
8use syn::{parse_macro_input, spanned::Spanned, ItemFn};
9
10mod handler_signature;
11mod macro_arguments;
12mod util;
13
14/// Derive macro for defining an endpoint.
15#[proc_macro_attribute]
16pub fn endpoint(
17    args: proc_macro::TokenStream,
18    input: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20    let macro_args = parse_macro_input!(args as MacroArgs);
21    let input_fn = parse_macro_input!(input as ItemFn);
22
23    let fn_args = match parse_handler_arguments(&input_fn.sig) {
24        Ok(args) => args,
25        Err(err) => return err.to_compile_error().into(),
26    };
27
28    let ret_type = match parse_handler_ret_type(&input_fn.sig) {
29        Ok(ty) => ty,
30        Err(err) => return err.to_compile_error().into(),
31    };
32
33    let fn_name = input_fn.sig.ident.clone();
34    let fn_name_str = fn_name.to_string();
35
36    let path = macro_args.path;
37    let method = macro_args.method.to_string();
38    let description = macro_args.description.unwrap_or_default();
39
40    let method: http::Method = method.parse().unwrap(); //The HTTP method parsing fails before
41
42    let (utoipa_method_name, axum_method) = match get_method_tokens(method) {
43        Ok(value) => value,
44        Err(_) => {
45            return syn::Error::new(input_fn.sig.span(), "Unsupported HTTP method")
46                .to_compile_error()
47                .into()
48        }
49    };
50
51    let ret_type = get_ret_type_token(ret_type);
52
53    let request_body = get_request_body_token(&fn_args);
54
55    let query_params = get_query_params_token(&fn_args);
56
57    let path_param_names = extract_params(&path);
58    let path_params = get_path_params_token(&fn_args, path_param_names);
59
60    let state = get_state_token(fn_args);
61
62    let path_for_openapi = transform_route(&path);
63
64    let public = get_public_token(&input_fn.vis);
65
66    let output = quote! {
67        #public fn #fn_name() -> (&'static str, axum::routing::MethodRouter < #state , std::convert::Infallible >)
68        {
69            #input_fn
70
71            let handler = axum::routing:: #axum_method (#fn_name);
72
73            let op_builder = axum_openapi3::utoipa::openapi::path::OperationBuilder::new()
74                .description(Some(#description));
75
76            #ret_type
77
78            #request_body
79
80            #query_params
81
82            #path_params
83
84            let op_builder = op_builder.operation_id(Some(#fn_name_str));
85
86            let paths = axum_openapi3::utoipa::openapi::PathsBuilder::new()
87                .path(#path_for_openapi, axum_openapi3::utoipa::openapi::path::PathItemBuilder::new()
88                    .operation(
89                        axum_openapi3::utoipa::openapi::HttpMethod:: #utoipa_method_name,
90                        op_builder.build()
91                    )
92                    .build())
93                .build();
94
95            axum_openapi3::ENDPOINTS.lock().unwrap().push(paths);
96
97            (#path, handler)
98        }
99
100    };
101
102    output.into()
103}
104
105fn get_ret_type_token(ret_type: Option<String>) -> proc_macro2::TokenStream {
106    let ret_type = if let Some(ret_type) = ret_type {
107        format!(
108            r#"
109let response_schema = < {ret_type} as axum_openapi3::utoipa::PartialSchema > :: schema();
110let op_builder = op_builder.response(
111    "200", 
112    axum_openapi3::utoipa::openapi::ResponseBuilder::new()
113        .content(
114            "application/json", 
115            axum_openapi3::utoipa::openapi::ContentBuilder::new()
116                .schema(Some(response_schema))
117                .build()
118        )
119        .build()
120);
121            "#
122        )
123    } else {
124        "let op_builder = op_builder;".to_string()
125    };
126    let ret_type: proc_macro2::TokenStream = ret_type.parse().unwrap();
127    ret_type
128}
129
130fn get_request_body_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
131    let request_body = fn_args.iter().find_map(|arg| match arg {
132        HandlerArgument::RequestBody(ty) => Some(format!(
133            r#"
134let request_body = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
135let op_builder = op_builder
136        .request_body(Some(
137            axum_openapi3::utoipa::openapi::request_body::RequestBodyBuilder::new()
138                .content(
139                    "application/json",
140                    axum_openapi3::utoipa::openapi::ContentBuilder::new()
141                        .schema(Some(request_body))
142                        .build()
143                )
144                .build()
145        ));
146            "#
147        )),
148        _ => None,
149    });
150    let request_body: proc_macro2::TokenStream = if let Some(request_body) = request_body {
151        request_body.parse().unwrap()
152    } else {
153        "let op_builder = op_builder;".parse().unwrap()
154    };
155    request_body
156}
157
158fn get_state_token(fn_args: Vec<HandlerArgument>) -> proc_macro2::TokenStream {
159    let state = fn_args
160        .iter()
161        .find_map(|arg| match arg {
162            HandlerArgument::State(ty) => Some(ty.as_str()),
163            _ => None,
164        })
165        .unwrap_or("()");
166    let state: proc_macro2::TokenStream = state.parse().unwrap();
167    state
168}
169
170fn get_public_token(public: &syn::Visibility) -> proc_macro2::TokenStream {
171    let public: proc_macro2::TokenStream = match public {
172        syn::Visibility::Public(_) => "pub ".parse().unwrap(),
173        _ => "".parse().unwrap(),
174    };
175    public
176}
177
178fn get_path_params_token(
179    fn_args: &[HandlerArgument],
180    path_param_names: Vec<String>,
181) -> proc_macro2::TokenStream {
182    let path_params: String = fn_args
183        .iter()
184        .filter_map(|arg| match arg {
185            HandlerArgument::Path(ty) => Some(ty),
186            _ => None,
187        })
188        .zip(path_param_names.iter())
189        .fold(String::new(), |mut acc, (ty, name)| {
190            let _ = write!(
191                acc,
192                r#"
193let schema = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
194let path_param = axum_openapi3::utoipa::openapi::path::ParameterBuilder::new()
195    .parameter_in(axum_openapi3::utoipa::openapi::path::ParameterIn::Path)
196    .name("{name}")
197    .required(axum_openapi3::utoipa::openapi::Required::True)
198    .schema(Some(schema))
199    .build();
200
201let op_builder = op_builder
202    .parameter(path_param);
203"#
204            );
205            acc
206        });
207    let path_params: proc_macro2::TokenStream = if path_params.is_empty() {
208        "let op_builder = op_builder;".parse().unwrap()
209    } else {
210        path_params.parse().unwrap()
211    };
212    path_params
213}
214
215fn get_query_params_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
216    let query_params = fn_args.iter().find_map(|arg| {
217        match arg {
218            HandlerArgument::Query(ty) => Some(format!(r#"
219let query_params = < {ty} as axum_openapi3::utoipa::IntoParams > :: into_params(|| Some(axum_openapi3::utoipa::openapi::path::ParameterIn::Query));
220let op_builder = op_builder
221    .parameters(Some(query_params));
222            "#)),
223            _ => None,
224        }
225    });
226    let query_params: proc_macro2::TokenStream = if let Some(query_params) = query_params {
227        query_params.parse().unwrap()
228    } else {
229        "let op_builder = op_builder;".parse().unwrap()
230    };
231    query_params
232}
233
234fn get_method_tokens(
235    method: http::Method,
236) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), ()> {
237    let (utoipa_method_name, axum_method): (proc_macro2::TokenStream, proc_macro2::TokenStream) =
238        match method {
239            http::Method::GET => ("Get".parse().unwrap(), "get".parse().unwrap()),
240            http::Method::POST => ("Post".parse().unwrap(), "post".parse().unwrap()),
241            http::Method::PUT => ("Put".parse().unwrap(), "put".parse().unwrap()),
242            http::Method::DELETE => ("Delete".parse().unwrap(), "delete".parse().unwrap()),
243            http::Method::HEAD => ("Head".parse().unwrap(), "head".parse().unwrap()),
244            http::Method::OPTIONS => ("Options".parse().unwrap(), "options".parse().unwrap()),
245            http::Method::CONNECT => ("Connect".parse().unwrap(), "connect".parse().unwrap()),
246            http::Method::PATCH => ("Patch".parse().unwrap(), "patch".parse().unwrap()),
247            // Ensure the HTTP method is valid
248            _ => return Err(()),
249        };
250    Ok((utoipa_method_name, axum_method))
251}
252
/// Returns the names of the path parameters declared in `input`, in order.
///
/// Both the brace syntax (`/todos/{id}`) and axum's colon syntax (`/todos/:id`) are
/// recognised, matching the two forms `transform_route` accepts. Segments that are not
/// parameters are skipped, as are parameters with an empty name (`{}` or a bare `:`).
fn extract_params(input: &str) -> Vec<String> {
    input
        .split('/')
        .filter_map(|segment| {
            // A parameter must occupy a whole segment: exactly one `{` ... `}` pair,
            // or a leading `:`.
            segment
                .strip_prefix('{')
                .and_then(|inner| inner.strip_suffix('}'))
                .or_else(|| segment.strip_prefix(':'))
        })
        .filter(|name| !name.is_empty())
        .map(str::to_string)
        .collect()
}
270
/// Converts an axum route into an OpenAPI path template.
///
/// Each `/`-separated segment of the form `:name` becomes `{name}`. All other segments,
/// including ones already written as `{name}`, are copied unchanged.
fn transform_route(route: &str) -> String {
    let mut converted: Vec<String> = Vec::new();
    for segment in route.split('/') {
        let piece = match segment.strip_prefix(':') {
            Some(name) => format!("{{{}}}", name),
            None => String::from(segment),
        };
        converted.push(piece);
    }
    converted.join("/")
}
284
#[cfg(test)]
mod tests {
    use super::{extract_params, transform_route};

    /// Brace-style parameters are returned in route order; a trailing slash is ignored.
    #[test]
    fn test_extract_params() {
        let cases: [(&str, &[&str]); 4] = [
            ("/foo/{id}/bar", &["id"]),
            ("/foo/{id}/bar/{baz}", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/{qux}", &["id", "baz", "qux"]),
        ];

        for (route, expected) in cases {
            assert_eq!(extract_params(route), expected);
        }
    }

    /// Colon-style segments are rewritten to braces; everything else passes through.
    #[test]
    fn test_transform_route() {
        let cases = [
            ("/todos", "/todos"),
            ("/todos/:id", "/todos/{id}"),
            ("/todos/:id/foo", "/todos/{id}/foo"),
            ("/bar/:bar_id/foo/:foo_id", "/bar/{bar_id}/foo/{foo_id}"),
            (
                "/bar/{bar_id}/foo/{foo_id}/baz",
                "/bar/{bar_id}/foo/{foo_id}/baz",
            ),
        ];

        for (route, expected) in cases {
            assert_eq!(transform_route(route), expected);
        }
    }
}