ant_libp2p_connection_limits/
lib.rs

1// Copyright 2023 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::{HashMap, HashSet},
23    convert::Infallible,
24    fmt,
25    task::{Context, Poll},
26};
27
28use ant_libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
29use ant_libp2p_swarm::{
30    behaviour::{ConnectionEstablished, DialFailure, ListenFailure},
31    dummy, ConnectionClosed, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
32    THandlerInEvent, THandlerOutEvent, ToSwarm,
33};
34use libp2p_identity::PeerId;
35
36/// A [`NetworkBehaviour`] that enforces a set of [`ConnectionLimits`].
37///
38/// For these limits to take effect, this needs to be composed
39/// into the behaviour tree of your application.
40///
41/// If a connection is denied due to a limit, either a
42/// [`SwarmEvent::IncomingConnectionError`](libp2p_swarm::SwarmEvent::IncomingConnectionError)
43/// or [`SwarmEvent::OutgoingConnectionError`](libp2p_swarm::SwarmEvent::OutgoingConnectionError)
44/// will be emitted. The [`ListenError::Denied`](libp2p_swarm::ListenError::Denied) and respectively
45/// the [`DialError::Denied`](libp2p_swarm::DialError::Denied) variant
46/// contain a [`ConnectionDenied`] type that can be downcast to [`Exceeded`] error if (and only if)
47/// **this** behaviour denied the connection.
48///
49/// If you employ multiple [`NetworkBehaviour`]s that manage connections,
50/// it may also be a different error.
51///
52/// # Example
53///
54/// ```rust
55/// # use libp2p_identify as identify;
56/// # use libp2p_ping as ping;
57/// # use libp2p_swarm_derive::NetworkBehaviour;
58/// # use libp2p_connection_limits as connection_limits;
59///
60/// #[derive(NetworkBehaviour)]
61/// # #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
62/// struct MyBehaviour {
63///     identify: identify::Behaviour,
64///     ping: ping::Behaviour,
65///     limits: connection_limits::Behaviour,
66/// }
67/// ```
68pub struct Behaviour {
69    limits: ConnectionLimits,
70
71    pending_inbound_connections: HashSet<ConnectionId>,
72    pending_outbound_connections: HashSet<ConnectionId>,
73    established_inbound_connections: HashSet<ConnectionId>,
74    established_outbound_connections: HashSet<ConnectionId>,
75    established_per_peer: HashMap<PeerId, HashSet<ConnectionId>>,
76}
77
78impl Behaviour {
79    pub fn new(limits: ConnectionLimits) -> Self {
80        Self {
81            limits,
82            pending_inbound_connections: Default::default(),
83            pending_outbound_connections: Default::default(),
84            established_inbound_connections: Default::default(),
85            established_outbound_connections: Default::default(),
86            established_per_peer: Default::default(),
87        }
88    }
89
90    /// Returns a mutable reference to [`ConnectionLimits`].
91    /// > **Note**: A new limit will not be enforced against existing connections.
92    pub fn limits_mut(&mut self) -> &mut ConnectionLimits {
93        &mut self.limits
94    }
95}
96
97fn check_limit(limit: Option<u32>, current: usize, kind: Kind) -> Result<(), ConnectionDenied> {
98    let limit = limit.unwrap_or(u32::MAX);
99    let current = current as u32;
100
101    if current >= limit {
102        return Err(ConnectionDenied::new(Exceeded { limit, kind }));
103    }
104
105    Ok(())
106}
107
/// A connection limit has been exceeded.
///
/// This is the error wrapped in a [`ConnectionDenied`] when this behaviour rejects a
/// connection.
#[derive(Debug, Clone, Copy)]
pub struct Exceeded {
    /// The configured limit that was reached.
    limit: u32,
    /// Which of the configured limits was reached.
    kind: Kind,
}
114
115impl Exceeded {
116    pub fn limit(&self) -> u32 {
117        self.limit
118    }
119}
120
121impl fmt::Display for Exceeded {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        write!(
124            f,
125            "connection limit exceeded: at most {} {} are allowed",
126            self.limit, self.kind
127        )
128    }
129}
130
/// Identifies which of the [`ConnectionLimits`] was reached.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// Limit set via [`ConnectionLimits::with_max_pending_incoming`].
    PendingIncoming,
    /// Limit set via [`ConnectionLimits::with_max_pending_outgoing`].
    PendingOutgoing,
    /// Limit set via [`ConnectionLimits::with_max_established_incoming`].
    EstablishedIncoming,
    /// Limit set via [`ConnectionLimits::with_max_established_outgoing`].
    EstablishedOutgoing,
    /// Limit set via [`ConnectionLimits::with_max_established_per_peer`].
    EstablishedPerPeer,
    /// Limit set via [`ConnectionLimits::with_max_established`] (inbound + outbound).
    EstablishedTotal,
}
140
141impl fmt::Display for Kind {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        match self {
144            Kind::PendingIncoming => write!(f, "pending incoming connections"),
145            Kind::PendingOutgoing => write!(f, "pending outgoing connections"),
146            Kind::EstablishedIncoming => write!(f, "established incoming connections"),
147            Kind::EstablishedOutgoing => write!(f, "established outgoing connections"),
148            Kind::EstablishedPerPeer => write!(f, "established connections per peer"),
149            Kind::EstablishedTotal => write!(f, "established connections"),
150        }
151    }
152}
153
// `Exceeded` is the error that `check_limit` wraps into a `ConnectionDenied`; callers recover it
// by downcasting. Its message comes from the `Display` impl for `Exceeded`.
impl std::error::Error for Exceeded {}
155
/// The configurable connection limits.
///
/// Every limit is optional and unset by default (see the `Default` derive); an unset limit is
/// effectively unlimited. Use the `with_*` builder methods to configure individual limits.
#[derive(Debug, Clone, Default)]
pub struct ConnectionLimits {
    /// Maximum number of inbound connections that are still pending.
    max_pending_incoming: Option<u32>,
    /// Maximum number of outbound dials that are still pending.
    max_pending_outgoing: Option<u32>,
    /// Maximum number of established inbound connections.
    max_established_incoming: Option<u32>,
    /// Maximum number of established outbound connections.
    max_established_outgoing: Option<u32>,
    /// Maximum number of established connections to a single peer, in either direction.
    max_established_per_peer: Option<u32>,
    /// Maximum number of established connections in total (inbound + outbound); configured
    /// via `with_max_established`.
    max_established_total: Option<u32>,
}
166
167impl ConnectionLimits {
168    /// Configures the maximum number of concurrently incoming connections being established.
169    pub fn with_max_pending_incoming(mut self, limit: Option<u32>) -> Self {
170        self.max_pending_incoming = limit;
171        self
172    }
173
174    /// Configures the maximum number of concurrently outgoing connections being established.
175    pub fn with_max_pending_outgoing(mut self, limit: Option<u32>) -> Self {
176        self.max_pending_outgoing = limit;
177        self
178    }
179
180    /// Configures the maximum number of concurrent established inbound connections.
181    pub fn with_max_established_incoming(mut self, limit: Option<u32>) -> Self {
182        self.max_established_incoming = limit;
183        self
184    }
185
186    /// Configures the maximum number of concurrent established outbound connections.
187    pub fn with_max_established_outgoing(mut self, limit: Option<u32>) -> Self {
188        self.max_established_outgoing = limit;
189        self
190    }
191
192    /// Configures the maximum number of concurrent established connections (both
193    /// inbound and outbound).
194    ///
195    /// Note: This should be used in conjunction with
196    /// [`ConnectionLimits::with_max_established_incoming`] to prevent possible
197    /// eclipse attacks (all connections being inbound).
198    pub fn with_max_established(mut self, limit: Option<u32>) -> Self {
199        self.max_established_total = limit;
200        self
201    }
202
203    /// Configures the maximum number of concurrent established connections per peer,
204    /// regardless of direction (incoming or outgoing).
205    pub fn with_max_established_per_peer(mut self, limit: Option<u32>) -> Self {
206        self.max_established_per_peer = limit;
207        self
208    }
209}
210
211impl NetworkBehaviour for Behaviour {
212    type ConnectionHandler = dummy::ConnectionHandler;
213    type ToSwarm = Infallible;
214
215    fn handle_pending_inbound_connection(
216        &mut self,
217        connection_id: ConnectionId,
218        _: &Multiaddr,
219        _: &Multiaddr,
220    ) -> Result<(), ConnectionDenied> {
221        check_limit(
222            self.limits.max_pending_incoming,
223            self.pending_inbound_connections.len(),
224            Kind::PendingIncoming,
225        )?;
226
227        self.pending_inbound_connections.insert(connection_id);
228
229        Ok(())
230    }
231
232    fn handle_established_inbound_connection(
233        &mut self,
234        connection_id: ConnectionId,
235        peer: PeerId,
236        _: &Multiaddr,
237        _: &Multiaddr,
238    ) -> Result<THandler<Self>, ConnectionDenied> {
239        self.pending_inbound_connections.remove(&connection_id);
240
241        check_limit(
242            self.limits.max_established_incoming,
243            self.established_inbound_connections.len(),
244            Kind::EstablishedIncoming,
245        )?;
246        check_limit(
247            self.limits.max_established_per_peer,
248            self.established_per_peer
249                .get(&peer)
250                .map(|connections| connections.len())
251                .unwrap_or(0),
252            Kind::EstablishedPerPeer,
253        )?;
254        check_limit(
255            self.limits.max_established_total,
256            self.established_inbound_connections.len()
257                + self.established_outbound_connections.len(),
258            Kind::EstablishedTotal,
259        )?;
260
261        Ok(dummy::ConnectionHandler)
262    }
263
264    fn handle_pending_outbound_connection(
265        &mut self,
266        connection_id: ConnectionId,
267        _: Option<PeerId>,
268        _: &[Multiaddr],
269        _: Endpoint,
270    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
271        check_limit(
272            self.limits.max_pending_outgoing,
273            self.pending_outbound_connections.len(),
274            Kind::PendingOutgoing,
275        )?;
276
277        self.pending_outbound_connections.insert(connection_id);
278
279        Ok(vec![])
280    }
281
282    fn handle_established_outbound_connection(
283        &mut self,
284        connection_id: ConnectionId,
285        peer: PeerId,
286        _: &Multiaddr,
287        _: Endpoint,
288        _: PortUse,
289    ) -> Result<THandler<Self>, ConnectionDenied> {
290        self.pending_outbound_connections.remove(&connection_id);
291
292        check_limit(
293            self.limits.max_established_outgoing,
294            self.established_outbound_connections.len(),
295            Kind::EstablishedOutgoing,
296        )?;
297        check_limit(
298            self.limits.max_established_per_peer,
299            self.established_per_peer
300                .get(&peer)
301                .map(|connections| connections.len())
302                .unwrap_or(0),
303            Kind::EstablishedPerPeer,
304        )?;
305        check_limit(
306            self.limits.max_established_total,
307            self.established_inbound_connections.len()
308                + self.established_outbound_connections.len(),
309            Kind::EstablishedTotal,
310        )?;
311
312        Ok(dummy::ConnectionHandler)
313    }
314
315    fn on_swarm_event(&mut self, event: FromSwarm) {
316        match event {
317            FromSwarm::ConnectionClosed(ConnectionClosed {
318                peer_id,
319                connection_id,
320                ..
321            }) => {
322                self.established_inbound_connections.remove(&connection_id);
323                self.established_outbound_connections.remove(&connection_id);
324                self.established_per_peer
325                    .entry(peer_id)
326                    .or_default()
327                    .remove(&connection_id);
328            }
329            FromSwarm::ConnectionEstablished(ConnectionEstablished {
330                peer_id,
331                endpoint,
332                connection_id,
333                ..
334            }) => {
335                match endpoint {
336                    ConnectedPoint::Listener { .. } => {
337                        self.established_inbound_connections.insert(connection_id);
338                    }
339                    ConnectedPoint::Dialer { .. } => {
340                        self.established_outbound_connections.insert(connection_id);
341                    }
342                }
343
344                self.established_per_peer
345                    .entry(peer_id)
346                    .or_default()
347                    .insert(connection_id);
348            }
349            FromSwarm::DialFailure(DialFailure { connection_id, .. }) => {
350                self.pending_outbound_connections.remove(&connection_id);
351            }
352            FromSwarm::ListenFailure(ListenFailure { connection_id, .. }) => {
353                self.pending_inbound_connections.remove(&connection_id);
354            }
355            _ => {}
356        }
357    }
358
359    fn on_connection_handler_event(
360        &mut self,
361        _id: PeerId,
362        _: ConnectionId,
363        event: THandlerOutEvent<Self>,
364    ) {
365        // TODO: remove when Rust 1.82 is MSRV
366        #[allow(unreachable_patterns)]
367        ant_libp2p_core::util::unreachable(event)
368    }
369
370    fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
371        Poll::Pending
372    }
373}
374
#[cfg(test)]
mod tests {
    use libp2p_swarm::{
        behaviour::toggle::Toggle,
        dial_opts::{DialOpts, PeerCondition},
        DialError, ListenError, Swarm, SwarmEvent,
    };
    use libp2p_swarm_test::SwarmExt;
    use quickcheck::*;

