use std::collections::HashSet;
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::{quote, quote_spanned};
use syn::{parse_macro_input, spanned::Spanned, FnArg, Item, LitStr, Type};
#[proc_macro_error]
#[proc_macro_attribute]
pub fn route(args: TokenStream, input: TokenStream) -> TokenStream {
let endpoint = parse_macro_input!(args as LitStr);
let endpoint_val = endpoint.value();
let mut param_names = HashSet::new();
for segment in endpoint_val.split("/") {
if segment.starts_with(":") || segment.starts_with("*") {
if segment == "*" {
continue;
}
if !(param_names.insert(segment[1..].to_owned())) {
abort!(
endpoint.span(),
"Cannot have multiple named parameters with the same name";
help = "Rename or remove one of the parameters named `{}`", &segment[1..]
);
}
}
}
let input = parse_macro_input!(input as Item);
let func = match &input {
Item::Fn(f) => f,
_ => {
abort!(input.span(), "You can only use route on functions");
}
};
let name = &func.sig.ident;
let return_ty = &func.sig.output;
let block = &func.block;
let mut request_arg = None;
let mut params = Vec::new();
for arg in &func.sig.inputs {
if let FnArg::Typed(arg) = arg {
if let syn::Pat::Ident(ident) = arg.pat.as_ref() {
let mut arg_name = ident.ident.to_string();
if arg_name.starts_with("_") {
arg_name.remove(0);
}
if arg_name == "request" {
request_arg = Some(arg);
} else {
if !param_names.contains(&arg_name) {
abort!(
arg.span(), "Parameter `{}` not in endpoint", arg_name;
note = endpoint.span() => "Add `{}` to the endpoint", arg_name
);
}
let ty = &arg.ty;
let param_lit = LitStr::new(&arg_name, ident.ident.span());
let get_param = quote! {
gemfra::error::ToGemError::into_gem(params.find(#param_lit))?
};
if let Type::Reference(r) = ty.as_ref() {
if let Type::Path(path) = r.elem.as_ref() {
if let Some(segment) = path.path.segments.first() {
if segment.ident.to_string() == "str" {
params.push(quote_spanned! {arg.span()=>
let #ident: #ty = #get_param;
});
continue;
}
}
}
}
params.push(quote_spanned! {arg.span()=>
let #ident: #ty = gemfra::error::ToGemError::into_gem_type(
#get_param.parse(),
gemfra::error::GemErrorType::NotFound
)?;
});
}
}
}
}
let request_arg = match request_arg {
Some(v) => v,
None => {
abort!(func.sig.span(), "input `request` is a required parameter");
}
};
TokenStream::from(quote! {
#[allow(non_camel_case_types)]
struct #name;
#[async_trait::async_trait]
impl gemfra::routed::Route for #name {
fn endpoint(&self) -> &str {
#endpoint
}
async fn handle(&self, params: &gemfra::routed::Params, #request_arg) #return_ty {
#(#params)*
#block
}
}
})
}