use proc_macro2::{Ident, Span, TokenStream};
use syn::{Error, FnArg, ImplItem, Result, Signature, Type, TypeParamBound, ItemImpl};
use quote::quote;
pub fn expand_middleware(mut mid_impl: ItemImpl) -> Result<TokenStream> {
let middleware_fn = remove_middleware_fn(&mut mid_impl)?;
let fn_def = Middleware::new(middleware_fn)?;
mid_impl.items.push(fn_def.def);
let middleware_ident = crate::util::parse_item_impl_ident(&mid_impl)?;
let middleware_name = middleware_ident.to_string();
let mod_ident = Ident::new(&format!("macro_{}", &middleware_name), Span::call_site());
let fn_ident = fn_def.fn_ident;
Ok(quote! {
#mid_impl
mod #mod_ident {
use super::*;
use athene::prelude::*;
#[athene::async_trait::async_trait]
impl Middleware for #middleware_ident {
async fn next(&'static self, ctx: Context, chain: &'static dyn Next) -> Result<Context, Error>{
self.#fn_ident(ctx, chain).await
}
}
}
})
}
fn remove_middleware_fn(input: &mut ItemImpl) -> Result<ImplItem> {
let middleware_fn_pos = input.items.iter().position(|item| {
if let ImplItem::Method(method) = item {
return method.sig.ident == "next";
}
false
}).ok_or_else(|| Error::new_spanned(&input, "No method `next` found in the impl section of the middleware.\nMake sure the impl block contains a fn with the following signature:\n `async fn next(&self, _: HttpContext, _: &dyn MiddlewareChain) -> Result<HttpContext, SaphirError>`"))?;
let mid_fn = input.items.remove(middleware_fn_pos);
Ok(mid_fn)
}
/// A validated middleware method, ready to be re-inserted into its `impl`
/// block under a new name.
pub struct Middleware {
/// The method item, with its signature identifier replaced by `fn_ident`.
pub def: ImplItem,
/// The new identifier of the method: the original name suffixed with
/// `_wrapped`.
pub fn_ident: Ident,
}
impl Middleware {
pub fn new(middleware_fn: ImplItem) -> Result<Self> {
let mut m = if let ImplItem::Method(m) = middleware_fn {
m
} else {
return Err(Error::new_spanned(middleware_fn, "The token named next is not method"));
};
check_signature(&m.sig)?;
let fn_ident = Ident::new(&format!("{}_wrapped", m.sig.ident), Span::call_site());
m.sig.ident = fn_ident.clone();
Ok(Self {
def: ImplItem::Method(m),
fn_ident,
})
}
}
fn check_signature(m: &Signature) -> Result<()> {
if m.asyncness.is_none() {
return Err(Error::new_spanned(m, "Invalid function signature, the middleware function should be async"));
}
if m.inputs.len() != 3 {
return Err(Error::new_spanned(
m,
"Invalid middleware function input parameters.\nExpected the following parameters:\n (&self, _: Context, _: &dyn Next)",
));
}
let mut input_args = m.inputs.iter();
match input_args.next().expect("len was checked above") {
FnArg::Receiver(_) => {}
arg => {
return Err(Error::new_spanned(arg, "Invalid 1st parameter, expected `&self`"));
}
}
let arg2 = input_args.next().expect("len was checked above");
let passed = match arg2 {
FnArg::Typed(t) => {
if let Type::Path(pt) = &*t.ty {
pt.path
.segments
.first()
.ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
.ident
.to_string()
.eq("Context")
} else {
false
}
}
_ => false,
};
if !passed {
return Err(Error::new_spanned(arg2, "Invalid parameter, expected `Context`"));
}
let arg3 = input_args.next().expect("len was checked above");
let passed = match arg3 {
FnArg::Typed(t) => {
if let Type::Reference(tr) = &*t.ty {
if let Type::TraitObject(to) = &*tr.elem {
if let TypeParamBound::Trait(bo) = to.bounds.first().ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))? {
bo.path
.segments
.first()
.ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
.ident
.to_string()
.eq("Next")
} else {
false
}
} else if let Type::Path(pt) = &*tr.elem {
pt.path
.segments
.first()
.ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
.ident
.to_string()
.eq("Next")
} else {
false
}
} else {
false
}
}
_ => false,
};
if !passed {
return Err(Error::new_spanned(arg3, "Invalid parameter, expected `&dyn Next`"));
}
Ok(())
}