1use crate::protocol::Message;
2use fnv::{FnvHashMap, FnvHashSet};
3use libp2p::core::connection::ConnectionId;
4use libp2p::swarm::{
5 NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, OneShotHandler, PollParameters,
6};
7use libp2p::{Multiaddr, PeerId};
8use std::collections::VecDeque;
9use std::fmt;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13mod protocol;
14
15use libp2p::swarm::derive_prelude::FromSwarm;
16pub use protocol::{BroadcastConfig, Topic};
17
/// Events reported by [`Broadcast`] to the owner of the swarm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BroadcastEvent {
    /// The peer announced a subscription to the topic.
    Subscribed(PeerId, Topic),
    /// The peer announced that it no longer subscribes to the topic.
    Unsubscribed(PeerId, Topic),
    /// The peer sent a payload on the topic.
    Received(PeerId, Topic, Arc<[u8]>),
}
/// Per-connection handler used by [`Broadcast`]: inbound protocol `BroadcastConfig`,
/// outbound messages of type `Message`, and handler output `HandlerEvent`.
type Handler = OneShotHandler<BroadcastConfig, Message, HandlerEvent>;
25
/// Network behaviour in which peers announce topic subscriptions to each other and
/// payloads are sent to the peers that announced a subscription to the topic.
#[derive(Default)]
pub struct Broadcast {
    /// Protocol configuration supplied to `Broadcast::new`.
    config: BroadcastConfig,
    /// Topics the local node is subscribed to.
    subscriptions: FnvHashSet<Topic>,
    /// Connected peers, each mapped to the topics it has announced.
    peers: FnvHashMap<PeerId, FnvHashSet<Topic>>,
    /// Reverse index of `peers`: topic -> peers that announced it.
    topics: FnvHashMap<Topic, FnvHashSet<PeerId>>,
    /// Actions waiting to be handed out by `poll`, oldest first.
    events: VecDeque<NetworkBehaviourAction<BroadcastEvent, Handler>>,
}
34
35impl fmt::Debug for Broadcast {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.debug_struct("Broadcast")
38 .field("config", &self.config)
39 .field("subscriptions", &self.subscriptions)
40 .field("peers", &self.peers)
41 .field("topics", &self.topics)
42 .finish()
43 }
44}
45
46impl Broadcast {
47 pub fn new(config: BroadcastConfig) -> Self {
48 Self {
49 config,
50 ..Default::default()
51 }
52 }
53
54 pub fn subscribed(&self) -> impl Iterator<Item = &Topic> + '_ {
55 self.subscriptions.iter()
56 }
57
58 pub fn peers(&self, topic: &Topic) -> Option<impl Iterator<Item = &PeerId> + '_> {
59 self.topics.get(topic).map(|peers| peers.iter())
60 }
61
62 pub fn topics(&self, peer: &PeerId) -> Option<impl Iterator<Item = &Topic> + '_> {
63 self.peers.get(peer).map(|topics| topics.iter())
64 }
65
66 pub fn subscribe(&mut self, topic: Topic) {
67 self.subscriptions.insert(topic);
68 let msg = Message::Subscribe(topic);
69 for peer in self.peers.keys() {
70 self.events
71 .push_back(NetworkBehaviourAction::NotifyHandler {
72 peer_id: *peer,
73 event: msg.clone(),
74 handler: NotifyHandler::Any,
75 });
76 }
77 }
78
79 pub fn unsubscribe(&mut self, topic: &Topic) {
80 self.subscriptions.remove(topic);
81 let msg = Message::Unsubscribe(*topic);
82 if let Some(peers) = self.topics.get(topic) {
83 for peer in peers {
84 self.events
85 .push_back(NetworkBehaviourAction::NotifyHandler {
86 peer_id: *peer,
87 event: msg.clone(),
88 handler: NotifyHandler::Any,
89 });
90 }
91 }
92 }
93
94 pub fn broadcast(&mut self, topic: &Topic, msg: Arc<[u8]>) {
95 let msg = Message::Broadcast(*topic, msg);
96 if let Some(peers) = self.topics.get(topic) {
97 for peer in peers {
98 self.events
99 .push_back(NetworkBehaviourAction::NotifyHandler {
100 peer_id: *peer,
101 event: msg.clone(),
102 handler: NotifyHandler::Any,
103 });
104 }
105 }
106 }
107
108 fn inject_connected(&mut self, peer: &PeerId) {
109 self.peers.insert(*peer, FnvHashSet::default());
110 for topic in &self.subscriptions {
111 self.events
112 .push_back(NetworkBehaviourAction::NotifyHandler {
113 peer_id: *peer,
114 event: Message::Subscribe(*topic),
115 handler: NotifyHandler::Any,
116 });
117 }
118 }
119
120 fn inject_disconnected(&mut self, peer: &PeerId) {
121 if let Some(topics) = self.peers.remove(peer) {
122 for topic in topics {
123 if let Some(peers) = self.topics.get_mut(&topic) {
124 peers.remove(peer);
125 }
126 }
127 }
128 }
129}
130
131impl NetworkBehaviour for Broadcast {
132 type ConnectionHandler = OneShotHandler<BroadcastConfig, Message, HandlerEvent>;
133 type OutEvent = BroadcastEvent;
134
135 fn new_handler(&mut self) -> Self::ConnectionHandler {
136 Default::default()
137 }
138
139 fn addresses_of_peer(&mut self, _peer: &PeerId) -> Vec<Multiaddr> {
140 Vec::new()
141 }
142
143 fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
144 match event {
145 FromSwarm::ConnectionEstablished(c) => {
146 if c.other_established == 0 {
147 self.inject_connected(&c.peer_id);
148 }
149 }
150 FromSwarm::ConnectionClosed(c) => {
151 if c.remaining_established == 0 {
152 self.inject_disconnected(&c.peer_id);
153 }
154 }
155 _ => {}
156 }
157 }
158
159 fn on_connection_handler_event(&mut self, peer: PeerId, _: ConnectionId, msg: HandlerEvent) {
160 use HandlerEvent::*;
161 use Message::*;
162 let ev = match msg {
163 Rx(Subscribe(topic)) => {
164 let peers = self.topics.entry(topic).or_default();
165 self.peers.get_mut(&peer).unwrap().insert(topic);
166 peers.insert(peer);
167 BroadcastEvent::Subscribed(peer, topic)
168 }
169 Rx(Broadcast(topic, msg)) => BroadcastEvent::Received(peer, topic, msg),
170 Rx(Unsubscribe(topic)) => {
171 self.peers.get_mut(&peer).unwrap().remove(&topic);
172 if let Some(peers) = self.topics.get_mut(&topic) {
173 peers.remove(&peer);
174 }
175 BroadcastEvent::Unsubscribed(peer, topic)
176 }
177 Tx => {
178 return;
179 }
180 };
181 self.events
182 .push_back(NetworkBehaviourAction::GenerateEvent(ev));
183 }
184
185 fn poll(
186 &mut self,
187 _: &mut Context,
188 _: &mut impl PollParameters,
189 ) -> Poll<NetworkBehaviourAction<BroadcastEvent, Handler>> {
190 if let Some(event) = self.events.pop_front() {
191 Poll::Ready(event)
192 } else {
193 Poll::Pending
194 }
195 }
196}
197
198#[derive(Debug)]
200pub enum HandlerEvent {
201 Rx(Message),
203 Tx,
205}
206
207impl From<Message> for HandlerEvent {
208 fn from(message: Message) -> Self {
209 Self::Rx(message)
210 }
211}
212
213impl From<()> for HandlerEvent {
214 fn from(_: ()) -> Self {
215 Self::Tx
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p::swarm::AddressRecord;
    use std::sync::{Arc, Mutex};

    /// In-process stand-in for a libp2p swarm.
    ///
    /// Holds one [`Broadcast`] behaviour plus shared handles to the behaviours of the
    /// swarms it is "connected" to. `NotifyHandler` actions are delivered by calling the
    /// remote behaviour directly instead of going over a transport.
    struct DummySwarm {
        peer_id: PeerId,
        behaviour: Arc<Mutex<Broadcast>>,
        // Remote behaviours keyed by their peer id; populated by `dial`.
        connections: FnvHashMap<PeerId, Arc<Mutex<Broadcast>>>,
    }

    impl DummySwarm {
        /// Creates an unconnected swarm with a random peer id and a default `Broadcast`.
        fn new() -> Self {
            Self {
                peer_id: PeerId::random(),
                behaviour: Default::default(),
                connections: Default::default(),
            }
        }

        fn peer_id(&self) -> &PeerId {
            &self.peer_id
        }

        /// Connects `self` and `other` in both directions. Each behaviour is told the
        /// other peer connected, and each swarm keeps the other's behaviour for delivery.
        fn dial(&mut self, other: &mut DummySwarm) {
            self.behaviour
                .lock()
                .unwrap()
                .inject_connected(other.peer_id());
            self.connections
                .insert(*other.peer_id(), other.behaviour.clone());
            other
                .behaviour
                .lock()
                .unwrap()
                .inject_connected(self.peer_id());
            other
                .connections
                .insert(*self.peer_id(), self.behaviour.clone());
        }

        /// Drives this swarm's behaviour until it produces an event for the user.
        ///
        /// Each `NotifyHandler` action is handed straight to the target peer's behaviour
        /// as a received message; it is silently dropped if the target was never dialed.
        /// Returns the first `GenerateEvent`, or `None` once the action queue is empty.
        /// Panics on any other action kind.
        ///
        /// Delivered messages only queue events in the receiving behaviour. Call `next`
        /// on that peer's swarm to observe them.
        fn next(&self) -> Option<BroadcastEvent> {
            // `Broadcast::poll` ignores its context, so a no-op waker is enough.
            let waker = futures::task::noop_waker();
            let mut ctx = Context::from_waker(&waker);
            let mut me = self.behaviour.lock().unwrap();
            loop {
                match me.poll(&mut ctx, &mut DummyPollParameters) {
                    Poll::Ready(NetworkBehaviourAction::NotifyHandler {
                        peer_id, event, ..
                    }) => {
                        if let Some(other) = self.connections.get(&peer_id) {
                            let mut other = other.lock().unwrap();
                            other.on_connection_handler_event(
                                *self.peer_id(),
                                ConnectionId::new(0),
                                HandlerEvent::Rx(event),
                            );
                        }
                    }
                    Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) => {
                        return Some(event);
                    }
                    Poll::Ready(_) => panic!(),
                    Poll::Pending => {
                        return None;
                    }
                }
            }
        }

        fn subscribe(&self, topic: Topic) {
            let mut me = self.behaviour.lock().unwrap();
            me.subscribe(topic);
        }

        fn unsubscribe(&self, topic: &Topic) {
            let mut me = self.behaviour.lock().unwrap();
            me.unsubscribe(topic);
        }

        fn broadcast(&self, topic: &Topic, msg: Arc<[u8]>) {
            let mut me = self.behaviour.lock().unwrap();
            me.broadcast(topic, msg);
        }
    }

    /// `PollParameters` stub. `Broadcast::poll` ignores its parameters, so none of these
    /// methods are expected to be called; each panics if it is.
    struct DummyPollParameters;

    impl PollParameters for DummyPollParameters {
        type SupportedProtocolsIter = std::iter::Empty<Vec<u8>>;
        type ListenedAddressesIter = std::iter::Empty<Multiaddr>;
        type ExternalAddressesIter = std::iter::Empty<AddressRecord>;

        fn supported_protocols(&self) -> Self::SupportedProtocolsIter {
            unimplemented!()
        }

        fn listened_addresses(&self) -> Self::ListenedAddressesIter {
            unimplemented!()
        }

        fn external_addresses(&self) -> Self::ExternalAddressesIter {
            unimplemented!()
        }

        fn local_peer_id(&self) -> &PeerId {
            unimplemented!()
        }
    }

    /// End-to-end exchange between two peers: subscription announced on connect,
    /// subscription announced after connect, a broadcast, and an unsubscribe.
    #[test]
    fn test_broadcast() {
        let topic = Topic::new(b"topic");
        let msg = Arc::new(*b"msg");
        let mut a = DummySwarm::new();
        let mut b = DummySwarm::new();

        // `a` subscribes before connecting; the connection carries the announcement to `b`.
        a.subscribe(topic);
        a.dial(&mut b);
        assert!(a.next().is_none());
        assert_eq!(
            b.next().unwrap(),
            BroadcastEvent::Subscribed(*a.peer_id(), topic)
        );
        // `b` subscribes while connected; `a` learns of it.
        b.subscribe(topic);
        assert!(b.next().is_none());
        assert_eq!(
            a.next().unwrap(),
            BroadcastEvent::Subscribed(*b.peer_id(), topic)
        );
        // `b` broadcasts to its subscribers (`a`).
        b.broadcast(&topic, msg.clone());
        assert!(b.next().is_none());
        assert_eq!(
            a.next().unwrap(),
            BroadcastEvent::Received(*b.peer_id(), topic, msg)
        );
        // `a` leaves the topic; `b` is told.
        a.unsubscribe(&topic);
        assert!(a.next().is_none());
        assert_eq!(
            b.next().unwrap(),
            BroadcastEvent::Unsubscribed(*a.peer_id(), topic)
        );
    }
}