use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{ToTokens, format_ident, quote};
use syn::{
Attribute, FnArg, GenericArgument, ItemTrait, Pat, PatType, PathArguments, ReturnType,
TraitItem, TraitItemFn, Type, Visibility, parse_macro_input,
};
pub fn hosted_rpc(attr: TokenStream, item: TokenStream) -> TokenStream {
if !attr.is_empty() {
let attr2: TokenStream2 = attr.into();
return syn::Error::new_spanned(
attr2,
"`#[hosted_rpc]` does not take any attribute arguments in the MVP",
)
.to_compile_error()
.into();
}
let item_trait = parse_macro_input!(item as ItemTrait);
expand(item_trait).into()
}
fn expand(item_trait: ItemTrait) -> TokenStream2 {
if item_trait.unsafety.is_some() {
return syn::Error::new_spanned(
item_trait.unsafety,
"`#[hosted_rpc]` traits must not be `unsafe` in the MVP",
)
.to_compile_error();
}
if let Some(g) = item_trait.generics.params.first() {
return syn::Error::new_spanned(g, "`#[hosted_rpc]` traits must be non-generic in the MVP")
.to_compile_error();
}
if !item_trait.supertraits.is_empty() {
return syn::Error::new_spanned(
&item_trait.supertraits,
"`#[hosted_rpc]` traits must not have supertraits in the MVP",
)
.to_compile_error();
}
if let Some(attr) = item_trait.attrs.iter().find(|a| is_cfg_attr(a)) {
return syn::Error::new_spanned(
attr,
"`#[hosted_rpc]` does not support `#[cfg(...)]` / `#[cfg_attr(...)]` on the trait in the MVP \
(the generated sibling items would not be cfg-propagated)",
)
.to_compile_error();
}
if let Some(item) = item_trait
.items
.iter()
.find(|it| !matches!(it, TraitItem::Fn(_)))
{
return syn::Error::new_spanned(
item,
"`#[hosted_rpc]` traits must only declare methods (no consts, types, etc.)",
)
.to_compile_error();
}
let methods: Vec<&TraitItemFn> = item_trait
.items
.iter()
.filter_map(|it| match it {
TraitItem::Fn(f) => Some(f),
_ => None,
})
.collect();
let async_mode = methods.iter().any(|m| m.sig.asyncness.is_some());
if async_mode && let Some(sync_method) = methods.iter().find(|m| m.sig.asyncness.is_none()) {
return syn::Error::new_spanned(
&sync_method.sig,
"`#[hosted_rpc]` requires methods to be either all `async fn` or all sync \
(mixed sync/async traits are rejected so the generated stub trait surface stays consistent)",
)
.to_compile_error();
}
for m in &methods {
if let Some(g) = m.sig.generics.params.first() {
return syn::Error::new_spanned(
g,
"`#[hosted_rpc]` methods must be non-generic in the MVP",
)
.to_compile_error();
}
if m.sig.unsafety.is_some() {
return syn::Error::new_spanned(
m.sig.unsafety,
"`#[hosted_rpc]` methods must not be `unsafe` in the MVP",
)
.to_compile_error();
}
if let Some(abi) = &m.sig.abi {
return syn::Error::new_spanned(
abi,
"`#[hosted_rpc]` methods must use the default Rust ABI (no `extern ...`) in the MVP",
)
.to_compile_error();
}
if let Some(variadic) = &m.sig.variadic {
return syn::Error::new_spanned(
variadic,
"`#[hosted_rpc]` methods must not be variadic in the MVP",
)
.to_compile_error();
}
if m.default.is_some() {
return syn::Error::new_spanned(
m,
"`#[hosted_rpc]` methods must not have a default body in the MVP",
)
.to_compile_error();
}
if let Some(attr) = m.attrs.iter().find(|a| is_cfg_attr(a)) {
return syn::Error::new_spanned(
attr,
"`#[hosted_rpc]` does not support `#[cfg(...)]` / `#[cfg_attr(...)]` on trait methods in the MVP \
(the generated dispatch arms would not be cfg-propagated)",
)
.to_compile_error();
}
let Some(first) = m.sig.inputs.first() else {
return syn::Error::new_spanned(&m.sig, "`#[hosted_rpc]` methods must take `&self`")
.to_compile_error();
};
let FnArg::Receiver(receiver) = first else {
return syn::Error::new_spanned(first, "`#[hosted_rpc]` methods must take `&self`")
.to_compile_error();
};
if receiver.reference.is_none() {
return syn::Error::new_spanned(
receiver,
"`#[hosted_rpc]` methods must take `&self` (no by-value `self`)",
)
.to_compile_error();
}
if receiver.colon_token.is_some() {
return syn::Error::new_spanned(
receiver,
"`#[hosted_rpc]` methods must take `&self` (no explicit `self: T` type)",
)
.to_compile_error();
}
if receiver.mutability.is_some() {
return syn::Error::new_spanned(
receiver,
"`#[hosted_rpc]` methods must take `&self` (test-r test deps are injected as `&Stub`; \
`&mut self` stub methods would be uncallable from injected test parameters)",
)
.to_compile_error();
}
for input in m.sig.inputs.iter() {
if let FnArg::Typed(t) = input
&& contains_impl_trait(&t.ty)
{
return syn::Error::new_spanned(
&t.ty,
"`#[hosted_rpc]` does not support `impl Trait` in argument position in the MVP",
)
.to_compile_error();
}
}
if let ReturnType::Type(_, ty) = &m.sig.output
&& contains_impl_trait(ty)
{
return syn::Error::new_spanned(
ty,
"`#[hosted_rpc]` does not support `impl Trait` in return position in the MVP",
)
.to_compile_error();
}
for input in m.sig.inputs.iter() {
if let FnArg::Typed(t) = input
&& !matches!(&*t.pat, Pat::Ident(_))
{
return syn::Error::new_spanned(
&t.pat,
"`#[hosted_rpc]` requires plain identifier argument patterns (no `_`, no destructuring) in the MVP",
)
.to_compile_error();
}
}
}
let trait_vis = &item_trait.vis;
let trait_ident = &item_trait.ident;
let stub_ident = format_ident!("{}Stub", trait_ident);
let dispatch_ident = format_ident!("{}Dispatch", trait_ident);
let dispatch_method_ident =
format_ident!("dispatch_{}", to_snake_case(&trait_ident.to_string()));
let mut stub_impl_arms: Vec<TokenStream2> = Vec::new();
let mut dispatch_arms: Vec<TokenStream2> = Vec::new();
for (idx, m) in methods.iter().enumerate() {
let method_idx = idx as u32;
let sig = &m.sig;
let method_ident = &sig.ident;
let asyncness = &sig.asyncness;
let await_token = if asyncness.is_some() {
quote!(.await)
} else {
quote!()
};
let (receiver, typed_args): (TokenStream2, Vec<&PatType>) = {
let mut recv = quote!();
let mut others: Vec<&PatType> = Vec::new();
for input in sig.inputs.iter() {
match input {
FnArg::Receiver(r) => recv = r.to_token_stream(),
FnArg::Typed(t) => others.push(t),
}
}
(recv, others)
};
let arg_idents: Vec<TokenStream2> = typed_args
.iter()
.map(|t| match &*t.pat {
Pat::Ident(p) => {
let i = &p.ident;
quote!(#i)
}
other => other.to_token_stream(),
})
.collect();
let arg_types: Vec<TokenStream2> =
typed_args.iter().map(|t| t.ty.to_token_stream()).collect();
let ret_ty: TokenStream2 = match &sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, t) => t.to_token_stream(),
};
let args_pack: TokenStream2 = if arg_idents.is_empty() {
quote!(())
} else if arg_idents.len() == 1 {
let id = &arg_idents[0];
quote!(#id)
} else {
quote!((#(#arg_idents),*))
};
let args_tuple_ty: TokenStream2 = if arg_types.is_empty() {
quote!(())
} else if arg_types.len() == 1 {
let t = &arg_types[0];
quote!(#t)
} else {
quote!((#(#arg_types),*))
};
let arg_unpack: Vec<TokenStream2> = if arg_idents.is_empty() {
Vec::new()
} else if arg_idents.len() == 1 {
let id = &arg_idents[0];
vec![quote!(let #id = __args;)]
} else {
vec![quote!(let (#(#arg_idents),*) = __args;)]
};
let attrs = &m.attrs;
let stub_label = format!("{}::{}", trait_ident, method_ident);
let stub_encode_msg = format!("hosted_rpc({stub_label}): encode args");
let stub_call_msg = format!("hosted_rpc({stub_label}): rpc call failed");
let stub_decode_msg = format!("hosted_rpc({stub_label}): decode reply");
let dispatch_decode_args_fmt = format!(
"hosted_rpc dispatch ({stub_label}, method_idx={method_idx}): decode args: {{:?}}"
);
let dispatch_encode_reply_fmt = format!(
"hosted_rpc dispatch ({stub_label}, method_idx={method_idx}): encode reply: {{:?}}"
);
stub_impl_arms.push(quote! {
#(#attrs)*
#asyncness fn #method_ident(#receiver, #(#typed_args),*) -> #ret_ty {
let __args: #args_tuple_ty = #args_pack;
let __args_bytes: ::std::vec::Vec<u8> =
::test_r::core::desert_rust::serialize_to_byte_vec(&__args)
.expect(#stub_encode_msg);
let __reply: ::std::vec::Vec<u8> = self
.channel
.call(#method_idx, __args_bytes)
.expect(#stub_call_msg);
::test_r::core::desert_rust::deserialize::<#ret_ty>(&__reply)
.expect(#stub_decode_msg)
}
});
dispatch_arms.push(quote! {
#method_idx => {
let __args: #args_tuple_ty =
::test_r::core::desert_rust::deserialize(args)
.map_err(|e| ::std::format!(#dispatch_decode_args_fmt, e))?;
#(#arg_unpack)*
let __result: #ret_ty = self.#method_ident(#(#arg_idents),*) #await_token;
::test_r::core::desert_rust::serialize_to_byte_vec(&__result)
.map_err(|e| ::std::format!(#dispatch_encode_reply_fmt, e))
}
});
}
let stub_vis: &Visibility = trait_vis;
let trait_decl = &item_trait;
let dispatch_unknown_method_text = format!("{}: unknown method_idx {{}}", trait_ident);
let stub_struct_name_text = stub_ident.to_string();
let dispatch_asyncness = if async_mode { quote!(async) } else { quote!() };
let blanket_bound = if async_mode {
quote!(#trait_ident + ::core::marker::Send + ::core::marker::Sync + ?Sized)
} else {
quote!(#trait_ident + ?Sized)
};
quote! {
#trait_decl
#stub_vis struct #stub_ident {
channel: ::test_r::core::HostedRpcChannel,
}
impl #stub_ident {
pub fn new(channel: ::test_r::core::HostedRpcChannel) -> Self {
Self { channel }
}
}
impl ::core::fmt::Debug for #stub_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
f.debug_struct(#stub_struct_name_text)
.field("dep_id", &self.channel.dep_id())
.finish()
}
}
impl #trait_ident for #stub_ident {
#(#stub_impl_arms)*
}
#trait_vis trait #dispatch_ident {
#dispatch_asyncness fn #dispatch_method_ident(
&mut self,
method_idx: u32,
args: &[u8],
) -> ::std::result::Result<::std::vec::Vec<u8>, ::std::string::String>;
}
impl<__T: #blanket_bound> #dispatch_ident for __T {
#dispatch_asyncness fn #dispatch_method_ident(
&mut self,
method_idx: u32,
args: &[u8],
) -> ::std::result::Result<::std::vec::Vec<u8>, ::std::string::String> {
match method_idx {
#(#dispatch_arms)*
other => ::std::result::Result::Err(
::std::format!(#dispatch_unknown_method_text, other),
),
}
}
}
}
}
/// Returns `true` when `attr` is a bare `#[cfg(...)]` or `#[cfg_attr(...)]` attribute,
/// i.e. its path is the single identifier `cfg` or `cfg_attr`.
fn is_cfg_attr(attr: &Attribute) -> bool {
    attr.path()
        .get_ident()
        .is_some_and(|ident| ident == "cfg" || ident == "cfg_attr")
}
fn contains_impl_trait(ty: &Type) -> bool {
match ty {
Type::ImplTrait(_) => true,
Type::Reference(r) => contains_impl_trait(&r.elem),
Type::Paren(p) => contains_impl_trait(&p.elem),
Type::Group(g) => contains_impl_trait(&g.elem),
Type::Slice(s) => contains_impl_trait(&s.elem),
Type::Array(a) => contains_impl_trait(&a.elem),
Type::Ptr(p) => contains_impl_trait(&p.elem),
Type::Tuple(t) => t.elems.iter().any(contains_impl_trait),
Type::Path(p) => {
p.path.segments.iter().any(|seg| match &seg.arguments {
PathArguments::None => false,
PathArguments::AngleBracketed(args) => args.args.iter().any(|a| match a {
GenericArgument::Type(t) => contains_impl_trait(t),
_ => false,
}),
PathArguments::Parenthesized(args) => {
args.inputs.iter().any(contains_impl_trait)
|| matches!(&args.output, ReturnType::Type(_, t) if contains_impl_trait(t))
}
})
}
_ => false,
}
}
/// Converts a trait name into the `snake_case` suffix of the generated
/// `dispatch_<name>` method (e.g. `KeyValueStore` -> `key_value_store`).
///
/// Every ASCII uppercase letter after the first character is prefixed with `_` and
/// lowercased. Runs of capitals are therefore split letter by letter (`HTTPServer` ->
/// `h_t_t_p_server`); this is kept as-is because the result names a public generated
/// method. Non-ASCII characters are copied unchanged.
///
/// A leading `r#` (how raw identifiers such as `r#match` stringify) is stripped first.
/// Keeping it would produce the invalid identifier `dispatch_r#match` and make
/// `format_ident!` panic.
fn to_snake_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut out = String::with_capacity(s.len() + 4);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_ascii_uppercase() {
            if i > 0 {
                out.push('_');
            }
            out.push(ch.to_ascii_lowercase());
        } else {
            out.push(ch);
        }
    }
    out
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `#[hosted_rpc]` expansion: rejection diagnostics for every
    //! unsupported trait shape, smoke checks on the generated items, and direct checks of
    //! the pure helper functions.
    use super::{contains_impl_trait, expand, is_cfg_attr, to_snake_case};
    use syn::parse_quote;

    /// Expands `item` and renders the resulting tokens (or `compile_error!`) as a string.
    fn expand_to_string(item: syn::ItemTrait) -> String {
        expand(item).to_string()
    }

    #[test]
    fn rejects_generic_trait() {
        let s = expand_to_string(parse_quote! {
            trait Foo<T> {
                fn one(&self) -> T;
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for generic traits, got: {s}"
        );
        assert!(
            s.contains("non-generic"),
            "expected the rejection to mention non-generic, got: {s}"
        );
    }

    #[test]
    fn rejects_generic_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one<T>(&self, x: T);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for generic methods, got: {s}"
        );
    }

    #[test]
    fn rejects_mixed_sync_and_async_methods() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                async fn async_one(&self);
                fn sync_two(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for mixed sync/async methods, got: {s}"
        );
        assert!(
            s.contains("all `async fn`") || s.contains("all sync"),
            "expected the rejection to mention all-or-nothing async, got: {s}"
        );
    }

    #[test]
    fn accepts_all_async_methods_and_emits_async_dispatch() {
        let s = expand_to_string(parse_quote! {
            trait Counter {
                async fn next(&self) -> u64;
                async fn reserve(&self, count: u32) -> u64;
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid async trait must not emit compile_error!, got: {s}"
        );
        let normalized: String = s.split_whitespace().collect::<Vec<_>>().join(" ");
        assert!(
            normalized.contains("async fn dispatch_counter"),
            "async-mode must produce an async dispatch helper method, got: {normalized}"
        );
        assert!(
            normalized.contains("async fn next") && normalized.contains("async fn reserve"),
            "async-mode stub impl methods must stay `async fn`, got: {normalized}"
        );
    }

    #[test]
    fn rejects_default_body_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self) { let _ = 1; }
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for default-body methods, got: {s}"
        );
        assert!(
            s.contains("default body"),
            "expected the rejection to mention default body, got: {s}"
        );
    }

    #[test]
    fn rejects_associated_type() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                type Item;
                fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for associated types, got: {s}"
        );
    }

    #[test]
    fn rejects_method_without_self() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(x: u32);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for non-`self` first argument, got: {s}"
        );
    }

    #[test]
    fn rejects_supertraits() {
        let s = expand_to_string(parse_quote! {
            trait Foo: Send {
                fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for supertraits, got: {s}"
        );
        assert!(
            s.contains("supertraits"),
            "expected the rejection to mention supertraits, got: {s}"
        );
    }

    #[test]
    fn rejects_self_by_value_receiver() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for by-value self, got: {s}"
        );
        assert!(
            s.contains("by-value"),
            "expected the rejection to mention by-value, got: {s}"
        );
    }

    #[test]
    fn rejects_explicit_self_type() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(self: Box<Self>);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for explicit `self: T`, got: {s}"
        );
        assert!(
            s.contains("&self") || s.contains("by-value") || s.contains("self: T"),
            "expected the rejection to mention &self / by-value / self: T, got: {s}"
        );
    }

    #[test]
    fn rejects_unsafe_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                unsafe fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for unsafe methods, got: {s}"
        );
        assert!(
            s.contains("unsafe"),
            "expected the rejection to mention unsafe, got: {s}"
        );
    }

    #[test]
    fn rejects_unsafe_trait() {
        let s = expand_to_string(parse_quote! {
            unsafe trait Foo {
                fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for unsafe traits, got: {s}"
        );
        assert!(
            s.contains("unsafe"),
            "expected the rejection to mention unsafe, got: {s}"
        );
    }

    #[test]
    fn rejects_extern_abi_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                extern "C" fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for non-default ABI, got: {s}"
        );
        assert!(
            s.contains("Rust ABI") || s.contains("extern"),
            "expected the rejection to mention ABI/extern, got: {s}"
        );
    }

    #[test]
    fn rejects_impl_trait_in_argument() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self, x: impl ::std::fmt::Display);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for impl Trait in argument, got: {s}"
        );
        assert!(
            s.contains("argument position"),
            "expected the rejection to mention argument position, got: {s}"
        );
    }

    #[test]
    fn rejects_nested_impl_trait_in_argument() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self, x: ::std::vec::Vec<impl Clone>);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for nested impl Trait in argument, got: {s}"
        );
        assert!(
            s.contains("argument position"),
            "expected the rejection to mention argument position, got: {s}"
        );
    }

    #[test]
    fn rejects_impl_trait_in_return() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self) -> impl ::std::fmt::Display;
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for impl Trait in return, got: {s}"
        );
        assert!(
            s.contains("return position"),
            "expected the rejection to mention return position, got: {s}"
        );
    }

    #[test]
    fn rejects_wildcard_arg_pattern() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self, _: u32);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for wildcard arg pattern, got: {s}"
        );
        assert!(
            s.contains("identifier"),
            "expected the rejection to mention identifier patterns, got: {s}"
        );
    }

    #[test]
    fn rejects_destructured_arg_pattern() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&self, (a, b): (u32, u32));
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for destructured arg pattern, got: {s}"
        );
    }

    #[test]
    fn rejects_cfg_on_trait() {
        let s = expand_to_string(parse_quote! {
            #[cfg(unix)]
            trait Foo {
                fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for cfg on trait, got: {s}"
        );
        assert!(
            s.contains("cfg"),
            "expected the rejection to mention cfg, got: {s}"
        );
    }

    #[test]
    fn rejects_cfg_on_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                #[cfg(unix)]
                fn one(&self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for cfg on method, got: {s}"
        );
        assert!(
            s.contains("cfg"),
            "expected the rejection to mention cfg, got: {s}"
        );
    }

    #[test]
    fn rejects_mut_self_receiver() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn one(&mut self);
            }
        });
        assert!(
            s.contains("compile_error"),
            "expected a compile_error! for `&mut self`, got: {s}"
        );
        assert!(
            s.contains("&Stub") && s.contains("uncallable"),
            "expected the rejection to mention the immutable `&Stub` injection rationale, got: {s}"
        );
    }

    #[test]
    fn accepts_two_arg_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn add(&self, a: u32, b: u32) -> u32;
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid `&self` + 2-arg trait must not emit compile_error!, got: {s}"
        );
        assert!(s.contains("struct FooStub"));
        assert!(s.contains("trait FooDispatch"));
    }

    #[test]
    fn accepts_unit_return_method() {
        let s = expand_to_string(parse_quote! {
            trait Foo {
                fn ping(&self);
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid `&self` + unit-return trait must not emit compile_error!, got: {s}"
        );
        assert!(s.contains("struct FooStub"));
    }

    #[test]
    fn accepts_simple_trait_and_emits_stub_and_dispatch() {
        let s = expand_to_string(parse_quote! {
            trait Counter {
                fn next(&self) -> u64;
                fn reserve(&self, count: u32) -> u64;
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid trait must not emit compile_error!, got: {s}"
        );
        assert!(s.contains("struct CounterStub"));
        assert!(s.contains("trait CounterDispatch"));
        assert!(s.contains("dispatch_counter"));
    }

    #[test]
    fn derives_dispatch_method_name_from_multi_word_trait() {
        let s = expand_to_string(parse_quote! {
            trait KeyValueStore {
                fn get(&self, key: String) -> Option<String>;
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid trait must not emit compile_error!, got: {s}"
        );
        assert!(
            s.contains("dispatch_key_value_store"),
            "expected a snake_case dispatch method name, got: {s}"
        );
        assert!(s.contains("struct KeyValueStoreStub"));
    }

    #[test]
    fn emits_debug_impl_on_stub() {
        let s = expand_to_string(parse_quote! {
            trait Counter {
                fn next(&self) -> u64;
            }
        });
        assert!(
            !s.contains("compile_error"),
            "valid trait must not emit compile_error!, got: {s}"
        );
        let normalized: String = s.split_whitespace().collect::<Vec<_>>().join(" ");
        assert!(
            normalized.contains("impl :: core :: fmt :: Debug for CounterStub")
                || normalized.contains("impl ::core::fmt::Debug for CounterStub")
                || normalized.contains("impl core :: fmt :: Debug for CounterStub")
                || normalized.contains("impl Debug for CounterStub"),
            "must emit Debug impl for CounterStub, got: {normalized}"
        );
        assert!(
            normalized.contains("\"CounterStub\""),
            "Debug fmt must include the stub type name, got: {normalized}"
        );
        assert!(
            normalized.contains("dep_id"),
            "Debug fmt must include the dep_id field, got: {normalized}"
        );
    }

    #[test]
    fn to_snake_case_inserts_underscores_between_words() {
        assert_eq!(to_snake_case("Counter"), "counter");
        assert_eq!(to_snake_case("KeyValueStore"), "key_value_store");
        assert_eq!(to_snake_case("V2Api"), "v2_api");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn contains_impl_trait_sees_nested_types() {
        let nested: syn::Type = parse_quote!(::std::vec::Vec<impl ::std::fmt::Display>);
        assert!(contains_impl_trait(&nested));
        let in_slice_ref: syn::Type = parse_quote!(&[impl Clone]);
        assert!(contains_impl_trait(&in_slice_ref));
        let in_tuple: syn::Type = parse_quote!((u32, impl Clone));
        assert!(contains_impl_trait(&in_tuple));
        let plain: syn::Type = parse_quote!(::std::collections::HashMap<String, (u32, Vec<u8>)>);
        assert!(!contains_impl_trait(&plain));
    }

    #[test]
    fn is_cfg_attr_matches_only_cfg_and_cfg_attr() {
        let cfg: syn::Attribute = parse_quote!(#[cfg(unix)]);
        let cfg_attr: syn::Attribute = parse_quote!(#[cfg_attr(unix, allow(dead_code))]);
        let doc: syn::Attribute = parse_quote!(#[doc = "hello"]);
        assert!(is_cfg_attr(&cfg));
        assert!(is_cfg_attr(&cfg_attr));
        assert!(!is_cfg_attr(&doc));
    }
}