libp2p_scatter/
lib.rs

1use std::collections::VecDeque;
2use std::fmt;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use fnv::{FnvHashMap, FnvHashSet};
7use libp2p::swarm::derive_prelude::FromSwarm;
8use libp2p::swarm::{
9    CloseConnection, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
10    OneShotHandler, ToSwarm,
11};
12use libp2p::{Multiaddr, PeerId};
13use prometheus_client::registry::Registry;
14
15use crate::protocol::Message;
16
17mod length_prefixed;
18mod metrics;
19mod protocol;
20
21pub use metrics::Metrics;
22pub use protocol::{Config, Topic};
23
24#[derive(Clone, Debug, Eq, PartialEq)]
25pub enum Event {
26    Subscribed(PeerId, Topic),
27    Unsubscribed(PeerId, Topic),
28    Received(PeerId, Topic, Bytes),
29}
30
31type Handler = OneShotHandler<Config, Message, HandlerEvent>;
32
33#[derive(Default)]
34pub struct Behaviour {
35    config: Config,
36    subscriptions: FnvHashSet<Topic>,
37    peers: FnvHashMap<PeerId, FnvHashSet<Topic>>,
38    topics: FnvHashMap<Topic, FnvHashSet<PeerId>>,
39    events: VecDeque<ToSwarm<Event, Message>>,
40    metrics: Option<Metrics>,
41}
42
43impl fmt::Debug for Behaviour {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.debug_struct("Behaviour")
46            .field("config", &self.config)
47            .field("subscriptions", &self.subscriptions)
48            .field("peers", &self.peers)
49            .field("topics", &self.topics)
50            .finish()
51    }
52}
53
54impl Behaviour {
55    pub fn new(config: Config) -> Self {
56        Self {
57            config,
58            ..Default::default()
59        }
60    }
61
62    pub fn new_with_metrics(config: Config, registry: &mut Registry) -> Self {
63        Self {
64            config,
65            metrics: Some(Metrics::new(registry)),
66            ..Default::default()
67        }
68    }
69
70    pub fn subscribed(&self) -> impl Iterator<Item = &Topic> + '_ {
71        self.subscriptions.iter()
72    }
73
74    pub fn peers(&self, topic: &Topic) -> Option<impl Iterator<Item = &PeerId> + '_> {
75        self.topics.get(topic).map(|peers| peers.iter())
76    }
77
78    pub fn topics(&self, peer: &PeerId) -> Option<impl Iterator<Item = &Topic> + '_> {
79        self.peers.get(peer).map(|topics| topics.iter())
80    }
81
82    pub fn subscribe(&mut self, topic: Topic) {
83        self.subscriptions.insert(topic);
84        let msg = Message::Subscribe(topic);
85        for peer in self.peers.keys() {
86            self.events.push_back(ToSwarm::NotifyHandler {
87                peer_id: *peer,
88                event: msg.clone(),
89                handler: NotifyHandler::Any,
90            });
91        }
92
93        if let Some(metrics) = &mut self.metrics {
94            metrics.subscribe(&topic);
95        }
96    }
97
98    pub fn unsubscribe(&mut self, topic: &Topic) {
99        self.subscriptions.remove(topic);
100        let msg = Message::Unsubscribe(*topic);
101        if let Some(peers) = self.topics.get(topic) {
102            for peer in peers {
103                self.events.push_back(ToSwarm::NotifyHandler {
104                    peer_id: *peer,
105                    event: msg.clone(),
106                    handler: NotifyHandler::Any,
107                });
108            }
109        }
110
111        if let Some(metrics) = &mut self.metrics {
112            metrics.unsubscribe(topic);
113        }
114    }
115
116    pub fn broadcast(&mut self, topic: &Topic, msg: Bytes) {
117        let msg = Message::Broadcast(*topic, msg);
118        if let Some(peers) = self.topics.get(topic) {
119            for peer in peers {
120                self.events.push_back(ToSwarm::NotifyHandler {
121                    peer_id: *peer,
122                    event: msg.clone(),
123                    handler: NotifyHandler::Any,
124                });
125            }
126        }
127
128        if let Some(metrics) = &mut self.metrics {
129            metrics.msg_sent(topic, msg.len());
130            metrics.register_published_message(topic);
131        }
132    }
133
134    fn inject_connected(&mut self, peer: &PeerId) {
135        self.peers.insert(*peer, FnvHashSet::default());
136        for topic in &self.subscriptions {
137            self.events.push_back(ToSwarm::NotifyHandler {
138                peer_id: *peer,
139                event: Message::Subscribe(*topic),
140                handler: NotifyHandler::Any,
141            });
142        }
143    }
144
145    fn inject_disconnected(&mut self, peer: &PeerId) {
146        if let Some(topics) = self.peers.remove(peer) {
147            for topic in topics {
148                if let Some(peers) = self.topics.get_mut(&topic) {
149                    peers.remove(peer);
150                }
151            }
152        }
153    }
154}
155
156impl NetworkBehaviour for Behaviour {
157    type ConnectionHandler = Handler;
158    type ToSwarm = Event;
159
160    fn handle_established_inbound_connection(
161        &mut self,
162        _connection_id: ConnectionId,
163        _peer: PeerId,
164        _local_addr: &Multiaddr,
165        _remote_addr: &Multiaddr,
166    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
167        Ok(Handler::default())
168    }
169
170    fn handle_established_outbound_connection(
171        &mut self,
172        _connection_id: ConnectionId,
173        _peer: PeerId,
174        _addr: &Multiaddr,
175        _role_override: libp2p::core::Endpoint,
176        _port_use: libp2p::core::transport::PortUse,
177    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
178        Ok(Handler::default())
179    }
180
181    fn on_swarm_event(&mut self, event: FromSwarm<'_>) {
182        match event {
183            FromSwarm::ConnectionEstablished(c) => {
184                if c.other_established == 0 {
185                    self.inject_connected(&c.peer_id);
186                }
187            }
188            FromSwarm::ConnectionClosed(c) => {
189                if c.remaining_established == 0 {
190                    self.inject_disconnected(&c.peer_id);
191                }
192            }
193            _ => {}
194        }
195    }
196
197    fn on_connection_handler_event(
198        &mut self,
199        peer: PeerId,
200        connection_id: ConnectionId,
201        msg: <Self::ConnectionHandler as ConnectionHandler>::ToBehaviour,
202    ) {
203        use HandlerEvent::*;
204        use Message::*;
205        let ev = match msg {
206            Ok(Rx(Subscribe(topic))) => {
207                let peers = self.topics.entry(topic).or_default();
208                self.peers.entry(peer).or_default().insert(topic);
209                peers.insert(peer);
210                if let Some(metrics) = self.metrics.as_mut() {
211                    metrics.inc_topic_peers(&topic);
212                }
213                Event::Subscribed(peer, topic)
214            }
215
216            Ok(Rx(Broadcast(topic, msg))) => {
217                if let Some(metrics) = self.metrics.as_mut() {
218                    metrics.msg_received(&topic, msg.len());
219                }
220                Event::Received(peer, topic, msg)
221            }
222
223            Ok(Rx(Unsubscribe(topic))) => {
224                self.peers.entry(peer).or_default().remove(&topic);
225                if let Some(peers) = self.topics.get_mut(&topic) {
226                    peers.remove(&peer);
227                }
228                if let Some(metrics) = self.metrics.as_mut() {
229                    metrics.dec_topic_peers(&topic);
230                }
231                Event::Unsubscribed(peer, topic)
232            }
233
234            Ok(Tx) => {
235                return;
236            }
237
238            Err(e) => {
239                tracing::debug!("Failed to broadcast message: {e}");
240
241                self.events.push_back(ToSwarm::CloseConnection {
242                    peer_id: peer,
243                    connection: CloseConnection::One(connection_id),
244                });
245
246                return;
247            }
248        };
249        self.events.push_back(ToSwarm::GenerateEvent(ev));
250    }
251
252    fn poll(&mut self, _: &mut Context) -> Poll<ToSwarm<Event, Message>> {
253        if let Some(event) = self.events.pop_front() {
254            Poll::Ready(event)
255        } else {
256            Poll::Pending
257        }
258    }
259}
260
261/// Transmission between the `OneShotHandler` and the `BroadcastHandler`.
262#[derive(Debug)]
263pub enum HandlerEvent {
264    /// We received a `Message` from a remote.
265    Rx(Message),
266    /// We successfully sent a `Message`.
267    Tx,
268}
269
270impl From<Message> for HandlerEvent {
271    fn from(message: Message) -> Self {
272        Self::Rx(message)
273    }
274}
275
276impl From<()> for HandlerEvent {
277    fn from(_: ()) -> Self {
278        Self::Tx
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Mutex};

    /// In-memory stand-in for a swarm. It holds one behaviour plus shared handles to the
    /// behaviours of the peers it is connected to, so queued `NotifyHandler` actions can be
    /// delivered by hand.
    struct DummySwarm {
        peer_id: PeerId,
        behaviour: Arc<Mutex<Behaviour>>,
        connections: FnvHashMap<PeerId, Arc<Mutex<Behaviour>>>,
    }

    impl DummySwarm {
        fn new() -> Self {
            DummySwarm {
                peer_id: PeerId::random(),
                behaviour: Arc::new(Mutex::new(Behaviour::default())),
                connections: FnvHashMap::default(),
            }
        }

        fn peer_id(&self) -> &PeerId {
            &self.peer_id
        }

        /// Connects `self` and `other` in both directions, telling each behaviour that the
        /// other peer has connected.
        fn dial(&mut self, other: &mut DummySwarm) {
            let (local_id, remote_id) = (*self.peer_id(), *other.peer_id());
            self.behaviour.lock().unwrap().inject_connected(&remote_id);
            self.connections
                .insert(remote_id, Arc::clone(&other.behaviour));
            other.behaviour.lock().unwrap().inject_connected(&local_id);
            other
                .connections
                .insert(local_id, Arc::clone(&self.behaviour));
        }

        /// Polls the local behaviour and delivers each queued message to the addressed peer.
        /// It stops at the first application event (returned) or when nothing is queued
        /// (`None`).
        fn next(&self) -> Option<Event> {
            let waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&waker);
            let mut local = self.behaviour.lock().unwrap();
            while let Poll::Ready(action) = local.poll(&mut cx) {
                match action {
                    ToSwarm::NotifyHandler { peer_id, event, .. } => {
                        if let Some(remote) = self.connections.get(&peer_id) {
                            remote.lock().unwrap().on_connection_handler_event(
                                *self.peer_id(),
                                ConnectionId::new_unchecked(0),
                                Ok(HandlerEvent::Rx(event)),
                            );
                        }
                    }
                    ToSwarm::GenerateEvent(event) => return Some(event),
                    _ => panic!(),
                }
            }
            None
        }

        fn subscribe(&self, topic: Topic) {
            self.behaviour.lock().unwrap().subscribe(topic);
        }

        fn unsubscribe(&self, topic: &Topic) {
            self.behaviour.lock().unwrap().unsubscribe(topic);
        }

        fn broadcast(&self, topic: &Topic, msg: Bytes) {
            self.behaviour.lock().unwrap().broadcast(topic, msg);
        }
    }

    #[test]
    fn test_broadcast() {
        let topic = Topic::new(b"topic");
        let payload = Bytes::from_static(b"msg");
        let mut alice = DummySwarm::new();
        let mut bob = DummySwarm::new();

        // Alice subscribes before connecting; her subscription is sent on connect.
        alice.subscribe(topic);
        alice.dial(&mut bob);
        assert!(alice.next().is_none());
        assert_eq!(bob.next().unwrap(), Event::Subscribed(*alice.peer_id(), topic));

        // Bob subscribes while already connected.
        bob.subscribe(topic);
        assert!(bob.next().is_none());
        assert_eq!(alice.next().unwrap(), Event::Subscribed(*bob.peer_id(), topic));

        // Bob's broadcast reaches the subscribed Alice.
        bob.broadcast(&topic, payload.clone());
        assert!(bob.next().is_none());
        assert_eq!(alice.next().unwrap(), Event::Received(*bob.peer_id(), topic, payload));

        // Alice's unsubscription is announced to Bob.
        alice.unsubscribe(&topic);
        assert!(alice.next().is_none());
        assert_eq!(bob.next().unwrap(), Event::Unsubscribed(*alice.peer_id(), topic));
    }
}