//! tractor-macros 0.0.6
//!
//! Proc macros for `tractor`.
use heck::CamelCase;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::{
    parse::{Error, Parse, ParseStream, Result},
    parse_macro_input, ItemImpl, Token,
};

// The `impl MyActor` the `#[actor]` or `#[async_actor]` attribute was specified on.
struct ActorDef(ItemImpl);

impl Parse for ActorDef {
    fn parse(input: ParseStream) -> Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![impl]) {
            let item: ItemImpl = input.parse()?;
            if item.trait_.is_some() {
                return Err(Error::new(Span::call_site(), "expected non-trait impl"));
            }
            Ok(Self(item))
        } else {
            Err(lookahead.error())
        }
    }
}

impl ToTokens for ActorDef {
    /// Re-emits the wrapped `impl` block unchanged.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self(item_impl) = self;
        item_impl.to_tokens(tokens);
    }
}

/// Generates the message enum `pub enum MyActorMsg { .. }` with one variant
/// per method of the annotated `impl` block.
struct MsgType<'a>(&'a ItemImpl);

impl ToTokens for MsgType<'_> {
    /// Appends the enum definition to `tokens`.
    ///
    /// # Panics
    ///
    /// Panics if the `impl` block contains anything other than methods.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);
        let enum_name = format_ident!("{}Msg", actor_name);

        // Lazily evaluated: the variants are produced while `quote!` expands.
        let variants = self.0.items.iter().map(|item| {
            if let syn::ImplItem::Method(method) = item {
                method_to_enum_case(method)
            } else {
                panic!("Method definition expected")
            }
        });

        tokens.extend(quote! {
            pub enum #enum_name {
                #(#variants),*
            }
        });
    }
}

/// Generates `impl ::tractor::Actor for MyActor` and the synchronous
/// `impl ::tractor::ActorBehavior for MyActor` whose `handle` dispatches each
/// message variant to the method of the same name.
struct ImplActor<'a>(&'a ItemImpl);

impl ToTokens for ImplActor<'_> {
    /// Appends both generated impls to `tokens`.
    ///
    /// # Panics
    ///
    /// Panics if the `impl` block contains anything other than methods.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);
        let msg_name = format_ident!("{}Msg", actor_name);

        // One `Msg::Variant { .. } => self.method(..)` arm per method, without `.await`.
        let arms = self.0.items.iter().map(|item| {
            if let syn::ImplItem::Method(method) = item {
                method_to_message_pattern_match(&msg_name, method, false)
            } else {
                panic!("Only methods supported in #[actor]")
            }
        });

        tokens.extend(quote! {
            impl ::tractor::Actor for #actor_name {
                type Msg = #msg_name;
            }

            impl ::tractor::ActorBehavior for #actor_name {
                fn handle(&mut self, msg: #msg_name, ctx: &::tractor::Context<Self>) {
                    match msg {
                        #(#arms),*
                    }
                }
            }
        });
    }
}

/// Generates `impl ::tractor::Actor for MyActor` and the asynchronous
/// `impl ::tractor::ActorBehaviorAsync for MyActor` (via `async_trait`) whose
/// `handle` awaits the method matching each message variant.
struct ImplAsyncActor<'a>(&'a ItemImpl);

impl ToTokens for ImplAsyncActor<'_> {
    /// Appends both generated impls to `tokens`.
    ///
    /// # Panics
    ///
    /// Panics if the `impl` block contains anything other than methods.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);
        let msg_name = format_ident!("{}Msg", actor_name);

        // One `Msg::Variant { .. } => self.method(..).await` arm per method.
        let arms = self.0.items.iter().map(|item| {
            if let syn::ImplItem::Method(method) = item {
                method_to_message_pattern_match(&msg_name, method, true)
            } else {
                panic!("Only methods supported in #[actor]")
            }
        });

        tokens.extend(quote! {
            impl ::tractor::Actor for #actor_name {
                type Msg = #msg_name;
            }

            #[::async_trait::async_trait]
            impl ::tractor::ActorBehaviorAsync for #actor_name {
                async fn handle(&mut self, msg: #msg_name, ctx: &::tractor::Context<Self>) {
                    match msg {
                        #(#arms),*
                    }
                }
            }
        });
    }
}

/// Generates the empty `impl ::tractor::ActorHooks for MyActor {}`.
struct ImplActorHooks<'a>(&'a ItemImpl);

impl ToTokens for ImplActorHooks<'_> {
    /// Appends the empty hooks impl for the actor type to `tokens`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);

        tokens.extend(quote! {
            impl ::tractor::ActorHooks for #actor_name {
            }
        });
    }
}
//
/// Generates the empty `impl ::tractor::ActorHooksAsync for MyActor {}`.
struct ImplActorHooksAsync<'a>(&'a ItemImpl);

