1use std::{
22 collections::{HashMap, HashSet},
23 convert::Infallible,
24 fmt,
25 task::{Context, Poll},
26};
27
28use ant_libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
29use ant_libp2p_swarm::{
30 behaviour::{ConnectionEstablished, DialFailure, ListenFailure},
31 dummy, ConnectionClosed, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
32 THandlerInEvent, THandlerOutEvent, ToSwarm,
33};
34use libp2p_identity::PeerId;
35
36pub struct Behaviour {
69 limits: ConnectionLimits,
70
71 pending_inbound_connections: HashSet<ConnectionId>,
72 pending_outbound_connections: HashSet<ConnectionId>,
73 established_inbound_connections: HashSet<ConnectionId>,
74 established_outbound_connections: HashSet<ConnectionId>,
75 established_per_peer: HashMap<PeerId, HashSet<ConnectionId>>,
76}
77
78impl Behaviour {
79 pub fn new(limits: ConnectionLimits) -> Self {
80 Self {
81 limits,
82 pending_inbound_connections: Default::default(),
83 pending_outbound_connections: Default::default(),
84 established_inbound_connections: Default::default(),
85 established_outbound_connections: Default::default(),
86 established_per_peer: Default::default(),
87 }
88 }
89
90 pub fn limits_mut(&mut self) -> &mut ConnectionLimits {
93 &mut self.limits
94 }
95}
96
97fn check_limit(limit: Option<u32>, current: usize, kind: Kind) -> Result<(), ConnectionDenied> {
98 let limit = limit.unwrap_or(u32::MAX);
99 let current = current as u32;
100
101 if current >= limit {
102 return Err(ConnectionDenied::new(Exceeded { limit, kind }));
103 }
104
105 Ok(())
106}
107
108#[derive(Debug, Clone, Copy)]
110pub struct Exceeded {
111 limit: u32,
112 kind: Kind,
113}
114
115impl Exceeded {
116 pub fn limit(&self) -> u32 {
117 self.limit
118 }
119}
120
121impl fmt::Display for Exceeded {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 write!(
124 f,
125 "connection limit exceeded: at most {} {} are allowed",
126 self.limit, self.kind
127 )
128 }
129}
130
/// The category of connections a limit applies to, as reported by [`Exceeded`].
#[derive(Clone, Copy, Debug)]
enum Kind {
    /// Inbound connections not yet established (`max_pending_incoming`).
    PendingIncoming,
    /// Outbound dials not yet established (`max_pending_outgoing`).
    PendingOutgoing,
    /// Established connections where we are the listener (`max_established_incoming`).
    EstablishedIncoming,
    /// Established connections where we are the dialer (`max_established_outgoing`).
    EstablishedOutgoing,
    /// Established connections to a single peer (`max_established_per_peer`).
    EstablishedPerPeer,
    /// All established connections, inbound plus outbound (`max_established_total`).
    EstablishedTotal,
}
140
141impl fmt::Display for Kind {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 match self {
144 Kind::PendingIncoming => write!(f, "pending incoming connections"),
145 Kind::PendingOutgoing => write!(f, "pending outgoing connections"),
146 Kind::EstablishedIncoming => write!(f, "established incoming connections"),
147 Kind::EstablishedOutgoing => write!(f, "established outgoing connections"),
148 Kind::EstablishedPerPeer => write!(f, "established connections per peer"),
149 Kind::EstablishedTotal => write!(f, "established connections"),
150 }
151 }
152}
153
154impl std::error::Error for Exceeded {}
155
/// Upper bounds on pending and established connections, enforced by [`Behaviour`].
///
/// Every limit is optional; `None` (the default for all fields) means the category is
/// not limited.
#[derive(Clone, Debug, Default)]
pub struct ConnectionLimits {
    /// Maximum number of inbound connections that are pending at the same time.
    max_pending_incoming: Option<u32>,
    /// Maximum number of outbound dials that are pending at the same time.
    max_pending_outgoing: Option<u32>,
    /// Maximum number of established connections where we are the listener.
    max_established_incoming: Option<u32>,
    /// Maximum number of established connections where we are the dialer.
    max_established_outgoing: Option<u32>,
    /// Maximum number of established connections to any single peer.
    max_established_per_peer: Option<u32>,
    /// Maximum number of established connections overall, inbound plus outbound.
    max_established_total: Option<u32>,
}
166
167impl ConnectionLimits {
168 pub fn with_max_pending_incoming(mut self, limit: Option<u32>) -> Self {
170 self.max_pending_incoming = limit;
171 self
172 }
173
174 pub fn with_max_pending_outgoing(mut self, limit: Option<u32>) -> Self {
176 self.max_pending_outgoing = limit;
177 self
178 }
179
180 pub fn with_max_established_incoming(mut self, limit: Option<u32>) -> Self {
182 self.max_established_incoming = limit;
183 self
184 }
185
186 pub fn with_max_established_outgoing(mut self, limit: Option<u32>) -> Self {
188 self.max_established_outgoing = limit;
189 self
190 }
191
192 pub fn with_max_established(mut self, limit: Option<u32>) -> Self {
199 self.max_established_total = limit;
200 self
201 }
202
203 pub fn with_max_established_per_peer(mut self, limit: Option<u32>) -> Self {
206 self.max_established_per_peer = limit;
207 self
208 }
209}
210
211impl NetworkBehaviour for Behaviour {
212 type ConnectionHandler = dummy::ConnectionHandler;
213 type ToSwarm = Infallible;
214
215 fn handle_pending_inbound_connection(
216 &mut self,
217 connection_id: ConnectionId,
218 _: &Multiaddr,
219 _: &Multiaddr,
220 ) -> Result<(), ConnectionDenied> {
221 check_limit(
222 self.limits.max_pending_incoming,
223 self.pending_inbound_connections.len(),
224 Kind::PendingIncoming,
225 )?;
226
227 self.pending_inbound_connections.insert(connection_id);
228
229 Ok(())
230 }
231
232 fn handle_established_inbound_connection(
233 &mut self,
234 connection_id: ConnectionId,
235 peer: PeerId,
236 _: &Multiaddr,
237 _: &Multiaddr,
238 ) -> Result<THandler<Self>, ConnectionDenied> {
239 self.pending_inbound_connections.remove(&connection_id);
240
241 check_limit(
242 self.limits.max_established_incoming,
243 self.established_inbound_connections.len(),
244 Kind::EstablishedIncoming,
245 )?;
246 check_limit(
247 self.limits.max_established_per_peer,
248 self.established_per_peer
249 .get(&peer)
250 .map(|connections| connections.len())
251 .unwrap_or(0),
252 Kind::EstablishedPerPeer,
253 )?;
254 check_limit(
255 self.limits.max_established_total,
256 self.established_inbound_connections.len()
257 + self.established_outbound_connections.len(),
258 Kind::EstablishedTotal,
259 )?;
260
261 Ok(dummy::ConnectionHandler)
262 }
263
264 fn handle_pending_outbound_connection(
265 &mut self,
266 connection_id: ConnectionId,
267 _: Option<PeerId>,
268 _: &[Multiaddr],
269 _: Endpoint,
270 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
271 check_limit(
272 self.limits.max_pending_outgoing,
273 self.pending_outbound_connections.len(),
274 Kind::PendingOutgoing,
275 )?;
276
277 self.pending_outbound_connections.insert(connection_id);
278
279 Ok(vec![])
280 }
281
282 fn handle_established_outbound_connection(
283 &mut self,
284 connection_id: ConnectionId,
285 peer: PeerId,
286 _: &Multiaddr,
287 _: Endpoint,
288 _: PortUse,
289 ) -> Result<THandler<Self>, ConnectionDenied> {
290 self.pending_outbound_connections.remove(&connection_id);
291
292 check_limit(
293 self.limits.max_established_outgoing,
294 self.established_outbound_connections.len(),
295 Kind::EstablishedOutgoing,
296 )?;
297 check_limit(
298 self.limits.max_established_per_peer,
299 self.established_per_peer
300 .get(&peer)
301 .map(|connections| connections.len())
302 .unwrap_or(0),
303 Kind::EstablishedPerPeer,
304 )?;
305 check_limit(
306 self.limits.max_established_total,
307 self.established_inbound_connections.len()
308 + self.established_outbound_connections.len(),
309 Kind::EstablishedTotal,
310 )?;
311
312 Ok(dummy::ConnectionHandler)
313 }
314
315 fn on_swarm_event(&mut self, event: FromSwarm) {
316 match event {
317 FromSwarm::ConnectionClosed(ConnectionClosed {
318 peer_id,
319 connection_id,
320 ..
321 }) => {
322 self.established_inbound_connections.remove(&connection_id);
323 self.established_outbound_connections.remove(&connection_id);
324 self.established_per_peer
325 .entry(peer_id)
326 .or_default()
327 .remove(&connection_id);
328 }
329 FromSwarm::ConnectionEstablished(ConnectionEstablished {
330 peer_id,
331 endpoint,
332 connection_id,
333 ..
334 }) => {
335 match endpoint {
336 ConnectedPoint::Listener { .. } => {
337 self.established_inbound_connections.insert(connection_id);
338 }
339 ConnectedPoint::Dialer { .. } => {
340 self.established_outbound_connections.insert(connection_id);
341 }
342 }
343
344 self.established_per_peer
345 .entry(peer_id)
346 .or_default()
347 .insert(connection_id);
348 }
349 FromSwarm::DialFailure(DialFailure { connection_id, .. }) => {
350 self.pending_outbound_connections.remove(&connection_id);
351 }
352 FromSwarm::ListenFailure(ListenFailure { connection_id, .. }) => {
353 self.pending_inbound_connections.remove(&connection_id);
354 }
355 _ => {}
356 }
357 }
358
359 fn on_connection_handler_event(
360 &mut self,
361 _id: PeerId,
362 _: ConnectionId,
363 event: THandlerOutEvent<Self>,
364 ) {
365 #[allow(unreachable_patterns)]
367 ant_libp2p_core::util::unreachable(event)
368 }
369
370 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
371 Poll::Pending
372 }
373}
374
#[cfg(test)]
mod tests {
    use libp2p_swarm::{
        behaviour::toggle::Toggle,
        dial_opts::{DialOpts, PeerCondition},
        DialError, ListenError, Swarm, SwarmEvent,
    };
    use libp2p_swarm_test::SwarmExt;
    use quickcheck::*;

