1use self::mdns_wrapper::MdnsWrapper;
2use futures::FutureExt;
3use libp2p::{
4 Multiaddr,
5 PeerId,
6 core::{
7 Endpoint,
8 transport::PortUse,
9 },
10 kad::{
11 self,
12 store::MemoryStore,
13 },
14 mdns,
15 swarm::{
16 ConnectionDenied,
17 ConnectionId,
18 NetworkBehaviour,
19 THandler,
20 derive_prelude::{
21 ConnectionClosed,
22 ConnectionEstablished,
23 FromSwarm,
24 },
25 },
26};
27
28use libp2p::swarm::{
29 THandlerInEvent,
30 THandlerOutEvent,
31 ToSwarm,
32};
33use std::{
34 collections::HashSet,
35 pin::Pin,
36 task::{
37 Context,
38 Poll,
39 },
40 time::Duration,
41};
42use tracing::trace;
43mod discovery_config;
44mod mdns_wrapper;
45
46pub use discovery_config::Config;
47
/// Upper bound for the delay between two consecutive Kademlia random walks:
/// `poll` doubles the delay after every walk but never lets it exceed this value.
const SIXTY_SECONDS: Duration = Duration::from_secs(60);
49
/// Event type emitted by the discovery [`Behaviour`]; an alias for the Kademlia
/// event type, which is also what the behaviour reports to the swarm.
pub type Event = kad::Event;
51
/// Peer discovery behaviour that combines Kademlia with mDNS.
///
/// Kademlia supplies the connection handler and the events emitted to the swarm;
/// addresses discovered through mDNS are forwarded to Kademlia.
pub struct Behaviour {
    /// Peers that currently have at least one established connection.
    /// Maintained from the swarm's connection established/closed events.
    connected_peers: HashSet<PeerId>,

    /// mDNS wrapper; the peers it reports as discovered are handed to `kademlia`
    /// via `add_address`.
    mdns: MdnsWrapper,

    /// Kademlia behaviour using an in-memory record store.
    kademlia: kad::Behaviour<MemoryStore>,

    /// Timer that fires when the next random Kademlia query should be started.
    /// When `None`, `poll` never starts random walks.
    next_kad_random_walk: Option<Pin<Box<tokio::time::Sleep>>>,

    /// Delay used when re-arming `next_kad_random_walk`; doubled after every
    /// firing and capped at `SIXTY_SECONDS`.
    duration_to_next_kad: Duration,

