1use std::collections::VecDeque;
2use std::fmt;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use fnv::{FnvHashMap, FnvHashSet};
7use libp2p::swarm::derive_prelude::FromSwarm;
8use libp2p::swarm::{
9 CloseConnection, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
10 OneShotHandler, ToSwarm,
11};
12use libp2p::{Multiaddr, PeerId};
13use prometheus_client::registry::Registry;
14
15use crate::protocol::Message;
16
17mod length_prefixed;
18mod metrics;
19mod protocol;
20
21pub use metrics::Metrics;
22pub use protocol::{Config, Topic};
23
24#[derive(Clone, Debug, Eq, PartialEq)]
25pub enum Event {
26 Subscribed(PeerId, Topic),
27 Unsubscribed(PeerId, Topic),
28 Received(PeerId, Topic, Bytes),
29}
30
31type Handler = OneShotHandler<Config, Message, HandlerEvent>;
32
33#[derive(Default)]
34pub struct Behaviour {
35 config: Config,
36 subscriptions: FnvHashSet<Topic>,
37 peers: FnvHashMap<PeerId, FnvHashSet<Topic>>,
38 topics: FnvHashMap<Topic, FnvHashSet<PeerId>>,
39 events: VecDeque<ToSwarm<Event, Message>>,
40 metrics: Option<Metrics>,
41}
42
43impl fmt::Debug for Behaviour {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 f.debug_struct("Behaviour")
46 .field("config", &self.config)
47 .field("subscriptions", &self.subscriptions)
48 .field("peers", &self.peers)
49 .field("topics", &self.topics)
50 .finish()
51 }
52}
53
54impl Behaviour {
55 pub fn new(config: Config) -> Self {
56 Self {
57 config,
58 ..Default::default()
59 }
60 }
61
62 pub fn new_with_metrics(config: Config, registry: &mut Registry) -> Self {
63 Self {
64 config,
65 metrics: Some(Metrics::new(registry)),
66 ..Default::default()
67 }
68 }
69
70 pub fn subscribed(&self) -> impl Iterator<Item = &Topic> + '_ {
71 self.subscriptions.iter()
72 }
73
74 pub fn peers(&self, topic: &Topic) -> Option<impl Iterator<Item = &PeerId> + '_> {
75 self.topics.get(topic).map(|peers| peers.iter())
76 }
77
78 pub fn topics(&self, peer: &PeerId) -> Option<impl Iterator<Item = &Topic> + '_> {
79 self.peers.get(peer).map(|topics| topics.iter())
80 }
81
82 pub fn subscribe(&mut self, topic: Topic) {
83 self.subscriptions.insert(topic);
84 let msg = Message::Subscribe(topic);
85 for peer in self.peers.keys() {
86 self.events.push_back(ToSwarm::NotifyHandler {
87 peer_id: *peer,
88 event: msg.clone(),
89 handler: NotifyHandler::Any,
90 });
91 }
92
93 if let Some(metrics) = &mut self.metrics {
94 metrics.subscribe(&topic);
95 }
96 }
97
98 pub fn unsubscribe(&mut self, topic: &Topic) {
99 self.subscriptions.remove(topic);
100 let msg = Message::Unsubscribe(*topic);
101 if let Some(peers) = self.topics.get(topic) {
102 for peer in peers {
103 self.events.push_back(ToSwarm::NotifyHandler {
104 peer_id: *peer,
105 event: msg.clone(),
106 handler: NotifyHandler::Any,
107 });
108 }
109 }
110
111 if let Some(metrics) = &mut self.metrics {
112 metrics.unsubscribe(topic);
113 }
114 }
115
116 pub fn broadcast(&mut self, topic: &Topic, msg: Bytes) {
117 let msg = Message::Broadcast(*topic, msg);
118 if let Some(peers) = self.topics.get(topic) {
119 for peer in peers {
120 self.events.push_back(ToSwarm::NotifyHandler {
121 peer_id: *peer,
122 event: msg.clone(),
123 handler: NotifyHandler::Any,
124 });
125 }
126 }
127
128 if let Some(metrics) = &mut self.metrics {
129 metrics.msg_sent(topic, msg.len());
130 metrics.register_published_message(topic);
131 }
132 }
133
134 fn inject_connected(&mut self, peer: &PeerId) {
135 self.peers.insert(*peer, FnvHashSet::default());
136 for topic in &self.subscriptions {
137 self.events.push_back(ToSwarm::NotifyHandler {
138 peer_id: *peer,
139 event: Message::Subscribe(*topic),
140 handler: NotifyHandler::Any,
141 });
142 }
143 }
144
145 fn inject_disconnected(&mut self, peer: &PeerId) {
146 if let Some(topics) = self.peers.remove(peer) {
147 for topic in topics {
148 if let Some(peers) = self.topics.get_mut(&topic) {
149 peers.remove(peer);
150 }
151 }
152 }
153 }
154}
155
156impl NetworkBehaviour for Behaviour {
157 type ConnectionHandler = Handler;
158 type ToSwarm = Event;
159
160 fn handle_established_inbound_connection(
161 &mut self,
162 _connection_id: ConnectionId,
163 _peer: PeerId,
164 _local_addr: &Multiaddr,
165 _remote_addr: &Multiaddr,
166 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
167 Ok(Handler::default())
168 }
169
170 fn handle_established_outbound_connection(
171 &mut self,
172 _connection_id: ConnectionId,
173 _peer: PeerId,
174 _addr: &Multiaddr,
175 _role_override: libp2p::core::Endpoint,
176 _port_use: libp2p::core::transport::PortUse,
177 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
178 Ok(Handler::default())
179 }
180
181 fn on_swarm_event(&mut self, event: FromSwarm<'_>) {
182 match event {
183 FromSwarm::ConnectionEstablished(c) => {
184 if c.other_established == 0 {
185 self.inject_connected(&c.peer_id);
186 }
187 }
188 FromSwarm::ConnectionClosed(c) => {
189 if c.remaining_established == 0 {
190 self.inject_disconnected(&c.peer_id);
191 }
192 }
193 _ => {}
194 }
195 }
196
197 fn on_connection_handler_event(
198 &mut self,
199 peer: PeerId,
200 connection_id: ConnectionId,
201 msg: <Self::ConnectionHandler as ConnectionHandler>::ToBehaviour,
202 ) {
203 use HandlerEvent::*;
204 use Message::*;
205 let ev = match msg {
206 Ok(Rx(Subscribe(topic))) => {
207 let peers = self.topics.entry(topic).or_default();
208 self.peers.entry(peer).or_default().insert(topic);
209 peers.insert(peer);
210 if let Some(metrics) = self.metrics.as_mut() {
211 metrics.inc_topic_peers(&topic);
212 }
213 Event::Subscribed(peer, topic)
214 }
215
216 Ok(Rx(Broadcast(topic, msg))) => {
217 if let Some(metrics) = self.metrics.as_mut() {
218 metrics.msg_received(&topic, msg.len());
219 }
220 Event::Received(peer, topic, msg)
221 }
222
223 Ok(Rx(Unsubscribe(topic))) => {
224 self.peers.entry(peer).or_default().remove(&topic);
225 if let Some(peers) = self.topics.get_mut(&topic) {
226 peers.remove(&peer);
227 }
228 if let Some(metrics) = self.metrics.as_mut() {
229 metrics.dec_topic_peers(&topic);
230 }
231 Event::Unsubscribed(peer, topic)
232 }
233
234 Ok(Tx) => {
235 return;
236 }
237
238 Err(e) => {
239 tracing::debug!("Failed to broadcast message: {e}");
240
241 self.events.push_back(ToSwarm::CloseConnection {
242 peer_id: peer,
243 connection: CloseConnection::One(connection_id),
244 });
245
246 return;
247 }
248 };
249 self.events.push_back(ToSwarm::GenerateEvent(ev));
250 }
251
252 fn poll(&mut self, _: &mut Context) -> Poll<ToSwarm<Event, Message>> {
253 if let Some(event) = self.events.pop_front() {
254 Poll::Ready(event)
255 } else {
256 Poll::Pending
257 }
258 }
259}
260
261#[derive(Debug)]
263pub enum HandlerEvent {
264 Rx(Message),
266 Tx,
268}
269
270impl From<Message> for HandlerEvent {
271 fn from(message: Message) -> Self {
272 Self::Rx(message)
273 }
274}
275
276impl From<()> for HandlerEvent {
277 fn from(_: ()) -> Self {
278 Self::Tx
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Mutex};

    /// In-memory stand-in for a swarm: owns one behaviour and delivers its
    /// handler notifications directly to the behaviours it is "connected" to.
    struct DummySwarm {
        peer_id: PeerId,
        behaviour: Arc<Mutex<Behaviour>>,
        connections: FnvHashMap<PeerId, Arc<Mutex<Behaviour>>>,
    }

    impl DummySwarm {
        /// Creates a swarm with a random peer id and a default behaviour.
        fn new() -> Self {
            Self {
                peer_id: PeerId::random(),
                behaviour: Arc::default(),
                connections: FnvHashMap::default(),
            }
        }

        fn peer_id(&self) -> &PeerId {
            &self.peer_id
        }

        /// Connects `self` and `other` in both directions, local side first.
        fn dial(&mut self, other: &mut DummySwarm) {
            let local_id = self.peer_id;
            let remote_id = other.peer_id;

            self.behaviour.lock().unwrap().inject_connected(&remote_id);
            self.connections
                .insert(remote_id, Arc::clone(&other.behaviour));

            other.behaviour.lock().unwrap().inject_connected(&local_id);
            other
                .connections
                .insert(local_id, Arc::clone(&self.behaviour));
        }

        /// Drives the local behaviour until it yields an event or goes idle.
        ///
        /// Handler notifications are forwarded to the addressed peer as if it
        /// had received the message; any other swarm action is a test failure.
        fn next(&self) -> Option<Event> {
            let waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&waker);
            let mut local = self.behaviour.lock().unwrap();
            loop {
                let action = match local.poll(&mut cx) {
                    Poll::Pending => return None,
                    Poll::Ready(action) => action,
                };

                match action {
                    ToSwarm::GenerateEvent(event) => return Some(event),
                    ToSwarm::NotifyHandler { peer_id, event, .. } => {
                        // Notifications for peers we never dialed are dropped.
                        let Some(remote) = self.connections.get(&peer_id) else {
                            continue;
                        };
                        remote.lock().unwrap().on_connection_handler_event(
                            self.peer_id,
                            ConnectionId::new_unchecked(0),
                            Ok(HandlerEvent::Rx(event)),
                        );
                    }
                    _ => panic!(),
                }
            }
        }

        fn subscribe(&self, topic: Topic) {
            self.behaviour.lock().unwrap().subscribe(topic);
        }

        fn unsubscribe(&self, topic: &Topic) {
            self.behaviour.lock().unwrap().unsubscribe(topic);
        }

        fn broadcast(&self, topic: &Topic, msg: Bytes) {
            self.behaviour.lock().unwrap().broadcast(topic, msg);
        }
    }

    /// Walks two peers through subscribe, broadcast and unsubscribe and checks
    /// the event each side observes at every step.
    #[test]
    fn test_broadcast() {
        let topic = Topic::new(b"topic");
        let msg = Bytes::from_static(b"msg");
        let mut a = DummySwarm::new();
        let mut b = DummySwarm::new();

        // `a` subscribes before connecting; the dial announces it to `b`.
        a.subscribe(topic);
        a.dial(&mut b);
        assert!(a.next().is_none());
        assert_eq!(b.next().unwrap(), Event::Subscribed(*a.peer_id(), topic));

        // `b` subscribes while connected; `a` is told immediately.
        b.subscribe(topic);
        assert!(b.next().is_none());
        assert_eq!(a.next().unwrap(), Event::Subscribed(*b.peer_id(), topic));

        // A broadcast from `b` reaches the subscribed peer `a`.
        b.broadcast(&topic, msg.clone());
        assert!(b.next().is_none());
        assert_eq!(a.next().unwrap(), Event::Received(*b.peer_id(), topic, msg));

        // `a` leaves the topic and `b` is notified.
        a.unsubscribe(&topic);
        assert!(a.next().is_none());
        assert_eq!(b.next().unwrap(), Event::Unsubscribed(*a.peer_id(), topic));
    }
}