axum_openapi3_derive/
lib.rs1use handler_signature::{parse_handler_arguments, parse_handler_ret_type, HandlerArgument};
5use macro_arguments::MacroArgs;
6use quote::quote;
7use std::fmt::Write;
8use syn::{parse_macro_input, spanned::Spanned, ItemFn};
9
10mod handler_signature;
11mod macro_arguments;
12mod util;
13
14#[proc_macro_attribute]
16pub fn endpoint(
17 args: proc_macro::TokenStream,
18 input: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20 let macro_args = parse_macro_input!(args as MacroArgs);
21 let input_fn = parse_macro_input!(input as ItemFn);
22
23 let fn_args = match parse_handler_arguments(&input_fn.sig) {
24 Ok(args) => args,
25 Err(err) => return err.to_compile_error().into(),
26 };
27
28 let ret_type = match parse_handler_ret_type(&input_fn.sig) {
29 Ok(ty) => ty,
30 Err(err) => return err.to_compile_error().into(),
31 };
32
33 let fn_name = input_fn.sig.ident.clone();
34 let fn_name_str = fn_name.to_string();
35
36 let path = macro_args.path;
37 let method = macro_args.method.to_string();
38 let description = macro_args.description.unwrap_or_default();
39
40 let method: http::Method = method.parse().unwrap(); let (utoipa_method_name, axum_method) = match get_method_tokens(method) {
43 Ok(value) => value,
44 Err(_) => {
45 return syn::Error::new(input_fn.sig.span(), "Unsupported HTTP method")
46 .to_compile_error()
47 .into()
48 }
49 };
50
51 let ret_type = get_ret_type_token(ret_type);
52
53 let request_body = get_request_body_token(&fn_args);
54
55 let query_params = get_query_params_token(&fn_args);
56
57 let path_param_names = extract_params(&path);
58 let path_params = get_path_params_token(&fn_args, path_param_names);
59
60 let state = get_state_token(fn_args);
61
62 let path_for_openapi = transform_route(&path);
63
64 let output = quote! {
65 fn #fn_name() -> (&'static str, axum::routing::MethodRouter < #state , std::convert::Infallible >)
66 {
67 #input_fn
68
69 let handler = axum::routing:: #axum_method (#fn_name);
70
71 let op_builder = axum_openapi3::utoipa::openapi::path::OperationBuilder::new()
72 .description(Some(#description));
73
74 #ret_type
75
76 #request_body
77
78 #query_params
79
80 #path_params
81
82 let op_builder = op_builder.operation_id(Some(#fn_name_str));
83
84 let paths = axum_openapi3::utoipa::openapi::PathsBuilder::new()
85 .path(#path_for_openapi, axum_openapi3::utoipa::openapi::path::PathItemBuilder::new()
86 .operation(
87 axum_openapi3::utoipa::openapi::HttpMethod:: #utoipa_method_name,
88 op_builder.build()
89 )
90 .build())
91 .build();
92
93 axum_openapi3::ENDPOINTS.lock().unwrap().push(paths);
94
95 (#path, handler)
96 }
97
98 };
99
100 output.into()
101}
102
103fn get_ret_type_token(ret_type: Option<String>) -> proc_macro2::TokenStream {
104 let ret_type = if let Some(ret_type) = ret_type {
105 format!(
106 r#"
107let response_schema = < {ret_type} as axum_openapi3::utoipa::PartialSchema > :: schema();
108let op_builder = op_builder.response(
109 "200",
110 axum_openapi3::utoipa::openapi::ResponseBuilder::new()
111 .content(
112 "application/json",
113 axum_openapi3::utoipa::openapi::ContentBuilder::new()
114 .schema(Some(response_schema))
115 .build()
116 )
117 .build()
118);
119 "#
120 )
121 } else {
122 "let op_builder = op_builder;".to_string()
123 };
124 let ret_type: proc_macro2::TokenStream = ret_type.parse().unwrap();
125 ret_type
126}
127
128fn get_request_body_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
129 let request_body = fn_args.iter().find_map(|arg| match arg {
130 HandlerArgument::RequestBody(ty) => Some(format!(
131 r#"
132let request_body = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
133let op_builder = op_builder
134 .request_body(Some(
135 axum_openapi3::utoipa::openapi::request_body::RequestBodyBuilder::new()
136 .content(
137 "application/json",
138 axum_openapi3::utoipa::openapi::ContentBuilder::new()
139 .schema(Some(request_body))
140 .build()
141 )
142 .build()
143 ));
144 "#
145 )),
146 _ => None,
147 });
148 let request_body: proc_macro2::TokenStream = if let Some(request_body) = request_body {
149 request_body.parse().unwrap()
150 } else {
151 "let op_builder = op_builder;".parse().unwrap()
152 };
153 request_body
154}
155
156fn get_state_token(fn_args: Vec<HandlerArgument>) -> proc_macro2::TokenStream {
157 let state = fn_args
158 .iter()
159 .find_map(|arg| match arg {
160 HandlerArgument::State(ty) => Some(ty.as_str()),
161 _ => None,
162 })
163 .unwrap_or("()");
164 let state: proc_macro2::TokenStream = state.parse().unwrap();
165 state
166}
167
168fn get_path_params_token(
169 fn_args: &[HandlerArgument],
170 path_param_names: Vec<String>,
171) -> proc_macro2::TokenStream {
172 let path_params: String = fn_args
173 .iter()
174 .filter_map(|arg| match arg {
175 HandlerArgument::Path(ty) => Some(ty),
176 _ => None,
177 })
178 .zip(path_param_names.iter())
179 .fold(String::new(), |mut acc, (ty, name)| {
180 let _ = write!(
181 acc,
182 r#"
183let schema = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
184let path_param = axum_openapi3::utoipa::openapi::path::ParameterBuilder::new()
185 .parameter_in(axum_openapi3::utoipa::openapi::path::ParameterIn::Path)
186 .name("{name}")
187 .required(axum_openapi3::utoipa::openapi::Required::True)
188 .schema(Some(schema))
189 .build();
190
191let op_builder = op_builder
192 .parameter(path_param);
193"#
194 );
195 acc
196 });
197 let path_params: proc_macro2::TokenStream = if path_params.is_empty() {
198 "let op_builder = op_builder;".parse().unwrap()
199 } else {
200 path_params.parse().unwrap()
201 };
202 path_params
203}
204
205fn get_query_params_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
206 let query_params = fn_args.iter().find_map(|arg| {
207 match arg {
208 HandlerArgument::Query(ty) => Some(format!(r#"
209let query_params = < {ty} as axum_openapi3::utoipa::IntoParams > :: into_params(|| Some(axum_openapi3::utoipa::openapi::path::ParameterIn::Query));
210let op_builder = op_builder
211 .parameters(Some(query_params));
212 "#)),
213 _ => None,
214 }
215 });
216 let query_params: proc_macro2::TokenStream = if let Some(query_params) = query_params {
217 query_params.parse().unwrap()
218 } else {
219 "let op_builder = op_builder;".parse().unwrap()
220 };
221 query_params
222}
223
224fn get_method_tokens(
225 method: http::Method,
226) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), ()> {
227 let (utoipa_method_name, axum_method): (proc_macro2::TokenStream, proc_macro2::TokenStream) =
228 match method {
229 http::Method::GET => ("Get".parse().unwrap(), "get".parse().unwrap()),
230 http::Method::POST => ("Post".parse().unwrap(), "post".parse().unwrap()),
231 http::Method::PUT => ("Put".parse().unwrap(), "put".parse().unwrap()),
232 http::Method::DELETE => ("Delete".parse().unwrap(), "delete".parse().unwrap()),
233 http::Method::HEAD => ("Head".parse().unwrap(), "head".parse().unwrap()),
234 http::Method::OPTIONS => ("Options".parse().unwrap(), "options".parse().unwrap()),
235 http::Method::CONNECT => ("Connect".parse().unwrap(), "connect".parse().unwrap()),
236 http::Method::PATCH => ("Patch".parse().unwrap(), "patch".parse().unwrap()),
237 _ => return Err(()),
239 };
240 Ok((utoipa_method_name, axum_method))
241}
242
/// Returns the names of the path parameters declared in a route, in order of appearance.
///
/// Two placeholder syntaxes are recognised, and only whole `/`-separated segments count:
/// - `{name}` (OpenAPI / axum 0.8 style);
/// - `:name` (axum 0.7 style).
///
/// For example, `"/todos/:id"` and `"/todos/{id}"` both yield `["id"]`.
///
/// The following segments are skipped because they do not declare a parameter:
/// - placeholders with an empty name (`{}`, `:`);
/// - escaped literal braces (`{{text}}`).
fn extract_params(input: &str) -> Vec<String> {
    input
        .split('/')
        .filter_map(|segment| {
            // Strip exactly one brace pair; `trim_*_matches` would also eat doubled braces.
            segment
                .strip_prefix('{')
                .and_then(|rest| rest.strip_suffix('}'))
                .or_else(|| segment.strip_prefix(':'))
        })
        .filter(|name| !name.is_empty() && !name.contains(|c: char| c == '{' || c == '}'))
        .map(str::to_string)
        .collect()
}

/// Converts an axum route into an OpenAPI path template.
///
/// Every `:name` segment is rewritten as `{name}`. All other segments, including ones
/// already written as `{name}`, are copied verbatim, and every `/` separator (leading and
/// trailing ones included) is preserved.
fn transform_route(route: &str) -> String {
    let mut converted: Vec<String> = Vec::new();
    for segment in route.split('/') {
        let piece = match segment.strip_prefix(':') {
            Some(name) => format!("{{{}}}", name),
            None => segment.to_owned(),
        };
        converted.push(piece);
    }
    converted.join("/")
}

#[cfg(test)]
mod tests {
    use super::{extract_params, transform_route};

    /// `{name}` placeholders are collected in order; a trailing slash adds nothing.
    #[test]
    fn test_extract_params() {
        let cases: [(&str, &[&str]); 4] = [
            ("/foo/{id}/bar", &["id"]),
            ("/foo/{id}/bar/{baz}", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/{qux}", &["id", "baz", "qux"]),
        ];

        for (route, expected) in cases {
            assert_eq!(extract_params(route), expected, "route: {route}");
        }
    }

    /// `:name` segments become `{name}`; everything else is left untouched.
    #[test]
    fn test_transform_route() {
        let cases = [
            ("/todos", "/todos"),
            ("/todos/:id", "/todos/{id}"),
            ("/todos/:id/foo", "/todos/{id}/foo"),
            ("/bar/:bar_id/foo/:foo_id", "/bar/{bar_id}/foo/{foo_id}"),
            (
                "/bar/{bar_id}/foo/{foo_id}/baz",
                "/bar/{bar_id}/foo/{foo_id}/baz",
            ),
        ];

        for (route, expected) in cases {
            assert_eq!(transform_route(route), expected, "route: {route}");
        }
    }
}