1use proc_macro::TokenStream;
20use quote::quote;
21
22#[proc_macro_derive(MultiSenderFrom)]
27pub fn derive_multi_sender_from(input: TokenStream) -> TokenStream {
28 derive_multi_sender_from_impl(input.into()).into()
29}
30
31fn derive_multi_sender_from_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
32 let ast: syn::DeriveInput = syn::parse2(input).unwrap();
33 let struct_name = ast.ident.clone();
34 let input = match ast.data {
35 syn::Data::Struct(input) => input,
36 _ => {
37 panic!("MultiSenderFrom can only be derived for structs");
38 }
39 };
40
41 let mut type_bounds = Vec::new();
42 let mut initializers = Vec::new();
43 let mut cfg_attrs = Vec::new();
44 let mut names = Vec::<syn::Ident>::new();
45 for (i, field) in input.fields.into_iter().enumerate() {
46 let field_name = field
47 .ident
48 .as_ref()
49 .map(|ident| ident.to_string())
50 .unwrap_or_else(|| format!("#{}", i));
51 cfg_attrs.push(extract_cfg_attributes(&field.attrs));
52 match &field.ty {
53 syn::Type::Path(path) => {
54 let last_segment = path.path.segments.last().unwrap();
55 let arguments = match last_segment.arguments.clone() {
56 syn::PathArguments::AngleBracketed(arguments) => {
57 arguments.args.into_iter().collect::<Vec<_>>()
58 }
59 _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
60 };
61 if last_segment.ident == "Sender" {
62 type_bounds.push(quote!(near_async::messaging::CanSend<#(#arguments),*>));
63 initializers.push(quote!(near_async::messaging::IntoSender::as_sender(&input)));
64 } else if last_segment.ident == "AsyncSender" {
65 type_bounds.push(quote!(
66 near_async::messaging::CanSendAsync<#(#arguments),*>
67 ));
68 initializers.push(quote!(
69 near_async::messaging::IntoAsyncSender::as_async_sender(&input)
70 ));
71 } else {
72 panic!("Field {} must be either a Sender or an AsyncSender", field_name);
73 }
74 if let Some(name) = &field.ident {
75 names.push(name.clone());
76 }
77 }
78 _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
79 }
80 }
81
82 assert!(!type_bounds.is_empty(), "Must have at least one field");
83
84 let initializer = if names.is_empty() {
85 quote!(#struct_name(#(#(#cfg_attrs)* #initializers,)*))
86 } else {
87 quote!(#struct_name {
88 #(#(#cfg_attrs)* #names: #initializers,)*
89 })
90 };
91
92 quote! {
93 impl<A: #(#type_bounds)+*> near_async::messaging::MultiSenderFrom<A> for #struct_name {
94 fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
95 #initializer
96 }
97 }
98 }
99}
100
101#[proc_macro_derive(MultiSend)]
105pub fn derive_multi_send(input: TokenStream) -> TokenStream {
106 derive_multi_send_impl(input.into()).into()
107}
108
109fn derive_multi_send_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
110 let ast: syn::DeriveInput = syn::parse2(input).unwrap();
111 let struct_name = ast.ident.clone();
112 let input = match ast.data {
113 syn::Data::Struct(input) => input,
114 _ => {
115 panic!("MultiSend can only be derived for structs");
116 }
117 };
118
119 let mut tokens = Vec::new();
120 for (i, field) in input.fields.into_iter().enumerate() {
121 let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
122 let index = syn::Index::from(i);
123 quote!(#index)
124 });
125 let cfg_attrs = extract_cfg_attributes(&field.attrs);
126 if let syn::Type::Path(path) = &field.ty {
127 let last_segment = path.path.segments.last().unwrap();
128 let arguments = match last_segment.arguments.clone() {
129 syn::PathArguments::AngleBracketed(arguments) => {
130 arguments.args.into_iter().collect::<Vec<_>>()
131 }
132 _ => {
133 continue;
134 }
135 };
136 if last_segment.ident == "Sender" {
137 let message_type = arguments[0].clone();
138 tokens.push(quote! {
139 #(#cfg_attrs)*
140 impl near_async::messaging::CanSend<#message_type> for #struct_name {
141 fn send(&self, message: #message_type) {
142 self.#field_name.send(message);
143 }
144 }
145 });
146 } else if last_segment.ident == "AsyncSender" {
147 let message_type = arguments[0].clone();
148 let result_type = arguments[1].clone();
149 tokens.push(quote! {
150 #(#cfg_attrs)*
151 impl near_async::messaging::CanSendAsync<#message_type, #result_type> for #struct_name {
152 fn send_async(&self, message: #message_type)
153 -> near_async::futures::BoxFuture<'static, Result<#result_type, near_async::messaging::AsyncSendError>>
154 {
155 self.#field_name.send_async(message)
156 }
157 }
158 });
159 }
160 }
161 }
162
163 quote! {#(#tokens)*}
164}
165
166fn extract_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
167 attrs.iter().filter(|attr| attr.path().is_ident("cfg")).cloned().collect()
168}
169
#[cfg(test)]
mod tests {
    use quote::quote;

