near_async_derive/
lib.rs

//! This crate provides a set of procedural macros for deriving traits for
//! multi-sender types, which are structs that contain multiple `Sender` and
//! `AsyncSender` fields. The structs can either have named fields or be a
//! tuple struct.
//!
//! The derive macros provided by this crate allow multi-senders to be
//! created from anything that behaves like all of the individual senders,
//! and allow the multi-sender to be used like any of the individual senders.
//! This can be very useful when one component needs to send multiple kinds of
//! messages to another component. For example, the networking layer needs to
//! send multiple types of messages to the ClientActor, each expecting a
//! different response. It would be very cumbersome to have to construct the
//! PeerManagerActor by passing in 10 different sender objects, so instead we
//! create a multi-sender interface of all the senders and pass that in.
//!
//! To better understand these macros:
//!  - Look at the tests in this crate for examples of what the macros generate.
//!  - Search for usages of the derive macros in the codebase.
19use proc_macro::TokenStream;
20use quote::quote;
21
22/// Derives the ability to convert an object into this struct of Sender and
23/// AsyncSenders, as long as the object can be converted into each individual
24/// Sender or AsyncSender.
25/// The conversion is done by calling `.as_multi_sender()` or `.into_multi_sender()`.
26#[proc_macro_derive(MultiSenderFrom)]
27pub fn derive_multi_sender_from(input: TokenStream) -> TokenStream {
28    derive_multi_sender_from_impl(input.into()).into()
29}
30
31fn derive_multi_sender_from_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
32    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
33    let struct_name = ast.ident.clone();
34    let input = match ast.data {
35        syn::Data::Struct(input) => input,
36        _ => {
37            panic!("MultiSenderFrom can only be derived for structs");
38        }
39    };
40
41    let mut type_bounds = Vec::new();
42    let mut initializers = Vec::new();
43    let mut cfg_attrs = Vec::new();
44    let mut names = Vec::<syn::Ident>::new();
45    for (i, field) in input.fields.into_iter().enumerate() {
46        let field_name = field
47            .ident
48            .as_ref()
49            .map(|ident| ident.to_string())
50            .unwrap_or_else(|| format!("#{}", i));
51        cfg_attrs.push(extract_cfg_attributes(&field.attrs));
52        match &field.ty {
53            syn::Type::Path(path) => {
54                let last_segment = path.path.segments.last().unwrap();
55                let arguments = match last_segment.arguments.clone() {
56                    syn::PathArguments::AngleBracketed(arguments) => {
57                        arguments.args.into_iter().collect::<Vec<_>>()
58                    }
59                    _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
60                };
61                if last_segment.ident == "Sender" {
62                    type_bounds.push(quote!(near_async::messaging::CanSend<#(#arguments),*>));
63                    initializers.push(quote!(near_async::messaging::IntoSender::as_sender(&input)));
64                } else if last_segment.ident == "AsyncSender" {
65                    type_bounds.push(quote!(
66                        near_async::messaging::CanSendAsync<#(#arguments),*>
67                    ));
68                    initializers.push(quote!(
69                        near_async::messaging::IntoAsyncSender::as_async_sender(&input)
70                    ));
71                } else {
72                    panic!("Field {} must be either a Sender or an AsyncSender", field_name);
73                }
74                if let Some(name) = &field.ident {
75                    names.push(name.clone());
76                }
77            }
78            _ => panic!("Field {} must be either a Sender or an AsyncSender", field_name),
79        }
80    }
81
82    assert!(!type_bounds.is_empty(), "Must have at least one field");
83
84    let initializer = if names.is_empty() {
85        quote!(#struct_name(#(#(#cfg_attrs)* #initializers,)*))
86    } else {
87        quote!(#struct_name {
88            #(#(#cfg_attrs)* #names: #initializers,)*
89        })
90    };
91
92    quote! {
93        impl<A: #(#type_bounds)+*> near_async::messaging::MultiSenderFrom<A> for #struct_name {
94            fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
95                #initializer
96            }
97        }
98    }
99}
100
101/// Derives the ability to use this struct of `Sender`s and `AsyncSender`s to
102/// call `.send` or `.send_async` directly as if using one of the included
103/// `Sender`s or `AsyncSender`s.
104#[proc_macro_derive(MultiSend)]
105pub fn derive_multi_send(input: TokenStream) -> TokenStream {
106    derive_multi_send_impl(input.into()).into()
107}
108
109fn derive_multi_send_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
110    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
111    let struct_name = ast.ident.clone();
112    let input = match ast.data {
113        syn::Data::Struct(input) => input,
114        _ => {
115            panic!("MultiSend can only be derived for structs");
116        }
117    };
118
119    let mut tokens = Vec::new();
120    for (i, field) in input.fields.into_iter().enumerate() {
121        let field_name = field.ident.as_ref().map(|ident| quote!(#ident)).unwrap_or_else(|| {
122            let index = syn::Index::from(i);
123            quote!(#index)
124        });
125        let cfg_attrs = extract_cfg_attributes(&field.attrs);
126        if let syn::Type::Path(path) = &field.ty {
127            let last_segment = path.path.segments.last().unwrap();
128            let arguments = match last_segment.arguments.clone() {
129                syn::PathArguments::AngleBracketed(arguments) => {
130                    arguments.args.into_iter().collect::<Vec<_>>()
131                }
132                _ => {
133                    continue;
134                }
135            };
136            if last_segment.ident == "Sender" {
137                let message_type = arguments[0].clone();
138                tokens.push(quote! {
139                    #(#cfg_attrs)*
140                    impl near_async::messaging::CanSend<#message_type> for #struct_name {
141                        fn send(&self, message: #message_type) {
142                            self.#field_name.send(message);
143                        }
144                    }
145                });
146            } else if last_segment.ident == "AsyncSender" {
147                let message_type = arguments[0].clone();
148                let result_type = arguments[1].clone();
149                tokens.push(quote! {
150                    #(#cfg_attrs)*
151                    impl near_async::messaging::CanSendAsync<#message_type, #result_type> for #struct_name {
152                        fn send_async(&self, message: #message_type)
153                            -> near_async::futures::BoxFuture<'static, Result<#result_type, near_async::messaging::AsyncSendError>>
154                        {
155                            self.#field_name.send_async(message)
156                        }
157                    }
158                });
159            }
160        }
161    }
162
163    quote! {#(#tokens)*}
164}
165
166fn extract_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
167    attrs.iter().filter(|attr| attr.path().is_ident("cfg")).cloned().collect()
168}
169
#[cfg(test)]
mod tests {
    //! Golden tests. Each test feeds a struct definition to one of the `*_impl`
    //! functions and compares the generated tokens, rendered as strings, against
    //! the expected expansion.
    use quote::quote;