    // NOTE: the test-local composite `Behaviour` defined further down shadows the glob-imported
    // `super::Behaviour`; inside this module `Behaviour` therefore means the composite one.
    use super::*;

    /// Dials beyond `max_pending_outgoing` must be denied with an [`Exceeded`] error carrying
    /// the configured limit, while exactly `outgoing_limit` dials remain pending.
    #[test]
    fn max_outgoing() {
        use rand::Rng;

        let outgoing_limit = rand::thread_rng().gen_range(1..10);

        let mut network = Swarm::new_ephemeral(|_| {
            Behaviour::new(
                ConnectionLimits::default().with_max_pending_outgoing(Some(outgoing_limit)),
            )
        });

        let addr: Multiaddr = "/memory/1234".parse().unwrap();
        let target = PeerId::random();

        // Use up the pending-outgoing budget. The swarm is never polled in this test, so these
        // dials stay pending.
        for _ in 0..outgoing_limit {
            network
                .dial(
                    DialOpts::peer_id(target)
                        // Dial always, even if already dialing or connected.
                        .condition(PeerCondition::Always)
                        .addresses(vec![addr.clone()])
                        .build(),
                )
                .expect("Unexpected connection limit.");
        }

        // One dial over the limit must be rejected by the connection-limits behaviour.
        match network
            .dial(
                DialOpts::peer_id(target)
                    .condition(PeerCondition::Always)
                    .addresses(vec![addr])
                    .build(),
            )
            .expect_err("Unexpected dialing success.")
        {
            DialError::Denied { cause } => {
                let exceeded = cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");

                assert_eq!(exceeded.limit(), outgoing_limit);
            }
            e => panic!("Unexpected error: {e:?}"),
        }

        let info = network.network_info();
        assert_eq!(info.num_peers(), 0);
        assert_eq!(
            info.connection_counters().num_pending_outgoing(),
            outgoing_limit
        );
    }

