Skip to main content

arc_handle/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse, FnArg, ItemTrait, Pat, Signature, TraitItem, TraitItemFn};
4
5/// Attribute macro that generates an Arc-based handle wrapper for a trait.
6///
7/// This macro renames the original trait to `TraitImpl` and creates a handle struct
8/// with the original trait name that wraps the trait in an `Arc<dyn TraitImpl + Send + Sync>`
9/// and provides methods that delegate to the inner trait implementation.
10///
11/// # Example
12///
13/// ```rust
14/// use arc_handle::arc_handle;
15///
16/// #[arc_handle]
17/// pub trait Greeter {
18///     fn greet(&self, name: &str) -> String;
19/// }
20///
21/// // The macro generates:
22/// //   - `GreeterImpl` trait (renamed from `Greeter`)
23/// //   - `Greeter` struct wrapping `Arc<dyn GreeterImpl + Send + Sync>`
24/// //   - Delegating methods on `Greeter` matching the trait
25///
26/// struct EnglishGreeter;
27///
28/// impl GreeterImpl for EnglishGreeter {
29///     fn greet(&self, name: &str) -> String {
30///         format!("Hello, {name}!")
31///     }
32/// }
33///
34/// let handle = Greeter::new(EnglishGreeter);
35/// assert_eq!(handle.greet("world"), "Hello, world!");
36/// ```
37///
38/// Async methods are also supported:
39///
40/// ```rust,ignore
41/// use arc_handle::arc_handle;
42/// use async_trait::async_trait;
43///
44/// #[arc_handle]
45/// #[async_trait]
46/// pub trait AsyncService {
47///     async fn fetch(&self, url: &str) -> String;
48///     fn name(&self) -> &str;
49/// }
50/// ```
51#[proc_macro_attribute]
52pub fn arc_handle(_args: TokenStream, input: TokenStream) -> TokenStream {
53    match arc_handle_inner(input) {
54        Ok(tokens) => tokens,
55        Err(e) => e.to_compile_error().into(),
56    }
57}
58
59fn arc_handle_inner(input: TokenStream) -> syn::Result<TokenStream> {
60    let mut input = parse::<ItemTrait>(input)?;
61
62    let original_trait_name = &input.ident;
63    let impl_trait_name = syn::Ident::new(
64        &format!("{}Impl", original_trait_name),
65        original_trait_name.span(),
66    );
67    let handle_name = original_trait_name.clone();
68    let vis = &input.vis;
69
70    // Rename the original trait to TraitImpl
71    input.ident = impl_trait_name.clone();
72
73    // Extract methods and validate trait items
74    let mut impl_methods = Vec::new();
75
76    for item in &input.items {
77        match item {
78            TraitItem::Fn(method) => {
79                // Reject default method bodies
80                if method.default.is_some() {
81                    return Err(syn::Error::new_spanned(
82                        method,
83                        "arc_handle does not support default method bodies",
84                    ));
85                }
86
87                // Validate receiver
88                validate_receiver(method)?;
89
90                let method_name = &method.sig.ident;
91                let inputs = &method.sig.inputs;
92                let output = &method.sig.output;
93                let is_async = is_async_method(&method.sig);
94
95                let param_names = extract_param_names(&method.sig)?;
96
97                if is_async {
98                    impl_methods.push(quote! {
99                        #[inline]
100                        #vis async fn #method_name(#inputs) #output {
101                            self.inner.#method_name(#(#param_names),*).await
102                        }
103                    });
104                } else {
105                    impl_methods.push(quote! {
106                        #[inline]
107                        #vis fn #method_name(#inputs) #output {
108                            self.inner.#method_name(#(#param_names),*)
109                        }
110                    });
111                }
112            }
113            TraitItem::Const(tc) => {
114                return Err(syn::Error::new_spanned(
115                    tc,
116                    "arc_handle does not support associated constants",
117                ));
118            }
119            TraitItem::Type(tt) => {
120                return Err(syn::Error::new_spanned(
121                    tt,
122                    "arc_handle does not support associated types",
123                ));
124            }
125            _ => {}
126        }
127    }
128
129    let expanded = quote! {
130        #input
131
132        #[doc = concat!("Arc-based handle wrapper for `", stringify!(#impl_trait_name), "`")]
133        #[derive(Clone)]
134        #vis struct #handle_name {
135            inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>,
136        }
137
138        impl #handle_name {
139            /// Create a new handle from a trait implementation
140            #[inline]
141            #vis fn new(inner: impl #impl_trait_name + Send + Sync + 'static) -> Self {
142                Self {
143                    inner: std::sync::Arc::new(inner),
144                }
145            }
146
147            /// Create a new handle from a boxed trait object
148            #[inline]
149            #vis fn from_boxed(inner: Box<dyn #impl_trait_name + Send + Sync>) -> Self {
150                Self {
151                    inner: std::sync::Arc::from(inner),
152                }
153            }
154
155            /// Create a new handle from an existing Arc
156            #[inline]
157            #vis fn from_arc(inner: std::sync::Arc<dyn #impl_trait_name + Send + Sync>) -> Self {
158                Self { inner }
159            }
160
161            /// Get a reference to the inner Arc
162            #[inline]
163            #vis fn inner(&self) -> &std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
164                &self.inner
165            }
166
167            /// Unwrap the handle into the inner Arc
168            #[inline]
169            #vis fn into_inner(self) -> std::sync::Arc<dyn #impl_trait_name + Send + Sync> {
170                self.inner
171            }
172
173            #(#impl_methods)*
174        }
175    };
176
177    Ok(TokenStream::from(expanded))
178}
179
180fn extract_param_names(sig: &Signature) -> syn::Result<Vec<&syn::Ident>> {
181    sig.inputs
182        .iter()
183        .skip(1)
184        .map(|arg| {
185            if let FnArg::Typed(pat_type) = arg {
186                if let Pat::Ident(ident) = &*pat_type.pat {
187                    Ok(&ident.ident)
188                } else {
189                    Err(syn::Error::new_spanned(
190                        pat_type,
191                        "unsupported parameter pattern; expected a simple identifier",
192                    ))
193                }
194            } else {
195                Err(syn::Error::new_spanned(
196                    sig,
197                    "unexpected receiver in parameter list",
198                ))
199            }
200        })
201        .collect()
202}
203
204fn validate_receiver(method: &TraitItemFn) -> syn::Result<()> {
205    match method.sig.inputs.first() {
206        Some(FnArg::Receiver(r)) => {
207            // &self is fine; &mut self is not (can't get &mut through Arc<dyn>)
208            if r.mutability.is_some() {
209                return Err(syn::Error::new_spanned(
210                    r,
211                    "arc_handle does not support &mut self receivers; \
212                     the handle uses Arc which only provides shared access",
213                ));
214            }
215            Ok(())
216        }
217        Some(FnArg::Typed(pat_type)) => Err(syn::Error::new_spanned(
218            pat_type,
219            "arc_handle requires &self as the first parameter; \
220             by-value self is not supported",
221        )),
222        None => Err(syn::Error::new_spanned(
223            method,
224            "arc_handle requires methods to have a &self receiver",
225        )),
226    }
227}
228
229fn is_async_method(sig: &Signature) -> bool {
230    sig.asyncness.is_some()
231}