many-macros 0.1.0

Procedural macros to support creating MANY modules.
Documentation
use inflections::Inflect;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned};
use serde::Deserialize;
use serde_tokenstream::from_tokenstream;
use syn::spanned::Spanned;
use syn::PathArguments::AngleBracketed;
use syn::{
    AngleBracketedGenericArguments, FnArg, GenericArgument, PatType, ReturnType, Signature,
    TraitItem, Type, TypePath,
};

/// Arguments accepted by `#[many_module(...)]`, deserialized from the
/// attribute's token stream with `serde_tokenstream`. Every field is optional.
#[derive(Deserialize)]
struct ManyModuleAttributes {
    /// Attribute id. When set, a constant-cased `<name>Attribute` constant is
    /// emitted and the module info advertises `Some(<that constant>)`.
    pub id: Option<u32>,
    /// Name of the generated module struct. When absent, the struct takes the
    /// trait's name and the trait itself is renamed to `<name>Backend`.
    pub name: Option<String>,
    /// Optional prefix; endpoint method names become `"<namespace>.<endpoint>"`.
    pub namespace: Option<String>,
    /// Crate through which the MANY runtime types are referenced
    /// (defaults to `many`).
    pub many_crate: Option<String>,
}

/// One trait method exposed as a MANY endpoint, as parsed by `Endpoint::new`.
#[derive(Debug)]
struct Endpoint {
    /// Rust name of the trait method (the string form of `func`).
    pub name: String,
    /// Identifier of the trait method; used to generate the dispatch call.
    pub func: Ident,
    /// Span of the whole method signature; generated dispatch code is spanned to it.
    pub span: Span,
    /// Whether the method is declared `async` (the generated call is `.await`ed).
    pub is_async: bool,
    /// Whether the receiver carries `mut` (e.g. `&mut self`); the backend
    /// guard is then bound with `let mut`.
    pub is_mut: bool,
    /// Whether the method takes three arguments; the generated call then passes
    /// `&message.from.unwrap_or_default()` before the decoded argument.
    pub has_sender: bool,
    /// Type of the argument decoded from the request payload, if the method has one.
    pub arg_type: Option<Box<Type>>,
    /// First type argument of the method's `Result<..>` return type (currently unused).
    #[allow(unused)]
    pub ret_type: Box<Type>,
}

impl Endpoint {
    pub fn new(signature: &Signature) -> Result<Self, (String, Span)> {
        let func = signature.ident.clone();
        let name = func.to_string();
        let is_async = signature.asyncness.is_some();

        let mut has_sender = false;
        let arg_type: Option<Box<Type>>;
        let mut ret_type: Option<Box<Type>> = None;

        let mut inputs = signature.inputs.iter();
        let receiver = inputs.next().ok_or_else(|| {
            (
                "Must have at least 1 argument".to_string(),
                signature.span(),
            )
        })?;
        let is_mut = if let FnArg::Receiver(r) = receiver {
            r.mutability.is_some()
        } else {
            return Err((
                "Function in trait must have a receiver".to_string(),
                receiver.span(),
            ));
        };

        let maybe_identity = inputs.next();
        let maybe_argument = inputs.next();

        match (maybe_identity, maybe_argument) {
            (_id, Some(FnArg::Typed(PatType { ty, .. }))) => {
                has_sender = true;
                arg_type = Some(ty.clone());
            }
            (Some(FnArg::Typed(PatType { ty, .. })), None) => {
                arg_type = Some(ty.clone());
            }
            (None, None) => {
                arg_type = None;
            }
            (_, _) => {
                return Err(("Must have 2 or 3 arguments".to_string(), signature.span()));
            }
        }

        if let ReturnType::Type(_, ty) = &signature.output {
            if let Type::Path(TypePath {
                path: syn::Path { segments, .. },
                ..
            }) = ty.as_ref()
            {
                if segments[0].ident == "Result"
                    || segments
                        .iter()
                        .map(|x| x.ident.to_string())
                        .collect::<Vec<String>>()
                        .join("::")
                        == "std::result::Result"
                {
                    if let AngleBracketed(AngleBracketedGenericArguments { ref args, .. }) =
                        segments[0].arguments
                    {
                        ret_type = Some(
                            args.iter()
                                .find_map(|x| match x {
                                    GenericArgument::Type(t) => Some(Box::new(t.clone())),
                                    _ => None,
                                })
                                .unwrap(),
                        );
                    }
                }
            }
        }

        if ret_type.is_none() {
            return Err((
                "Must have a result return type.".to_string(),
                signature.output.span(),
            ));
        }

        Ok(Self {
            name,
            func,
            span: signature.span(),
            is_async,
            is_mut,
            has_sender,
            arg_type,
            ret_type: ret_type.unwrap(),
        })
    }
}