    /// Property test: after `limit` inbound connections have been established on `swarm1`,
    /// the next inbound connection must be denied with an [`Exceeded`] carrying `limit`.
    #[test]
    fn max_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });
            let mut swarm2 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });

            async_std::task::block_on(async {
                let (listen_addr, _) = swarm1.listen().with_memory_addr_external().await;

                // Fill swarm1's inbound budget with connections from swarm2.
                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                // One more dial: swarm1 must deny it once it is established.
                swarm2.dial(listen_addr).unwrap();

                async_std::task::spawn(swarm2.loop_on_next());

                let cause = swarm1
                    .wait(|event| match event {
                        SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } => Some(cause),
                        _ => None,
                    })
                    .await;

                assert_eq!(cause.downcast::<Exceeded>().unwrap().limit, limit);
            });
        }

        /// A connection limit drawn from `1..10`, keeping every quickcheck case small.
        #[derive(Debug, Clone)]
        struct Limit(u32);

        impl Arbitrary for Limit {
            fn arbitrary(g: &mut Gen) -> Self {
                Self(g.gen_range(1..10))
            }
        }

        quickcheck(prop as fn(_));
    }

    /// Another sibling [`NetworkBehaviour`] implementation might deny established connections in
    /// [`handle_established_outbound_connection`] or [`handle_established_inbound_connection`].
    /// [`Behaviour`] must not increase the established counters in
    /// [`handle_established_outbound_connection`] or [`handle_established_inbound_connection`], but
    /// in [`SwarmEvent::ConnectionEstablished`] as the connection might still be denied by a
    /// sibling [`NetworkBehaviour`] in the former case. Only in the latter case
    /// ([`SwarmEvent::ConnectionEstablished`]) can the connection be seen as established.
    #[test]
    fn support_other_behaviour_denying_connection() {
        let mut swarm1 = Swarm::new_ephemeral(|_| {
            Behaviour::new_with_connection_denier(ConnectionLimits::default())
        });
        let mut swarm2 = Swarm::new_ephemeral(|_| Behaviour::new(ConnectionLimits::default()));

        async_std::task::block_on(async {
            // Have swarm2 dial swarm1.
            let (listen_addr, _) = swarm1.listen().await;
            swarm2.dial(listen_addr).unwrap();
            async_std::task::spawn(swarm2.loop_on_next());

            // Wait for the ConnectionDenier of swarm1 to deny the established connection.
            let cause = swarm1
                .wait(|event| match event {
                    SwarmEvent::IncomingConnectionError {
                        error: ListenError::Denied { cause },
                        ..
                    } => Some(cause),
                    _ => None,
                })
                .await;

            // The denial must come from `ConnectionDenier` (an `io::Error`), not from the limits.
            cause.downcast::<std::io::Error>().unwrap();

            assert_eq!(
                0,
                swarm1
                    .behaviour_mut()
                    .limits
                    .established_inbound_connections
                    .len(),
                "swarm1 connection limit behaviour to not count denied established connection as established connection"
            )
        });
    }

    /// Composite test behaviour: the connection-limits behaviour plus an optional
    /// [`ConnectionDenier`] that rejects every established connection.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour {
        limits: super::Behaviour,
        connection_denier: Toggle<ConnectionDenier>,
    }

    impl Behaviour {
        /// Limits only; the denier is disabled.
        fn new(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: None.into(),
            }
        }
        /// Limits plus an enabled [`ConnectionDenier`].
        fn new_with_connection_denier(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: Some(ConnectionDenier {}).into(),
            }
        }
    }

    /// Test behaviour that denies every established connection, inbound and outbound, with an
    /// `std::io::Error`.
    struct ConnectionDenier {}

    impl NetworkBehaviour for ConnectionDenier {
        type ConnectionHandler = dummy::ConnectionHandler;
        type ToSwarm = Infallible;

        fn handle_established_inbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _local_addr: &Multiaddr,
            _remote_addr: &Multiaddr,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn handle_established_outbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _addr: &Multiaddr,
            _role_override: Endpoint,
            _port_use: PortUse,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn on_swarm_event(&mut self, _event: FromSwarm) {}

        fn on_connection_handler_event(
            &mut self,
            _peer_id: PeerId,
            _connection_id: ConnectionId,
            event: THandlerOutEvent<Self>,
        ) {
            // TODO: remove when Rust 1.82 is MSRV
            #[allow(unreachable_patterns)]
            ant_libp2p_core::util::unreachable(event)
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
            Poll::Pending
        }
    }
}