use heck::{ToPascalCase, ToSnakeCase};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use openapi_trait_shared::codegen::operations::{OperationInfo, ParamInfo, ResponseStatus};
pub fn generate_router(mod_ident: &syn::Ident, ops: &[OperationInfo]) -> TokenStream {
let trait_name = format_ident!("{}Api", mod_ident.to_string().to_pascal_case());
let into_response_impls: Vec<TokenStream> =
ops.iter().map(generate_into_response_impl).collect();
let (query_structs, route_calls): (Vec<_>, Vec<_>) = ops.iter().map(generate_route).unzip();
quote! {
#(#into_response_impls)*
fn make_router<T, S>(__api: ::std::sync::Arc<T>) -> ::axum::Router<S>
where
T: #trait_name<S> + ::core::marker::Send + ::core::marker::Sync + 'static,
S: ::core::clone::Clone + ::core::marker::Send + ::core::marker::Sync + 'static,
{
#(#query_structs)*
::axum::Router::new()
#(#route_calls)*
}
}
}
/// Generates `impl IntoResponse for <OperationId>Response` for a single operation.
///
/// Each declared response becomes one match arm:
/// - numeric status with a body type: variant `StatusN(body)`, rendered as
///   `(StatusCode::<NAME>, Json(body))`;
/// - numeric status without a body type: unit variant `StatusN`, rendered as
///   the bare status code with no body;
/// - the `default` response: variant `Default(msg)`, always rendered with
///   `500 INTERNAL_SERVER_ERROR` and `msg` as the body.
///   NOTE(review): `msg`'s type is declared elsewhere and must itself
///   implement `IntoResponse`.
fn generate_into_response_impl(op: &OperationInfo) -> TokenStream {
    let response_enum = format_ident!("{}Response", op.operation_id.to_pascal_case());

    let mut match_arms: Vec<TokenStream> = Vec::with_capacity(op.responses.len());
    for response in op.responses.iter() {
        let arm = match (&response.status, response.rust_type.is_some()) {
            (ResponseStatus::Default, _) => quote! {
                Self::Default(msg) => (
                    ::axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                    msg,
                ).into_response(),
            },
            (ResponseStatus::Code(code), has_body) => {
                let variant = format_ident!("Status{}", code);
                let status = status_code_ident(*code);
                if has_body {
                    quote! {
                        Self::#variant(body) => (
                            ::axum::http::StatusCode::#status,
                            ::axum::Json(body),
                        ).into_response(),
                    }
                } else {
                    quote! {
                        Self::#variant => {
                            ::axum::http::StatusCode::#status
                                .into_response()
                        },
                    }
                }
            }
        };
        match_arms.push(arm);
    }

    quote! {
        impl ::axum::response::IntoResponse for #response_enum {
            fn into_response(self) -> ::axum::response::Response {
                use ::axum::response::IntoResponse as _;
                match self {
                    #(#match_arms)*
                }
            }
        }
    }
}
/// Maps a numeric HTTP status to the identifier of the matching associated
/// constant on `axum::http::StatusCode` (e.g. `404` -> `NOT_FOUND`).
///
/// All named constants defined by the `http` crate's `StatusCode` are covered.
/// A declared response such as `206` or `307` is therefore sent with its own
/// status instead of being turned into a 500.
///
/// Codes with no named constant (non-standard or unassigned values) fall back to
/// `INTERNAL_SERVER_ERROR`.
fn status_code_ident(n: u16) -> proc_macro2::Ident {
    let name = match n {
        // 1xx informational
        100 => "CONTINUE",
        101 => "SWITCHING_PROTOCOLS",
        102 => "PROCESSING",
        // 2xx success
        200 => "OK",
        201 => "CREATED",
        202 => "ACCEPTED",
        203 => "NON_AUTHORITATIVE_INFORMATION",
        204 => "NO_CONTENT",
        205 => "RESET_CONTENT",
        206 => "PARTIAL_CONTENT",
        207 => "MULTI_STATUS",
        208 => "ALREADY_REPORTED",
        226 => "IM_USED",
        // 3xx redirection
        300 => "MULTIPLE_CHOICES",
        301 => "MOVED_PERMANENTLY",
        302 => "FOUND",
        303 => "SEE_OTHER",
        304 => "NOT_MODIFIED",
        305 => "USE_PROXY",
        307 => "TEMPORARY_REDIRECT",
        308 => "PERMANENT_REDIRECT",
        // 4xx client errors
        400 => "BAD_REQUEST",
        401 => "UNAUTHORIZED",
        402 => "PAYMENT_REQUIRED",
        403 => "FORBIDDEN",
        404 => "NOT_FOUND",
        405 => "METHOD_NOT_ALLOWED",
        406 => "NOT_ACCEPTABLE",
        407 => "PROXY_AUTHENTICATION_REQUIRED",
        408 => "REQUEST_TIMEOUT",
        409 => "CONFLICT",
        410 => "GONE",
        411 => "LENGTH_REQUIRED",
        412 => "PRECONDITION_FAILED",
        413 => "PAYLOAD_TOO_LARGE",
        414 => "URI_TOO_LONG",
        415 => "UNSUPPORTED_MEDIA_TYPE",
        416 => "RANGE_NOT_SATISFIABLE",
        417 => "EXPECTATION_FAILED",
        418 => "IM_A_TEAPOT",
        421 => "MISDIRECTED_REQUEST",
        422 => "UNPROCESSABLE_ENTITY",
        423 => "LOCKED",
        424 => "FAILED_DEPENDENCY",
        426 => "UPGRADE_REQUIRED",
        428 => "PRECONDITION_REQUIRED",
        429 => "TOO_MANY_REQUESTS",
        431 => "REQUEST_HEADER_FIELDS_TOO_LARGE",
        451 => "UNAVAILABLE_FOR_LEGAL_REASONS",
        // 5xx server errors
        500 => "INTERNAL_SERVER_ERROR",
        501 => "NOT_IMPLEMENTED",
        502 => "BAD_GATEWAY",
        503 => "SERVICE_UNAVAILABLE",
        504 => "GATEWAY_TIMEOUT",
        505 => "HTTP_VERSION_NOT_SUPPORTED",
        506 => "VARIANT_ALSO_NEGOTIATES",
        507 => "INSUFFICIENT_STORAGE",
        508 => "LOOP_DETECTED",
        510 => "NOT_EXTENDED",
        511 => "NETWORK_AUTHENTICATION_REQUIRED",
        // No named constant exists: keep the historical 500 fallback.
        _ => "INTERNAL_SERVER_ERROR",
    };
    format_ident!("{}", name)
}
/// Generates the pieces needed to mount one operation on the router.
///
/// Returns `(query_struct, route_call)`:
/// - `query_struct` is the `<OperationId>QueryParams` struct definition, or an
///   empty token stream when the operation has no query parameters;
/// - `route_call` is a `.route(path, axum::routing::<method>(handler))` call,
///   meant to be chained onto `Router::new()`.
///
/// The generated handler takes its arguments in this order:
/// - `State<S>` and the full `HeaderMap`, always;
/// - path and query extractors, when present;
/// - the JSON body extractor last, because axum requires the body-consuming
///   extractor in the final position.
///
/// The handler builds `<OperationId>Request` from the extracted values. Header
/// parameters become `Option<String>` fields read from `headers`; values that
/// fail `to_str()` become `None`. It then calls
/// `<operation_id>(req, state, headers)` on the API and converts both `Ok` and
/// `Err` into a response.
///
/// NOTE(review): `op.method` is used verbatim as the `axum::routing` function
/// name, so it presumably arrives lowercase (`get`, `post`, ...) — confirm.
fn generate_route(op: &OperationInfo) -> (TokenStream, TokenStream) {
    let handler_method = format_ident!("{}", op.operation_id.to_snake_case());
    let request_type = format_ident!("{}Request", op.operation_id.to_pascal_case());
    let route_path = &op.path;
    let routing_fn = format_ident!("{}", op.method);

    let (path_extractor, path_fields) = build_path_extractor(&op.path_params);
    let (query_struct, query_extractor, query_fields) = build_query_extractor(op);
    let (body_extractor, body_field) = build_body_extractor(op);

    // Header parameters are read straight out of the extracted `HeaderMap`.
    let header_fields: Vec<TokenStream> = op
        .header_params
        .iter()
        .map(|param| {
            let field = format_ident!("{}", param.name.to_snake_case());
            let wire_name = &param.name;
            quote! {
                #field: headers
                    .get(#wire_name)
                    .and_then(|v| v.to_str().ok())
                    .map(::std::string::String::from),
            }
        })
        .collect();

    // Handler arguments: fixed ones first, then optional extractors, body last.
    let closure_params: Vec<TokenStream> = vec![
        quote! { state: ::axum::extract::State<S> },
        quote! { headers: ::axum::http::HeaderMap },
    ]
    .into_iter()
    .chain(path_extractor)
    .chain(query_extractor)
    .chain(body_extractor)
    .collect();

    // Request struct initializers: path, query, header, then body.
    let req_fields: Vec<TokenStream> = path_fields
        .into_iter()
        .chain(query_fields)
        .chain(header_fields)
        .chain(body_field)
        .collect();

    let route_call = quote! {
        .route(#route_path, ::axum::routing::#routing_fn({
            let __api = __api.clone();
            move |#(#closure_params),*| {
                let __api = __api.clone();
                async move {
                    use ::axum::response::IntoResponse as _;
                    let req = #request_type { #(#req_fields)* };
                    match __api.#handler_method(req, state, headers).await {
                        ::core::result::Result::Ok(r) => r.into_response(),
                        ::core::result::Result::Err(e) => e.into_response(),
                    }
                }
            }
        }))
    };

    (query_struct, route_call)
}
/// Builds the `Path` extractor for an operation's path parameters.
///
/// Returns `(None, [])` when there are no path parameters. Otherwise it returns
/// two things:
/// - a closure-parameter pattern that binds each parameter to a local
///   `path_<snake_name>`: a bare `Path<T>` for a single parameter, or a tuple
///   `Path<(A, B, ..)>` for several;
/// - one `<snake_name>: path_<snake_name>,` initializer per parameter, for the
///   request struct.
///
/// Tuple elements follow the order of `params`. NOTE(review): axum fills tuple
/// path extractors positionally, so `params` presumably must be in route-path
/// placeholder order — confirm the caller guarantees this.
fn build_path_extractor(params: &[ParamInfo]) -> (Option<TokenStream>, Vec<TokenStream>) {
    if params.is_empty() {
        return (None, vec![]);
    }

    let mut bindings: Vec<proc_macro2::Ident> = Vec::with_capacity(params.len());
    let mut param_types: Vec<&TokenStream> = Vec::with_capacity(params.len());
    let mut field_inits: Vec<TokenStream> = Vec::with_capacity(params.len());
    for param in params {
        let snake = param.name.to_snake_case();
        let binding = format_ident!("path_{}", snake);
        let field = format_ident!("{}", snake);
        field_inits.push(quote! { #field: #binding, });
        bindings.push(binding);
        param_types.push(&param.rust_type);
    }

    let extractor = match (bindings.as_slice(), param_types.as_slice()) {
        ([only], [only_ty]) => quote! {
            ::axum::extract::Path(#only):
                ::axum::extract::Path<#only_ty>
        },
        _ => quote! {
            ::axum::extract::Path((#(#bindings),*)):
                ::axum::extract::Path<(#(#param_types),*)>
        },
    };

    (Some(extractor), field_inits)
}
fn build_query_extractor(
op: &OperationInfo,
) -> (TokenStream, Option<TokenStream>, Vec<TokenStream>) {
if op.query_params.is_empty() {
return (quote! {}, None, vec![]);
}
let struct_ident = format_ident!("{}QueryParams", op.operation_id.to_pascal_case());
let struct_fields: Vec<TokenStream> = op
.query_params
.iter()
.map(|p| {
let field_ident = format_ident!("{}", p.name.to_snake_case());
let rename_attr = if field_ident == p.name.as_str() {
quote! {}
} else {
let n = &p.name;
quote! { #[serde(rename = #n)] }
};
let inner = &p.rust_type;
let ftype = if p.required {
quote! { #inner }
} else {
quote! { ::core::option::Option<#inner> }
};
quote! {
#rename_attr
pub #field_ident: #ftype,
}
})
.collect();
let query_struct = quote! {
#[derive(::serde::Deserialize)]
struct #struct_ident {
#(#struct_fields)*
}
};
let extractor = quote! {
::axum::extract::Query(query_params):
::axum::extract::Query<#struct_ident>
};
let inits: Vec<TokenStream> = op
.query_params
.iter()
.map(|p| {
let field_ident = format_ident!("{}", p.name.to_snake_case());
quote! { #field_ident: query_params.#field_ident, }
})
.collect();
(query_struct, Some(extractor), inits)
}
/// Builds the JSON request-body extractor for `op`, plus the initializer for the
/// request struct's `body` field.
///
/// Returns `(None, None)` when the operation declares no request body. Otherwise
/// the extractor binds a handler argument named `body`:
/// - required body: `Json(body): Json<T>`, initialized as `body,`. A missing or
///   non-JSON body is rejected by axum before the handler runs.
/// - optional body: `body: Option<Json<T>>`, initialized as
///   `body: body.map(|Json(inner)| inner),`. A request without a body reaches the
///   handler with `None` instead of being rejected by the `Json` extractor.
///
/// Because the extractor consumes the request body, it must be the last
/// handler argument.
fn build_body_extractor(op: &OperationInfo) -> (Option<TokenStream>, Option<TokenStream>) {
    let body = match op.body.as_ref() {
        Some(body) => body,
        None => return (None, None),
    };
    let ty = &body.rust_type;

    let (extractor, field_init) = if body.required {
        (
            quote! {
                ::axum::extract::Json(body):
                    ::axum::extract::Json<#ty>
            },
            quote! { body, },
        )
    } else {
        // A plain `Json<T>` extractor would reject body-less requests, making the
        // optional field unreachable as `None`; extract `Option<Json<T>>` instead.
        // NOTE(review): a present-but-invalid body is handled differently across axum
        // versions. In 0.7 the blanket `Option<T>` extractor yields `None`; in 0.8,
        // `Json`'s `OptionalFromRequest` impl rejects it. Confirm against the pinned axum.
        (
            quote! {
                body: ::core::option::Option<::axum::extract::Json<#ty>>
            },
            quote! {
                body: body.map(|::axum::extract::Json(inner)| inner),
            },
        )
    };

    (Some(extractor), Some(field_init))
}