impl ToTokens for ImplActorHooksAsync<'_> {
    /// Appends the empty async hooks impl for the actor type to `tokens`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);

        tokens.extend(quote! {
            impl ::tractor::ActorHooksAsync for #actor_name {
            }
        });
    }
}

/// Generates `pub trait MyActorRef: ::tractor::Address<MyActor> { .. }` with
/// one sending method per behavior, plus `impl MyActorRef for Addr<MyActor> {}`.
//
// NOTE(review): the generated impl names `Addr` unqualified, unlike every
// other generated path (`::tractor::..`), so `Addr` must be in scope at the
// macro call site. Confirm whether `::tractor::Addr` should be used instead.
struct ActorRefTrait<'a>(&'a ItemImpl);

impl ToTokens for ActorRefTrait<'_> {
    /// Appends the reference trait and its `Addr` impl to `tokens`.
    ///
    /// # Panics
    ///
    /// Panics if the `impl` block contains anything other than methods.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let actor_name = actor_type_ident(&self.0.self_ty);
        let msg_name = format_ident!("{}Msg", actor_name);
        let ref_trait_name = format_ident!("{}Ref", actor_name);

        // Each behavior becomes a default trait method that sends the matching message.
        let senders = self.0.items.iter().map(|item| {
            if let syn::ImplItem::Method(method) = item {
                method_to_trait_method(&msg_name, method)
            } else {
                panic!("Only methods supported in #[actor]")
            }
        });

        tokens.extend(quote! {
            pub trait #ref_trait_name : ::tractor::Address<#actor_name> {
                #(#senders)*
            }
            impl #ref_trait_name for Addr<#actor_name> {}
        });
    }
}

/// Returns `true` when the attribute arguments contain the top-level
/// identifier `derive_hooks`.
///
/// Only top-level tokens are inspected; identifiers nested inside groups are
/// not considered.
fn should_derive_hooks(args: proc_macro::TokenStream) -> bool {
    for token in args {
        if let proc_macro::TokenTree::Ident(ident) = token {
            if ident.to_string() == "derive_hooks" {
                return true;
            }
        }
    }
    false
}

#[proc_macro_attribute]
pub fn actor(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let mut tokens = TokenStream::new();

    let actor_def = parse_macro_input!(input as ActorDef);
    actor_def.to_tokens(&mut tokens);

    MsgType(&actor_def.0).to_tokens(&mut tokens);
    ImplActor(&actor_def.0).to_tokens(&mut tokens);
    ActorRefTrait(&actor_def.0).to_tokens(&mut tokens);

    if should_derive_hooks(args) {
        ImplActorHooks(&actor_def.0).to_tokens(&mut tokens);
    }

    proc_macro::TokenStream::from(tokens)
}

#[proc_macro_attribute]
pub fn async_actor(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let mut tokens = TokenStream::new();

    let actor_def = parse_macro_input!(input as ActorDef);
    actor_def.to_tokens(&mut tokens);

    MsgType(&actor_def.0).to_tokens(&mut tokens);
    ImplAsyncActor(&actor_def.0).to_tokens(&mut tokens);
    ActorRefTrait(&actor_def.0).to_tokens(&mut tokens);

    if should_derive_hooks(args) {
        ImplActorHooks(&actor_def.0).to_tokens(&mut tokens);
    }

    proc_macro::TokenStream::from(tokens)
}

/// Extracts the bare identifier of the actor type from the self type of the
/// annotated `impl` block.
///
/// # Panics
///
/// Panics when the self type is not a path made of a single identifier, e.g.
/// when it has a module prefix (`impl path::MyActor`) or generic arguments
/// (`impl MyActor<T>`), or when it is not a path type at all.
fn actor_type_ident(ty: &syn::Type) -> syn::Ident {
    match ty {
        syn::Type::Path(type_path) => type_path.path.get_ident().cloned().expect(
            "actor type must be a single identifier without module path or generic arguments",
        ),
        _ => panic!("actor type must be a plain type name, e.g. `impl MyActor`"),
    }
}

