use proc_macro::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{
    parse_macro_input, parse_quote, FnArg, GenericParam, Ident, ItemTrait, Pat, PatIdent,
    Receiver, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type,
};
4
5#[proc_macro_attribute]
6pub fn arc_trait(_attr: TokenStream, item: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(item as ItemTrait);
9 let trait_name = &input.ident;
10
11 let trait_def = quote! {
13 #input
14 };
15
16 let methods: Vec<TraitItemFn> = input
18 .items
19 .iter()
20 .filter_map(|item| {
21 if let TraitItem::Fn(method) = item {
22 Some(method.clone())
23 } else {
24 None
25 }
26 })
27 .collect();
28
29 let impls = methods.iter().map(|method| {
31 let name = &method.sig.ident;
32 let inputs = &method.sig.inputs;
33 let output = &method.sig.output;
34 let generics = &method.sig.generics;
35 let where_clause = &method.sig.generics.where_clause;
36 let attrs = &method.attrs;
37 let is_async = method.sig.asyncness.is_some();
38
39 let call_args = inputs.iter().skip(1).map(|arg| {
40 if let syn::FnArg::Typed(pat_type) = arg {
41 let pat = &pat_type.pat;
42 quote! { #pat }
43 } else {
44 quote! {}
45 }
46 });
47
48 if is_async {
49 quote! {
50 #(#attrs)*
51 async fn #name #generics (#inputs) #output #where_clause {
52 self.as_ref().#name(#(#call_args),*).await
53 }
54 }
55 } else {
56 quote! {
57 #(#attrs)*
58 fn #name #generics (#inputs) #output #where_clause {
59 self.as_ref().#name(#(#call_args),*)
60 }
61 }
62 }
63 });
64
65 let expanded = if methods.iter().any(|method| method.sig.asyncness.is_some()) {
66 quote! {
67 #trait_def
68
69 #[async_trait::async_trait]
70 impl<T: #trait_name + Send + Sync> #trait_name for std::sync::Arc<T> {
71 #(#impls)*
72 }
73 }
74 } else {
75 quote! {
76 #trait_def
77
78 impl<T: #trait_name> #trait_name for std::sync::Arc<T> {
79 #(#impls)*
80 }
81 }
82 };
83
84 TokenStream::from(expanded)
86}