use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
Expr, ExprLit, FnArg, GenericArgument, Ident, ItemTrait, Lit, PathArguments, ReturnType,
TraitItem, Type, parse_macro_input,
};
/// Attribute macro that turns a trait into an RPC service definition.
///
/// Usage: `#[service(id = <integer>)]` on a trait whose methods look like
/// `async fn name(&self, req: Req) -> Result<Resp, E>`. The trait is re-emitted
/// under `#[async_trait::async_trait(?Send)]`, and `<Trait>Server<C>` and
/// `<Trait>Client<C>` types are generated next to it.
///
/// Attribute parsing, trait parsing and validation failures are all turned
/// into compile errors instead of panics.
#[proc_macro_attribute]
pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed_attr = parse_macro_input!(attr as InterfaceAttr);
    let parsed_trait = parse_macro_input!(item as ItemTrait);
    service_impl(parsed_attr, parsed_trait)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn service_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
let mut has_ref = false;
let mut has_mut_ref = false;
for trait_item in &item.items {
if let TraitItem::Fn(method) = trait_item
&& let Some(FnArg::Receiver(recv)) = method.sig.inputs.first()
{
if recv.mutability.is_some() {
has_mut_ref = true;
} else {
has_ref = true;
}
}
}
if has_ref && has_mut_ref {
return Err(syn::Error::new_spanned(
&item.ident,
"all methods must use `&self` receivers",
));
}
if has_mut_ref {
return Err(syn::Error::new_spanned(
&item.ident,
"`&mut self` methods (virtual actor mode) have been removed. Use `&self` for RPC services.",
));
}
interface_impl(attr, item)
}
/// Per-method data collected from the service trait. It drives generation of
/// the server fields, the client endpoints and the handler tasks.
struct MethodInfo {
    /// Method id within the interface: the method's 1-based position among
    /// all items of the trait.
    index: u32,
    /// The trait method's identifier. It is reused as the server/client field name.
    name: Ident,
    /// Type of the method's single request parameter.
    req_type: Type,
    /// The `Ok` type of the method's `Result` return type.
    resp_type: Type,
}
fn interface_impl(attr: InterfaceAttr, item: ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
let interface_id = attr.id;
let name = &item.ident;
let server_name = format_ident!("{}Server", name);
let client_name = format_ident!("{}Client", name);
let mut method_infos: Vec<MethodInfo> = Vec::new();
for (index, trait_item) in item.items.iter().enumerate() {
if let TraitItem::Fn(method) = trait_item {
let method_name = &method.sig.ident;
let (req_type, resp_type) = extract_method_types(&method.sig)?;
method_infos.push(MethodInfo {
index: (index + 1) as u32,
name: method_name.clone(),
req_type,
resp_type,
});
}
}
let method_count = method_infos.len() as u32;
let server_fields = method_infos.iter().map(|m| {
let name = &m.name;
let req_type = &m.req_type;
quote! { pub #name: moonpool_transport::RequestStream<#req_type, C> }
});
let server_inits: Vec<_> = method_infos
.iter()
.enumerate()
.map(|(i, m)| {
let name = &m.name;
let idx = m.index;
let is_last = i == method_infos.len() - 1;
if is_last {
quote! {
let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec);
}
} else {
quote! {
let (#name, _) = transport.register_handler_at(Self::INTERFACE_ID, #idx as u64, codec.clone());
}
}
})
.collect();
let server_field_names: Vec<_> = method_infos.iter().map(|m| &m.name).collect();
let client_fields = method_infos.iter().map(|m| {
let name = &m.name;
let req_type = &m.req_type;
let resp_type = &m.resp_type;
quote! {
pub #name: moonpool_transport::ServiceEndpoint<#req_type, #resp_type, C>
}
});
let client_field_inits = method_infos.iter().map(|m| {
let name = &m.name;
let idx = m.index;
quote! {
#name: moonpool_transport::ServiceEndpoint::new(
moonpool_transport::Endpoint::new(
address.clone(),
moonpool_transport::UID::new(Self::INTERFACE_ID, #idx as u64),
),
codec.clone(),
)
}
});
let first_field_name = &method_infos[0].name;
let trait_vis = &item.vis;
let trait_items = &item.items;
let trait_name_snake = to_snake_case(&name.to_string());
let serve_close_handles: Vec<_> = method_infos
.iter()
.map(|m| {
let method_name = &m.name;
quote! {
let queue = self.#method_name.queue();
close_fns.push(Box::new(move || queue.close()));
}
})
.collect();
let serve_spawn_tasks: Vec<_> = method_infos
.iter()
.map(|m| {
let method_name = &m.name;
let resp_type = &m.resp_type;
let task_name = format!("{}_{}", trait_name_snake, m.name);
quote! {
{
let stream = self.#method_name;
let t = transport.clone();
let h = handler.clone();
providers.task().spawn_task(#task_name, async move {
while let Some((req, reply)) = stream
.recv_with_transport::<_, #resp_type>(&t)
.await
{
match h.#method_name(req).await {
Ok(resp) => reply.send(resp),
Err(e) => {
tracing::warn!(error = %e, method = #task_name, "handler error");
reply.send_error(moonpool_transport::ReplyError::BrokenPromise);
}
}
}
});
}
}
})
.collect();
let expanded = quote! {
#[async_trait::async_trait(?Send)]
#trait_vis trait #name {
#(#trait_items)*
}
pub struct #server_name<C: moonpool_transport::MessageCodec> {
#(#server_fields,)*
}
impl<C: moonpool_transport::MessageCodec + Clone> #server_name<C> {
pub const INTERFACE_ID: u64 = #interface_id;
pub const METHOD_COUNT: u32 = #method_count;
pub fn init<P>(transport: &std::rc::Rc<moonpool_transport::NetTransport<P>>, codec: C) -> Self
where
P: moonpool_transport::Providers,
{
#(#server_inits)*
Self { #(#server_field_names,)* }
}
pub fn serve<P, H>(
self,
transport: std::rc::Rc<moonpool_transport::NetTransport<P>>,
handler: std::rc::Rc<H>,
providers: &P,
) -> moonpool_transport::ServerHandle
where
P: moonpool_transport::Providers,
H: #name + 'static,
{
use moonpool_transport::TaskProvider as _;
let mut close_fns: Vec<Box<dyn Fn()>> = Vec::new();
#(#serve_close_handles)*
#(#serve_spawn_tasks)*
moonpool_transport::ServerHandle::new(close_fns)
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(
serialize = "",
deserialize = "C: moonpool_transport::MessageCodec + Default",
))]
pub struct #client_name<C: moonpool_transport::MessageCodec> {
#(#client_fields,)*
}
impl<C: moonpool_transport::MessageCodec + Clone> #client_name<C> {
pub const INTERFACE_ID: u64 = #interface_id;
pub const METHOD_COUNT: u32 = #method_count;
pub fn new(address: moonpool_transport::NetworkAddress, codec: C) -> Self {
Self {
#(#client_field_inits,)*
}
}
pub fn address(&self) -> &moonpool_transport::NetworkAddress {
&self.#first_field_name.endpoint().address
}
}
};
Ok(expanded)
}
/// Extracts `(ReqType, RespType)` from a service method signature of the form
/// `async fn name(&self, req: ReqType) -> Result<RespType, RpcError>`.
///
/// # Errors
/// Fails in any of these cases:
/// - the method is not `async`;
/// - its first parameter is not a `self` receiver;
/// - it has no request parameter, or more than one;
/// - it does not return a `Result<...>` (see `extract_result_ok_type`).
fn extract_method_types(sig: &syn::Signature) -> syn::Result<(Type, Type)> {
    // The generated server `.await`s every handler call. A plain `fn` would
    // otherwise fail later with a much less helpful error.
    if sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(
            sig,
            "Interface method must be async: async fn name(&self, req: ReqType) -> Result<RespType, RpcError>",
        ));
    }

    let mut inputs = sig.inputs.iter();
    match inputs.next() {
        Some(FnArg::Receiver(_)) => {}
        _ => {
            return Err(syn::Error::new_spanned(
                sig,
                "Interface method must have &self as first parameter",
            ));
        }
    }

    let req_type = match inputs.next() {
        Some(FnArg::Typed(pat_type)) => (*pat_type.ty).clone(),
        _ => {
            return Err(syn::Error::new_spanned(
                sig,
                "Interface method must have a request parameter: async fn name(&self, req: ReqType) -> Result<RespType, RpcError>",
            ));
        }
    };

    // The generated code calls `handler.method(req)` with exactly one argument,
    // so any further parameter could never be supplied.
    if let Some(extra) = inputs.next() {
        return Err(syn::Error::new_spanned(
            extra,
            "Interface method must take exactly one request parameter after &self; bundle multiple values into a single request type",
        ));
    }

    let resp_type = match &sig.output {
        ReturnType::Type(_, ty) => extract_result_ok_type(ty)?,
        ReturnType::Default => {
            return Err(syn::Error::new_spanned(
                sig,
                "Interface method must return Result<RespType, RpcError>",
            ));
        }
    };

    Ok((req_type, resp_type))
}
fn extract_result_ok_type(ty: &Type) -> syn::Result<Type> {
if let Type::Path(type_path) = ty
&& let Some(segment) = type_path.path.segments.last()
&& segment.ident == "Result"
&& let PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(GenericArgument::Type(ok_type)) = args.args.first()
{
return Ok(ok_type.clone());
}
Err(syn::Error::new_spanned(
ty,
"Interface method must return Result<RespType, RpcError>",
))
}
/// Converts an `UpperCamelCase` identifier into `snake_case`.
///
/// An underscore is inserted before every uppercase character except a
/// leading one, so runs of capitals are split individually (`"HTTPService"`
/// becomes `"h_t_t_p_service"`). Other characters are copied unchanged.
/// The result is used to build the names of the spawned handler tasks.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Rust identifiers may contain non-ASCII capitals (e.g. 'Ä'), which
            // `to_ascii_lowercase` would leave unchanged.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
