axum_openapi3_derive/
lib.rs1use handler_signature::{parse_handler_arguments, parse_handler_ret_type, HandlerArgument};
5use macro_arguments::MacroArgs;
6use quote::quote;
7use std::fmt::Write;
8use syn::{parse_macro_input, spanned::Spanned, ItemFn};
9
10mod handler_signature;
11mod macro_arguments;
12mod util;
13
14#[proc_macro_attribute]
16pub fn endpoint(
17 args: proc_macro::TokenStream,
18 input: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20 let macro_args = parse_macro_input!(args as MacroArgs);
21 let input_fn = parse_macro_input!(input as ItemFn);
22
23 let fn_args = match parse_handler_arguments(&input_fn.sig) {
24 Ok(args) => args,
25 Err(err) => return err.to_compile_error().into(),
26 };
27
28 let ret_type = match parse_handler_ret_type(&input_fn.sig) {
29 Ok(ty) => ty,
30 Err(err) => return err.to_compile_error().into(),
31 };
32
33 let fn_name = input_fn.sig.ident.clone();
34 let fn_name_str = fn_name.to_string();
35
36 let path = macro_args.path;
37 let method = macro_args.method.to_string();
38 let description = macro_args.description.unwrap_or_default();
39
40 let method: http::Method = method.parse().unwrap(); let (utoipa_method_name, axum_method) = match get_method_tokens(method) {
43 Ok(value) => value,
44 Err(_) => {
45 return syn::Error::new(input_fn.sig.span(), "Unsupported HTTP method")
46 .to_compile_error()
47 .into()
48 }
49 };
50
51 let ret_type = get_ret_type_token(ret_type);
52
53 let request_body = get_request_body_token(&fn_args);
54
55 let query_params = get_query_params_token(&fn_args);
56
57 let path_param_names = extract_params(&path);
58 let path_params = get_path_params_token(&fn_args, path_param_names);
59
60 let state = get_state_token(fn_args);
61
62 let path_for_openapi = transform_route(&path);
63
64 let public = get_public_token(&input_fn.vis);
65
66 let output = quote! {
67 #public fn #fn_name() -> (&'static str, axum::routing::MethodRouter < #state , std::convert::Infallible >)
68 {
69 #input_fn
70
71 let handler = axum::routing:: #axum_method (#fn_name);
72
73 let op_builder = axum_openapi3::utoipa::openapi::path::OperationBuilder::new()
74 .description(Some(#description));
75
76 #ret_type
77
78 #request_body
79
80 #query_params
81
82 #path_params
83
84 let op_builder = op_builder.operation_id(Some(#fn_name_str));
85
86 let paths = axum_openapi3::utoipa::openapi::PathsBuilder::new()
87 .path(#path_for_openapi, axum_openapi3::utoipa::openapi::path::PathItemBuilder::new()
88 .operation(
89 axum_openapi3::utoipa::openapi::HttpMethod:: #utoipa_method_name,
90 op_builder.build()
91 )
92 .build())
93 .build();
94
95 axum_openapi3::ENDPOINTS.lock().unwrap().push(paths);
96
97 (#path, handler)
98 }
99
100 };
101
102 output.into()
103}
104
105fn get_ret_type_token(ret_type: Option<String>) -> proc_macro2::TokenStream {
106 let ret_type = if let Some(ret_type) = ret_type {
107 format!(
108 r#"
109let response_schema = < {ret_type} as axum_openapi3::utoipa::PartialSchema > :: schema();
110let op_builder = op_builder.response(
111 "200",
112 axum_openapi3::utoipa::openapi::ResponseBuilder::new()
113 .content(
114 "application/json",
115 axum_openapi3::utoipa::openapi::ContentBuilder::new()
116 .schema(Some(response_schema))
117 .build()
118 )
119 .build()
120);
121 "#
122 )
123 } else {
124 "let op_builder = op_builder;".to_string()
125 };
126 let ret_type: proc_macro2::TokenStream = ret_type.parse().unwrap();
127 ret_type
128}
129
130fn get_request_body_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
131 let request_body = fn_args.iter().find_map(|arg| match arg {
132 HandlerArgument::RequestBody(ty) => Some(format!(
133 r#"
134let request_body = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
135let op_builder = op_builder
136 .request_body(Some(
137 axum_openapi3::utoipa::openapi::request_body::RequestBodyBuilder::new()
138 .content(
139 "application/json",
140 axum_openapi3::utoipa::openapi::ContentBuilder::new()
141 .schema(Some(request_body))
142 .build()
143 )
144 .build()
145 ));
146 "#
147 )),
148 _ => None,
149 });
150 let request_body: proc_macro2::TokenStream = if let Some(request_body) = request_body {
151 request_body.parse().unwrap()
152 } else {
153 "let op_builder = op_builder;".parse().unwrap()
154 };
155 request_body
156}
157
158fn get_state_token(fn_args: Vec<HandlerArgument>) -> proc_macro2::TokenStream {
159 let state = fn_args
160 .iter()
161 .find_map(|arg| match arg {
162 HandlerArgument::State(ty) => Some(ty.as_str()),
163 _ => None,
164 })
165 .unwrap_or("()");
166 let state: proc_macro2::TokenStream = state.parse().unwrap();
167 state
168}
169
170fn get_public_token(public: &syn::Visibility) -> proc_macro2::TokenStream {
171 let public: proc_macro2::TokenStream = match public {
172 syn::Visibility::Public(_) => "pub ".parse().unwrap(),
173 _ => "".parse().unwrap(),
174 };
175 public
176}
177
178fn get_path_params_token(
179 fn_args: &[HandlerArgument],
180 path_param_names: Vec<String>,
181) -> proc_macro2::TokenStream {
182 let path_params: String = fn_args
183 .iter()
184 .filter_map(|arg| match arg {
185 HandlerArgument::Path(ty) => Some(ty),
186 _ => None,
187 })
188 .zip(path_param_names.iter())
189 .fold(String::new(), |mut acc, (ty, name)| {
190 let _ = write!(
191 acc,
192 r#"
193let schema = < {ty} as axum_openapi3::utoipa::PartialSchema > :: schema();
194let path_param = axum_openapi3::utoipa::openapi::path::ParameterBuilder::new()
195 .parameter_in(axum_openapi3::utoipa::openapi::path::ParameterIn::Path)
196 .name("{name}")
197 .required(axum_openapi3::utoipa::openapi::Required::True)
198 .schema(Some(schema))
199 .build();
200
201let op_builder = op_builder
202 .parameter(path_param);
203"#
204 );
205 acc
206 });
207 let path_params: proc_macro2::TokenStream = if path_params.is_empty() {
208 "let op_builder = op_builder;".parse().unwrap()
209 } else {
210 path_params.parse().unwrap()
211 };
212 path_params
213}
214
215fn get_query_params_token(fn_args: &[HandlerArgument]) -> proc_macro2::TokenStream {
216 let query_params = fn_args.iter().find_map(|arg| {
217 match arg {
218 HandlerArgument::Query(ty) => Some(format!(r#"
219let query_params = < {ty} as axum_openapi3::utoipa::IntoParams > :: into_params(|| Some(axum_openapi3::utoipa::openapi::path::ParameterIn::Query));
220let op_builder = op_builder
221 .parameters(Some(query_params));
222 "#)),
223 _ => None,
224 }
225 });
226 let query_params: proc_macro2::TokenStream = if let Some(query_params) = query_params {
227 query_params.parse().unwrap()
228 } else {
229 "let op_builder = op_builder;".parse().unwrap()
230 };
231 query_params
232}
233
234fn get_method_tokens(
235 method: http::Method,
236) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), ()> {
237 let (utoipa_method_name, axum_method): (proc_macro2::TokenStream, proc_macro2::TokenStream) =
238 match method {
239 http::Method::GET => ("Get".parse().unwrap(), "get".parse().unwrap()),
240 http::Method::POST => ("Post".parse().unwrap(), "post".parse().unwrap()),
241 http::Method::PUT => ("Put".parse().unwrap(), "put".parse().unwrap()),
242 http::Method::DELETE => ("Delete".parse().unwrap(), "delete".parse().unwrap()),
243 http::Method::HEAD => ("Head".parse().unwrap(), "head".parse().unwrap()),
244 http::Method::OPTIONS => ("Options".parse().unwrap(), "options".parse().unwrap()),
245 http::Method::CONNECT => ("Connect".parse().unwrap(), "connect".parse().unwrap()),
246 http::Method::PATCH => ("Patch".parse().unwrap(), "patch".parse().unwrap()),
247 _ => return Err(()),
249 };
250 Ok((utoipa_method_name, axum_method))
251}
252
/// Extracts the names of path parameters from a route template, in order.
///
/// Two segment syntaxes are recognised:
/// * brace syntax, e.g. `/users/{id}`;
/// * colon syntax, e.g. `/users/:id`. `transform_route` also accepts this form.
///
/// Exactly one brace is stripped from each side. A segment whose remaining name
/// is empty or still contains a brace is not treated as a parameter. That
/// excludes `{}` and escaped literals such as `{{x}}`. Empty segments, such as
/// those produced by a trailing `/`, are ignored.
fn extract_params(input: &str) -> Vec<String> {
    input
        .split('/')
        .filter_map(|segment| {
            let name = segment
                .strip_prefix('{')
                .and_then(|rest| rest.strip_suffix('}'))
                .or_else(|| segment.strip_prefix(':'))?;
            let is_valid = !name.is_empty() && !name.contains(|c| c == '{' || c == '}');
            is_valid.then(|| name.to_string())
        })
        .collect()
}
270
/// Rewrites colon-style route segments (`:id`) into OpenAPI brace syntax (`{id}`).
///
/// Segments that do not start with `:` are copied through untouched. The `/`
/// separators, including leading and trailing ones, are preserved exactly.
fn transform_route(route: &str) -> String {
    let mut rewritten = String::with_capacity(route.len() + 2);
    for (index, segment) in route.split('/').enumerate() {
        if index > 0 {
            rewritten.push('/');
        }
        match segment.strip_prefix(':') {
            Some(name) => {
                rewritten.push('{');
                rewritten.push_str(name);
                rewritten.push('}');
            }
            None => rewritten.push_str(segment),
        }
    }
    rewritten
}
284
#[cfg(test)]
mod tests {
    use super::{extract_params, transform_route};

    /// Brace-style parameters are returned in order; a trailing slash adds nothing.
    #[test]
    fn test_extract_params() {
        let cases: [(&str, &[&str]); 4] = [
            ("/foo/{id}/bar", &["id"]),
            ("/foo/{id}/bar/{baz}", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/", &["id", "baz"]),
            ("/foo/{id}/bar/{baz}/{qux}", &["id", "baz", "qux"]),
        ];
        for (route, expected) in cases.iter() {
            assert_eq!(extract_params(route), *expected, "route: {route}");
        }
    }

    /// Colon segments become brace segments; brace segments pass through untouched.
    #[test]
    fn test_transform_route() {
        let cases = [
            ("/todos", "/todos"),
            ("/todos/:id", "/todos/{id}"),
            ("/todos/:id/foo", "/todos/{id}/foo"),
            ("/bar/:bar_id/foo/:foo_id", "/bar/{bar_id}/foo/{foo_id}"),
            (
                "/bar/{bar_id}/foo/{foo_id}/baz",
                "/bar/{bar_id}/foo/{foo_id}/baz",
            ),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(transform_route(input), *expected, "route: {input}");
        }
    }
}