    use super::*;

    /// Dials the same peer `outgoing_limit` times (all accepted), then once more, and expects
    /// that extra dial to be denied with an [`Exceeded`] carrying the configured limit.
    #[test]
    fn max_outgoing() {
        use rand::Rng;

        // Random limit in 1..10.
        let outgoing_limit = rand::thread_rng().gen_range(1..10);

        // `Behaviour` here is the test-local composite defined at the bottom of this module,
        // which shadows `super::Behaviour` from the glob import.
        let mut network = Swarm::new_ephemeral(|_| {
            Behaviour::new(
                ConnectionLimits::default().with_max_pending_outgoing(Some(outgoing_limit)),
            )
        });

        let addr: Multiaddr = "/memory/1234".parse().unwrap();
        let target = PeerId::random();

        // Fill every pending-outgoing slot. `PeerCondition::Always` is presumably needed so
        // repeated dials to the same peer are not suppressed — confirm against libp2p-swarm.
        for _ in 0..outgoing_limit {
            network
                .dial(
                    DialOpts::peer_id(target)
                        .condition(PeerCondition::Always)
                        .addresses(vec![addr.clone()])
                        .build(),
                )
                .expect("Unexpected connection limit.");
        }

        // One dial past the limit must be denied by the limits behaviour.
        match network
            .dial(
                DialOpts::peer_id(target)
                    .condition(PeerCondition::Always)
                    .addresses(vec![addr])
                    .build(),
            )
            .expect_err("Unexpected dialing success.")
        {
            DialError::Denied { cause } => {
                let exceeded = cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");

                assert_eq!(exceeded.limit(), outgoing_limit);
            }
            e => panic!("Unexpected error: {e:?}"),
        }

        // The denied dial was not counted: exactly `outgoing_limit` dials remain pending and
        // no peer is connected.
        let info = network.network_info();
        assert_eq!(info.num_peers(), 0);
        assert_eq!(
            info.connection_counters().num_pending_outgoing(),
            outgoing_limit
        );
    }

