libp2p_broadcast/
lib.rs

1use crate::protocol::Message;
2use fnv::{FnvHashMap, FnvHashSet};
3use libp2p::core::connection::ConnectionId;
4use libp2p::swarm::{
5    NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, OneShotHandler, PollParameters,
6};
7use libp2p::{Multiaddr, PeerId};
8use std::collections::VecDeque;
9use std::fmt;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13mod protocol;
14
15use libp2p::swarm::derive_prelude::FromSwarm;
16pub use protocol::{BroadcastConfig, Topic};
17
18#[derive(Clone, Debug, Eq, PartialEq)]
19pub enum BroadcastEvent {
20    Subscribed(PeerId, Topic),
21    Unsubscribed(PeerId, Topic),
22    Received(PeerId, Topic, Arc<[u8]>),
23}
24type Handler = OneShotHandler<BroadcastConfig, Message, HandlerEvent>;
25
26#[derive(Default)]
27pub struct Broadcast {
28    config: BroadcastConfig,
29    subscriptions: FnvHashSet<Topic>,
30    peers: FnvHashMap<PeerId, FnvHashSet<Topic>>,
31    topics: FnvHashMap<Topic, FnvHashSet<PeerId>>,
32    events: VecDeque<NetworkBehaviourAction<BroadcastEvent, Handler>>,
33}
34
35impl fmt::Debug for Broadcast {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.debug_struct("Broadcast")
38            .field("config", &self.config)
39            .field("subscriptions", &self.subscriptions)
40            .field("peers", &self.peers)
41            .field("topics", &self.topics)
42            .finish()
43    }
44}
45
46impl Broadcast {
47    pub fn new(config: BroadcastConfig) -> Self {
48        Self {
49            config,
50            ..Default::default()
51        }
52    }
53
54    pub fn subscribed(&self) -> impl Iterator<Item = &Topic> + '_ {
55        self.subscriptions.iter()
56    }
57
58    pub fn peers(&self, topic: &Topic) -> Option<impl Iterator<Item = &PeerId> + '_> {
59        self.topics.get(topic).map(|peers| peers.iter())
60    }
61
62    pub fn topics(&self, peer: &PeerId) -> Option<impl Iterator<Item = &Topic> + '_> {
63        self.peers.get(peer).map(|topics| topics.iter())
64    }
65
66    pub fn subscribe(&mut self, topic: Topic) {
67        self.subscriptions.insert(topic);
68        let msg = Message::Subscribe(topic);
69        for peer in self.peers.keys() {
70            self.events
71                .push_back(NetworkBehaviourAction::NotifyHandler {
72                    peer_id: *peer,
73                    event: msg.clone(),
74                    handler: NotifyHandler::Any,
75                });
76        }
77    }
78
79    pub fn unsubscribe(&mut self, topic: &Topic) {
80        self.subscriptions.remove(topic);
81        let msg = Message::Unsubscribe(*topic);
82        if let Some(peers) = self.topics.get(topic) {
83            for peer in peers {
84                self.events
85                    .push_back(NetworkBehaviourAction::NotifyHandler {
86                        peer_id: *peer,
87                        event: msg.clone(),
88                        handler: NotifyHandler::Any,
89                    });
90            }
91        }
92    }
93
94    pub fn broadcast(&mut self, topic: &Topic, msg: Arc<[u8]>) {
95        let msg = Message::Broadcast(*topic, msg);
96        if let Some(peers) = self.topics.get(topic) {
97            for peer in peers {
98                self.events
99                    .push_back(NetworkBehaviourAction::NotifyHandler {
100                        peer_id: *peer,
101                        event: msg.clone(),
102                        handler: NotifyHandler::Any,
103                    });
104            }
105        }
106    }
107
108    fn inject_connected(&mut self, peer: &PeerId) {
109        self.peers.insert(*peer, FnvHashSet::default());
110        for topic in &self.subscriptions {
111            self.events
112                .push_back(NetworkBehaviourAction::NotifyHandler {
113                    peer_id: *peer,
114                    event: Message::Subscribe(*topic),
115                    handler: NotifyHandler::Any,
116                });
117        }
118    }
119
120    fn inject_disconnected(&mut self, peer: &PeerId) {
121        if let Some(topics) = self.peers.remove(peer) {
122            for topic in topics {
123                if let Some(peers) = self.topics.get_mut(&topic) {
124                    peers.remove(peer);
125                }
126            }
127        }
128    }
129}
130
131impl NetworkBehaviour for Broadcast {
132    type ConnectionHandler = OneShotHandler<BroadcastConfig, Message, HandlerEvent>;
133    type OutEvent = BroadcastEvent;
134
135    fn new_handler(&mut self) -> Self::ConnectionHandler {
136        Default::default()
137    }
138
139    fn addresses_of_peer(&mut self, _peer: &PeerId) -> Vec<Multiaddr> {
140        Vec::new()
141    }
142
143    fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
144        match event {
145            FromSwarm::ConnectionEstablished(c) => {
146                if c.other_established == 0 {
147                    self.inject_connected(&c.peer_id);
148                }
149            }
150            FromSwarm::ConnectionClosed(c) => {
151                if c.remaining_established == 0 {
152                    self.inject_disconnected(&c.peer_id);
153                }
154            }
155            _ => {}
156        }
157    }
158
159    fn on_connection_handler_event(&mut self, peer: PeerId, _: ConnectionId, msg: HandlerEvent) {
160        use HandlerEvent::*;
161        use Message::*;
162        let ev = match msg {
163            Rx(Subscribe(topic)) => {
164                let peers = self.topics.entry(topic).or_default();
165                self.peers.get_mut(&peer).unwrap().insert(topic);
166                peers.insert(peer);
167                BroadcastEvent::Subscribed(peer, topic)
168            }
169            Rx(Broadcast(topic, msg)) => BroadcastEvent::Received(peer, topic, msg),
170            Rx(Unsubscribe(topic)) => {
171                self.peers.get_mut(&peer).unwrap().remove(&topic);
172                if let Some(peers) = self.topics.get_mut(&topic) {
173                    peers.remove(&peer);
174                }
175                BroadcastEvent::Unsubscribed(peer, topic)
176            }
177            Tx => {
178                return;
179            }
180        };
181        self.events
182            .push_back(NetworkBehaviourAction::GenerateEvent(ev));
183    }
184
185    fn poll(
186        &mut self,
187        _: &mut Context,
188        _: &mut impl PollParameters,
189    ) -> Poll<NetworkBehaviourAction<BroadcastEvent, Handler>> {
190        if let Some(event) = self.events.pop_front() {
191            Poll::Ready(event)
192        } else {
193            Poll::Pending
194        }
195    }
196}
197
198/// Transmission between the `OneShotHandler` and the `BroadcastHandler`.
199#[derive(Debug)]
200pub enum HandlerEvent {
201    /// We received a `Message` from a remote.
202    Rx(Message),
203    /// We successfully sent a `Message`.
204    Tx,
205}
206
207impl From<Message> for HandlerEvent {
208    fn from(message: Message) -> Self {
209        Self::Rx(message)
210    }
211}
212
213impl From<()> for HandlerEvent {
214    fn from(_: ()) -> Self {
215        Self::Tx
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p::swarm::AddressRecord;
    use std::sync::{Arc, Mutex};

    /// A behaviour shared between the swarm owning it and the peers linked to it.
    type SharedBehaviour = Arc<Mutex<Broadcast>>;

    /// In-memory stand-in for a swarm. It delivers `NotifyHandler` actions straight
    /// into the linked peer's behaviour instead of using a real connection.
    struct DummySwarm {
        local_id: PeerId,
        behaviour: SharedBehaviour,
        links: FnvHashMap<PeerId, SharedBehaviour>,
    }

    impl DummySwarm {
        fn new() -> Self {
            let local_id = PeerId::random();
            DummySwarm {
                local_id,
                behaviour: Arc::new(Mutex::new(Broadcast::default())),
                links: FnvHashMap::default(),
            }
        }

        fn peer_id(&self) -> &PeerId {
            &self.local_id
        }

        /// Connects `self` and `other` in both directions, running each side's
        /// connection bookkeeping.
        fn dial(&mut self, other: &mut DummySwarm) {
            let local = *self.peer_id();
            let remote = *other.peer_id();

            self.behaviour.lock().unwrap().inject_connected(&remote);
            self.links.insert(remote, Arc::clone(&other.behaviour));

            other.behaviour.lock().unwrap().inject_connected(&local);
            other.links.insert(local, Arc::clone(&self.behaviour));
        }

        /// Polls the local behaviour until it yields an event, delivering every
        /// outgoing message to the addressed peer along the way.
        /// Returns `None` once the action queue is drained.
        fn next(&self) -> Option<BroadcastEvent> {
            let waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&waker);
            let mut local = self.behaviour.lock().unwrap();
            loop {
                let action = match local.poll(&mut cx, &mut DummyPollParameters) {
                    Poll::Ready(action) => action,
                    Poll::Pending => return None,
                };
                match action {
                    NetworkBehaviourAction::GenerateEvent(event) => return Some(event),
                    NetworkBehaviourAction::NotifyHandler { peer_id, event, .. } => {
                        if let Some(remote) = self.links.get(&peer_id) {
                            remote.lock().unwrap().on_connection_handler_event(
                                *self.peer_id(),
                                ConnectionId::new(0),
                                HandlerEvent::Rx(event),
                            );
                        }
                    }
                    _ => panic!(),
                }
            }
        }

        fn subscribe(&self, topic: Topic) {
            self.behaviour.lock().unwrap().subscribe(topic);
        }

        fn unsubscribe(&self, topic: &Topic) {
            self.behaviour.lock().unwrap().unsubscribe(topic);
        }

        fn broadcast(&self, topic: &Topic, msg: Arc<[u8]>) {
            self.behaviour.lock().unwrap().broadcast(topic, msg);
        }
    }

    /// `PollParameters` stub; `Broadcast::poll` ignores its parameters argument.
    struct DummyPollParameters;

    impl PollParameters for DummyPollParameters {
        type SupportedProtocolsIter = std::iter::Empty<Vec<u8>>;
        type ListenedAddressesIter = std::iter::Empty<Multiaddr>;
        type ExternalAddressesIter = std::iter::Empty<AddressRecord>;

        fn supported_protocols(&self) -> Self::SupportedProtocolsIter {
            unimplemented!()
        }

        fn listened_addresses(&self) -> Self::ListenedAddressesIter {
            unimplemented!()
        }

        fn external_addresses(&self) -> Self::ExternalAddressesIter {
            unimplemented!()
        }

        fn local_peer_id(&self) -> &PeerId {
            unimplemented!()
        }
    }

    #[test]
    fn test_broadcast() {
        let topic = Topic::new(b"topic");
        let payload = Arc::new(*b"msg");
        let mut alice = DummySwarm::new();
        let mut bob = DummySwarm::new();
        let alice_id = *alice.peer_id();
        let bob_id = *bob.peer_id();

        // Alice's subscription reaches Bob when they connect.
        alice.subscribe(topic);
        alice.dial(&mut bob);
        assert!(alice.next().is_none());
        assert_eq!(
            bob.next().unwrap(),
            BroadcastEvent::Subscribed(alice_id, topic)
        );

        // Bob subscribes while connected.
        bob.subscribe(topic);
        assert!(bob.next().is_none());
        assert_eq!(
            alice.next().unwrap(),
            BroadcastEvent::Subscribed(bob_id, topic)
        );

        // Bob's broadcast reaches the subscribed Alice.
        bob.broadcast(&topic, payload.clone());
        assert!(bob.next().is_none());
        assert_eq!(
            alice.next().unwrap(),
            BroadcastEvent::Received(bob_id, topic, payload)
        );

        // Alice's unsubscription reaches Bob.
        alice.unsubscribe(&topic);
        assert!(alice.next().is_none());
        assert_eq!(
            bob.next().unwrap(),
            BroadcastEvent::Unsubscribed(alice_id, topic)
        );
    }
}