1mod handler;
2
3use std::{
4 collections::{hash_map::Entry, HashMap, HashSet, VecDeque},
5 task::{Context, Poll},
6};
7
8use crate::AddPeerOpt;
9use libp2p::core::transport::PortUse;
10use libp2p::swarm::ConnectionClosed;
11use libp2p::{
12 core::{ConnectedPoint, Endpoint},
13 multiaddr::Protocol,
14 swarm::{
15 self, behaviour::ConnectionEstablished, AddressChange, ConnectionDenied, ConnectionId,
16 FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, ToSwarm,
17 },
18 Multiaddr, PeerId,
19};
20
/// Tunables for the address-book [`Behaviour`].
///
/// Both switches are off by default.
#[derive(Default, Debug, Copy, Clone)]
pub struct Config {
    /// When set, the remote address of every newly established connection is
    /// recorded for the connected peer (any trailing `/p2p` component is stripped).
    pub store_on_connection: bool,
    /// When set, every peer added through `Behaviour::add_address` that has at
    /// least one stored address is kept alive, even if the individual request did
    /// not ask for it.
    pub keep_connection_alive: bool,
}
28
29#[derive(Default, Debug)]
30pub struct Behaviour {
31 events: VecDeque<ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
32 connections: HashMap<PeerId, HashSet<ConnectionId>>,
33 peer_addresses: HashMap<PeerId, HashSet<Multiaddr>>,
34 peer_keepalive: HashSet<PeerId>,
35 config: Config,
36}
37
38impl Behaviour {
39 pub fn with_config(config: Config) -> Self {
40 Self {
41 config,
42 ..Default::default()
43 }
44 }
45 pub fn add_address<I: Into<AddPeerOpt>>(&mut self, opt: I) -> bool {
46 let opt = opt.into();
47
48 let peer_id = opt.peer_id();
49 let addresses = opt.addresses();
50
51 if !addresses.is_empty() {
52 let addrs = self.peer_addresses.entry(*peer_id).or_default();
53
54 for addr in addresses {
55 addrs.insert(addr.clone());
56 }
57
58 if let Some(opts) = opt.to_dial_opts() {
59 self.events.push_back(ToSwarm::Dial { opts });
60 }
61 }
62
63 if (opt.can_keep_alive() || self.config.keep_connection_alive)
64 && self.peer_addresses.contains_key(peer_id)
65 {
66 self.keep_peer_alive(peer_id);
67 }
68
69 true
70 }
71
72 pub fn remove_address(&mut self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
73 if let Entry::Occupied(mut e) = self.peer_addresses.entry(*peer_id) {
74 let entry = e.get_mut();
75
76 if !entry.remove(addr) {
77 return false;
78 }
79
80 if entry.is_empty() {
81 e.remove();
82 self.dont_keep_peer_alive(peer_id);
83 }
84 }
85 true
86 }
87
88 pub fn remove_peer(&mut self, peer_id: &PeerId) -> bool {
89 let removed = self.peer_addresses.remove(peer_id).is_some();
90 if removed {
91 self.dont_keep_peer_alive(peer_id);
92 }
93 removed
94 }
95
96 pub fn contains(&self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
97 self.peer_addresses
98 .get(peer_id)
99 .map(|list| list.contains(addr))
100 .unwrap_or_default()
101 }
102
103 pub fn get_peer_addresses(&self, peer_id: &PeerId) -> Option<Vec<Multiaddr>> {
104 self.peer_addresses
105 .get(peer_id)
106 .cloned()
107 .map(Vec::from_iter)
108 }
109
110 pub fn iter(&self) -> impl Iterator<Item = (&PeerId, &HashSet<Multiaddr>)> {
111 self.peer_addresses.iter()
112 }
113
114 fn keep_peer_alive(&mut self, peer_id: &PeerId) {
115 self.peer_keepalive.insert(*peer_id);
116 if let Some(conns) = self.connections.get(peer_id) {
117 self.events.extend(
118 conns
119 .iter()
120 .copied()
121 .map(|connection_id| ToSwarm::NotifyHandler {
122 peer_id: *peer_id,
123 handler: swarm::NotifyHandler::One(connection_id),
124 event: handler::In::Protect,
125 }),
126 )
127 }
128 }
129
130 fn dont_keep_peer_alive(&mut self, peer_id: &PeerId) {
131 self.peer_keepalive.remove(peer_id);
132 if let Some(conns) = self.connections.get(peer_id) {
133 self.events.extend(
134 conns
135 .iter()
136 .copied()
137 .map(|connection_id| ToSwarm::NotifyHandler {
138 peer_id: *peer_id,
139 handler: swarm::NotifyHandler::One(connection_id),
140 event: handler::In::Unprotect,
141 }),
142 )
143 }
144 }
145
146 fn on_connection_established(
147 &mut self,
148 ConnectionEstablished {
149 peer_id,
150 connection_id,
151 endpoint,
152 ..
153 }: ConnectionEstablished,
154 ) {
155 self.connections
156 .entry(peer_id)
157 .or_default()
158 .insert(connection_id);
159
160 if !self.config.store_on_connection {
161 return;
162 }
163
164 let mut addr = match endpoint {
165 ConnectedPoint::Dialer { address, .. } => address.clone(),
166 ConnectedPoint::Listener { local_addr, .. } if endpoint.is_relayed() => {
167 local_addr.clone()
168 }
169 ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
170 };
171
172 if matches!(addr.iter().last(), Some(Protocol::P2p(_))) {
173 addr.pop();
174 }
175
176 self.peer_addresses.entry(peer_id).or_default().insert(addr);
177 }
178
179 fn on_connection_closed(
180 &mut self,
181 ConnectionClosed {
182 peer_id,
183 connection_id,
184 remaining_established,
185 ..
186 }: ConnectionClosed,
187 ) {
188 if let Entry::Occupied(mut entry) = self.connections.entry(peer_id) {
189 let list = entry.get_mut();
190 list.remove(&connection_id);
191 if list.is_empty() && remaining_established == 0 {
192 entry.remove();
193 }
194 }
195 }
196}
197
198impl NetworkBehaviour for Behaviour {
199 type ConnectionHandler = handler::Handler;
200 type ToSwarm = void::Void;
201
202 fn handle_pending_outbound_connection(
203 &mut self,
204 _: ConnectionId,
205 peer_id: Option<PeerId>,
206 _: &[Multiaddr],
207 _: Endpoint,
208 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
209 let Some(peer_id) = peer_id else {
210 return Ok(vec![]);
211 };
212
213 let list = self
214 .peer_addresses
215 .get(&peer_id)
216 .cloned()
217 .map(Vec::from_iter)
218 .unwrap_or_default();
219
220 Ok(list)
221 }
222
223 fn handle_established_inbound_connection(
224 &mut self,
225 _: ConnectionId,
226 peer_id: PeerId,
227 _: &Multiaddr,
228 _: &Multiaddr,
229 ) -> Result<THandler<Self>, ConnectionDenied> {
230 let keepalive = self.peer_keepalive.contains(&peer_id);
231 Ok(handler::Handler::new(keepalive))
232 }
233
234 fn handle_established_outbound_connection(
235 &mut self,
236 _: ConnectionId,
237 peer_id: PeerId,
238 _: &Multiaddr,
239 _: Endpoint,
240 _: PortUse,
241 ) -> Result<THandler<Self>, ConnectionDenied> {
242 let keepalive = self.peer_keepalive.contains(&peer_id);
243 Ok(handler::Handler::new(keepalive))
244 }
245
246 fn on_connection_handler_event(
247 &mut self,
248 _: PeerId,
249 _: ConnectionId,
250 _: swarm::THandlerOutEvent<Self>,
251 ) {
252 }
253
254 fn on_swarm_event(&mut self, event: FromSwarm) {
255 match event {
256 FromSwarm::AddressChange(AddressChange {
257 peer_id, old, new, ..
258 }) => {
259 let mut old = match old {
260 ConnectedPoint::Dialer { address, .. } => address.clone(),
261 ConnectedPoint::Listener { local_addr, .. } if old.is_relayed() => {
262 local_addr.clone()
263 }
264 ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
265 };
266
267 if matches!(old.iter().last(), Some(Protocol::P2p(_))) {
268 old.pop();
269 }
270
271 let mut new = match new {
272 ConnectedPoint::Dialer { address, .. } => address.clone(),
273 ConnectedPoint::Listener { local_addr, .. } if new.is_relayed() => {
274 local_addr.clone()
275 }
276 ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
277 };
278
279 if matches!(new.iter().last(), Some(Protocol::P2p(_))) {
280 new.pop();
281 }
282
283 if let Entry::Occupied(mut e) = self.peer_addresses.entry(peer_id) {
284 let entry = e.get_mut();
285 entry.insert(new);
286 entry.remove(&old);
287 }
288 }
289 FromSwarm::ConnectionEstablished(ev) => self.on_connection_established(ev),
290 FromSwarm::ConnectionClosed(ev) => self.on_connection_closed(ev),
291 FromSwarm::DialFailure(_) => {}
292 FromSwarm::ListenFailure(_) => {}
293 FromSwarm::NewListener(_) => {}
294 FromSwarm::NewListenAddr(_) => {}
295 FromSwarm::ExpiredListenAddr(_) => {}
296 FromSwarm::ListenerError(_) => {}
297 FromSwarm::ListenerClosed(_) => {}
298 FromSwarm::NewExternalAddrCandidate(_) => {}
299 FromSwarm::ExternalAddrConfirmed(_) => {}
300 FromSwarm::ExternalAddrExpired(_) => {}
301 _ => {}
302 }
303 }
304
305 fn poll(&mut self, _: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
306 if let Some(event) = self.events.pop_front() {
307 return Poll::Ready(event);
308 }
309 Poll::Pending
310 }
311}
312
#[cfg(test)]
mod test {
    use std::time::Duration;

