mod doc_attr;
mod route_attr;
use crate::get_add_operation_fn_name;
use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use rocket_http::Method;
use std::collections::BTreeMap;
use syn::{AttributeArgs, FnArg, Ident, ItemFn, ReturnType, Type, TypeTuple};
/// Arguments accepted by the attribute macro handled in this module,
/// parsed from the attribute's meta list via `darling`.
///
/// Every field is optional (`#[darling(default)]`), so a missing argument
/// falls back to its `Default` value.
#[derive(Debug, Default, FromMeta)]
#[darling(default)]
struct OpenApiAttribute {
/// When `true`, `parse` emits an operation function that does nothing
/// instead of one describing the route.
pub skip: bool,
/// Tags for the generated operation; written as `tag = "..."` in the
/// attribute and may be repeated to supply several tags.
#[darling(multiple, rename = "tag")]
pub tags: Vec<String>,
}
/// Expands the attribute macro implemented by this module.
///
/// `args` holds the attribute's own arguments and `input` the annotated
/// function. The result is the token stream of a generated function that
/// registers the route's operation, or a token stream carrying compile
/// errors when the arguments or the route attribute cannot be parsed.
pub fn parse(args: TokenStream, input: TokenStream) -> TokenStream {
    let attr_args = parse_macro_input!(args as AttributeArgs);
    let route_fn = parse_macro_input!(input as ItemFn);

    // Malformed attribute arguments are reported as compile errors.
    let attribute = match OpenApiAttribute::from_list(&attr_args) {
        Err(err) => return err.write_errors().into(),
        Ok(attribute) => attribute,
    };

    // A skipped route still gets a function, just one that does nothing.
    if attribute.skip {
        return create_empty_route_operation_fn(route_fn);
    }

    // `parse_attrs` already yields a ready-made error token stream on failure.
    let route = match route_attr::parse_attrs(&route_fn.attrs) {
        Err(error_tokens) => return error_tokens,
        Ok(route) => route,
    };
    create_route_operation_fn(route_fn, route, attribute.tags)
}
/// Generates the operation function for a skipped route.
///
/// The emitted function has the same name and signature as a regular
/// operation function but ignores both arguments and returns `Ok(())`.
fn create_empty_route_operation_fn(route_fn: ItemFn) -> TokenStream {
    let operation_fn_ident = get_add_operation_fn_name(&route_fn.sig.ident);
    let expanded = quote! {
        pub fn #operation_fn_ident(
            _gen: &mut ::rocket_okapi::gen::OpenApiGenerator,
            _op_id: String,
        ) -> ::rocket_okapi::Result<()> {
            Ok(())
        }
    };
    expanded.into()
}
/// Generates the function that registers a route's OpenAPI operation.
///
/// * `route_fn` - the annotated route handler; its argument types, return
///   type, doc attributes and name drive the generated code.
/// * `route` - the parsed route attribute (method, origin, and the names of
///   the path, query, multi-query and data parameters).
/// * `tags` - tags to attach to the generated operation.
///
/// Returns the token stream of a `pub fn` taking an `OpenApiGenerator` and
/// an operation id. If the route attribute names a parameter that is not an
/// argument of the handler, the returned tokens are a `compile_error!`
/// invocation instead.
fn create_route_operation_fn(
    route_fn: ItemFn,
    route: route_attr::Route,
    tags: Vec<String>,
) -> TokenStream {
    let arg_types = get_arg_types(route_fn.sig.inputs.into_iter());
    let return_type = match route_fn.sig.output {
        ReturnType::Type(_, ty) => *ty,
        // A handler without an explicit return type returns `()`.
        ReturnType::Default => unit_type(),
    };

    // Expression producing the operation's request body, if the route has a
    // data parameter.
    let request_body = match &route.data_param {
        Some(arg) => {
            let ty = match arg_types.get(arg) {
                Some(ty) => ty,
                None => return quote! {
                    compile_error!(concat!("Could not find argument ", #arg, " matching data param."));
                }.into()
            };
            quote! {
                Some(<#ty as ::rocket_okapi::request::OpenApiFromData>::request_body(gen)?.into())
            }
        }
        None => quote! { None },
    };

    // Expressions that each evaluate to one `RequestHeaderInput` in the
    // generated function.
    let mut params = Vec::new();
    for arg in route.path_params() {
        let ty = match arg_types.get(arg) {
            Some(ty) => ty,
            None => return quote! {
                compile_error!(concat!("Could not find argument ", #arg, " matching path param."));
            }
            .into(),
        };
        params.push(quote! {
            <#ty as ::rocket_okapi::request::OpenApiFromParam>::path_parameter(gen, #arg.to_owned())?.into()
        })
    }
    for arg in route.query_params() {
        let ty = match arg_types.get(arg) {
            Some(ty) => ty,
            None => return quote! {
                compile_error!(concat!("Could not find argument ", #arg, " matching query param."));
            }
            .into(),
        };
        params.push(quote! {
            <#ty as ::rocket_okapi::request::OpenApiFromFormField>::form_parameter(gen, #arg.to_owned(), true)?.into()
        })
    }

    // Expressions that each evaluate to a list of parameters; the generated
    // function flattens them into the operation's parameter list.
    let mut params_nested_list = Vec::new();
    for arg in route.query_multi_params() {
        let ty = match arg_types.get(arg) {
            Some(ty) => ty,
            None => return quote! {
                compile_error!(concat!("Could not find argument ", #arg, " matching query multi param."));
            }.into(),
        };
        params_nested_list.push(quote! {
            <#ty as ::rocket_okapi::request::OpenApiFromForm>::form_multi_parameter(gen, #arg.to_owned(), true)?.into()
        })
    }

    // Every handler argument that the route attribute does not bind (not a
    // path, query, multi-query or data parameter) is described through
    // `OpenApiFromRequest`.
    let data_param_arg = route.data_param.clone().unwrap_or_default();
    for (arg, ty) in arg_types {
        let bound_by_route = route
            .path_params()
            .any(|item| arg == item.to_string())
            || route
                .query_params()
                .any(|item| arg == item.to_string())
            || route
                .query_multi_params()
                .any(|item| arg == item.to_string())
            || data_param_arg == arg;
        if !bound_by_route {
            params.push(quote! {
                <#ty as ::rocket_okapi::request::OpenApiFromRequest>::request_input(gen, #arg.to_owned())?.into()
            });
        }
    }

    let fn_name = get_add_operation_fn_name(&route_fn.sig.ident);
    // Rewrite the route's `<param>` placeholders as `{param}`.
    let path = route
        .origin
        .path()
        .as_str()
        .replace("<", "{")
        .replace(">", "}");
    let method = Ident::new(&to_pascal_case_string(route.method), Span::call_site());
    let (title, desc) = doc_attr::get_title_and_desc_from_doc(&route_fn.attrs);
    let title = match title {
        Some(x) => quote!(Some(#x.to_owned())),
        None => quote!(None),
    };
    let desc = match desc {
        Some(x) => quote!(Some(#x.to_owned())),
        None => quote!(None),
    };
    let tags = tags
        .into_iter()
        .map(|tag| quote!(#tag.to_owned()))
        .collect::<Vec<_>>();

    TokenStream::from(quote! {
        pub fn #fn_name(
            gen: &mut ::rocket_okapi::gen::OpenApiGenerator,
            op_id: String,
        ) -> ::rocket_okapi::Result<()> {
            let responses = <#return_type as ::rocket_okapi::response::OpenApiResponder>::responses(gen)?;
            let request_body = #request_body;
            use ::rocket_okapi::request::RequestHeaderInput;
            use ::okapi::openapi3::Parameter;
            use ::okapi::openapi3::RefOr;
            let request_inputs: Vec<RequestHeaderInput> = vec![#(#params),*];
            let mut parameters: Vec<::okapi::openapi3::RefOr<Parameter>> = Vec::new();
            use std::collections::BTreeMap;
            let mut security_schemes = BTreeMap::new();
            for inp in request_inputs {
                match inp {
                    RequestHeaderInput::Parameter(p) => {
                        parameters.push(p.into());
                    }
                    RequestHeaderInput::Security(s) => {
                        security_schemes.insert(s.0.scheme_identifier.clone(), Vec::new());
                        gen.add_security_scheme(s.0.scheme_identifier.clone(), s.0.clone());
                    }
                    _ => {
                    }
                }
            }
            let security = if security_schemes.is_empty() {
                None
            } else {
                Some(vec![security_schemes])
            };
            let parameters_nested_list: Vec<Vec<::okapi::openapi3::Parameter>> = vec![#(#params_nested_list),*];
            for inner_list in parameters_nested_list{
                for item in inner_list{
                    parameters.push(item.into());
                }
            }
            gen.add_operation(::rocket_okapi::OperationInfo {
                path: #path.to_owned(),
                method: ::rocket::http::Method::#method,
                operation: ::okapi::openapi3::Operation {
                    operation_id: Some(op_id),
                    responses,
                    request_body,
                    parameters,
                    summary: #title,
                    description: #desc,
                    security,
                    tags: vec![#(#tags),*],
                    ..::okapi::openapi3::Operation::default()
                },
            });
            Ok(())
        }
    })
}
/// Builds the `syn` representation of the unit type `()`: a tuple type with
/// default parentheses and no elements.
fn unit_type() -> Type {
    let no_elems = syn::punctuated::Punctuated::default();
    let parens = syn::token::Paren::default();
    Type::Tuple(TypeTuple {
        paren_token: parens,
        elems: no_elems,
    })
}
/// Renders the method's name with its first byte upper-cased and the rest
/// lower-cased (ASCII only).
///
/// # Panics
///
/// Panics if `method.as_str()` is empty or byte index 1 is not a character
/// boundary, because the name is split with `str::split_at(1)`.
fn to_pascal_case_string(method: Method) -> String {
    let name = method.as_str();
    let (head, tail) = name.split_at(1);
    let mut pascal = String::with_capacity(name.len());
    pascal.push_str(&head.to_ascii_uppercase());
    pascal.push_str(&tail.to_ascii_lowercase());
    pascal
}
/// Maps each function argument's name to its declared type.
///
/// Only typed arguments whose pattern is a plain identifier are included;
/// receivers and arguments bound through any other pattern are skipped.
/// If a name occurs more than once, the last occurrence wins.
fn get_arg_types(args: impl Iterator<Item = FnArg>) -> BTreeMap<String, Type> {
    args.filter_map(|fn_arg| match fn_arg {
        syn::FnArg::Typed(typed) => match *typed.pat {
            syn::Pat::Ident(pat_ident) => {
                let key = pat_ident.ident.into_token_stream().to_string();
                Some((key, *typed.ty))
            }
            _ => None,
        },
        _ => None,
    })
    .collect()
}