//! arc-handle 1.1.0
//!
//! Proc macro for generating `Arc`-based handle wrappers for traits.
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse, FnArg, ItemTrait, Pat, Signature, TraitItem, TraitItemFn};

/// Turns a trait into a cheaply clonable, `Arc`-backed handle type.
///
/// Applied to `trait Foo`, the attribute:
///
/// * renames the trait to `FooImpl`, leaving its items in place;
/// * declares a `#[derive(Clone)]` struct named `Foo`, with the trait's visibility, that holds an
///   `Arc<dyn FooImpl + Send + Sync>`;
/// * gives `Foo` the constructors and accessors `new`, `from_boxed`, `from_arc`, `inner` and
///   `into_inner`, plus one method per trait method that forwards to the wrapped object.
///
/// # Examples
///
/// ```rust
/// use arc_handle::arc_handle;
///
/// #[arc_handle]
/// pub trait Greeter {
///     fn greet(&self, name: &str) -> String;
/// }
///
/// // The macro generates:
/// //   - `GreeterImpl` trait (renamed from `Greeter`)
/// //   - `Greeter` struct wrapping `Arc<dyn GreeterImpl + Send + Sync>`
/// //   - Delegating methods on `Greeter` matching the trait
///
/// struct EnglishGreeter;
///
/// impl GreeterImpl for EnglishGreeter {
///     fn greet(&self, name: &str) -> String {
///         format!("Hello, {name}!")
///     }
/// }
///
/// let handle = Greeter::new(EnglishGreeter);
/// assert_eq!(handle.greet("world"), "Hello, world!");
/// ```
///
/// `async` methods are forwarded with `.await`, so they work together with `#[async_trait]`:
///
/// ```rust,ignore
/// use arc_handle::arc_handle;
/// use async_trait::async_trait;
///
/// #[arc_handle]
/// #[async_trait]
/// pub trait AsyncService {
///     async fn fetch(&self, url: &str) -> String;
///     fn name(&self) -> &str;
/// }
/// ```
///
/// # Limitations
///
/// Expansion is rejected with a compile error if the trait has any of the following:
///
/// * default method bodies;
/// * associated constants or types;
/// * a `&mut self` receiver;
/// * a parameter bound by anything other than a plain identifier.
///
/// Every method should take `&self`, because the handle only has shared access to the wrapped
/// object. Arguments passed to the attribute itself are currently ignored.
#[proc_macro_attribute]
pub fn arc_handle(_args: TokenStream, input: TokenStream) -> TokenStream {
    // Expansion failures are reported as `compile_error!` tokens at the user's call site.
    arc_handle_inner(input).unwrap_or_else(|err| err.to_compile_error().into())
}

