// rust_ipfs/p2p/addressbook.rs

1mod handler;
2
3use std::{
4    collections::{hash_map::Entry, HashMap, HashSet, VecDeque},
5    task::{Context, Poll},
6};
7
8use crate::AddPeerOpt;
9use libp2p::core::transport::PortUse;
10use libp2p::swarm::ConnectionClosed;
11use libp2p::{
12    core::{ConnectedPoint, Endpoint},
13    multiaddr::Protocol,
14    swarm::{
15        self, behaviour::ConnectionEstablished, AddressChange, ConnectionDenied, ConnectionId,
16        FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, ToSwarm,
17    },
18    Multiaddr, PeerId,
19};
20
/// Options controlling how the address book [`Behaviour`] records peers.
///
/// Both options are off (`false`) in the [`Default`] configuration.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config {
    /// If `true`, the remote address of each newly established connection is recorded for
    /// the connected peer.
    pub store_on_connection: bool,
    /// If `true`, peers registered through [`Behaviour::add_address`] are kept alive even when
    /// the supplied options did not ask for it (provided the book holds an address for them).
    pub keep_connection_alive: bool,
}
28
29#[derive(Default, Debug)]
30pub struct Behaviour {
31    events: VecDeque<ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
32    connections: HashMap<PeerId, HashSet<ConnectionId>>,
33    peer_addresses: HashMap<PeerId, HashSet<Multiaddr>>,
34    peer_keepalive: HashSet<PeerId>,
35    config: Config,
36}
37
38impl Behaviour {
39    pub fn with_config(config: Config) -> Self {
40        Self {
41            config,
42            ..Default::default()
43        }
44    }
45    pub fn add_address<I: Into<AddPeerOpt>>(&mut self, opt: I) -> bool {
46        let opt = opt.into();
47
48        let peer_id = opt.peer_id();
49        let addresses = opt.addresses();
50
51        if !addresses.is_empty() {
52            let addrs = self.peer_addresses.entry(*peer_id).or_default();
53
54            for addr in addresses {
55                addrs.insert(addr.clone());
56            }
57
58            if let Some(opts) = opt.to_dial_opts() {
59                self.events.push_back(ToSwarm::Dial { opts });
60            }
61        }
62
63        if (opt.can_keep_alive() || self.config.keep_connection_alive)
64            && self.peer_addresses.contains_key(peer_id)
65        {
66            self.keep_peer_alive(peer_id);
67        }
68
69        true
70    }
71
72    pub fn remove_address(&mut self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
73        if let Entry::Occupied(mut e) = self.peer_addresses.entry(*peer_id) {
74            let entry = e.get_mut();
75
76            if !entry.remove(addr) {
77                return false;
78            }
79
80            if entry.is_empty() {
81                e.remove();
82                self.dont_keep_peer_alive(peer_id);
83            }
84        }
85        true
86    }
87
88    pub fn remove_peer(&mut self, peer_id: &PeerId) -> bool {
89        let removed = self.peer_addresses.remove(peer_id).is_some();
90        if removed {
91            self.dont_keep_peer_alive(peer_id);
92        }
93        removed
94    }
95
96    pub fn contains(&self, peer_id: &PeerId, addr: &Multiaddr) -> bool {
97        self.peer_addresses
98            .get(peer_id)
99            .map(|list| list.contains(addr))
100            .unwrap_or_default()
101    }
102
103    pub fn get_peer_addresses(&self, peer_id: &PeerId) -> Option<Vec<Multiaddr>> {
104        self.peer_addresses
105            .get(peer_id)
106            .cloned()
107            .map(Vec::from_iter)
108    }
109
110    pub fn iter(&self) -> impl Iterator<Item = (&PeerId, &HashSet<Multiaddr>)> {
111        self.peer_addresses.iter()
112    }
113
114    fn keep_peer_alive(&mut self, peer_id: &PeerId) {
115        self.peer_keepalive.insert(*peer_id);
116        if let Some(conns) = self.connections.get(peer_id) {
117            self.events.extend(
118                conns
119                    .iter()
120                    .copied()
121                    .map(|connection_id| ToSwarm::NotifyHandler {
122                        peer_id: *peer_id,
123                        handler: swarm::NotifyHandler::One(connection_id),
124                        event: handler::In::Protect,
125                    }),
126            )
127        }
128    }
129
130    fn dont_keep_peer_alive(&mut self, peer_id: &PeerId) {
131        self.peer_keepalive.remove(peer_id);
132        if let Some(conns) = self.connections.get(peer_id) {
133            self.events.extend(
134                conns
135                    .iter()
136                    .copied()
137                    .map(|connection_id| ToSwarm::NotifyHandler {
138                        peer_id: *peer_id,
139                        handler: swarm::NotifyHandler::One(connection_id),
140                        event: handler::In::Unprotect,
141                    }),
142            )
143        }
144    }
145
146    fn on_connection_established(
147        &mut self,
148        ConnectionEstablished {
149            peer_id,
150            connection_id,
151            endpoint,
152            ..
153        }: ConnectionEstablished,
154    ) {
155        self.connections
156            .entry(peer_id)
157            .or_default()
158            .insert(connection_id);
159
160        if !self.config.store_on_connection {
161            return;
162        }
163
164        let mut addr = match endpoint {
165            ConnectedPoint::Dialer { address, .. } => address.clone(),
166            ConnectedPoint::Listener { local_addr, .. } if endpoint.is_relayed() => {
167                local_addr.clone()
168            }
169            ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
170        };
171
172        if matches!(addr.iter().last(), Some(Protocol::P2p(_))) {
173            addr.pop();
174        }
175
176        self.peer_addresses.entry(peer_id).or_default().insert(addr);
177    }
178
179    fn on_connection_closed(
180        &mut self,
181        ConnectionClosed {
182            peer_id,
183            connection_id,
184            remaining_established,
185            ..
186        }: ConnectionClosed,
187    ) {
188        if let Entry::Occupied(mut entry) = self.connections.entry(peer_id) {
189            let list = entry.get_mut();
190            list.remove(&connection_id);
191            if list.is_empty() && remaining_established == 0 {
192                entry.remove();
193            }
194        }
195    }
196}
197
198impl NetworkBehaviour for Behaviour {
199    type ConnectionHandler = handler::Handler;
200    type ToSwarm = void::Void;
201
202    fn handle_pending_outbound_connection(
203        &mut self,
204        _: ConnectionId,
205        peer_id: Option<PeerId>,
206        _: &[Multiaddr],
207        _: Endpoint,
208    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
209        let Some(peer_id) = peer_id else {
210            return Ok(vec![]);
211        };
212
213        let list = self
214            .peer_addresses
215            .get(&peer_id)
216            .cloned()
217            .map(Vec::from_iter)
218            .unwrap_or_default();
219
220        Ok(list)
221    }
222
223    fn handle_established_inbound_connection(
224        &mut self,
225        _: ConnectionId,
226        peer_id: PeerId,
227        _: &Multiaddr,
228        _: &Multiaddr,
229    ) -> Result<THandler<Self>, ConnectionDenied> {
230        let keepalive = self.peer_keepalive.contains(&peer_id);
231        Ok(handler::Handler::new(keepalive))
232    }
233
234    fn handle_established_outbound_connection(
235        &mut self,
236        _: ConnectionId,
237        peer_id: PeerId,
238        _: &Multiaddr,
239        _: Endpoint,
240        _: PortUse,
241    ) -> Result<THandler<Self>, ConnectionDenied> {
242        let keepalive = self.peer_keepalive.contains(&peer_id);
243        Ok(handler::Handler::new(keepalive))
244    }
245
246    fn on_connection_handler_event(
247        &mut self,
248        _: PeerId,
249        _: ConnectionId,
250        _: swarm::THandlerOutEvent<Self>,
251    ) {
252    }
253
254    fn on_swarm_event(&mut self, event: FromSwarm) {
255        match event {
256            FromSwarm::AddressChange(AddressChange {
257                peer_id, old, new, ..
258            }) => {
259                let mut old = match old {
260                    ConnectedPoint::Dialer { address, .. } => address.clone(),
261                    ConnectedPoint::Listener { local_addr, .. } if old.is_relayed() => {
262                        local_addr.clone()
263                    }
264                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
265                };
266
267                if matches!(old.iter().last(), Some(Protocol::P2p(_))) {
268                    old.pop();
269                }
270
271                let mut new = match new {
272                    ConnectedPoint::Dialer { address, .. } => address.clone(),
273                    ConnectedPoint::Listener { local_addr, .. } if new.is_relayed() => {
274                        local_addr.clone()
275                    }
276                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr.clone(),
277                };
278
279                if matches!(new.iter().last(), Some(Protocol::P2p(_))) {
280                    new.pop();
281                }
282
283                if let Entry::Occupied(mut e) = self.peer_addresses.entry(peer_id) {
284                    let entry = e.get_mut();
285                    entry.insert(new);
286                    entry.remove(&old);
287                }
288            }
289            FromSwarm::ConnectionEstablished(ev) => self.on_connection_established(ev),
290            FromSwarm::ConnectionClosed(ev) => self.on_connection_closed(ev),
291            FromSwarm::DialFailure(_) => {}
292            FromSwarm::ListenFailure(_) => {}
293            FromSwarm::NewListener(_) => {}
294            FromSwarm::NewListenAddr(_) => {}
295            FromSwarm::ExpiredListenAddr(_) => {}
296            FromSwarm::ListenerError(_) => {}
297            FromSwarm::ListenerClosed(_) => {}
298            FromSwarm::NewExternalAddrCandidate(_) => {}
299            FromSwarm::ExternalAddrConfirmed(_) => {}
300            FromSwarm::ExternalAddrExpired(_) => {}
301            _ => {}
302        }
303    }
304
305    fn poll(&mut self, _: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
306        if let Some(event) = self.events.pop_front() {
307            return Poll::Ready(event);
308        }
309        Poll::Pending
310    }
311}
312
#[cfg(test)]
mod test {
    use std::time::Duration;

