use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::Token;
pub struct AbstractProcess {
args: Args,
item_impl: syn::ItemImpl,
arg_ty: syn::Type,
init: syn::ImplItemMethod,
terminate: Option<syn::ImplItemMethod>,
handle_link_trapped: Option<syn::ImplItemMethod>,
message_handlers: Vec<syn::ImplItemMethod>,
request_handlers: Vec<syn::ImplItemMethod>,
message_builder_ident: syn::Ident,
request_builder_ident: syn::Ident,
handler_trait_ident: syn::Ident,
}
impl AbstractProcess {
pub fn new(args: proc_macro::TokenStream, item: proc_macro::TokenStream) -> syn::Result<Self> {
let args: Args = syn::parse(args)?;
let mut item_impl: syn::ItemImpl = syn::parse(item)?;
let self_ident = match *item_impl.self_ty {
syn::Type::Path(ref ty_path) => match ty_path.path.segments.last() {
Some(segment) => segment.ident.clone(),
None => {
return Err(syn::Error::new(
item_impl.self_ty.span(),
"type is not supported",
))
}
},
_ => {
return Err(syn::Error::new(
item_impl.self_ty.span(),
"type is not supported",
))
}
};
let (init, terminate, handle_link_trapped, message_handlers, request_handlers) = item_impl
.items
.clone()
.into_iter()
.enumerate()
.filter_map(|(i, item)| match item {
syn::ImplItem::Method(impl_item_method) => Some((i, impl_item_method)),
_ => None,
})
.filter_map(|(i, mut impl_item_method)| {
let (j, item_attr) =
impl_item_method
.attrs
.iter()
.enumerate()
.find_map(|(i, attr)| {
attr.path
.get_ident()
.map(|ident| ident.to_string())
.and_then(|ident_string| ItemAttr::from_str(&ident_string))
.map(|item_attr| (i, item_attr))
})?;
impl_item_method.attrs.remove(j);
if let syn::ImplItem::Method(impl_item_method) = item_impl.items.get_mut(i).unwrap()
{
impl_item_method.attrs.remove(j);
}
Some((item_attr, impl_item_method))
})
.fold(
Ok((None, None, None, Vec::new(), Vec::new())),
|acc, (item_attr, impl_item_method)| {
let (
mut init,
mut terminate,
mut handle_link_trapped,
mut message_handlers,
mut request_handlers,
) = acc?;
match item_attr {
ItemAttr::Init => {
if init.is_some() {
return Err(syn::Error::new(
impl_item_method.sig.ident.span(),
"init method already defined",
));
}
init = Some(impl_item_method);
}
ItemAttr::Terminate => {
if terminate.is_some() {
return Err(syn::Error::new(
impl_item_method.sig.ident.span(),
"terminate method already defined",
));
}
terminate = Some(impl_item_method);
}
ItemAttr::HandleLinkTrapped => {
if handle_link_trapped.is_some() {
return Err(syn::Error::new(
impl_item_method.sig.ident.span(),
"handle_link_trapped method already defined",
));
}
handle_link_trapped = Some(impl_item_method);
}
ItemAttr::HandleMessage => {
message_handlers.push(impl_item_method);
}
ItemAttr::HandleRequest => {
request_handlers.push(impl_item_method);
}
}
Ok((
init,
terminate,
handle_link_trapped,
message_handlers,
request_handlers,
))
},
)?;
let init =
init.ok_or_else(|| syn::Error::new(item_impl.self_ty.span(), "missing init method"))?;
let arg_ty = match init
.sig
.inputs
.last()
.ok_or_else(|| syn::Error::new(init.sig.span(), "init must take 2 arguments"))?
{
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new(init.sig.span(), "init cannot take `&self`"))
}
syn::FnArg::Typed(typed_arg) => *typed_arg.ty.clone(),
};
let message_builder_ident = format_ident!("{}MsgBuilder", self_ident);
let request_builder_ident = format_ident!("{}ReqBuilder", self_ident);
let handler_trait_ident = args
.trait_name
.as_ref()
.map(|trait_name| format_ident!("{}", trait_name.value()))
.unwrap_or_else(|| format_ident!("{}Handler", self_ident));
Ok(AbstractProcess {
args,
item_impl,
arg_ty,
init,
terminate,
handle_link_trapped,
message_handlers,
request_handlers,
message_builder_ident,
request_builder_ident,
handler_trait_ident,
})
}
pub fn expand(&self) -> TokenStream {
let handler_wrappers = self.expand_handler_wrappers();
let original_impl = self.expand_original_impl();
let impl_abstract_process = self.expand_impl_abstract_process();
let message_handler_impls = self.expand_message_handler_impls();
let request_handler_impls = self.expand_request_handler_impls();
let handler_trait = self.expand_handler_trait();
let impl_handler_trait = self.expand_impl_handler_trait();
let message_builders = self.expand_builders();
quote! {
#handler_wrappers
#original_impl
#impl_abstract_process
#message_handler_impls
#request_handler_impls
#handler_trait
#impl_handler_trait
#message_builders
}
}
fn expand_handler_wrappers(&self) -> TokenStream {
let wrappers = self
.message_handlers
.iter()
.chain(self.request_handlers.iter())
.map(|impl_item_method| self.expand_handler_wrapper(impl_item_method));
quote! {
#( #wrappers )*
}
}
fn expand_handler_wrapper(&self, impl_item_method: &syn::ImplItemMethod) -> TokenStream {
let ident = Self::handler_wrapper_ident(&impl_item_method.sig.ident);
let (_, ty_generics, _) = &self.item_impl.generics.split_for_impl();
let phantom_generics = &self.item_impl.generics.params;
let fields = filter_typed_args(impl_item_method.sig.inputs.iter()).map(|field| &*field.ty);
let phantom_field = if !self.item_impl.generics.params.is_empty() {
Some(quote! { std::marker::PhantomData <(#phantom_generics)>, })
} else {
None
};
quote! {
#[derive(serde::Serialize, serde::Deserialize)]
struct #ident #ty_generics (
#phantom_field
#( #fields ),*
);
}
}
fn expand_original_impl(&self) -> TokenStream {
let syn::ItemImpl {
attrs,
defaultness,
unsafety,
impl_token,
generics,
self_ty,
items,
..
} = &self.item_impl;
let (impl_generics, _ty_generics, where_clause) = generics.split_for_impl();
quote! {
#( #attrs )*
#defaultness #unsafety #impl_token #impl_generics #self_ty #where_clause {
#( #items )*
}
}
}
fn expand_impl_abstract_process(&self) -> TokenStream {
let syn::ItemImpl {
generics, self_ty, ..
} = &self.item_impl;
let (impl_generics, _ty_generics, where_clause) = generics.split_for_impl();
let arg_ty = &self.arg_ty;
let init_impl = self.expand_init_impl();
let terminate_impl = self.expand_terminate_impl();
let handle_link_trapped_impl = self.expand_handle_link_trapped_impl();
quote! {
impl #impl_generics lunatic::process::AbstractProcess for #self_ty #where_clause {
type State = #self_ty;
type Arg = #arg_ty;
#init_impl
#terminate_impl
#handle_link_trapped_impl
}
}
}
fn expand_init_impl(&self) -> TokenStream {
let ident = &self.init.sig.ident;
let arg_ty = &self.arg_ty;
quote! {
fn init(this: ProcessRef<Self>, arg: #arg_ty) -> Self::State {
Self::#ident(this, arg)
}
}
}
fn expand_terminate_impl(&self) -> TokenStream {
self.terminate
.as_ref()
.map(|terminate| {
let ident = &terminate.sig.ident;
quote! {
fn terminate(state: Self::State) {
state.#ident()
}
}
})
.unwrap_or_default()
}
fn expand_handle_link_trapped_impl(&self) -> TokenStream {
self.handle_link_trapped
.as_ref()
.map(|handle_link_trapped| {
let ident = &handle_link_trapped.sig.ident;
quote! {
fn handle_link_trapped(state: &mut Self::State, tag: lunatic::Tag) {
state.#ident(tag);
}
}
})
.unwrap_or_default()
}
fn expand_message_handler_impls(&self) -> TokenStream {
let message_handler_impls = self.message_handlers.iter().map(|message_handler| {
let syn::ImplItemMethod {
attrs,
sig,
..
} = message_handler;
let self_ty = &self.item_impl.self_ty;
let message_type = Self::handler_wrapper_ident(&sig.ident);
let fn_ident = &sig.ident;
let (impl_generics, ty_generics, where_clause) = self.item_impl.generics.split_for_impl();
let args = filter_typed_args(sig.inputs.iter());
let offset = if self.item_impl.generics.params.is_empty() { 0 } else { 1 };
let message_fields = (offset..args.count() + offset).map(|i| {
let i = proc_macro2::Literal::usize_unsuffixed(i);
quote! { message. #i }
});
quote! {
#( #attrs )*
impl #impl_generics lunatic::process::MessageHandler<#message_type #ty_generics> for #self_ty #where_clause {
fn handle(state: &mut Self::State, message: #message_type #ty_generics) {
state.#fn_ident(#( #message_fields ),*)
}
}
}
});
quote! {
#( #message_handler_impls )*
}
}
fn expand_request_handler_impls(&self) -> TokenStream {
let request_handler_impls = self.request_handlers.iter().map(|request_handler| {
let syn::ImplItemMethod {
attrs,
sig,
..
} = request_handler;
let self_ty = &self.item_impl.self_ty;
let request_type = Self::handler_wrapper_ident(&sig.ident);
let response_type = match &sig.output {
syn::ReturnType::Type(_, ty) => quote! { #ty },
syn::ReturnType::Default => {
quote! { () }
}
};
let fn_ident = &sig.ident;
let (impl_generics, ty_generics, where_clause) = self.item_impl.generics.split_for_impl();
let args = filter_typed_args(sig.inputs.iter());
let offset = if self.item_impl.generics.params.is_empty() { 0 } else { 1 };
let request_fields = (offset..args.count() + offset).map(|i| {
let i = proc_macro2::Literal::usize_unsuffixed(i);
quote! { request. #i }
});
quote! {
#( #attrs )*
impl #impl_generics lunatic::process::RequestHandler<#request_type #ty_generics> for #self_ty #where_clause {
type Response = #response_type;
fn handle(state: &mut Self::State, request: #request_type #ty_generics) -> Self::Response {
state.#fn_ident(#( #request_fields ),*)
}
}
}
});
quote! {
#( #request_handler_impls )*
}
}
fn expand_handler_trait(&self) -> TokenStream {
let Self {
args,
item_impl,
message_handlers,
request_handlers,
message_builder_ident,
request_builder_ident,
handler_trait_ident,
..
} = self;
let vis = &args.visibility;
let (_impl_generics, ty_generics, where_clause) = item_impl.generics.split_for_impl();
let message_handler_defs = message_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
generics,
args,
..
} = handler;
quote! {
#( #attrs )*
fn #ident #generics (&self #(, #args )*);
}
});
let request_handler_defs = request_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
generics,
args,
return_ty,
..
} = handler;
quote! {
#( #attrs )*
fn #ident #generics (&self #(, #args )*) -> #return_ty;
}
});
quote! {
#vis trait #handler_trait_ident #ty_generics #where_clause {
#( #message_handler_defs )*
#( #request_handler_defs )*
fn after(&self, duration: std::time::Duration) -> #message_builder_ident #ty_generics;
fn with_timeout(&self, duration: std::time::Duration) -> #request_builder_ident #ty_generics;
}
}
}
fn expand_impl_handler_trait(&self) -> TokenStream {
let Self {
item_impl,
message_handlers,
request_handlers,
message_builder_ident,
request_builder_ident,
handler_trait_ident,
..
} = self;
let self_ty = &item_impl.self_ty;
let (impl_generics, ty_generics, where_clause) = item_impl.generics.split_for_impl();
let arg_phantom = if !item_impl.generics.params.is_empty() {
Some(quote! { std::marker::PhantomData, })
} else {
None
};
let message_handler_impls = message_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
generics,
args,
message_type,
handler_args,
..
} = handler;
quote! {
#( #attrs )*
fn #ident #generics (&self #(, #args )*) {
use lunatic::process::Message;
let msg = #message_type(#arg_phantom #( #handler_args ),*);
self.send(msg);
}
}
});
let request_handler_impls = request_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
generics,
args,
return_ty,
message_type,
handler_args,
} = handler;
quote! {
#( #attrs )*
fn #ident #generics (&self #(, #args )*) -> #return_ty {
use lunatic::process::Request;
let req = #message_type(#arg_phantom #( #handler_args ),*);
self.request(req)
}
}
});
quote! {
impl #impl_generics #handler_trait_ident #ty_generics for lunatic::process::ProcessRef<#self_ty> #where_clause {
#( #message_handler_impls )*
#( #request_handler_impls )*
fn after(&self, duration: std::time::Duration) -> #message_builder_ident #ty_generics {
#message_builder_ident::new(duration, self.clone())
}
fn with_timeout(&self, duration: std::time::Duration) -> #request_builder_ident #ty_generics {
#request_builder_ident::new(duration, self.clone())
}
}
}
}
fn expand_builders(&self) -> TokenStream {
let Self {
args,
item_impl,
message_handlers,
request_handlers,
message_builder_ident,
request_builder_ident,
..
} = self;
let vis = &args.visibility;
let (impl_generics, ty_generics, where_clause) = item_impl.generics.split_for_impl();
let self_ty = &item_impl.self_ty;
let arg_phantom = if !item_impl.generics.params.is_empty() {
Some(quote! { std::marker::PhantomData, })
} else {
None
};
let message_builder_methods = message_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
args,
message_type,
handler_args,
..
} = handler;
quote! {
#( #attrs )*
#vis fn #ident(&self #(, #args )*) {
use lunatic::process::Message;
let msg = #message_type(#arg_phantom #( #handler_args ),*);
self.process_ref.send_after(msg, self.duration);
}
}
});
let request_builder_methods = request_handlers
.iter()
.map(HandlerStructure::from_handler)
.map(|handler| {
let HandlerStructure {
attrs,
ident,
args,
return_ty,
message_type,
handler_args,
..
} = handler;
quote! {
#( #attrs )*
#vis fn #ident(&self #(, #args )*) -> lunatic::MailboxResult<#return_ty> {
use lunatic::process::Request;
let req = #message_type(#arg_phantom #( #handler_args ),*);
self.process_ref.request_timeout(req, self.duration)
}
}
});
quote! {
#vis struct #message_builder_ident #ty_generics #where_clause {
duration: std::time::Duration,
process_ref: lunatic::process::ProcessRef<#self_ty>,
}
impl #impl_generics #message_builder_ident #ty_generics #where_clause {
fn new(duration: std::time::Duration, process_ref: ProcessRef<#self_ty>) -> Self {
Self {
duration,
process_ref,
}
}
#( #message_builder_methods )*
}
#vis struct #request_builder_ident #ty_generics #where_clause {
duration: std::time::Duration,
process_ref: lunatic::process::ProcessRef<#self_ty>,
}
impl #impl_generics #request_builder_ident #ty_generics #where_clause {
fn new(duration: std::time::Duration, process_ref: ProcessRef<#self_ty>) -> Self {
Self {
duration,
process_ref,
}
}
#( #request_builder_methods )*
}
}
}
fn handler_wrapper_ident(ident: impl ToString) -> syn::Ident {
format_ident!("__MsgWrap{}", ident.to_string().to_case(Case::Pascal))
}
}
#[derive(Default)]
pub struct Args {
trait_name: Option<syn::LitStr>,
visibility: Option<syn::Visibility>,
}
impl Args {
fn parse_arg(&mut self, input: ParseStream) -> syn::Result<()> {
if input.is_empty() {
return Ok(());
}
let ident: syn::Ident = input.parse()?;
let _: syn::Token![=] = input.parse()?;
if ident == "trait_name" {
if self.trait_name.is_some() {
return Err(syn::Error::new(
ident.span(),
"trait name already specified",
));
}
self.trait_name = Some(input.parse()?);
} else if ident == "visibility" {
if self.visibility.is_some() {
return Err(syn::Error::new(
ident.span(),
"visibility already specified",
));
}
self.visibility = Some(input.parse()?);
} else {
return Err(syn::Error::new(ident.span(), "unknown argument"));
}
Ok(())
}
}
impl Parse for Args {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut args = Args::default();
while !input.is_empty() {
args.parse_arg(input)?;
let _: Option<Token![,]> = input.parse()?;
}
Ok(args)
}
}
enum ItemAttr {
Init,
Terminate,
HandleLinkTrapped,
HandleMessage,
HandleRequest,
}
impl ItemAttr {
fn from_str(s: &str) -> Option<ItemAttr> {
match s {
"init" => Some(ItemAttr::Init),
"terminate" => Some(ItemAttr::Terminate),
"handle_link_trapped" => Some(ItemAttr::HandleLinkTrapped),
"handle_message" => Some(ItemAttr::HandleMessage),
"handle_request" => Some(ItemAttr::HandleRequest),
_ => None,
}
}
}
fn filter_typed_args<'a>(
args: impl Iterator<Item = &'a syn::FnArg>,
) -> impl Iterator<Item = &'a syn::PatType> {
args.filter_map(|input| match input {
syn::FnArg::Receiver(_) => None,
syn::FnArg::Typed(pat_type) => Some(pat_type),
})
}
fn filter_typed_arg_names<'a>(
args: impl Iterator<Item = &'a syn::FnArg> + 'a,
) -> impl Iterator<Item = (syn::Ident, &'a syn::Type)> + 'a {
filter_typed_args(args)
.enumerate()
.map(|(i, arg)| match &*arg.pat {
syn::Pat::Ident(pat_ident) => (pat_ident.ident.clone(), &*arg.ty),
_ => (format_ident!("__arg_{i}"), &*arg.ty),
})
}
struct HandlerStructure<'a> {
attrs: &'a Vec<syn::Attribute>,
ident: &'a syn::Ident,
generics: &'a syn::Generics,
args: Vec<TokenStream>,
return_ty: syn::Type,
message_type: syn::Ident,
handler_args: Vec<syn::Ident>,
}
impl<'a> HandlerStructure<'a> {
fn from_handler(handler: &'a syn::ImplItemMethod) -> Self {
let syn::ImplItemMethod { attrs, sig, .. } = handler;
let syn::Signature {
ident,
inputs,
generics,
output,
..
} = &sig;
let args = filter_typed_arg_names(inputs.iter())
.map(|(ident, ty)| quote! { #ident: #ty })
.collect();
let return_ty = match output {
syn::ReturnType::Default => syn::parse_quote! { () },
syn::ReturnType::Type(_, ty) => *ty.clone(),
};
let message_type = AbstractProcess::handler_wrapper_ident(ident);
let handler_args = filter_typed_arg_names(inputs.iter())
.map(|(ident, _ty)| ident)
.collect();
HandlerStructure {
attrs,
ident,
generics,
args,
return_ty,
message_type,
handler_args,
}
}
}