    #[test]
    fn test_derive_multi_sender_from() {
        let input = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSendAsync<String, u32>
                + near_async::messaging::CanSend<i32>
                + near_async::messaging::CanSendAsync<i32, String>
                > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TestSenders {
                        sender: near_async::messaging::IntoSender::as_sender(&input),
                        async_sender: near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                        qualified_sender: near_async::messaging::IntoSender::as_sender(&input),
                        qualified_async_sender: near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                    }
                }
            }
        };
        let actual = super::derive_multi_sender_from_impl(input);
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    /// Tuple structs use the positional constructor, and field `#[cfg]`s are
    /// copied onto the matching initializer.
    #[test]
    fn test_derive_multi_sender_from_tuple_struct_with_cfg() {
        let input = quote! {
            struct TestSenders(
                Sender<String>,
                #[cfg(feature = "test_features")] AsyncSender<String, u32>,
            );
        };
        let expected = quote! {
            impl<A:
                near_async::messaging::CanSend<String>
                + near_async::messaging::CanSendAsync<String, u32>
                > near_async::messaging::MultiSenderFrom<A> for TestSenders {
                fn multi_sender_from(input: std::sync::Arc<A>) -> Self {
                    TestSenders(
                        near_async::messaging::IntoSender::as_sender(&input),
                        #[cfg(feature = "test_features")]
                        near_async::messaging::IntoAsyncSender::as_async_sender(&input),
                    )
                }
            }
        };
        let actual = super::derive_multi_sender_from_impl(input);
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_derive_multi_send() {
        let input = quote! {
            struct TestSenders {
                sender: Sender<String>,
                async_sender: AsyncSender<String, u32>,
                qualified_sender: near_async::messaging::Sender<i32>,
                qualified_async_sender: near_async::messaging::AsyncSender<i32, String>,
            }
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.sender.send(message);
                }
            }
            impl near_async::messaging::CanSendAsync<String, u32> for TestSenders {
                fn send_async(&self, message: String) -> near_async::futures::BoxFuture<'static, Result<u32, near_async::messaging::AsyncSendError>> {
                    self.async_sender.send_async(message)
                }
            }
            impl near_async::messaging::CanSend<i32> for TestSenders {
                fn send(&self, message: i32) {
                    self.qualified_sender.send(message);
                }
            }
            impl near_async::messaging::CanSendAsync<i32, String> for TestSenders {
                fn send_async(&self, message: i32) -> near_async::futures::BoxFuture<'static, Result<String, near_async::messaging::AsyncSendError>> {
                    self.qualified_async_sender.send_async(message)
                }
            }
        };
        let actual = super::derive_multi_send_impl(input);
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    /// Tuple fields are accessed positionally, and field `#[cfg]`s are copied
    /// onto the generated impl.
    #[test]
    fn test_derive_multi_send_tuple_struct_with_cfg() {
        let input = quote! {
            struct TestSenders(
                Sender<String>,
                #[cfg(feature = "test_features")] AsyncSender<String, u32>,
            );
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.0.send(message);
                }
            }
            #[cfg(feature = "test_features")]
            impl near_async::messaging::CanSendAsync<String, u32> for TestSenders {
                fn send_async(&self, message: String) -> near_async::futures::BoxFuture<'static, Result<u32, near_async::messaging::AsyncSendError>> {
                    self.1.send_async(message)
                }
            }
        };
        let actual = super::derive_multi_send_impl(input);
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }

    /// Fields that are neither `Sender` nor `AsyncSender` produce no impl.
    #[test]
    fn test_derive_multi_send_skips_non_sender_fields() {
        let input = quote! {
            struct TestSenders {
                sender: Sender<String>,
                counter: u32,
                marker: std::marker::PhantomData<i32>,
            }
        };
        let expected = quote! {
            impl near_async::messaging::CanSend<String> for TestSenders {
                fn send(&self, message: String) {
                    self.sender.send(message);
                }
            }
        };
        let actual = super::derive_multi_send_impl(input);
        pretty_assertions::assert_str_eq!(actual.to_string(), expected.to_string());
    }
}