    use futures::{FutureExt, StreamExt};
    use libp2p::{
        swarm::{dial_opts::DialOpts, SwarmEvent},
        Multiaddr, PeerId, Swarm, SwarmBuilder,
    };

    use crate::AddPeerOpt;

    /// Dialing by peer id alone succeeds once the peer's address is stored.
    #[tokio::test]
    async fn dial_with_peer_id() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;

        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);

        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }
        Ok(())
    }

    /// After the peer is removed, dialing it by peer id alone fails again.
    #[tokio::test]
    async fn remove_peer_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);
        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        swarm1.disconnect_peer_id(peer2).expect("Shouldnt fail");

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        assert!(
            swarm1.behaviour_mut().remove_peer(&peer2),
            "peer had a stored address, so removal must be reported"
        );

        assert!(swarm1.dial(peer2).is_err());

        Ok(())
    }

    /// Peers added with keep-alive stay connected past the idle timeout.
    #[tokio::test]
    async fn dial_and_keepalive() -> anyhow::Result<()> {
        let (peer1, addr1, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts_1 = AddPeerOpt::with_peer_id(peer2)
            .add_address(addr2)
            .keepalive();
        swarm1.behaviour_mut().add_address(opts_1);

        let opts_2 = AddPeerOpt::with_peer_id(peer1)
            .add_address(addr1)
            .keepalive();
        swarm2.behaviour_mut().add_address(opts_2);

        swarm1.dial(peer2)?;

        let mut peer_a_connected = false;
        let mut peer_b_connected = false;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        peer_b_connected = true;
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        peer_a_connected = true;
                    }
                }
            }

            if peer_a_connected && peer_b_connected {
                break;
            }
        }

        // Wait longer than the 3s idle timeout configured in `build_swarm`.
        let mut timer = futures_timer::Delay::new(Duration::from_secs(4)).fuse();

        loop {
            futures::select! {
                _ = &mut timer => {
                    break;
                }
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        unreachable!("connection shouldnt have closed")
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        unreachable!("connection shouldnt have closed")
                    }
                }
            }
        }

        Ok(())
    }

    /// With `store_on_connection`, the dialed address is recorded for the peer.
    #[tokio::test]
    async fn store_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(true).await;
        let (peer2, addr2, mut swarm2) = build_swarm(true).await;

        let opt = DialOpts::peer_id(peer2)
            .addresses(vec![addr2.clone()])
            .build();

        swarm1.dial(opt)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        let addrs = swarm1
            .behaviour()
            .get_peer_addresses(&peer2)
            .expect("address should be stored on connection");

        // Exactly the dialed address (with any `/p2p` suffix stripped) must be stored.
        // The previous per-element loop would also have passed on an empty list.
        assert_eq!(addrs, vec![addr2]);
        Ok(())
    }

    /// Builds a TCP swarm with a 3s idle timeout, starts listening on a random
    /// loopback port, and returns its peer id, listen address and the swarm.
    async fn build_swarm(
        store_on_connection: bool,
    ) -> (PeerId, Multiaddr, Swarm<super::Behaviour>) {
        let mut swarm = SwarmBuilder::with_new_identity()
            .with_tokio()
            .with_tcp(
                libp2p::tcp::Config::default(),
                libp2p::noise::Config::new,
                libp2p::yamux::Config::default,
            )
            .expect("tcp transport should build")
            .with_behaviour(|_| {
                super::Behaviour::with_config(super::Config {
                    store_on_connection,
                    ..Default::default()
                })
            })
            .expect("behaviour construction should not fail")
            .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(3)))
            .build();

        swarm
            .listen_on(
                "/ip4/127.0.0.1/tcp/0"
                    .parse()
                    .expect("loopback multiaddr is valid"),
            )
            .expect("listening on loopback should succeed");

        if let Some(SwarmEvent::NewListenAddr { address, .. }) = swarm.next().await {
            let peer_id = *swarm.local_peer_id();
            return (peer_id, address, swarm);
        }

        panic!("no new addrs")
    }
}