axum_openapi3_derive/
lib.rs

//! Attribute macro (`#[endpoint]`) for defining an axum endpoint together with its OpenAPI documentation.
2//! See [`axum_openapi3`](https://crates.io/crates/axum-openapi3) for more information.
3
4use handler_signature::{parse_handler_arguments, parse_handler_ret_type, HandlerArgument};
5use macro_arguments::MacroArgs;
6use quote::quote;
7use std::fmt::Write;
8use syn::{parse_macro_input, spanned::Spanned, ItemFn};
9
10mod handler_signature;
11mod macro_arguments;
12mod util;
13
14/// Derive macro for defining an endpoint.
15#[proc_macro_attribute]
16pub fn endpoint(
17    args: proc_macro::TokenStream,
18    input: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20    let macro_args = parse_macro_input!(args as MacroArgs);
21    let input_fn = parse_macro_input!(input as ItemFn);
22
23    let fn_args = match parse_handler_arguments(&input_fn.sig) {
24        Ok(args) => args,
25        Err(err) => return err.to_compile_error().into(),
26    };
27
28    let ret_type = match parse_handler_ret_type(&input_fn.sig) {
29        Ok(ty) => ty,
30        Err(err) => return err.to_compile_error().into(),
31    };
32
33    let fn_name = input_fn.sig.ident.clone();
34    let fn_name_str = fn_name.to_string();
35
36    let path = macro_args.path;
37    let method = macro_args.method.to_string();
38    let description = macro_args.description.unwrap_or_default();
39
40    let method: http::Method = method.parse().unwrap(); //The HTTP method parsing fails before
41
42    let (utoipa_method_name, axum_method) = match get_method_tokens(method) {
43        Ok(value) => value,
44        Err(_) => {
45            return syn::Error::new(input_fn.sig.span(), "Unsupported HTTP method")
46                .to_compile_error()
47                .into()
48        }
49    };
50
51    let ret_type = get_ret_type_token(ret_type);
52
53    let request_body = get_request_body_token(&fn_args);
54
55    let query_params = get_query_params_token(&fn_args);
56
57    let path_param_names = extract_params(&path);
58    let path_params = get_path_params_token(&fn_args, path_param_names);
59
60    let state = get_state_token(fn_args);
61
62    let path_for_openapi = transform_route(&path);
63
64    let output = quote! {
65        fn #fn_name() -> (&'static str, axum::routing::MethodRouter < #state , std::convert::Infallible >)
66        {
67            #input_fn
68
69            let handler = axum::routing:: #axum_method (#fn_name);
70
71            let op_builder = axum_openapi3::utoipa::openapi::path::OperationBuilder::new()
72                .description(Some(#description));
73
74            #ret_type
75
76            #request_body
77
78            #query_params
79
80            #path_params
81
82            let op_builder = op_builder.operation_id(Some(#fn_name_str));
83
84            let paths = axum_openapi3::utoipa::openapi::PathsBuilder::new()
85                .path(#path_for_openapi, axum_openapi3::utoipa::openapi::path::PathItemBuilder::new()
86                    .operation(
87                        axum_openapi3::utoipa::openapi::HttpMethod:: #utoipa_method_name,
88                        op_builder.build()
89                    )
90                    .build())
91                .build();
92
93            axum_openapi3::ENDPOINTS.lock().unwrap().push(paths);
94
95            (#path, handler)
96        }
97
98    };
99
100    output.into()
101}
102
103fn get_ret_type_token(ret_type: Option<String>) -> proc_macro2::TokenStream {
104    let ret_type = if let Some(ret_type) = ret_type {
105        format!(
106            r#"
107let response_schema = < {ret_type} as axum_openapi3::utoipa::PartialSchema > :: schema();
108let op_builder = op_builder.response(
109    "200", 
110    axum_openapi3::utoipa::openapi::ResponseBuilder::new()
111        .content(
112            "application/json", 
113            axum_openapi3::utoipa::openapi::ContentBuilder::new()
114                .schema(Some(response_schema))
115                .build()
116        )
117        .build()
118);
119            "#
120        )
121    } else {
122        "let op_builder = op_builder;".to_string()
123    };
124    let ret_type: proc_macro2::TokenStream = ret_type.parse().unwrap();
125    ret_type
126}
127
128fn get_request_body_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
129    let request_body = fn_args.iter().find_map(|arg| match arg {
130        HandlerArgument::RequestBody(ty) => Some(format!(
131            r#"
132let request_body = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
133let op_builder = op_builder
134        .request_body(Some(
135            axum_openapi3::utoipa::openapi::request_body::RequestBodyBuilder::new()
136                .content(
137                    "application/json",
138                    axum_openapi3::utoipa::openapi::ContentBuilder::new()
139                        .schema(Some(request_body))
140                        .build()
141                )
142                .build()
143        ));
144            "#
145        )),
146        _ => None,
147    });
148    let request_body: proc_macro2::TokenStream = if let Some(request_body) = request_body {
149        request_body.parse().unwrap()
150    } else {
151        "let op_builder = op_builder;".parse().unwrap()
152    };
153    request_body
154}
155
156fn get_state_token(fn_args: Vec<HandlerArgument>) -> proc_macro2::TokenStream {
157    let state = fn_args
158        .iter()
159        .find_map(|arg| match arg {
160            HandlerArgument::State(ty) => Some(ty.as_str()),
161            _ => None,
162        })
163        .unwrap_or("()");
164    let state: proc_macro2::TokenStream = state.parse().unwrap();
165    state
166}
167
168fn get_path_params_token(
169    fn_args: &[HandlerArgument],
170    path_param_names: Vec<String>,
171) -> proc_macro2::TokenStream {
172    let path_params: String = fn_args
173        .iter()
174        .filter_map(|arg| match arg {
175            HandlerArgument::Path(ty) => Some(ty),
176            _ => None,
177        })
178        .zip(path_param_names.iter())
179        .fold(String::new(), |mut acc, (ty, name)| {
180            let _ = write!(
181                acc,
182                r#"
183let schema = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
184let path_param = axum_openapi3::utoipa::openapi::path::ParameterBuilder::new()
185    .parameter_in(axum_openapi3::utoipa::openapi::path::ParameterIn::Path)
186    .name("{name}")
187    .required(axum_openapi3::utoipa::openapi::Required::True)
188    .schema(Some(schema))
189    .build();
190
191let op_builder = op_builder
192    .parameter(path_param);
193"#
194            );
195            acc
196        });
197    let path_params: proc_macro2::TokenStream = if path_params.is_empty() {
198        "let op_builder = op_builder;".parse().unwrap()
199    } else {
200        path_params.parse().unwrap()
201    };
202    path_params
203}
204
205fn get_query_params_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
206    let query_params = fn_args.iter().find_map(|arg| {
207        match arg {
208            HandlerArgument::Query(ty) => Some(format!(r#"
209let query_params = < {ty} as axum_openapi3::utoipa::IntoParams > :: into_params(|| Some(axum_openapi3::utoipa::openapi::path::ParameterIn::Query));
210let op_builder = op_builder
211    .parameters(Some(query_params));
212            "#)),
213            _ => None,
214        }
215    });
216    let query_params: proc_macro2::TokenStream = if let Some(query_params) = query_params {
217        query_params.parse().unwrap()
218    } else {
219        "let op_builder = op_builder;".parse().unwrap()
220    };
221    query_params
222}
223
224fn get_method_tokens(
225    method: http::Method,
226) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), ()> {
227    let (utoipa_method_name, axum_method): (proc_macro2::TokenStream, proc_macro2::TokenStream) =
228        match method {
229            http::Method::GET => ("Get".parse().unwrap(), "get".parse().unwrap()),
230            http::Method::POST => ("Post".parse().unwrap(), "post".parse().unwrap()),
231            http::Method::PUT => ("Put".parse().unwrap(), "put".parse().unwrap()),
232            http::Method::DELETE => ("Delete".parse().unwrap(), "delete".parse().unwrap()),
233            http::Method::HEAD => ("Head".parse().unwrap(), "head".parse().unwrap()),
234            http::Method::OPTIONS => ("Options".parse().unwrap(), "options".parse().unwrap()),
235            http::Method::CONNECT => ("Connect".parse().unwrap(), "connect".parse().unwrap()),
236            http::Method::PATCH => ("Patch".parse().unwrap(), "patch".parse().unwrap()),
237            // Ensure the HTTP method is valid
238            _ => return Err(()),
239        };
240    Ok((utoipa_method_name, axum_method))
241}
242
/// Returns the names of the path parameters in the route `input`, in order.
///
/// Both placeholder syntaxes are recognised: `{name}` (OpenAPI style) and
/// `:name` (the legacy axum style that `transform_route` rewrites to
/// `{name}`). Exactly one pair of braces is stripped. Placeholders with an
/// empty name (`{}` or a lone `:`) are skipped.
fn extract_params(input: &str) -> Vec<String> {
    input
        .split('/')
        .filter_map(|segment| {
            segment
                .strip_prefix('{')
                .and_then(|inner| inner.strip_suffix('}'))
                .or_else(|| segment.strip_prefix(':'))
        })
        .filter(|name| !name.is_empty())
        .map(str::to_string)
        .collect()
}
260
/// Rewrites legacy `:name` route segments into the OpenAPI `{name}` form.
///
/// Only segments that *start* with `:` are rewritten. Every other segment,
/// including ones already written as `{name}`, is copied verbatim, and the
/// `/` separators are preserved exactly.
fn transform_route(route: &str) -> String {
    let mut rewritten = String::with_capacity(route.len());
    for (index, part) in route.split('/').enumerate() {
        if index > 0 {
            rewritten.push('/');
        }
        match part.strip_prefix(':') {
            Some(param) => {
                rewritten.push('{');
                rewritten.push_str(param);
                rewritten.push('}');
            }
            None => rewritten.push_str(part),
        }
    }
    rewritten
}
274
#[cfg(test)]
mod tests {
    use super::{extract_params, transform_route};

    /// `{name}` placeholders are returned in route order; a trailing slash
    /// adds nothing.
    #[test]
    fn test_extract_params() {
        let cases = vec![
            ("/foo/{id}/bar", vec!["id"]),
            ("/foo/{id}/bar/{baz}", vec!["id", "baz"]),
            ("/foo/{id}/bar/{baz}/", vec!["id", "baz"]),
            ("/foo/{id}/bar/{baz}/{qux}", vec!["id", "baz", "qux"]),
        ];

        for (route, expected) in cases {
            assert_eq!(extract_params(route), expected, "route: {route}");
        }
    }

    /// `:name` segments become `{name}`; everything else is left untouched.
    #[test]
    fn test_transform_route() {
        let cases = vec![
            ("/todos", "/todos"),
            ("/todos/:id", "/todos/{id}"),
            ("/todos/:id/foo", "/todos/{id}/foo"),
            ("/bar/:bar_id/foo/:foo_id", "/bar/{bar_id}/foo/{foo_id}"),
            (
                "/bar/{bar_id}/foo/{foo_id}/baz",
                "/bar/{bar_id}/foo/{foo_id}/baz",
            ),
        ];

        for (route, expected) in cases {
            assert_eq!(transform_route(route), expected, "route: {route}");
        }
    }
}