athene_macro 0.3.5

Macro generation for athene
Documentation
use proc_macro2::{Ident, Span, TokenStream};
use syn::{Error, FnArg, ImplItem, Result, Signature, Type, TypeParamBound, ItemImpl};
use quote::quote;

pub fn expand_middleware(mut mid_impl: ItemImpl) -> Result<TokenStream> {
    let middleware_fn = remove_middleware_fn(&mut mid_impl)?;

    let fn_def = Middleware::new(middleware_fn)?;
    mid_impl.items.push(fn_def.def);

    let middleware_ident = crate::util::parse_item_impl_ident(&mid_impl)?;
    let middleware_name = middleware_ident.to_string();

    let mod_ident = Ident::new(&format!("macro_{}", &middleware_name), Span::call_site());
    let fn_ident = fn_def.fn_ident;

    Ok(quote! {
        #mid_impl

        mod #mod_ident {
            use super::*;
            use athene::prelude::*;
            
            #[athene::async_trait::async_trait]
            impl Middleware for #middleware_ident {
                async fn next(&'static self, ctx: Context, chain: &'static dyn Next) -> Result<Context, Error>{
                    self.#fn_ident(ctx, chain).await
                }
            }
        }
    })
}

fn remove_middleware_fn(input: &mut ItemImpl) -> Result<ImplItem> {
    let middleware_fn_pos = input.items.iter().position(|item| {
        if let ImplItem::Method(method) = item {
            return method.sig.ident == "next";
        }

        false
    }).ok_or_else(|| Error::new_spanned(&input, "No method `next` found in the impl section of the middleware.\nMake sure the impl block contains a fn with the following signature:\n `async fn next(&self, _: HttpContext, _: &dyn MiddlewareChain) -> Result<HttpContext, SaphirError>`"))?;

    let mid_fn = input.items.remove(middleware_fn_pos);

    Ok(mid_fn)
}

/// The user's `next` method after it has been validated and renamed, ready to
/// be put back into the middleware's impl block.
pub struct Middleware {
    /// The identifier the method was renamed to (`<original name>_wrapped`).
    pub fn_ident: Ident,
    /// The renamed method, as an impl item.
    pub def: ImplItem,
}

impl Middleware {
    pub fn new(middleware_fn: ImplItem) -> Result<Self> {
        let mut m = if let ImplItem::Method(m) = middleware_fn {
            m
        } else {
            return Err(Error::new_spanned(middleware_fn, "The token named next is not method"));
        };

        check_signature(&m.sig)?;

        let fn_ident = Ident::new(&format!("{}_wrapped", m.sig.ident), Span::call_site());
        m.sig.ident = fn_ident.clone();

        Ok(Self {
            def: ImplItem::Method(m),
            fn_ident,
        })
    }
}

fn check_signature(m: &Signature) -> Result<()> {
    if m.asyncness.is_none() {
        return Err(Error::new_spanned(m, "Invalid function signature, the middleware function should be async"));
    }

    if m.inputs.len() != 3 {
        return Err(Error::new_spanned(
            m,
            "Invalid middleware function input parameters.\nExpected the following parameters:\n (&self, _: Context, _: &dyn Next)",
        ));
    }

    let mut input_args = m.inputs.iter();

    match input_args.next().expect("len was checked above") {
        FnArg::Receiver(_) => {}
        arg => {
            return Err(Error::new_spanned(arg, "Invalid 1st parameter, expected `&self`"));
        }
    }

    let arg2 = input_args.next().expect("len was checked above");
    let passed = match arg2 {
        FnArg::Typed(t) => {
            if let Type::Path(pt) = &*t.ty {
                pt.path
                    .segments
                    .first()
                    .ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
                    .ident
                    .to_string()
                    .eq("Context")
            } else {
                false
            }
        }
        _ => false,
    };

    if !passed {
        return Err(Error::new_spanned(arg2, "Invalid  parameter, expected `Context`"));
    }

    let arg3 = input_args.next().expect("len was checked above");
    let passed = match arg3 {
        FnArg::Typed(t) => {
            if let Type::Reference(tr) = &*t.ty {
                if let Type::TraitObject(to) = &*tr.elem {
                    if let TypeParamBound::Trait(bo) = to.bounds.first().ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))? {
                        bo.path
                            .segments
                            .first()
                            .ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
                            .ident
                            .to_string()
                            .eq("Next")
                    } else {
                        false
                    }
                } else if let Type::Path(pt) = &*tr.elem {
                    pt.path
                        .segments
                        .first()
                        .ok_or_else(|| Error::new_spanned(&t.ty, "Unexpected type"))?
                        .ident
                        .to_string()
                        .eq("Next")
                } else {
                    false
                }
            } else {
                false
            }
        }
        _ => false,
    };

    if !passed {
        return Err(Error::new_spanned(arg3, "Invalid parameter, expected `&dyn Next`"));
    }

    Ok(())
}