    /// Property test: once `limit` inbound connections are established on `swarm1`, the next
    /// inbound connection is denied with an [`Exceeded`] reporting `limit`.
    #[test]
    fn max_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });
            let mut swarm2 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });

            async_std::task::block_on(async {
                let (listen_addr, _) = swarm1.listen().with_memory_addr_external().await;

                // Use up swarm1's inbound budget.
                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                // One connection too many; swarm1 is expected to deny it.
                swarm2.dial(listen_addr).unwrap();

                // Presumably keeps swarm2 polled in the background so the dial progresses.
                async_std::task::spawn(swarm2.loop_on_next());

                let cause = swarm1
                    .wait(|event| match event {
                        SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } => Some(cause),
                        _ => None,
                    })
                    .await;

                assert_eq!(cause.downcast::<Exceeded>().unwrap().limit, limit);
            });
        }

        /// Connection limit generated by quickcheck, drawn from `1..10`.
        #[derive(Debug, Clone)]
        struct Limit(u32);

        impl Arbitrary for Limit {
            fn arbitrary(g: &mut Gen) -> Self {
                Self(g.gen_range(1..10))
            }
        }

        quickcheck(prop as fn(_));
    }

    /// `ConnectionDenier` on swarm1 denies every established connection. The limits behaviour
    /// running alongside it must not count such a denied connection as established.
    #[test]
    fn support_other_behaviour_denying_connection() {
        let mut swarm1 = Swarm::new_ephemeral(|_| {
            Behaviour::new_with_connection_denier(ConnectionLimits::default())
        });
        let mut swarm2 = Swarm::new_ephemeral(|_| Behaviour::new(ConnectionLimits::default()));

        async_std::task::block_on(async {
            let (listen_addr, _) = swarm1.listen().await;
            swarm2.dial(listen_addr).unwrap();
            async_std::task::spawn(swarm2.loop_on_next());

            let cause = swarm1
                .wait(|event| match event {
                    SwarmEvent::IncomingConnectionError {
                        error: ListenError::Denied { cause },
                        ..
                    } => Some(cause),
                    _ => None,
                })
                .await;

            // `ConnectionDenier` denies with an `io::Error`, so this confirms it was the
            // denier (not a limit) that rejected the connection.
            cause.downcast::<std::io::Error>().unwrap();

            assert_eq!(
                0,
                swarm1
                    .behaviour_mut()
                    .limits
                    .established_inbound_connections
                    .len(),
                "swarm1 connection limit behaviour to not count denied established connection as established connection"
            )
        });
    }

    /// Test composite behaviour: the connection-limits behaviour under test plus an optional
    /// [`ConnectionDenier`] wrapped in `Toggle`.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour {
        limits: super::Behaviour,
        connection_denier: Toggle<ConnectionDenier>,
    }

    impl Behaviour {
        /// Limits only; the denier is absent.
        fn new(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: None.into(),
            }
        }
        /// Limits plus a `ConnectionDenier` that rejects every established connection.
        fn new_with_connection_denier(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: Some(ConnectionDenier {}).into(),
            }
        }
    }

    /// Behaviour that denies every established connection, inbound and outbound, with an
    /// `io::Error` whose message is "ConnectionDenier".
    struct ConnectionDenier {}

    impl NetworkBehaviour for ConnectionDenier {
        type ConnectionHandler = dummy::ConnectionHandler;
        type ToSwarm = Infallible;

        fn handle_established_inbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _local_addr: &Multiaddr,
            _remote_addr: &Multiaddr,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn handle_established_outbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _addr: &Multiaddr,
            _role_override: Endpoint,
            _port_use: PortUse,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn on_swarm_event(&mut self, _event: FromSwarm) {}

        fn on_connection_handler_event(
            &mut self,
            _peer_id: PeerId,
            _connection_id: ConnectionId,
            event: THandlerOutEvent<Self>,
        ) {
            // NOTE(review): presumably silences an `unreachable_patterns` lint for the
            // uninhabited handler event type on some compiler versions; confirm.
            #[allow(unreachable_patterns)]
            ant_libp2p_core::util::unreachable(event)
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
            Poll::Pending
        }
    }
}