    use futures::{FutureExt, StreamExt};
    use libp2p::{
        swarm::{dial_opts::DialOpts, SwarmEvent},
        Multiaddr, PeerId, Swarm, SwarmBuilder,
    };

    use crate::AddPeerOpt;

    /// Dialing by peer id alone succeeds once an address for that peer is in the book.
    #[tokio::test]
    async fn dial_with_peer_id() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;

        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);

        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }
        Ok(())
    }

    /// After a peer is removed from the book it is gone from the book and can no longer be
    /// dialed by peer id.
    #[tokio::test]
    async fn remove_peer_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts = AddPeerOpt::with_peer_id(peer2).add_address(addr2);
        swarm1.behaviour_mut().add_address(opts);

        swarm1.dial(peer2)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        swarm1.disconnect_peer_id(peer2).expect("Shouldnt fail");

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        assert!(swarm1.behaviour_mut().remove_peer(&peer2));
        assert!(swarm1.behaviour().get_peer_addresses(&peer2).is_none());

        assert!(swarm1.dial(peer2).is_err());

        Ok(())
    }

    /// Peers added with keep-alive stay connected past the swarm's idle timeout (3s).
    #[tokio::test]
    async fn dial_and_keepalive() -> anyhow::Result<()> {
        let (peer1, addr1, mut swarm1) = build_swarm(false).await;
        let (peer2, addr2, mut swarm2) = build_swarm(false).await;
        let opts_1 = AddPeerOpt::with_peer_id(peer2)
            .add_address(addr2)
            .keepalive();
        swarm1.behaviour_mut().add_address(opts_1);

        let opts_2 = AddPeerOpt::with_peer_id(peer1)
            .add_address(addr1)
            .keepalive();
        swarm2.behaviour_mut().add_address(opts_2);

        swarm1.dial(peer2)?;

        let mut peer_a_connected = false;
        let mut peer_b_connected = false;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        peer_b_connected = true;
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        peer_a_connected = true;
                    }
                }
            }

            if peer_a_connected && peer_b_connected {
                break;
            }
        }

        // Wait longer than the idle timeout configured in `build_swarm`.
        let mut timer = futures_timer::Delay::new(Duration::from_secs(4)).fuse();

        loop {
            futures::select! {
                _ = &mut timer => {
                    break;
                }
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        unreachable!("connection shouldnt have closed")
                    }
                }
                event = swarm2.select_next_some() => {
                    if let SwarmEvent::ConnectionClosed { peer_id, .. } = event {
                        assert_eq!(peer_id, peer1);
                        unreachable!("connection shouldnt have closed")
                    }
                }
            }
        }

        Ok(())
    }

    /// With `store_on_connection`, the dialed address of a peer is recorded once connected.
    #[tokio::test]
    async fn store_address() -> anyhow::Result<()> {
        let (_, _, mut swarm1) = build_swarm(true).await;
        let (peer2, addr2, mut swarm2) = build_swarm(true).await;

        let opt = DialOpts::peer_id(peer2)
            .addresses(vec![addr2.clone()])
            .build();

        swarm1.dial(opt)?;

        loop {
            futures::select! {
                event = swarm1.select_next_some() => {
                    if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event {
                        assert_eq!(peer_id, peer2);
                        break;
                    }
                }
                _ = swarm2.next() => {}
            }
        }

        let addrs = swarm1
            .behaviour()
            .get_peer_addresses(&peer2)
            .expect("Exist");

        // Assert the exact list: a loop over `addrs` would pass vacuously if it were empty.
        assert_eq!(addrs, vec![addr2]);
        Ok(())
    }

    /// Builds a TCP swarm listening on an ephemeral localhost port and returns its peer id,
    /// first listen address and the swarm itself.
    async fn build_swarm(
        store_on_connection: bool,
    ) -> (PeerId, Multiaddr, Swarm<super::Behaviour>) {
        let mut swarm = SwarmBuilder::with_new_identity()
            .with_tokio()
            .with_tcp(
                libp2p::tcp::Config::default(),
                libp2p::noise::Config::new,
                libp2p::yamux::Config::default,
            )
            .expect("")
            .with_behaviour(|_| {
                super::Behaviour::with_config(super::Config {
                    store_on_connection,
                    ..Default::default()
                })
            })
            .expect("")
            .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(3)))
            .build();

        Swarm::listen_on(&mut swarm, "/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap();

        if let Some(SwarmEvent::NewListenAddr { address, .. }) = swarm.next().await {
            let peer_id = swarm.local_peer_id();
            return (*peer_id, address, swarm);
        }

        panic!("no new addrs")
    }
}