/// Parsed arguments of the `#[service(id = <integer>)]` attribute.
struct InterfaceAttr {
    /// Interface identifier. It is emitted as `INTERFACE_ID` on the generated
    /// server and client, and passed as the first argument of every endpoint
    /// `UID::new`.
    id: u64,
}
impl syn::parse::Parse for InterfaceAttr {
    /// Parses `id = <integer literal>`, optionally followed by a trailing comma.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the attribute is empty;
    /// - the key is something other than `id`;
    /// - the `=` is missing;
    /// - the value is not an integer literal that fits in `u64`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        // `#[service]` with no arguments: name the required form instead of
        // surfacing syn's generic "expected identifier".
        if input.is_empty() {
            return Err(input.error("expected `id = <integer>` in interface attribute"));
        }

        let ident: Ident = input.parse()?;
        if ident != "id" {
            return Err(syn::Error::new_spanned(
                ident,
                "expected `id` in interface attribute",
            ));
        }
        let _eq: syn::Token![=] = input.parse()?;

        let value: Expr = input.parse()?;
        let id = match &value {
            Expr::Lit(ExprLit {
                lit: Lit::Int(lit_int),
                ..
            }) => lit_int.base10_parse::<u64>()?,
            _ => {
                return Err(syn::Error::new_spanned(
                    value,
                    "expected integer literal for interface id",
                ));
            }
        };

        // Tolerate `#[service(id = 1,)]`.
        let _trailing_comma: Option<syn::Token![,]> = input.parse()?;

        Ok(InterfaceAttr { id })
    }
}