arc_trait/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, ItemTrait, TraitItem, TraitItemFn};
4
5#[proc_macro_attribute]
6pub fn arc_trait(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    // Parse the trait
8    let input = parse_macro_input!(item as ItemTrait);
9    let trait_name = &input.ident;
10
11    // Create the trait definition token to keep it intact
12    let trait_def = quote! {
13        #input
14    };
15
16    // Collect all methods in the trait
17    let methods: Vec<TraitItemFn> = input
18        .items
19        .iter()
20        .filter_map(|item| {
21            if let TraitItem::Fn(method) = item {
22                Some(method.clone())
23            } else {
24                None
25            }
26        })
27        .collect();
28
29    // Generate implementations for Arc<T>
30    let impls = methods.iter().map(|method| {
31        let name = &method.sig.ident;
32        let inputs = &method.sig.inputs;
33        let output = &method.sig.output;
34        let generics = &method.sig.generics;
35        let where_clause = &method.sig.generics.where_clause;
36        let attrs = &method.attrs;
37        let is_async = method.sig.asyncness.is_some();
38
39        let call_args = inputs.iter().skip(1).map(|arg| {
40            if let syn::FnArg::Typed(pat_type) = arg {
41                let pat = &pat_type.pat;
42                quote! { #pat }
43            } else {
44                quote! {}
45            }
46        });
47
48        if is_async {
49            quote! {
50                #(#attrs)*
51                async fn #name #generics (#inputs) #output #where_clause {
52                    self.as_ref().#name(#(#call_args),*).await
53                }
54            }
55        } else {
56            quote! {
57                #(#attrs)*
58                fn #name #generics (#inputs) #output #where_clause {
59                    self.as_ref().#name(#(#call_args),*)
60                }
61            }
62        }
63    });
64
65    let expanded = if methods.iter().any(|method| method.sig.asyncness.is_some()) {
66        quote! {
67            #trait_def
68
69            #[async_trait::async_trait]
70            impl<T: #trait_name + Send + Sync> #trait_name for std::sync::Arc<T> {
71                #(#impls)*
72            }
73        }
74    } else {
75        quote! {
76            #trait_def
77
78            impl<T: #trait_name> #trait_name for std::sync::Arc<T> {
79                #(#impls)*
80            }
81        }
82    };
83
84    // Return the generated implementation
85    TokenStream::from(expanded)
86}