    /// Random walks are only started while fewer than this many peers are
    /// connected.
    max_peers_connected: usize,
}
73
74impl Behaviour {
75 pub fn add_address(&mut self, peer_id: &PeerId, address: Multiaddr) {
77 self.kademlia.add_address(peer_id, address);
78 }
79}
80
81impl NetworkBehaviour for Behaviour {
82 type ConnectionHandler =
83 <kad::Behaviour<MemoryStore> as NetworkBehaviour>::ConnectionHandler;
84 type ToSwarm = kad::Event;
85
86 fn handle_established_inbound_connection(
87 &mut self,
88 connection_id: ConnectionId,
89 peer: PeerId,
90 local_addr: &Multiaddr,
91 remote_addr: &Multiaddr,
92 ) -> Result<THandler<Self>, ConnectionDenied> {
93 self.kademlia.handle_established_inbound_connection(
94 connection_id,
95 peer,
96 local_addr,
97 remote_addr,
98 )
99 }
100
101 fn handle_pending_outbound_connection(
103 &mut self,
104 connection_id: ConnectionId,
105 maybe_peer: Option<PeerId>,
106 addresses: &[Multiaddr],
107 effective_role: Endpoint,
108 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
109 let mut kademlia_addrs = self.kademlia.handle_pending_outbound_connection(
110 connection_id,
111 maybe_peer,
112 addresses,
113 effective_role,
114 )?;
115 let mdns_addrs = self.mdns.handle_pending_outbound_connection(
116 connection_id,
117 maybe_peer,
118 addresses,
119 effective_role,
120 )?;
121 kademlia_addrs.extend(mdns_addrs);
122 Ok(kademlia_addrs)
123 }
124
125 fn handle_established_outbound_connection(
126 &mut self,
127 connection_id: ConnectionId,
128 peer: PeerId,
129 addr: &Multiaddr,
130 role_override: Endpoint,
131 port_use: PortUse,
132 ) -> Result<THandler<Self>, ConnectionDenied> {
133 self.kademlia.handle_established_outbound_connection(
134 connection_id,
135 peer,
136 addr,
137 role_override,
138 port_use,
139 )
140 }
141
142 fn on_swarm_event(&mut self, event: FromSwarm) {
143 match &event {
144 FromSwarm::ConnectionEstablished(ConnectionEstablished {
145 peer_id,
146 other_established,
147 ..
148 }) => {
149 if *other_established == 0 {
150 self.connected_peers.insert(*peer_id);
151
152 trace!("Connected to a peer {:?}", peer_id);
153 }
154 }
155 FromSwarm::ConnectionClosed(ConnectionClosed {
156 peer_id,
157 remaining_established,
158 ..
159 }) => {
160 if *remaining_established == 0 {
161 self.connected_peers.remove(peer_id);
162 trace!("Disconnected from {:?}", peer_id);
163 }
164 }
165 _ => (),
166 }
167 self.mdns.on_swarm_event(&event);
168 self.kademlia.on_swarm_event(event);
169 }
170
171 fn on_connection_handler_event(
173 &mut self,
174 peer_id: PeerId,
175 connection: ConnectionId,
176 event: THandlerOutEvent<Self>,
177 ) {
178 self.kademlia
179 .on_connection_handler_event(peer_id, connection, event);
180 }
181
182 fn poll(
184 &mut self,
185 cx: &mut Context<'_>,
186 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
187 if let Some(next_kad_random_query) = self.next_kad_random_walk.as_mut() {
189 while next_kad_random_query.poll_unpin(cx).is_ready() {
190 if self.connected_peers.len() < self.max_peers_connected {
191 let random_peer_id = PeerId::random();
192 self.kademlia.get_closest_peers(random_peer_id);
193 }
194
195 *next_kad_random_query =
196 Box::pin(tokio::time::sleep(self.duration_to_next_kad));
197 self.duration_to_next_kad = std::cmp::min(
200 self.duration_to_next_kad.saturating_mul(2),
201 SIXTY_SECONDS,
202 );
203 }
204 }
205
206 if let Poll::Ready(kad_action) = self.kademlia.poll(cx) {
208 return Poll::Ready(kad_action)
209 };
210
211 while let Poll::Ready(mdns_event) = self.mdns.poll(cx) {
212 match mdns_event {
213 ToSwarm::GenerateEvent(mdns::Event::Discovered(list)) => {
214 for (peer_id, multiaddr) in list {
215 self.kademlia.add_address(&peer_id, multiaddr);
216 }
217 }
218 ToSwarm::CloseConnection {
219 peer_id,
220 connection,
221 } => {
222 return Poll::Ready(ToSwarm::CloseConnection {
223 peer_id,
224 connection,
225 })
226 }
227 _ => {}
228 }
229 }
230 Poll::Pending
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::{
        Behaviour,
        Config,
        Event,
    };
    use futures::{
        StreamExt,
        future::poll_fn,
    };
    use libp2p::{
        Multiaddr,
        PeerId,
        Swarm,
        identity::Keypair,
        multiaddr::Protocol,
        swarm::SwarmEvent,
    };
    use std::{
        collections::HashSet,
        task::Poll,
        time::Duration,
    };

    use libp2p_swarm_test::SwarmExt;

    /// Connection limit handed to every behaviour built by these tests.
    const MAX_PEERS: usize = 50;

    /// Returns a constructor that builds a discovery [`Behaviour`] for the given
    /// keypair, using `bootstrap_nodes` and a 500 ms random-walk interval.
    fn build_behavior_fn(
        bootstrap_nodes: Vec<Multiaddr>,
    ) -> impl FnOnce(Keypair) -> Behaviour {
        // `move` makes it explicit that the closure owns the bootstrap list.
        move |keypair| {
            let mut config =
                Config::new(keypair.public().to_peer_id(), "test_network".into());
            config
                .max_peers_connected(MAX_PEERS)
                .with_bootstrap_nodes(bootstrap_nodes)
                .with_random_walk(Duration::from_millis(500));

            config.finish().expect("Config should be valid")
        }
    }

    /// Builds an ephemeral swarm running the discovery behaviour and starts
    /// listening on a random in-memory address.
    ///
    /// Returns the swarm together with its listen address and local peer id.
    fn build_fuel_discovery(
        bootstrap_nodes: Vec<Multiaddr>,
    ) -> (Swarm<Behaviour>, Multiaddr, PeerId) {
        let behaviour_fn = build_behavior_fn(bootstrap_nodes);

        let listen_addr: Multiaddr = Protocol::Memory(rand::random::<u64>()).into();

        let mut swarm = Swarm::new_ephemeral(behaviour_fn);

        swarm
            .listen_on(listen_addr.clone())
            .expect("swarm should start listening");

        // `PeerId` is `Copy`, so a plain dereference is enough.
        let peer_id = *swarm.local_peer_id();

        (swarm, listen_addr, peer_id)
    }

    /// Starts a set of swarms that all bootstrap from the first one and drives
    /// them until every swarm has established a connection to each peer it was
    /// expected to discover.
    #[tokio::test]
    async fn discovery_works() {
        let num_of_swarms = 25;
        // The first swarm has no bootstrap nodes; it is the bootstrap node itself.
        let (first_swarm, first_peer_addr, first_peer_id) = build_fuel_discovery(vec![]);
        let bootstrap_addr: Multiaddr = format!("{first_peer_addr}/p2p/{first_peer_id}")
            .parse()
            .expect("bootstrap address should be a valid multiaddr");

        let mut discovery_swarms = Vec::new();
        discovery_swarms.push((first_swarm, first_peer_addr, first_peer_id));

        for _ in 1..num_of_swarms {
            let (swarm, peer_addr, peer_id) =
                build_fuel_discovery(vec![bootstrap_addr.clone()]);
            discovery_swarms.push((swarm, peer_addr, peer_id));
        }

        // For every swarm, the set of peers it still has to connect to.
        // The first swarm is never part of these sets - presumably because every
        // other swarm already knows it as its bootstrap node (TODO confirm).
        let mut left_to_discover = (0..discovery_swarms.len())
            .map(|current_index| {
                discovery_swarms
                    .iter()
                    .enumerate()
                    .skip(1)
                    .filter(|(swarm_index, _)| *swarm_index != current_index)
                    .map(|(_, (_, _, peer_id))| *peer_id)
                    .collect::<HashSet<_>>()
            })
            .collect::<Vec<_>>();

        let test_future = poll_fn(move |cx| {
            // Keep polling all swarms until a full pass yields no event.
            'polling: loop {
                for swarm_index in 0..discovery_swarms.len() {
                    if let Poll::Ready(Some(event)) =
                        discovery_swarms[swarm_index].0.poll_next_unpin(cx)
                    {
                        match event {
                            SwarmEvent::ConnectionEstablished { peer_id, .. } => {
                                left_to_discover[swarm_index].remove(&peer_id);
                            }
                            SwarmEvent::Behaviour(Event::UnroutablePeer {
                                peer: peer_id,
                            }) => {
                                // Kademlia has no address for this peer yet, so
                                // look it up among the test swarms and supply it.
                                let unroutable_peer_addr = discovery_swarms
                                    .iter()
                                    .find_map(|(_, next_addr, next_peer_id)| {
                                        (next_peer_id == &peer_id)
                                            .then(|| next_addr.clone())
                                    })
                                    .expect(
                                        "unroutable peer should be one of the test swarms",
                                    );

                                discovery_swarms[swarm_index]
                                    .0
                                    .behaviour_mut()
                                    .kademlia
                                    .add_address(&peer_id, unroutable_peer_addr);
                            }
                            SwarmEvent::ConnectionClosed { peer_id, .. } => {
                                tracing::trace!("Connection closed with {:?}", peer_id);
                            }
                            _ => {}
                        }
                        // An event may have unblocked other swarms: start over.
                        continue 'polling;
                    }
                }
                break;
            }

            if left_to_discover.iter().all(|l| l.is_empty()) {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        });

        test_future.await;
    }
}