    /// Compares generated tokens with the expected tokens through their string forms, so a
    /// mismatch is reported as a readable diff at the calling test.
    #[track_caller]
    fn assert_tokens_eq(actual: proc_macro2::TokenStream, expected: proc_macro2::TokenStream) {
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_derive_into_multi_send() {
        let input = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSendAsync<String, u32>
                + near_async::messaging::CanSend<i32>
                + near_async::messaging::CanSendAsync<i32, String>
            > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TestSenders {
                        sender: near_async::messaging::IntoSender::as_sender(&input),
                        async_sender: near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                        qualified_sender: near_async::messaging::IntoSender::as_sender(&input),
                        qualified_async_sender: near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                    }
                }
            }
        };
        assert_tokens_eq(super::derive_multi_sender_from_impl(input), expected);
    }

    /// Tuple structs are built positionally rather than with `field: value` pairs.
    #[test]
    fn test_derive_multi_sender_from_tuple_struct() {
        let input = quote! {
            struct TupleSenders(Sender<String>, AsyncSender<i32, u32>);
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSendAsync<i32, u32>
            > near_async::messaging::MultiSenderFrom<A> for TupleSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TupleSenders(
                        near_async::messaging::IntoSender::as_sender(&input),
                        near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                    )
                }
            }
        };
        assert_tokens_eq(super::derive_multi_sender_from_impl(input), expected);
    }

    /// `#[cfg]` is copied onto the field initializer; the bound on `A` stays unconditional.
    #[test]
    fn test_derive_multi_sender_from_keeps_cfg_on_initializer() {
        let input = quote! {
            struct TestSenders {
                #[cfg(feature = "foo")]
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
            }
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSendAsync<String, u32>
            > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TestSenders {
                        #[cfg(feature = "foo")]
                        sender: near_async::messaging::IntoSender::as_sender(&input),
                        async_sender: near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                    }
                }
            }
        };
        assert_tokens_eq(super::derive_multi_sender_from_impl(input), expected);
    }

    #[test]
    fn test_derive_multi_send() {
        let input = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.sender.send(message);
                }
            }
            impl near_async::messaging::CanSendAsync<String, u32> for TestSenders {
                fn send_async(&self, message: String) -> near_async::futures::BoxFuture<'static, Result<u32, near_async::messaging::AsyncSendError>> {
                    self.async_sender.send_async(message)
                }
            }
            impl near_async::messaging::CanSend<i32> for TestSenders {
                fn send(&self, message: i32) {
                    self.qualified_sender.send(message);
                }
            }
            impl near_async::messaging::CanSendAsync<i32, String> for TestSenders {
                fn send_async(&self, message: i32) -> near_async::futures::BoxFuture<'static, Result<String, near_async::messaging::AsyncSendError>> {
                    self.qualified_async_sender.send_async(message)
                }
            }
        };
        assert_tokens_eq(super::derive_multi_send_impl(input), expected);
    }

    /// Tuple fields are forwarded through positional access (`self.0`, `self.1`).
    #[test]
    fn test_derive_multi_send_tuple_struct() {
        let input = quote! {
            struct TupleSenders(Sender<String>, AsyncSender<i32, u32>);
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TupleSenders {
                fn send(&self, message: String) {
                    self.0.send(message);
                }
            }
            impl near_async::messaging::CanSendAsync<i32, u32> for TupleSenders {
                fn send_async(&self, message: i32) -> near_async::futures::BoxFuture<'static, Result<u32, near_async::messaging::AsyncSendError>> {
                    self.1.send_async(message)
                }
            }
        };
        assert_tokens_eq(super::derive_multi_send_impl(input), expected);
    }

    /// `#[cfg]` is copied onto the generated impl, and non-sender fields produce nothing.
    #[test]
    fn test_derive_multi_send_cfg_and_ignored_fields() {
        let input = quote! {
            struct TestSenders {
                #[cfg(feature = "foo")]
                sender: Sender<String>,
                counter: u32,
                labels: Vec<String>,
            }
        };
        let expected = quote! {
            #[cfg(feature = "foo")]
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.sender.send(message);
                }
            }
        };
        assert_tokens_eq(super::derive_multi_send_impl(input), expected);
    }
}