/// Expands `#[arc_handle]` on a trait.
///
/// The output contains:
///
/// * the input trait, renamed to `<Name>Impl`;
/// * a `Clone` handle struct called `<Name>` that wraps `Arc<dyn <Name>Impl + Send + Sync>`;
/// * an inherent `impl` with constructors/accessors and one forwarding method per trait method.
///
/// # Errors
///
/// Returns a spanned `syn::Error` in any of these cases:
///
/// * the input is not a trait, or the trait is generic;
/// * the trait has default method bodies, associated consts/types, or an unsupported receiver;
/// * a parameter is not bound by a plain identifier;
/// * a trait method is named after one of the handle's generated inherent methods.
fn arc_handle_inner(input: TokenStream) -> syn::Result<TokenStream> {
    // Inherent methods generated on every handle. A trait method with one of these names would
    // cause a duplicate definition inside the generated `impl` block, so it is rejected up front
    // with an error that points at the offending method instead.
    const GENERATED_METHOD_NAMES: &[&str] = &["new", "from_boxed", "from_arc", "inner", "into_inner"];

    let mut input = parse::<ItemTrait>(input)?;

    // The handle stores `Arc<dyn TraitImpl + Send + Sync>` without generic arguments. A generic
    // trait would otherwise fail later with a confusing "missing generics" error.
    if !input.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &input.generics,
            "arc_handle does not support generic traits",
        ));
    }

    let handle_name = input.ident.clone();
    let impl_trait_name = syn::Ident::new(&format!("{}Impl", handle_name), handle_name.span());
    let vis = input.vis.clone();

    // Rename the original trait to `TraitImpl`; the handle struct takes over the original name.
    input.ident = impl_trait_name.clone();

    let mut impl_methods = Vec::with_capacity(input.items.len());

    for item in &input.items {
        match item {
            TraitItem::Fn(method) => {
                // Reject default method bodies
                if method.default.is_some() {
                    return Err(syn::Error::new_spanned(
                        method,
                        "arc_handle does not support default method bodies",
                    ));
                }

                validate_receiver(method)?;

                let sig = &method.sig;
                let method_name = &sig.ident;

                let name = method_name.to_string();
                if GENERATED_METHOD_NAMES.contains(&name.as_str()) {
                    return Err(syn::Error::new_spanned(
                        method_name,
                        format!(
                            "arc_handle cannot delegate `{name}`: the generated handle already \
                             defines an inherent `{name}` method"
                        ),
                    ));
                }

                let param_names = extract_param_names(sig)?;

                // Use a fully-qualified call. With `self.inner.method(..)`, method resolution
                // would first find same-named trait methods implemented on `Arc` itself
                // (e.g. `Clone::clone`) and silently call the wrong function.
                let call = quote! {
                    #impl_trait_name::#method_name(&*self.inner, #(#param_names),*)
                };
                let call = if is_async_method(sig) {
                    quote! { #call.await }
                } else {
                    call
                };
                let body = if sig.unsafety.is_some() {
                    quote! { unsafe { #call } }
                } else {
                    call
                };

                // Forward docs so the handle method is documented like the trait method. Forward
                // `cfg` so the handle method disappears together with a cfg-gated trait method.
                let forwarded_attrs: Vec<_> = method
                    .attrs
                    .iter()
                    .filter(|attr| {
                        let path = attr.path();
                        path.is_ident("doc") || path.is_ident("cfg")
                    })
                    .collect();

                // Emitting the full signature keeps method generics, where-clauses,
                // `async` and `unsafe` in sync with the trait declaration.
                impl_methods.push(quote! {
                    #(#forwarded_attrs)*
                    #[inline]
                    #vis #sig {
                        #body
                    }
                });
            }
            TraitItem::Const(tc) => {
                return Err(syn::Error::new_spanned(
                    tc,
                    "arc_handle does not support associated constants",
                ));
            }
            TraitItem::Type(tt) => {
                return Err(syn::Error::new_spanned(
                    tt,
                    "arc_handle does not support associated types",
                ));
            }
            // Macro invocations and verbatim items are left in the trait without delegation.
            _ => {}
        }
    }

    let expanded = quote! {
        #input

        #[doc = concat!("Arc-based handle wrapper for `", stringify!(#impl_trait_name), "`")]
        #[derive(Clone)]
        #vis struct #handle_name {
            inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>,
        }

        impl #handle_name {
            /// Create a new handle from a trait implementation
            #[inline]
            #vis fn new(inner: impl #impl_trait_name + Send + Sync + 'static) -> Self {
                Self {
                    inner: std::sync::Arc::new(inner),
                }
            }

            /// Create a new handle from a boxed trait object
            #[inline]
            #vis fn from_boxed(inner: Box<dyn #impl_trait_name + Send + Sync>) -> Self {
                Self {
                    inner: std::sync::Arc::from(inner),
                }
            }

            /// Create a new handle from an existing Arc
            #[inline]
            #vis fn from_arc(inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>) -> Self {
                Self { inner }
            }

            /// Get a reference to the inner Arc
            #[inline]
            #vis fn inner(&self) -> &std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
                &self.inner
            }

            /// Unwrap the handle into the inner Arc
            #[inline]
            #vis fn into_inner(self) -> std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
                self.inner
            }

            #(#impl_methods)*
        }
    };

    Ok(TokenStream::from(expanded))
}

fn extract_param_names(sig: &Signature) -> syn::Result<Vec<&syn::Ident>> {
    sig.inputs
        .iter()
        .skip(1)
        .map(|arg| {
            if let FnArg::Typed(pat_type) = arg {
                if let Pat::Ident(ident) = &*pat_type.pat {
                    Ok(&ident.ident)
                } else {
                    Err(syn::Error::new_spanned(
                        pat_type,
                        "unsupported parameter pattern; expected a simple identifier",
                    ))
                }
            } else {
                Err(syn::Error::new_spanned(
                    sig,
                    "unexpected receiver in parameter list",
                ))
            }
        })
        .collect()
}

fn validate_receiver(method: &TraitItemFn) -> syn::Result<()> {
    match method.sig.inputs.first() {
        Some(FnArg::Receiver(r)) => {
            // &self is fine; &mut self is not (can't get &mut through Arc<dyn>)
            if r.mutability.is_some() {
                return Err(syn::Error::new_spanned(
                    r,
                    "arc_handle does not support &mut self receivers; \
                     the handle uses Arc which only provides shared access",
                ));
            }
            Ok(())
        }
        Some(FnArg::Typed(pat_type)) => Err(syn::Error::new_spanned(
            pat_type,
            "arc_handle requires &self as the first parameter; \
             by-value self is not supported",
        )),
        None => Err(syn::Error::new_spanned(
            method,
            "arc_handle requires methods to have a &self receiver",
        )),
    }
}

/// Reports whether `sig` was declared with the `async` keyword.
fn is_async_method(sig: &Signature) -> bool {
    matches!(sig.asyncness, Some(_))
}