near_async_derive/
lib.rs

1//! This crate provides a set of procedural macros for deriving traits for
2//! multi-sender types, which are structs that contain multiple `Sender` and
3//! `AsyncSender` fields. The structs can either have named fields or be a
4//! tuple struct.
5//!
//! The derive macros provided by this crate allow multi-senders to be
//! created from anything that behaves like all of the individual senders,
//! and allow the multi-sender to be used like any of the individual senders.
//! This is useful when one component needs to send multiple kinds of
//! messages to another component. For example, the networking layer needs to
//! send multiple types of messages to the ClientActor, each expecting a
//! different response. Constructing the PeerManagerActor by passing in 10
//! different sender objects would be very cumbersome, so instead we create a
//! multi-sender interface of all the senders and pass that in.
15//!
16//! To better understand these macros,
17//!  - Look at the tests in this crate for examples of what the macros generate.
18//!  - Search for usages of the derive macros in the codebase.
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::quote;
22use syn::Meta;
23
24/// Derives the ability to convert an object into this struct of Sender and
25/// AsyncSenders, as long as the object can be converted into each individual
26/// Sender or AsyncSender.
27/// The conversion is done by calling `.as_multi_sender()` or `.into_multi_sender()`.
28#[proc_macro_derive(MultiSenderFrom)]
29pub fn derive_multi_sender_from(input: TokenStream) -> TokenStream {
30    derive_multi_sender_from_impl(input.into()).into()
31}
32
33fn derive_multi_sender_from_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
34    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
35    let struct_name = ast.ident.clone();
36    let input = match ast.data {
37        syn::Data::Struct(input) => input,
38        _ => {
39            panic!("MultiSenderFrom can only be derived for structs");
40        }
41    };
42
43    let mut type_bounds = Vec::new();
44    let mut initializers = Vec::new();
45    let mut cfg_attrs = Vec::new();
46    let mut names = Vec::<syn::Ident>::new();
47    for (i, field) in input.fields.into_iter().enumerate() {
48        let field_name = field
49            .ident
50            .as_ref()
51            .map(|ident| ident.to_string())
52            .unwrap_or_else(|| format!("#{}", i));
53        cfg_attrs.push(extract_cfg_attributes(&field.attrs));
54        match &field.ty {
55            syn::Type::Path(path) => {
56                let last_segment = path.path.segments.last().unwrap();
57                let arguments = match last_segment.arguments.clone() {
58                    syn::PathArguments::AngleBracketed(arguments) => {
59                        arguments.args.into_iter().collect::<Vec<_>>()
60                    }
61                    _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
62                };
63                if last_segment.ident == "Sender" {
64                    type_bounds.push(quote!(near_async::messaging::CanSend<#(#arguments),*>));
65                } else if last_segment.ident == "AsyncSender" {
66                    type_bounds.push(quote!(
67                            near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<#(#arguments),*>>));
68                } else {
69                    panic!("Field {} must be either a Sender or an AsyncSender", field_name);
70                }
71                initializers.push(quote!(near_async::messaging::IntoSender::as_sender(&input)));
72                if let Some(name) = &field.ident {
73                    names.push(name.clone());
74                }
75            }
76            _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
77        }
78    }
79
80    assert!(!type_bounds.is_empty(), "Must have at least one field");
81
82    let initializer = if names.is_empty() {
83        quote!(#struct_name(#(#(#cfg_attrs)* #initializers,)*))
84    } else {
85        quote!(#struct_name {
86            #(#(#cfg_attrs)* #names: #initializers,)*
87        })
88    };
89
90    quote! {
91        impl<A: #(#type_bounds)+*> near_async::messaging::MultiSenderFrom<A> for #struct_name {
92            fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
93                #initializer
94            }
95        }
96    }
97}
98
99/// Derives the ability to use this struct of `Sender`s and `AsyncSender`s to
100/// call `.send` or `.send_async` directly as if using one of the included
101/// `Sender`s or `AsyncSender`s.
102#[proc_macro_derive(MultiSend)]
103pub fn derive_multi_send(input: TokenStream) -> TokenStream {
104    derive_multi_send_impl(input.into()).into()
105}
106
107fn derive_multi_send_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
108    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
109    let struct_name = ast.ident.clone();
110    let input = match ast.data {
111        syn::Data::Struct(input) => input,
112        _ => {
113            panic!("MultiSend can only be derived for structs");
114        }
115    };
116
117    let mut tokens = Vec::new();
118    for (i, field) in input.fields.into_iter().enumerate() {
119        let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
120            let index = syn::Index::from(i);
121            quote!(#index)
122        });
123        let cfg_attrs = extract_cfg_attributes(&field.attrs);
124        if let syn::Type::Path(path) = &field.ty {
125            let last_segment = path.path.segments.last().unwrap();
126            let arguments = match last_segment.arguments.clone() {
127                syn::PathArguments::AngleBracketed(arguments) => {
128                    arguments.args.into_iter().collect::<Vec<_>>()
129                }
130                _ => {
131                    continue;
132                }
133            };
134            if last_segment.ident == "Sender" {
135                let message_type = arguments[0].clone();
136                tokens.push(quote! {
137                    #(#cfg_attrs)*
138                    impl near_async::messaging::CanSend<#message_type> for #struct_name {
139                        fn send(&self, message: #message_type) {
140                            self.#field_name.send(message);
141                        }
142                    }
143                });
144            } else if last_segment.ident == "AsyncSender" {
145                let message_type = arguments[0].clone();
146                let result_type = arguments[1].clone();
147                let outer_msg_type =
148                    quote!(near_async::messaging::MessageWithCallback<#message_type, #result_type>);
149                tokens.push(quote! {
150                    #(#cfg_attrs)*
151                    impl near_async::messaging::CanSend<#outer_msg_type> for #struct_name {
152                        fn send(&self, message: #outer_msg_type) {
153                            self.#field_name.send(message);
154                        }
155                    }
156                });
157            }
158        }
159    }
160
161    quote! {#(#tokens)*}
162}
163
164fn extract_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
165    attrs.iter().filter(|attr| attr.path().is_ident("cfg")).cloned().collect()
166}
167
168/// Derives two enums, whose names are based on this struct by appending
169/// `Message` and `Input`. Each enum has a case for each `Sender` or
170/// `AsyncSender` in this struct. The `Message` enum contains the raw message
171/// being sent, which is `X` for `Sender<X>` and `MessageWithCallback<X, Y>`
172/// for `AsyncSender<X, Y>`. The `Input` enum contains the same for `Sender` but
173/// only the input, `X` for `AsyncSender<X, Y>`.
174///
175/// Additionally, this struct can then be used to `.send` using the derived
176/// `Message` enum. This is useful for packaging a multi-sender as a singular
177/// `Sender` that can then be embedded into another multi-sender. The `Input`
178/// enum is useful for capturing messages for testing purposes.
179#[proc_macro_derive(
180    MultiSendMessage,
181    attributes(multi_send_message_derive, multi_send_input_derive)
182)]
183pub fn derive_multi_send_message(input: TokenStream) -> TokenStream {
184    derive_multi_send_message_impl(input.into()).into()
185}
186
187fn derive_multi_send_message_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
188    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
189    let struct_name = ast.ident.clone();
190    let message_enum_name = syn::Ident::new(&format!("{}Message", struct_name), Span::call_site());
191    let input_enum_name = syn::Ident::new(&format!("{}Input", struct_name), Span::call_site());
192    let input = match ast.data {
193        syn::Data::Struct(input) => input,
194        _ => {
195            panic!("MultiSendMessage can only be derived for structs");
196        }
197    };
198
199    let mut field_names = Vec::new();
200    let mut message_types = Vec::new();
201    let mut input_types = Vec::new();
202    let mut discriminator_names = Vec::new();
203    let mut input_extractors = Vec::new();
204    for (i, field) in input.fields.into_iter().enumerate() {
205        let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
206            let index = syn::Index::from(i);
207            quote!(#index)
208        });
209        field_names.push(field_name.clone());
210        discriminator_names.push(syn::Ident::new(&format!("_{}", field_name), Span::call_site()));
211        if let syn::Type::Path(path) = &field.ty {
212            let last_segment = path.path.segments.last().unwrap();
213            let arguments = match last_segment.arguments.clone() {
214                syn::PathArguments::AngleBracketed(arguments) => {
215                    arguments.args.into_iter().collect::<Vec<_>>()
216                }
217                _ => {
218                    continue;
219                }
220            };
221            if last_segment.ident == "Sender" {
222                let message_type = arguments[0].clone();
223                message_types.push(quote!(#message_type));
224                input_types.push(quote!(#message_type));
225                input_extractors.push(quote!(msg));
226            } else if last_segment.ident == "AsyncSender" {
227                let message_type = arguments[0].clone();
228                let result_type = arguments[1].clone();
229                message_types.push(
230                    quote!(near_async::messaging::MessageWithCallback<#message_type, #result_type>),
231                );
232                input_types.push(quote!(#message_type));
233                input_extractors.push(quote!(msg.message));
234            }
235        }
236    }
237
238    let mut message_derives = proc_macro2::TokenStream::new();
239    let mut input_derives = proc_macro2::TokenStream::new();
240    for attr in ast.attrs {
241        if attr.path().is_ident("multi_send_message_derive") {
242            let Meta::List(metalist) = attr.meta else {
243                panic!("multi_send_message_derive must be a list");
244            };
245            message_derives = metalist.tokens;
246        } else if attr.path().is_ident("multi_send_input_derive") {
247            let Meta::List(metalist) = attr.meta else {
248                panic!("multi_send_input_derive must be a list");
249            };
250            input_derives = metalist.tokens;
251        }
252    }
253
254    quote! {
255        #[derive(#message_derives)]
256        pub enum #message_enum_name {
257            #(#discriminator_names(#message_types),)*
258        }
259
260        #[derive(#input_derives)]
261        pub enum #input_enum_name {
262            #(#discriminator_names(#input_types),)*
263        }
264
265        impl near_async::messaging::CanSend<#message_enum_name> for #struct_name {
266            fn send(&self, message: #message_enum_name) {
267                match message {
268                    #(#message_enum_name::#discriminator_names(message) => self.#field_names.send(message),)*
269                }
270            }
271        }
272
273        #(impl From<#message_types> for #message_enum_name {
274            fn from(message: #message_types) -> Self {
275                #message_enum_name::#discriminator_names(message)
276            }
277        })*
278
279        impl #message_enum_name {
280            pub fn into_input(self) -> #input_enum_name {
281                match self {
282                    #(Self::#discriminator_names(msg) => #input_enum_name::#discriminator_names(#input_extractors),)*
283                }
284            }
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use quote::quote;

    /// Asserts that a macro expansion matches the expected tokens.
    ///
    /// The comparison is on the `to_string()` rendering. That rendering is
    /// sensitive to punctuation spacing (e.g. `>>` vs `> >`), which is why some
    /// expected types below close with `> >`.
    fn assert_expansion(actual: proc_macro2::TokenStream, expected: proc_macro2::TokenStream) {
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_derive_into_multi_send() {
        let definition = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<String, u32>>
                + near_async::messaging::CanSend<i32>
                + near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<i32, String>>
                > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TestSenders {
                        sender: near_async::messaging::IntoSender::as_sender(&input),
                        async_sender: near_async::messaging::IntoSender::as_sender(&input),
                        qualified_sender: near_async::messaging::IntoSender::as_sender(&input),
                        qualified_async_sender: near_async::messaging::IntoSender::as_sender(&input),
                    }
                }
            }
        };
        assert_expansion(super::derive_multi_sender_from_impl(definition), expected);
    }

    #[test]
    fn test_derive_multi_send() {
        let definition = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.sender.send(message);
                }
            }
            impl near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<String, u32> > for TestSenders {
                fn send(&self, message: near_async::messaging::MessageWithCallback<String, u32>) {
                    self.async_sender.send(message);
                }
            }
            impl near_async::messaging::CanSend<i32> for TestSenders {
                fn send(&self, message: i32) {
                    self.qualified_sender.send(message);
                }
            }
            impl near_async::messaging::CanSend<near_async::messaging::MessageWithCallback<i32, String> > for TestSenders {
                fn send(&self, message: near_async::messaging::MessageWithCallback<i32, String>) {
                    self.qualified_async_sender.send(message);
                }
            }
        };
        assert_expansion(super::derive_multi_send_impl(definition), expected);
    }

    #[test]
    fn test_derive_multi_send_message() {
        let definition = quote! {
            #[multi_send_message_derive(X, Y)]
            #[multi_send_input_derive(Z, W)]
            struct TestSenders {
                sender: Sender<A>,
                async_sender: AsyncSender<B, C>,
                qualified_sender: near_async::messaging::Sender<D>,
                qualified_async_sender: near_async::messaging::AsyncSender<E, F>,
            }
        };

        let expected = quote! {
            #[derive(X, Y)]
            pub enum TestSendersMessage {
                _sender(A),
                _async_sender(near_async::messaging::MessageWithCallback<B, C>),
                _qualified_sender(D),
                _qualified_async_sender(near_async::messaging::MessageWithCallback<E, F>),
            }

            #[derive(Z, W)]
            pub enum TestSendersInput {
                _sender(A),
                _async_sender(B),
                _qualified_sender(D),
                _qualified_async_sender(E),
            }

            impl near_async::messaging::CanSend<TestSendersMessage> for TestSenders {
                fn send(&self, message: TestSendersMessage) {
                    match message {
                        TestSendersMessage::_sender(message) => self.sender.send(message),
                        TestSendersMessage::_async_sender(message) => self.async_sender.send(message),
                        TestSendersMessage::_qualified_sender(message) => self.qualified_sender.send(message),
                        TestSendersMessage::_qualified_async_sender(message) => self.qualified_async_sender.send(message),
                    }
                }
            }

            impl From<A> for TestSendersMessage {
                fn from(message: A) -> Self {
                    TestSendersMessage::_sender(message)
                }
            }

            impl From<near_async::messaging::MessageWithCallback<B, C> > for TestSendersMessage {
                fn from(message: near_async::messaging::MessageWithCallback<B, C>) -> Self {
                    TestSendersMessage::_async_sender(message)
                }
            }

            impl From<D> for TestSendersMessage {
                fn from(message: D) -> Self {
                    TestSendersMessage::_qualified_sender(message)
                }
            }

            impl From<near_async::messaging::MessageWithCallback<E, F> > for TestSendersMessage {
                fn from(message: near_async::messaging::MessageWithCallback<E, F>) -> Self {
                    TestSendersMessage::_qualified_async_sender(message)
                }
            }

            impl TestSendersMessage {
                pub fn into_input(self) -> TestSendersInput {
                    match self {
                        Self::_sender(msg) => TestSendersInput::_sender(msg),
                        Self::_async_sender(msg) => TestSendersInput::_async_sender(msg.message),
                        Self::_qualified_sender(msg) => TestSendersInput::_qualified_sender(msg),
                        Self::_qualified_async_sender(msg) => TestSendersInput::_qualified_async_sender(msg.message),
                    }
                }
            }
        };
        assert_expansion(super::derive_multi_send_message_impl(definition), expected);
    }
}