#[allow(clippy::too_many_lines)]
fn many_module_impl(attr: &TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
    let attrs: ManyModuleAttributes = from_tokenstream(attr)?;
    let many = Ident::new(
        attrs.many_crate.as_ref().map_or("many", String::as_str),
        attr.span(),
    );

    let namespace = attrs.namespace;
    let span = item.span();
    let tr: syn::ItemTrait = syn::parse2(item)
        .map_err(|_| syn::Error::new(span, "`many_module` only applies to traits.".to_string()))?;

    let struct_name = attrs.name.clone().unwrap_or_else(|| tr.ident.to_string());
    let struct_ident = Ident::new(
        struct_name.as_str(),
        attrs
            .name
            .as_ref()
            .map_or_else(|| attr.span(), |_| tr.ident.span()),
    );

    let mut trait_ = tr.clone();

    if attrs.name.is_none() {
        trait_.ident = Ident::new(&format!("{}Backend", struct_name), tr.ident.span());
    }
    let trait_ident = trait_.ident.clone();

    let vis = trait_.vis.clone();

    let attr_id = attrs.id.iter();
    let attr_name =
        inflections::Inflect::to_constant_case(format!("{}Attribute", struct_name).as_str());
    let attr_ident = Ident::new(&attr_name, attr.span());

    let info_name = format!("{}Info", struct_name);
    let info_ident = Ident::new(&info_name, attr.span());

    let endpoints: Result<Vec<_>, (String, Span)> = trait_
        .items
        .iter()
        .filter_map(|item| match item {
            TraitItem::Method(m) => Some(m),
            _ => None,
        })
        .map(|item| Endpoint::new(&item.sig))
        .collect();
    let endpoints = endpoints.map_err(|(msg, span)| syn::Error::new(span, msg))?;
    let ns = namespace.clone();
    let endpoint_strings: Vec<String> = endpoints
        .iter()
        .map(move |e| {
            let name = e.name.as_str().to_camel_case();
            match ns {
                Some(ref namespace) => format!("{}.{}", namespace, name),
                None => name,
            }
        })
        .collect();

    let ns = namespace.clone();
    let validate_endpoint_pat = endpoints.iter().map(|e| {
        let span = e.span;
        let name = e.name.as_str().to_camel_case();
        let ep = match ns {
            Some(ref namespace) => format!("{}.{}", namespace, name),
            None => name,
        };

        if let Some(ty) = &e.arg_type {
            quote_spanned! { span =>
                #ep => {
                    minicbor::decode::<'_, #ty>(data)
                        .map_err(|e| ManyError::deserialization_error(e.to_string()))?;
                }
            }
        } else {
            quote! {
                #ep => {}
            }
        }
    });
    let validate = quote! {
        fn validate(&self, message: & #many ::message::RequestMessage) -> Result<(),  #many ::ManyError> {
            let method = message.method.as_str();
            let data = message.data.as_slice();
            match method {
                #(#validate_endpoint_pat)*

                _ => return Err( #many ::ManyError::invalid_method_name(method.to_string())),
            };
            Ok(())
        }
    };

    let ns = namespace;
    let execute_endpoint_pat = endpoints.iter().map(|e| {
        let span = e.span;
        let name = e.name.as_str().to_camel_case();
        let ep = match ns {
            Some(ref namespace) => format!("{}.{}", namespace, name),
            None => name,
        };
        let ep_ident = &e.func;

        let backend_decl = if e.is_mut {
            quote! { let mut backend = self.backend.lock().unwrap(); }
        } else {
            quote! { let backend = self.backend.lock().unwrap(); }
        };

        let call = match (e.has_sender, e.arg_type.is_some(), e.is_async) {
            (false, true, false) => quote_spanned! { span => encode( backend . #ep_ident ( decode( data )? ) ) },
            (false, true, true) => quote_spanned! { span => encode( backend . #ep_ident ( decode( data )? ).await ) },
            (true, true, false) => quote_spanned! { span => encode( backend . #ep_ident ( &message.from.unwrap_or_default(), decode( data )? ) ) },
            (true, true, true) => quote_spanned! { span => encode( backend . #ep_ident ( &message.from.unwrap_or_default(), decode( data )? ).await ) },
            (false, false, false) => quote_spanned! { span => encode( backend . #ep_ident ( ) ) },
            (false, false, true) => quote_spanned! { span => encode( backend . #ep_ident ( ).await ) },
            (true, false, false) => quote_spanned! { span => encode( backend . #ep_ident ( &message.from.unwrap_or_default() ) ) },
            (true, false, true) => quote_spanned! { span => encode( backend . #ep_ident ( &message.from.unwrap_or_default() ).await ) },
        };

        quote_spanned! { span =>
            #ep => {
                #backend_decl
                #call
            }
        }
    });
    let execute = quote! {
        async fn execute(
            &self,
            message:  #many ::message::RequestMessage,
        ) -> Result< #many ::message::ResponseMessage,  #many ::ManyError> {
            use  #many ::ManyError;
            fn decode<'a, T: minicbor::Decode<'a>>(data: &'a [u8]) -> Result<T, ManyError> {
                minicbor::decode(data).map_err(|e| ManyError::deserialization_error(e.to_string()))
            }
            fn encode<T: minicbor::Encode>(result: Result<T, ManyError>) -> Result<Vec<u8>, ManyError> {
                minicbor::to_vec(result?).map_err(|e| ManyError::serialization_error(e.to_string()))
            }

            let data = message.data.as_slice();
            let result = match message.method.as_str() {
                #( #execute_endpoint_pat )*

                _ => Err(ManyError::internal_server_error()),
            }?;

            Ok( #many ::message::ResponseMessage::from_request(
                &message,
                &message.to,
                Ok(result),
            ))
        }
    };

    let attribute = if attrs.id.is_some() {
        quote! { Some(#attr_ident) }
    } else {
        quote! { None }
    };

    Ok(quote! {
        #( #vis const #attr_ident:  #many ::protocol::Attribute =  #many ::protocol::Attribute::id(#attr_id); )*

        #vis struct #info_ident;
        impl std::ops::Deref for #info_ident {
            type Target =  #many ::server::module::ManyModuleInfo;

            fn deref(&self) -> & #many ::server::module::ManyModuleInfo {
                use  #many ::server::module::ManyModuleInfo;
                static ONCE: std::sync::Once = std::sync::Once::new();
                static mut VALUE: *mut ManyModuleInfo = 0 as *mut ManyModuleInfo;

                unsafe {
                    ONCE.call_once(|| VALUE = Box::into_raw(Box::new(ManyModuleInfo {
                        name: #struct_name .to_string(),
                        attribute: #attribute,
                        endpoints: vec![ #( #endpoint_strings .to_string() ),* ],
                    })));
                    &*VALUE
                }
            }
        }

        #[async_trait::async_trait]
        #trait_

        #vis struct #struct_ident<T: #trait_ident> {
            backend: std::sync::Arc<std::sync::Mutex<T>>
        }

        impl<T: #trait_ident> std::fmt::Debug for #struct_ident<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(#struct_name).finish()
            }
        }

        impl<T: #trait_ident> #struct_ident<T> {
            pub fn new(backend: std::sync::Arc<std::sync::Mutex<T>>) -> Self {
                Self { backend }
            }
        }

        #[async_trait::async_trait]
        impl<T: #trait_ident>  #many ::ManyModule for #struct_ident<T> {
            fn info(&self) -> & #many ::server::module::ManyModuleInfo {
                & #info_ident
            }

            #validate

            #execute
        }
    })
}

#[proc_macro_attribute]
pub fn many_module(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    many_module_impl(&attr.into(), item.into())
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}