/// Turns a behavior method into a struct-like variant of the message enum.
///
/// For `fn add(&mut self, num: usize)` this generates something like:
///
///     Add { num: usize }
///
/// Only the argument's identifier and type are emitted. Binding modifiers
/// such as `mut` and attributes on the argument are not valid in an enum
/// field and are therefore dropped.
///
/// # Panics
///
/// Panics when the method has a return type, when its first argument is not
/// a `&self`/`&mut self` receiver, or when an argument is not a plain
/// `ident: Type` binding.
fn method_to_enum_case(method: &syn::ImplItemMethod) -> proc_macro2::TokenStream {
    match method.sig.output {
        syn::ReturnType::Default => {}
        _ => panic!("Behaviors have to return ()"),
    }

    let method_name = method.sig.ident.to_string();
    let enum_case_ident = format_ident!("{}", method_name.as_str().to_camel_case());

    let mut inputs_iter = method.sig.inputs.iter();
    match inputs_iter.next() {
        Some(syn::FnArg::Receiver(recv)) if recv.reference.is_some() => {}
        _ => panic!("Behaviors require &self or &mut self receiver"),
    }

    let fields = inputs_iter.map(|arg| match arg {
        syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
            syn::Pat::Ident(pat_ident) => {
                let field_name = &pat_ident.ident;
                let field_ty = &pat_type.ty;
                quote! { #field_name : #field_ty }
            }
            _ => panic!("Ident required"),
        },
        _ => panic!("Requires syn::FnArg::Typed"),
    });

    quote! {
        #enum_case_ident { #(#fields),* }
    }
}

/// Builds one `match` arm that dispatches a message variant to its method.
///
/// This generates something like:
///
///     AdderMsg::Add {num: m_num} => self.add(m_num)
///
/// With `use_await` set, `.await` is appended to the method call.
///
/// # Panics
///
/// Panics when the method has a return type, when its first argument is not
/// a `&self`/`&mut self` receiver, or when an argument is not a plain
/// `ident: Type` binding.
fn method_to_message_pattern_match(
    msg_type: &syn::Ident,
    method: &syn::ImplItemMethod,
    use_await: bool,
) -> proc_macro2::TokenStream {
    if let syn::ReturnType::Type(..) = method.sig.output {
        panic!("Behaviors have to return ()");
    }

    let method_ident = &method.sig.ident;
    let variant = format_ident!("{}", method_ident.to_string().as_str().to_camel_case());

    let mut params = method.sig.inputs.iter();
    let has_ref_receiver = match params.next() {
        Some(syn::FnArg::Receiver(recv)) => recv.reference.is_some(),
        _ => false,
    };
    if !has_ref_receiver {
        panic!("Behaviors require &self or &mut self receiver");
    }

    // Field names of the variant and the `m_`-prefixed locals they are bound to.
    let mut field_names = Vec::new();
    let mut bound_names = Vec::new();
    for param in params {
        let typed = match param {
            syn::FnArg::Typed(pat_type) => pat_type,
            _ => panic!("Requires syn::FnArg::Typed"),
        };
        let field = match typed.pat.as_ref() {
            syn::Pat::Ident(pat_ident) => &pat_ident.ident,
            _ => panic!("Ident required"),
        };
        bound_names.push(format_ident!("m_{}", field));
        field_names.push(field.clone());
    }

    let call = quote! { self . #method_ident (#(#bound_names),*) };
    let arm_body = if use_await {
        quote! { #call . await }
    } else {
        call
    };

    quote! {
        #msg_type :: #variant { #(#field_names : #bound_names),* } => #arm_body
    }
}

// This generates something like:
//
//     fn add(&self, num: usize) { self.send(AdderMsg::Add {num}) };
//
fn method_to_trait_method(
    msg_type: &syn::Ident,
    method: &syn::ImplItemMethod,
) -> proc_macro2::TokenStream {
    match method.sig.output {
        syn::ReturnType::Default => {}
        _ => panic!("Behaviors have to return ()"),
    }

    let method_ident = &method.sig.ident;
    let method_name = method_ident.to_string();
    let enum_case = format_ident!("{}", method_name.as_str().to_camel_case());

    let args_with_type: Vec<TokenStream> = {
        let mut inputs_iter = method.sig.inputs.iter();
        match inputs_iter.next() {
            Some(syn::FnArg::Receiver(recv)) if recv.reference.is_some() => {}
            _ => panic!("Behaviors require &self or &mut self receiver"),
        }

        inputs_iter
            .map(|arg| match arg {
                syn::FnArg::Typed(pat_type) => quote! { #pat_type },
                _ => panic!("Requires syn::FnArg::Typed"),
            })
            .collect()
    };

    let args: Vec<syn::Ident> = {
        let mut inputs_iter = method.sig.inputs.iter();
        match inputs_iter.next() {
            Some(syn::FnArg::Receiver(recv)) if recv.reference.is_some() => {}
            _ => panic!("Behaviors require &self or &mut self receiver"),
        }

        inputs_iter
            .map(|arg| match arg {
                syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
                    syn::Pat::Ident(pat_ident) => pat_ident.ident.clone(),
                    _ => panic!("Ident required"),
                },
                _ => panic!("Requires syn::FnArg::Typed"),
            })
            .collect()
    };

    quote! {
        fn #method_ident (&self, #(#args_with_type),*) {
            self.send(#msg_type :: #enum_case { #(#args),* });
        }
    }
}