// mwc_libp2p_swarm_derivee/lib.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21#![recursion_limit = "256"]
22
23use quote::quote;
24use proc_macro::TokenStream;
25use syn::{parse_macro_input, DeriveInput, Data, DataStruct, Ident};
26
27/// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See
28/// the trait documentation for better description.
29#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
30pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
31    let ast = parse_macro_input!(input as DeriveInput);
32    build(&ast)
33}
34
35/// The actual implementation.
36fn build(ast: &DeriveInput) -> TokenStream {
37    match ast.data {
38        Data::Struct(ref s) => build_struct(ast, s),
39        Data::Enum(_) => unimplemented!("Deriving NetworkBehaviour is not implemented for enums"),
40        Data::Union(_) => unimplemented!("Deriving NetworkBehaviour is not implemented for unions"),
41    }
42}
43
44/// The version for structs
45fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream {
46    let name = &ast.ident;
47    let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
48    let multiaddr = quote!{::libp2p::core::Multiaddr};
49    let trait_to_impl = quote!{::libp2p::swarm::NetworkBehaviour};
50    let net_behv_event_proc = quote!{::libp2p::swarm::NetworkBehaviourEventProcess};
51    let either_ident = quote!{::libp2p::core::either::EitherOutput};
52    let network_behaviour_action = quote!{::libp2p::swarm::NetworkBehaviourAction};
53    let into_protocols_handler = quote!{::libp2p::swarm::IntoProtocolsHandler};
54    let protocols_handler = quote!{::libp2p::swarm::ProtocolsHandler};
55    let into_proto_select_ident = quote!{::libp2p::swarm::IntoProtocolsHandlerSelect};
56    let peer_id = quote!{::libp2p::core::PeerId};
57    let connection_id = quote!{::libp2p::core::connection::ConnectionId};
58    let connected_point = quote!{::libp2p::core::ConnectedPoint};
59    let listener_id = quote!{::libp2p::core::connection::ListenerId};
60
61    let poll_parameters = quote!{::libp2p::swarm::PollParameters};
62
63    // Build the generics.
64    let impl_generics = {
65        let tp = ast.generics.type_params();
66        let lf = ast.generics.lifetimes();
67        let cst = ast.generics.const_params();
68        quote!{<#(#lf,)* #(#tp,)* #(#cst,)*>}
69    };
70
71    // Whether or not we require the `NetworkBehaviourEventProcess` trait to be implemented.
72    let event_process = {
73        let mut event_process = true; // Default to true for backwards compatibility
74
75        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
76            for meta_item in meta_items {
77                match meta_item {
78                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("event_process") => {
79                        if let syn::Lit::Bool(ref b) = m.lit {
80                            event_process = b.value
81                        }
82                    }
83                    _ => ()
84                }
85            }
86        }
87
88        event_process
89    };
90
91    // The final out event.
92    // If we find a `#[behaviour(out_event = "Foo")]` attribute on the struct, we set `Foo` as
93    // the out event. Otherwise we use `()`.
94    let out_event = {
95        let mut out = quote!{()};
96        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
97            for meta_item in meta_items {
98                match meta_item {
99                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("out_event") => {
100                        if let syn::Lit::Str(ref s) = m.lit {
101                            let ident: syn::Type = syn::parse_str(&s.value()).unwrap();
102                            out = quote!{#ident};
103                        }
104                    }
105                    _ => ()
106                }
107            }
108        }
109        out
110    };
111
112    // Build the `where ...` clause of the trait implementation.
113    let where_clause = {
114        let additional = data_struct.fields.iter()
115            .filter(|x| !is_ignored(x))
116            .flat_map(|field| {
117                let ty = &field.ty;
118                vec![
119                    quote!{#ty: #trait_to_impl},
120                    if event_process {
121                        quote!{Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>}
122                    } else {
123                        quote!{#out_event: From< <#ty as #trait_to_impl>::OutEvent >}
124                    }
125                ]
126            })
127            .collect::<Vec<_>>();
128
129        if let Some(where_clause) = where_clause {
130            if where_clause.predicates.trailing_punct() {
131                Some(quote!{#where_clause #(#additional),*})
132            } else {
133                Some(quote!{#where_clause, #(#additional),*})
134            }
135        } else {
136            Some(quote!{where #(#additional),*})
137        }
138    };
139
140    // Build the list of statements to put in the body of `addresses_of_peer()`.
141    let addresses_of_peer_stmts = {
142        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
143            if is_ignored(&field) {
144                return None;
145            }
146
147            Some(match field.ident {
148                Some(ref i) => quote!{ out.extend(self.#i.addresses_of_peer(peer_id)); },
149                None => quote!{ out.extend(self.#field_n.addresses_of_peer(peer_id)); },
150            })
151        })
152    };
153
154    // Build the list of statements to put in the body of `inject_connected()`.
155    let inject_connected_stmts = {
156        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
157            if is_ignored(&field) {
158                return None;
159            }
160            Some(match field.ident {
161                Some(ref i) => quote!{ self.#i.inject_connected(peer_id); },
162                None => quote!{ self.#field_n.inject_connected(peer_id); },
163            })
164        })
165    };
166
167    // Build the list of statements to put in the body of `inject_disconnected()`.
168    let inject_disconnected_stmts = {
169        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
170            if is_ignored(&field) {
171                return None;
172            }
173            Some(match field.ident {
174                Some(ref i) => quote!{ self.#i.inject_disconnected(peer_id); },
175                None => quote!{ self.#field_n.inject_disconnected(peer_id); },
176            })
177        })
178    };
179
180    // Build the list of statements to put in the body of `inject_connection_established()`.
181    let inject_connection_established_stmts = {
182        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
183            if is_ignored(&field) {
184                return None;
185            }
186            Some(match field.ident {
187                Some(ref i) => quote!{ self.#i.inject_connection_established(peer_id, connection_id, endpoint); },
188                None => quote!{ self.#field_n.inject_connection_established(peer_id, connection_id, endpoint); },
189            })
190        })
191    };
192
193    // Build the list of statements to put in the body of `inject_address_change()`.
194    let inject_address_change_stmts = {
195        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
196            if is_ignored(&field) {
197                return None;
198            }
199            Some(match field.ident {
200                Some(ref i) => quote!{ self.#i.inject_address_change(peer_id, connection_id, old, new); },
201                None => quote!{ self.#field_n.inject_address_change(peer_id, connection_id, old, new); },
202            })
203        })
204    };
205
206    // Build the list of statements to put in the body of `inject_connection_closed()`.
207    let inject_connection_closed_stmts = {
208        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
209            if is_ignored(&field) {
210                return None;
211            }
212            Some(match field.ident {
213                Some(ref i) => quote!{ self.#i.inject_connection_closed(peer_id, connection_id, endpoint); },
214                None => quote!{ self.#field_n.inject_connection_closed(peer_id, connection_id, endpoint); },
215            })
216        })
217    };
218
219    // Build the list of statements to put in the body of `inject_addr_reach_failure()`.
220    let inject_addr_reach_failure_stmts = {
221        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
222            if is_ignored(&field) {
223                return None;
224            }
225
226            Some(match field.ident {
227                Some(ref i) => quote!{ self.#i.inject_addr_reach_failure(peer_id, addr, error); },
228                None => quote!{ self.#field_n.inject_addr_reach_failure(peer_id, addr, error); },
229            })
230        })
231    };
232
233    // Build the list of statements to put in the body of `inject_dial_failure()`.
234    let inject_dial_failure_stmts = {
235        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
236            if is_ignored(&field) {
237                return None;
238            }
239
240            Some(match field.ident {
241                Some(ref i) => quote!{ self.#i.inject_dial_failure(peer_id); },
242                None => quote!{ self.#field_n.inject_dial_failure(peer_id); },
243            })
244        })
245    };
246
247    // Build the list of statements to put in the body of `inject_new_listen_addr()`.
248    let inject_new_listen_addr_stmts = {
249        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
250            if is_ignored(&field) {
251                return None;
252            }
253
254            Some(match field.ident {
255                Some(ref i) => quote!{ self.#i.inject_new_listen_addr(addr); },
256                None => quote!{ self.#field_n.inject_new_listen_addr(addr); },
257            })
258        })
259    };
260
261    // Build the list of statements to put in the body of `inject_expired_listen_addr()`.
262    let inject_expired_listen_addr_stmts = {
263        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
264            if is_ignored(&field) {
265                return None;
266            }
267
268            Some(match field.ident {
269                Some(ref i) => quote!{ self.#i.inject_expired_listen_addr(addr); },
270                None => quote!{ self.#field_n.inject_expired_listen_addr(addr); },
271            })
272        })
273    };
274
275    // Build the list of statements to put in the body of `inject_new_external_addr()`.
276    let inject_new_external_addr_stmts = {
277        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
278            if is_ignored(&field) {
279                return None;
280            }
281
282            Some(match field.ident {
283                Some(ref i) => quote!{ self.#i.inject_new_external_addr(addr); },
284                None => quote!{ self.#field_n.inject_new_external_addr(addr); },
285            })
286        })
287    };
288
289    // Build the list of statements to put in the body of `inject_listener_error()`.
290    let inject_listener_error_stmts = {
291        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
292            if is_ignored(&field) {
293                return None
294            }
295            Some(match field.ident {
296                Some(ref i) => quote!(self.#i.inject_listener_error(id, err);),
297                None => quote!(self.#field_n.inject_listener_error(id, err);)
298            })
299        })
300    };
301
302    // Build the list of statements to put in the body of `inject_listener_closed()`.
303    let inject_listener_closed_stmts = {
304        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
305            if is_ignored(&field) {
306                return None
307            }
308            Some(match field.ident {
309                Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);),
310                None => quote!(self.#field_n.inject_listener_closed(id, reason);)
311            })
312        })
313    };
314
315    // Build the list of variants to put in the body of `inject_event()`.
316    //
317    // The event type is a construction of nested `#either_ident`s of the events of the children.
318    // We call `inject_event` on the corresponding child.
319    let inject_node_event_stmts = data_struct.fields.iter().enumerate().filter(|f| !is_ignored(&f.1)).enumerate().map(|(enum_n, (field_n, field))| {
320        let mut elem = if enum_n != 0 {
321            quote!{ #either_ident::Second(ev) }
322        } else {
323            quote!{ ev }
324        };
325
326        for _ in 0 .. data_struct.fields.iter().filter(|f| !is_ignored(f)).count() - 1 - enum_n {
327            elem = quote!{ #either_ident::First(#elem) };
328        }
329
330        Some(match field.ident {
331            Some(ref i) => quote!{ #elem => #trait_to_impl::inject_event(&mut self.#i, peer_id, connection_id, ev) },
332            None => quote!{ #elem => #trait_to_impl::inject_event(&mut self.#field_n, peer_id, connection_id, ev) },
333        })
334    });
335
336    // The `ProtocolsHandler` associated type.
337    let protocols_handler_ty = {
338        let mut ph_ty = None;
339        for field in data_struct.fields.iter() {
340            if is_ignored(&field) {
341                continue;
342            }
343            let ty = &field.ty;
344            let field_info = quote!{ <#ty as #trait_to_impl>::ProtocolsHandler };
345            match ph_ty {
346                Some(ev) => ph_ty = Some(quote!{ #into_proto_select_ident<#ev, #field_info> }),
347                ref mut ev @ None => *ev = Some(field_info),
348            }
349        }
350        ph_ty.unwrap_or(quote!{()})     // TODO: `!` instead
351    };
352
353    // The content of `new_handler()`.
354    // Example output: `self.field1.select(self.field2.select(self.field3))`.
355    let new_handler = {
356        let mut out_handler = None;
357
358        for (field_n, field) in data_struct.fields.iter().enumerate() {
359            if is_ignored(&field) {
360                continue;
361            }
362
363            let field_name = match field.ident {
364                Some(ref i) => quote!{ self.#i },
365                None => quote!{ self.#field_n },
366            };
367
368            let builder = quote! {
369                #field_name.new_handler()
370            };
371
372            match out_handler {
373                Some(h) => out_handler = Some(quote!{ #into_protocols_handler::select(#h, #builder) }),
374                ref mut h @ None => *h = Some(builder),
375            }
376        }
377
378        out_handler.unwrap_or(quote!{()})     // TODO: incorrect
379    };
380
381    // The method to use to poll.
382    // If we find a `#[behaviour(poll_method = "poll")]` attribute on the struct, we call
383    // `self.poll()` at the end of the polling.
384    let poll_method = {
385        let mut poll_method = quote!{std::task::Poll::Pending};
386        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
387            for meta_item in meta_items {
388                match meta_item {
389                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("poll_method") => {
390                        if let syn::Lit::Str(ref s) = m.lit {
391                            let ident: Ident = syn::parse_str(&s.value()).unwrap();
392                            poll_method = quote!{#name::#ident(self, cx, poll_params)};
393                        }
394                    }
395                    _ => ()
396                }
397            }
398        }
399        poll_method
400    };
401
402    // List of statements to put in `poll()`.
403    //
404    // We poll each child one by one and wrap around the output.
405    let poll_stmts = data_struct.fields.iter().enumerate().filter(|f| !is_ignored(&f.1)).enumerate().map(|(enum_n, (field_n, field))| {
406        let field_name = match field.ident {
407            Some(ref i) => quote!{ self.#i },
408            None => quote!{ self.#field_n },
409        };
410
411        let mut wrapped_event = if enum_n != 0 {
412            quote!{ #either_ident::Second(event) }
413        } else {
414            quote!{ event }
415        };
416        for _ in 0 .. data_struct.fields.iter().filter(|f| !is_ignored(f)).count() - 1 - enum_n {
417            wrapped_event = quote!{ #either_ident::First(#wrapped_event) };
418        }
419
420        let generate_event_match_arm = if event_process {
421            quote! {
422                std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event)) => {
423                    #net_behv_event_proc::inject_event(self, event)
424                }
425            }
426        } else {
427            quote! {
428                std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event)) => {
429                    return std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event.into()))
430                }
431            }
432        };
433
434        Some(quote!{
435            loop {
436                match #trait_to_impl::poll(&mut #field_name, cx, poll_params) {
437                    #generate_event_match_arm
438                    std::task::Poll::Ready(#network_behaviour_action::DialAddress { address }) => {
439                        return std::task::Poll::Ready(#network_behaviour_action::DialAddress { address });
440                    }
441                    std::task::Poll::Ready(#network_behaviour_action::DialPeer { peer_id, condition }) => {
442                        return std::task::Poll::Ready(#network_behaviour_action::DialPeer { peer_id, condition });
443                    }
444                    std::task::Poll::Ready(#network_behaviour_action::DisconnectPeer { peer_id }) => {
445                        return std::task::Poll::Ready(#network_behaviour_action::DisconnectPeer { peer_id });
446                    }
447                    std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => {
448                        return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler {
449                            peer_id,
450                            handler,
451                            event: #wrapped_event,
452                        });
453                    }
454                    std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => {
455                        return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score });
456                    }
457                    std::task::Poll::Pending => break,
458                }
459            }
460        })
461    });
462
463    // Now the magic happens.
464    let final_quote = quote!{
465        impl #impl_generics #trait_to_impl for #name #ty_generics
466        #where_clause
467        {
468            type ProtocolsHandler = #protocols_handler_ty;
469            type OutEvent = #out_event;
470
471            fn new_handler(&mut self) -> Self::ProtocolsHandler {
472                use #into_protocols_handler;
473                #new_handler
474            }
475
476            fn addresses_of_peer(&mut self, peer_id: &#peer_id) -> Vec<#multiaddr> {
477                let mut out = Vec::new();
478                #(#addresses_of_peer_stmts);*
479                out
480            }
481
482            fn inject_connected(&mut self, peer_id: &#peer_id) {
483                #(#inject_connected_stmts);*
484            }
485
486            fn inject_disconnected(&mut self, peer_id: &#peer_id) {
487                #(#inject_disconnected_stmts);*
488            }
489
490            fn inject_connection_established(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, endpoint: &#connected_point) {
491                #(#inject_connection_established_stmts);*
492            }
493
494            fn inject_address_change(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, old: &#connected_point, new: &#connected_point) {
495                #(#inject_address_change_stmts);*
496            }
497
498            fn inject_connection_closed(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, endpoint: &#connected_point) {
499                #(#inject_connection_closed_stmts);*
500            }
501
502            fn inject_addr_reach_failure(&mut self, peer_id: Option<&#peer_id>, addr: &#multiaddr, error: &dyn std::error::Error) {
503                #(#inject_addr_reach_failure_stmts);*
504            }
505
506            fn inject_dial_failure(&mut self, peer_id: &#peer_id) {
507                #(#inject_dial_failure_stmts);*
508            }
509
510            fn inject_new_listen_addr(&mut self, addr: &#multiaddr) {
511                #(#inject_new_listen_addr_stmts);*
512            }
513
514            fn inject_expired_listen_addr(&mut self, addr: &#multiaddr) {
515                #(#inject_expired_listen_addr_stmts);*
516            }
517
518            fn inject_new_external_addr(&mut self, addr: &#multiaddr) {
519                #(#inject_new_external_addr_stmts);*
520            }
521
522            fn inject_listener_error(&mut self, id: #listener_id, err: &(dyn std::error::Error + 'static)) {
523                #(#inject_listener_error_stmts);*
524            }
525
526            fn inject_listener_closed(&mut self, id: #listener_id, reason: std::result::Result<(), &std::io::Error>) {
527                #(#inject_listener_closed_stmts);*
528            }
529
530            fn inject_event(
531                &mut self,
532                peer_id: #peer_id,
533                connection_id: #connection_id,
534                event: <<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::OutEvent
535            ) {
536                match event {
537                    #(#inject_node_event_stmts),*
538                }
539            }
540
541            fn poll(&mut self, cx: &mut std::task::Context, poll_params: &mut impl #poll_parameters) -> std::task::Poll<#network_behaviour_action<<<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::InEvent, Self::OutEvent>> {
542                use libp2p::futures::prelude::*;
543                #(#poll_stmts)*
544                let f: std::task::Poll<#network_behaviour_action<<<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::InEvent, Self::OutEvent>> = #poll_method;
545                f
546            }
547        }
548    };
549
550    final_quote.into()
551}
552
553fn get_meta_items(attr: &syn::Attribute) -> Option<Vec<syn::NestedMeta>> {
554    if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "behaviour" {
555        match attr.parse_meta() {
556            Ok(syn::Meta::List(ref meta)) => Some(meta.nested.iter().cloned().collect()),
557            Ok(_) => None,
558            Err(e) => {
559                eprintln!("error parsing attribute metadata: {}", e);
560                None
561            }
562        }
563    } else {
564        None
565    }
566}
567
568/// Returns true if a field is marked as ignored by the user.
569fn is_ignored(field: &syn::Field) -> bool {
570    for meta_items in field.attrs.iter().filter_map(get_meta_items) {
571        for meta_item in meta_items {
572            match meta_item {
573                syn::NestedMeta::Meta(syn::Meta::Path(ref m)) if m.is_ident("ignore") => {
574                    return true;
575                }
576                _ => ()
577            }
578        }
579    }
580
581    false
582}