1use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::quote;
22use syn::Meta;
23
24#[proc_macro_derive(MultiSenderFrom)]
29pub fn derive_multi_sender_from(input: TokenStream) -> TokenStream {
30 derive_multi_sender_from_impl(input.into()).into()
31}
32
33fn derive_multi_sender_from_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
34 let ast: syn::DeriveInput = syn::parse2(input).unwrap();
35 let struct_name = ast.ident.clone();
36 let input = match ast.data {
37 syn::Data::Struct(input) => input,
38 _ => {
39 panic!("MultiSenderFrom can only be derived for structs");
40 }
41 };
42
43 let mut type_bounds = Vec::new();
44 let mut initializers = Vec::new();
45 let mut cfg_attrs = Vec::new();
46 let mut names = Vec::<syn::Ident>::new();
47 for (i, field) in input.fields.into_iter().enumerate() {
48 let field_name = field
49 .ident
50 .as_ref()
51 .map(|ident| ident.to_string())
52 .unwrap_or_else(|| format!("#{}", i));
53 cfg_attrs.push(extract_cfg_attributes(&field.attrs));
54 match &field.ty {
55 syn::Type::Path(path) => {
56 let last_segment = path.path.segments.last().unwrap();
57 let arguments = match last_segment.arguments.clone() {
58 syn::PathArguments::AngleBracketed(arguments) => {
59 arguments.args.into_iter().collect::<Vec<_>>()
60 }
61 _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
62 };
63 if last_segment.ident == "Sender" {
64 type_bounds.push(quote!(near_async::messaging::CanSend<#(#arguments),*>));
65 } else if last_segment.ident == "AsyncSender" {
66 type_bounds.push(quote!(
67 near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<#(#arguments),*>>));
68 } else {
69 panic!("Field {} must be either a Sender or an AsyncSender", field_name);
70 }
71 initializers.push(quote!(near_async::messaging::IntoSender::as_sender(&input)));
72 if let Some(name) = &field.ident {
73 names.push(name.clone());
74 }
75 }
76 _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
77 }
78 }
79
80 assert!(!type_bounds.is_empty(), "Must have at least one field");
81
82 let initializer = if names.is_empty() {
83 quote!(#struct_name(#(#(#cfg_attrs)* #initializers,)*))
84 } else {
85 quote!(#struct_name {
86 #(#(#cfg_attrs)* #names: #initializers,)*
87 })
88 };
89
90 quote! {
91 impl<A: #(#type_bounds)+*> near_async::messaging::MultiSenderFrom<A> for #struct_name {
92 fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
93 #initializer
94 }
95 }
96 }
97}
98
99#[proc_macro_derive(MultiSend)]
103pub fn derive_multi_send(input: TokenStream) -> TokenStream {
104 derive_multi_send_impl(input.into()).into()
105}
106
107fn derive_multi_send_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
108 let ast: syn::DeriveInput = syn::parse2(input).unwrap();
109 let struct_name = ast.ident.clone();
110 let input = match ast.data {
111 syn::Data::Struct(input) => input,
112 _ => {
113 panic!("MultiSend can only be derived for structs");
114 }
115 };
116
117 let mut tokens = Vec::new();
118 for (i, field) in input.fields.into_iter().enumerate() {
119 let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
120 let index = syn::Index::from(i);
121 quote!(#index)
122 });
123 let cfg_attrs = extract_cfg_attributes(&field.attrs);
124 if let syn::Type::Path(path) = &field.ty {
125 let last_segment = path.path.segments.last().unwrap();
126 let arguments = match last_segment.arguments.clone() {
127 syn::PathArguments::AngleBracketed(arguments) => {
128 arguments.args.into_iter().collect::<Vec<_>>()
129 }
130 _ => {
131 continue;
132 }
133 };
134 if last_segment.ident == "Sender" {
135 let message_type = arguments[0].clone();
136 tokens.push(quote! {
137 #(#cfg_attrs)*
138 impl near_async::messaging::CanSend<#message_type> for #struct_name {
139 fn send(&self, message: #message_type) {
140 self.#field_name.send(message);
141 }
142 }
143 });
144 } else if last_segment.ident == "AsyncSender" {
145 let message_type = arguments[0].clone();
146 let result_type = arguments[1].clone();
147 let outer_msg_type =
148 quote!(near_async::messaging::MessageWithCallback<#message_type, #result_type>);
149 tokens.push(quote! {
150 #(#cfg_attrs)*
151 impl near_async::messaging::CanSend<#outer_msg_type> for #struct_name {
152 fn send(&self, message: #outer_msg_type) {
153 self.#field_name.send(message);
154 }
155 }
156 });
157 }
158 }
159 }
160
161 quote! {#(#tokens)*}
162}
163
164fn extract_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
165 attrs.iter().filter(|attr| attr.path().is_ident("cfg")).cloned().collect()
166}
167
168#[proc_macro_derive(
180 MultiSendMessage,
181 attributes(multi_send_message_derive, multi_send_input_derive)
182)]
183pub fn derive_multi_send_message(input: TokenStream) -> TokenStream {
184 derive_multi_send_message_impl(input.into()).into()
185}
186
187fn derive_multi_send_message_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
188 let ast: syn::DeriveInput = syn::parse2(input).unwrap();
189 let struct_name = ast.ident.clone();
190 let message_enum_name = syn::Ident::new(&format!("{}Message", struct_name), Span::call_site());
191 let input_enum_name = syn::Ident::new(&format!("{}Input", struct_name), Span::call_site());
192 let input = match ast.data {
193 syn::Data::Struct(input) => input,
194 _ => {
195 panic!("MultiSendMessage can only be derived for structs");
196 }
197 };
198
199 let mut field_names = Vec::new();
200 let mut message_types = Vec::new();
201 let mut input_types = Vec::new();
202 let mut discriminator_names = Vec::new();
203 let mut input_extractors = Vec::new();
204 for (i, field) in input.fields.into_iter().enumerate() {
205 let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
206 let index = syn::Index::from(i);
207 quote!(#index)
208 });
209 field_names.push(field_name.clone());
210 discriminator_names.push(syn::Ident::new(&format!("_{}", field_name), Span::call_site()));
211 if let syn::Type::Path(path) = &field.ty {
212 let last_segment = path.path.segments.last().unwrap();
213 let arguments = match last_segment.arguments.clone() {
214 syn::PathArguments::AngleBracketed(arguments) => {
215 arguments.args.into_iter().collect::<Vec<_>>()
216 }
217 _ => {
218 continue;
219 }
220 };
221 if last_segment.ident == "Sender" {
222 let message_type = arguments[0].clone();
223 message_types.push(quote!(#message_type));
224 input_types.push(quote!(#message_type));
225 input_extractors.push(quote!(msg));
226 } else if last_segment.ident == "AsyncSender" {
227 let message_type = arguments[0].clone();
228 let result_type = arguments[1].clone();
229 message_types.push(
230 quote!(near_async::messaging::MessageWithCallback<#message_type, #result_type>),
231 );
232 input_types.push(quote!(#message_type));
233 input_extractors.push(quote!(msg.message));
234 }
235 }
236 }
237
238 let mut message_derives = proc_macro2::TokenStream::new();
239 let mut input_derives = proc_macro2::TokenStream::new();
240 for attr in ast.attrs {
241 if attr.path().is_ident("multi_send_message_derive") {
242 let Meta::List(metalist) = attr.meta else {
243 panic!("multi_send_message_derive must be a list");
244 };
245 message_derives = metalist.tokens;
246 } else if attr.path().is_ident("multi_send_input_derive") {
247 let Meta::List(metalist) = attr.meta else {
248 panic!("multi_send_input_derive must be a list");
249 };
250 input_derives = metalist.tokens;
251 }
252 }
253
254 quote! {
255 #[derive(#message_derives)]
256 pub enum #message_enum_name {
257 #(#discriminator_names(#message_types),)*
258 }
259
260 #[derive(#input_derives)]
261 pub enum #input_enum_name {
262 #(#discriminator_names(#input_types),)*
263 }
264
265 impl near_async::messaging::CanSend<#message_enum_name> for #struct_name {
266 fn send(&self, message: #message_enum_name) {
267 match message {
268 #(#message_enum_name::#discriminator_names(message) => self.#field_names.send(message),)*
269 }
270 }
271 }
272
273 #(impl From<#message_types> for #message_enum_name {
274 fn from(message: #message_types) -> Self {
275 #message_enum_name::#discriminator_names(message)
276 }
277 })*
278
279 impl #message_enum_name {
280 pub fn into_input(self) -> #input_enum_name {
281 match self {
282 #(Self::#discriminator_names(msg) => #input_enum_name::#discriminator_names(#input_extractors),)*
283 }
284 }
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    use quote::quote;

    /// Input shared by the `MultiSenderFrom` and `MultiSend` tests. It has one field of each
    /// sender kind, both as a bare name and as a fully qualified path.
    fn test_senders_struct() -> proc_macro2::TokenStream {
        quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        }
    }

    /// Asserts that two token streams render to the same string, with a readable diff on failure.
    fn assert_tokens_eq(actual: proc_macro2::TokenStream, expected: proc_macro2::TokenStream) {
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_derive_into_multi_send() {
        let generated = super::derive_multi_sender_from_impl(test_senders_struct());
        assert_tokens_eq(
            generated,
            quote! {
                impl<A:
                    near_async::messaging::CanSend<String>
                    + near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<String, u32>>
                    + near_async::messaging::CanSend<i32>
                    + near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<i32, String>>
                > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                    fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                        TestSenders {
                            sender: near_async::messaging::IntoSender::as_sender(&input),
                            async_sender: near_async::messaging::IntoSender::as_sender(&input),
                            qualified_sender: near_async::messaging::IntoSender::as_sender(&input),
                            qualified_async_sender: near_async::messaging::IntoSender::as_sender(&input),
                        }
                    }
                }
            },
        );
    }

    #[test]
    fn test_derive_multi_send() {
        let generated = super::derive_multi_send_impl(test_senders_struct());
        assert_tokens_eq(
            generated,
            quote! {
                impl near_async::messaging::CanSend<String> for TestSenders {
                    fn send(&self, message: String) {
                        self.sender.send(message);
                    }
                }
                impl near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<String, u32> > for TestSenders {
                    fn send(&self, message: near_async::messaging::MessageWithCallback<String, u32>) {
                        self.async_sender.send(message);
                    }
                }
                impl near_async::messaging::CanSend<i32> for TestSenders {
                    fn send(&self, message: i32) {
                        self.qualified_sender.send(message);
                    }
                }
                impl near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<i32, String> > for TestSenders {
                    fn send(&self, message: near_async::messaging::MessageWithCallback<i32, String>) {
                        self.qualified_async_sender.send(message);
                    }
                }
            },
        );
    }

    #[test]
    fn test_derive_multi_send_message() {
        let source = quote! {
            #[multi_send_message_derive(X, Y)]
            #[multi_send_input_derive(Z, W)]
            struct TestSenders {
                sender: Sender<A>,
                async_sender: AsyncSender<B, C>,
                qualified_sender: near_async::messaging::Sender<D>,
                qualified_async_sender: near_async::messaging::AsyncSender<E, F>,
            }
        };
        let generated = super::derive_multi_send_message_impl(source);
        assert_tokens_eq(
            generated,
            quote! {
                #[derive(X, Y)]
                pub enum TestSendersMessage {
                    _sender(A),
                    _async_sender(near_async::messaging::MessageWithCallback<B, C>),
                    _qualified_sender(D),
                    _qualified_async_sender(near_async::messaging::MessageWithCallback<E, F>),
                }

                #[derive(Z, W)]
                pub enum TestSendersInput {
                    _sender(A),
                    _async_sender(B),
                    _qualified_sender(D),
                    _qualified_async_sender(E),
                }

                impl near_async::messaging::CanSend<TestSendersMessage> for TestSenders {
                    fn send(&self, message: TestSendersMessage) {
                        match message {
                            TestSendersMessage::_sender(message) => self.sender.send(message),
                            TestSendersMessage::_async_sender(message) => self.async_sender.send(message),
                            TestSendersMessage::_qualified_sender(message) => self.qualified_sender.send(message),
                            TestSendersMessage::_qualified_async_sender(message) => self.qualified_async_sender.send(message),
                        }
                    }
                }

                impl From<A> for TestSendersMessage {
                    fn from(message: A) -> Self {
                        TestSendersMessage::_sender(message)
                    }
                }

                impl From<near_async::messaging::MessageWithCallback<B, C> > for TestSendersMessage {
                    fn from(message: near_async::messaging::MessageWithCallback<B, C>) -> Self {
                        TestSendersMessage::_async_sender(message)
                    }
                }

                impl From<D> for TestSendersMessage {
                    fn from(message: D) -> Self {
                        TestSendersMessage::_qualified_sender(message)
                    }
                }

                impl From<near_async::messaging::MessageWithCallback<E, F> > for TestSendersMessage {
                    fn from(message: near_async::messaging::MessageWithCallback<E, F>) -> Self {
                        TestSendersMessage::_qualified_async_sender(message)
                    }
                }

                impl TestSendersMessage {
                    pub fn into_input(self) -> TestSendersInput {
                        match self {
                            Self::_sender(msg) => TestSendersInput::_sender(msg),
                            Self::_async_sender(msg) => TestSendersInput::_async_sender(msg.message),
                            Self::_qualified_sender(msg) => TestSendersInput::_qualified_sender(msg),
                            Self::_qualified_async_sender(msg) => TestSendersInput::_qualified_async_sender(msg.message),
                        }
                    }
                }
            },
        );
    }
}
439}