use proc_macro::TokenStream;
use quote::quote;
use syn::{parse, FnArg, ItemTrait, Pat, Signature, TraitItem, TraitItemFn};
#[proc_macro_attribute]
pub fn arc_handle(_args: TokenStream, input: TokenStream) -> TokenStream {
match arc_handle_inner(input) {
Ok(tokens) => tokens,
Err(e) => e.to_compile_error().into(),
}
}
fn arc_handle_inner(input: TokenStream) -> syn::Result<TokenStream> {
let mut input = parse::<ItemTrait>(input)?;
let original_trait_name = &input.ident;
let impl_trait_name = syn::Ident::new(
&format!("{}Impl", original_trait_name),
original_trait_name.span(),
);
let handle_name = original_trait_name.clone();
let vis = &input.vis;
input.ident = impl_trait_name.clone();
let mut impl_methods = Vec::new();
for item in &input.items {
match item {
TraitItem::Fn(method) => {
if method.default.is_some() {
return Err(syn::Error::new_spanned(
method,
"arc_handle does not support default method bodies",
));
}
validate_receiver(method)?;
let method_name = &method.sig.ident;
let inputs = &method.sig.inputs;
let output = &method.sig.output;
let is_async = is_async_method(&method.sig);
let param_names = extract_param_names(&method.sig)?;
if is_async {
impl_methods.push(quote! {
#[inline]
#vis async fn #method_name(#inputs) #output {
self.inner.#method_name(#(#param_names),*).await
}
});
} else {
impl_methods.push(quote! {
#[inline]
#vis fn #method_name(#inputs) #output {
self.inner.#method_name(#(#param_names),*)
}
});
}
}
TraitItem::Const(tc) => {
return Err(syn::Error::new_spanned(
tc,
"arc_handle does not support associated constants",
));
}
TraitItem::Type(tt) => {
return Err(syn::Error::new_spanned(
tt,
"arc_handle does not support associated types",
));
}
_ => {}
}
}
let expanded = quote! {
#input
#[doc = concat!("Arc-based handle wrapper for `", stringify!(#impl_trait_name), "`")]
#[derive(Clone)]
#vis struct #handle_name {
inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>,
}
impl #handle_name {
#[inline]
#vis fn new(inner: impl #impl_trait_name + Send + Sync + 'static) -> Self {
Self {
inner: std::sync::Arc::new(inner),
}
}
#[inline]
#vis fn from_boxed(inner: Box<dyn #impl_trait_name + Send + Sync>) -> Self {
Self {
inner: std::sync::Arc::from(inner),
}
}
#[inline]
#vis fn from_arc(inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>) -> Self {
Self { inner }
}
#[inline]
#vis fn inner(&self) -> &std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
&self.inner
}
#[inline]
#vis fn into_inner(self) -> std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
self.inner
}
#(#impl_methods)*
}
};
Ok(TokenStream::from(expanded))
}
fn extract_param_names(sig: &Signature) -> syn::Result<Vec<&syn::Ident>> {
sig.inputs
.iter()
.skip(1)
.map(|arg| {
if let FnArg::Typed(pat_type) = arg {
if let Pat::Ident(ident) = &*pat_type.pat {
Ok(&ident.ident)
} else {
Err(syn::Error::new_spanned(
pat_type,
"unsupported parameter pattern; expected a simple identifier",
))
}
} else {
Err(syn::Error::new_spanned(
sig,
"unexpected receiver in parameter list",
))
}
})
.collect()
}
fn validate_receiver(method: &TraitItemFn) -> syn::Result<()> {
match method.sig.inputs.first() {
Some(FnArg::Receiver(r)) => {
if r.mutability.is_some() {
return Err(syn::Error::new_spanned(
r,
"arc_handle does not support &mut self receivers; \
the handle uses Arc which only provides shared access",
));
}
Ok(())
}
Some(FnArg::Typed(pat_type)) => Err(syn::Error::new_spanned(
pat_type,
"arc_handle requires &self as the first parameter; \
by-value self is not supported",
)),
None => Err(syn::Error::new_spanned(
method,
"arc_handle requires methods to have a &self receiver",
)),
}
}
/// Returns `true` when `sig` is declared with the `async` keyword.
fn is_async_method(sig: &Signature) -> bool {
    matches!(sig.asyncness, Some(_))
}