1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26use std::{
27 collections::{HashMap, HashSet},
28 fmt,
29 fmt::{Display, Formatter},
30 future::Future,
31 io, mem,
32 pin::Pin,
33 sync::atomic::{AtomicUsize, Ordering},
34 task::{Context, Poll, Waker},
35 time::Duration,
36};
37
38pub use error::ConnectionError;
39pub(crate) use error::{
40 PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
41};
42use futures::{future::BoxFuture, stream, stream::FuturesUnordered, FutureExt, StreamExt};
43use futures_timer::Delay;
44use ant_libp2p_core::{
45 connection::ConnectedPoint,
46 multiaddr::Multiaddr,
47 muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox},
48 transport::PortUse,
49 upgrade,
50 upgrade::{NegotiationError, ProtocolError},
51 Endpoint,
52};
53use libp2p_identity::PeerId;
54pub use supported_protocols::SupportedProtocols;
55use web_time::Instant;
56
57use crate::{
58 handler::{
59 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError,
60 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport,
61 ProtocolsChange, UpgradeInfoSend,
62 },
63 stream::ActiveStreamCounter,
64 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
65 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
66};
67
/// Source of fresh connection identifiers; starts at 1 and is bumped on every
/// call to [`ConnectionId::next`].
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Opaque identifier of a connection.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps `id` as-is.
    ///
    /// No validation or uniqueness check is performed, so the caller is responsible for
    /// not colliding with ids produced by [`ConnectionId::next`].
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Allocates the next id from the process-wide counter.
    pub(crate) fn next() -> Self {
        let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(raw)
    }
}
91
92impl Display for ConnectionId {
93 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
/// Endpoint and remote peer identity of a connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
    /// How this connection was established (dialer or listener side, plus addresses).
    pub(crate) endpoint: ConnectedPoint,
    /// Identity of the remote peer.
    pub(crate) peer_id: PeerId,
}
106
/// Event yielded by [`Connection::poll`].
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
    /// An event the connection handler emitted via `ConnectionHandlerEvent::NotifyBehaviour`.
    Handler(T),
    /// The stream muxer reported a new address for this connection.
    AddressChange(Multiaddr),
}
115
/// A multiplexed connection driven by a [`ConnectionHandler`].
///
/// Owns the stream muxer, the handler, and all in-flight substream negotiations.
pub(crate) struct Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// The underlying stream multiplexer.
    muxing: StreamMuxerBox,
    /// The handler that receives substreams and connection events.
    handler: THandler,
    /// Inbound substreams currently running protocol negotiation + upgrade.
    negotiating_in: FuturesUnordered<
        StreamUpgrade<
            THandler::InboundOpenInfo,
            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
        >,
    >,
    /// Outbound substreams currently running protocol negotiation + upgrade.
    negotiating_out: FuturesUnordered<
        StreamUpgrade<
            THandler::OutboundOpenInfo,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
        >,
    >,
    /// Current idle-shutdown state.
    shutdown: Shutdown,
    /// Optional multistream-select version to use for outbound substreams instead of the default.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
    /// Inbound substreams are only accepted from the muxer while `negotiating_in` holds
    /// fewer than this many entries.
    max_negotiating_inbound_streams: usize,
    /// Outbound substreams the handler requested that have not been opened by the muxer yet.
    /// Each entry resolves with `Err(info)` if its timeout fires before it is extracted.
    requested_substreams: FuturesUnordered<
        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
    >,

    /// Inbound protocols last reported to the handler, keyed by their string form.
    local_supported_protocols:
        HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
    /// Protocols the handler has reported the remote as supporting.
    remote_supported_protocols: HashSet<StreamProtocol>,
    /// Scratch buffer reused when computing `ProtocolsChange`s.
    protocol_buffer: Vec<StreamProtocol>,

    /// How long the connection may sit idle before shutting down (zero = immediately).
    idle_timeout: Duration,
    /// Tracks streams handed to the handler; shutdown is postponed while any are active.
    stream_counter: ActiveStreamCounter,
}
170
171impl<THandler> fmt::Debug for Connection<THandler>
172where
173 THandler: ConnectionHandler + fmt::Debug,
174 THandler::OutboundOpenInfo: fmt::Debug,
175{
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 f.debug_struct("Connection")
178 .field("handler", &self.handler)
179 .finish()
180 }
181}
182
// Declared `Unpin` for every handler type so that `Connection::poll`, which takes
// `Pin<&mut Self>`, can call `get_mut()` to access the fields directly.
impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
184
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a connection around an established muxer and its handler.
    ///
    /// The inbound protocols the handler advertises via `listen_protocol` are recorded and,
    /// if there are any, reported to the handler as a `LocalProtocolsChange` event before
    /// the connection is returned.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        let mut buffer = Vec::new();

        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::from_initial_protocols(
                    initial_protocols.keys().map(|e| &e.0),
                    &mut buffer,
                ),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            protocol_buffer: buffer,
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Forwards an event from the network behaviour to the connection handler.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Consumes the connection and starts closing it.
    ///
    /// Returns:
    /// - a stream yielding whatever the handler produces from `poll_close`, and
    /// - a future that closes the muxer.
    ///
    /// All other state (pending requests, in-flight negotiations) is dropped.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Drives the connection: handler, substream negotiations, idle shutdown and muxer.
    ///
    /// Each loop iteration checks sources in a fixed order and restarts from the top
    /// (`continue`) as soon as one of them makes progress:
    /// 1. outbound-substream requests (to surface request timeouts to the handler),
    /// 2. the handler itself,
    /// 3. negotiating outbound substreams,
    /// 4. negotiating inbound substreams,
    /// 5. the idle-shutdown check (only while nothing is negotiating, requested or active),
    /// 6. muxer events, then new outbound and inbound substreams from the muxer,
    /// 7. changes to the handler's locally supported inbound protocols.
    ///
    /// Returns `Poll::Pending` only once none of these can make progress.
    /// Muxer errors are propagated via `?` and converted into `ConnectionError`.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            protocol_buffer,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // `Ok(())` means the request was extracted (handed to `negotiating_out`);
            // `Err(info)` means its timeout fired before the muxer produced a substream.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    // The request's timeout starts now, not when the muxer opens the substream.
                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Keep polling the handler until it is exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // Only notify the handler when something actually changed. The buffer is
                    // drained into the remote set afterwards; presumably `add` leaves the newly
                    // added protocols in `protocol_buffer` — TODO confirm in `ProtocolsChange`.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocol_buffer.drain(..));
                    }
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    if let Some(removed) = ProtocolsChange::remove(
                        remote_supported_protocols,
                        protocols,
                        protocol_buffer,
                    ) {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                    }
                    continue;
                }
            }

            // The handler cannot make progress; poll the negotiating outbound streams.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // Then the negotiating inbound streams. Only upgrade (`Apply`) failures are
            // reported to the handler; other inbound failures are only logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Idle shutdown is only considered while nothing is negotiating, requested or
            // active; otherwise any scheduled shutdown is cancelled.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                // `None` from `compute_new_shutdown` means "keep the current state", e.g. an
                // already-running `Later` timer keeps ticking.
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                *shutdown = Shutdown::None;
            }

            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Only ask the muxer for an outbound substream while a request is waiting; the
            // first pending request takes it.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        continue; // The handler can potentially make progress again.
                    }
                }
            }

            // Back-pressure: stop accepting inbound substreams while too many are negotiating.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        continue; // The handler can potentially make progress again.
                    }
                }
            }

            // Diff the handler's current inbound protocols against the recorded set.
            let changes = ProtocolsChange::from_full_sets(
                supported_protocols,
                handler.listen_protocol().upgrade().protocol_info(),
                protocol_buffer,
            );

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }
                continue; // The handler can potentially make progress again.
            }

            // Nothing can make progress.
            return Poll::Pending;
        }
    }

    /// Test helper: polls once with a no-op waker.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
481
482fn gather_supported_protocols<C: ConnectionHandler>(
483 handler: &C,
484) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
485 handler
486 .listen_protocol()
487 .upgrade()
488 .protocol_info()
489 .map(|info| (AsStrHashEq(info), true))
490 .collect()
491}
492
493fn compute_new_shutdown(
494 handler_keep_alive: bool,
495 current_shutdown: &Shutdown,
496 idle_timeout: Duration,
497) -> Option<Shutdown> {
498 match (current_shutdown, handler_keep_alive) {
499 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
500 (Shutdown::Later(_), false) => None,
502 (_, false) => {
503 let now = Instant::now();
504 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
505
506 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
507 }
508 (_, true) => Some(Shutdown::None),
509 }
510}
511
512fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
519 while start.checked_add(duration).is_none() {
520 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
521
522 duration /= 2;
523 }
524
525 duration
526}
527
528#[derive(Debug, Copy, Clone)]
530pub(crate) struct IncomingInfo<'a> {
531 pub(crate) local_addr: &'a Multiaddr,
533 pub(crate) send_back_addr: &'a Multiaddr,
535}
536
537impl IncomingInfo<'_> {
538 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
540 ConnectedPoint::Listener {
541 local_addr: self.local_addr.clone(),
542 send_back_addr: self.send_back_addr.clone(),
543 }
544 }
545}
546
/// A substream undergoing protocol negotiation and upgrade, bounded by a timeout.
///
/// Resolves to the user data together with either the upgrade output or an error
/// (including `StreamUpgradeError::Timeout` when `timeout` fires first).
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Taken exactly once, when the future completes.
    user_data: Option<UserData>,
    /// Deadline for the whole negotiation + upgrade.
    timeout: Delay,
    /// The boxed negotiation + upgrade work.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
552
553impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
554 fn new_outbound<Upgrade>(
555 substream: SubstreamBox,
556 user_data: UserData,
557 timeout: Delay,
558 upgrade: Upgrade,
559 version_override: Option<upgrade::Version>,
560 counter: ActiveStreamCounter,
561 ) -> Self
562 where
563 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
564 {
565 let effective_version = match version_override {
566 Some(version_override) if version_override != upgrade::Version::default() => {
567 tracing::debug!(
568 "Substream upgrade protocol override: {:?} -> {:?}",
569 upgrade::Version::default(),
570 version_override
571 );
572
573 version_override
574 }
575 _ => upgrade::Version::default(),
576 };
577 let protocols = upgrade.protocol_info();
578
579 Self {
580 user_data: Some(user_data),
581 timeout,
582 upgrade: Box::pin(async move {
583 let (info, stream) = multistream_select::dialer_select_proto(
584 substream,
585 protocols,
586 effective_version,
587 )
588 .await
589 .map_err(to_stream_upgrade_error)?;
590
591 let output = upgrade
592 .upgrade_outbound(Stream::new(stream, counter), info)
593 .await
594 .map_err(StreamUpgradeError::Apply)?;
595
596 Ok(output)
597 }),
598 }
599 }
600}
601
602impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
603 fn new_inbound<Upgrade>(
604 substream: SubstreamBox,
605 protocol: SubstreamProtocol<Upgrade, UserData>,
606 counter: ActiveStreamCounter,
607 ) -> Self
608 where
609 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
610 {
611 let timeout = *protocol.timeout();
612 let (upgrade, open_info) = protocol.into_upgrade();
613 let protocols = upgrade.protocol_info();
614
615 Self {
616 user_data: Some(open_info),
617 timeout: Delay::new(timeout),
618 upgrade: Box::pin(async move {
619 let (info, stream) =
620 multistream_select::listener_select_proto(substream, protocols)
621 .await
622 .map_err(to_stream_upgrade_error)?;
623
624 let output = upgrade
625 .upgrade_inbound(Stream::new(stream, counter), info)
626 .await
627 .map_err(StreamUpgradeError::Apply)?;
628
629 Ok(output)
630 }),
631 }
632 }
633}
634
635fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
636 match e {
637 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
638 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
639 NegotiationError::ProtocolError(other) => {
640 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
641 }
642 }
643}
644
645impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
646
647impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
648 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
649
650 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
651 match self.timeout.poll_unpin(cx) {
652 Poll::Ready(()) => {
653 return Poll::Ready((
654 self.user_data
655 .take()
656 .expect("Future not to be polled again once ready."),
657 Err(StreamUpgradeError::Timeout),
658 ))
659 }
660
661 Poll::Pending => {}
662 }
663
664 let result = futures::ready!(self.upgrade.poll_unpin(cx));
665 let user_data = self
666 .user_data
667 .take()
668 .expect("Future not to be polled again once ready.");
669
670 Poll::Ready((user_data, result))
671 }
672}
673
/// An outbound substream the handler requested but the muxer has not opened yet.
///
/// As a future it resolves to `Err(user_data)` if the timeout fires while still
/// `Waiting`, and to `Ok(())` once the request has been `extract`ed.
enum SubstreamRequested<UserData, Upgrade> {
    Waiting {
        user_data: UserData,
        /// Started when the request was created.
        timeout: Delay,
        upgrade: Upgrade,
        /// Waker from the most recent `poll`. `extract` wakes it so the future is polled
        /// again and observes `Done`.
        extracted_waker: Option<Waker>,
    },
    /// The request was extracted or has timed out.
    Done,
}
687
688impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
689 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
690 Self::Waiting {
691 user_data,
692 timeout: Delay::new(timeout),
693 upgrade,
694 extracted_waker: None,
695 }
696 }
697
698 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
699 match mem::replace(self, Self::Done) {
700 SubstreamRequested::Waiting {
701 user_data,
702 timeout,
703 upgrade,
704 extracted_waker: waker,
705 } => {
706 if let Some(waker) = waker {
707 waker.wake();
708 }
709
710 (user_data, timeout, upgrade)
711 }
712 SubstreamRequested::Done => panic!("cannot extract twice"),
713 }
714 }
715}
716
717impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
718
719impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
720 type Output = Result<(), UserData>;
721
722 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
723 let this = self.get_mut();
724
725 match mem::replace(this, Self::Done) {
726 SubstreamRequested::Waiting {
727 user_data,
728 upgrade,
729 mut timeout,
730 ..
731 } => match timeout.poll_unpin(cx) {
732 Poll::Ready(()) => Poll::Ready(Err(user_data)),
733 Poll::Pending => {
734 *this = Self::Waiting {
735 user_data,
736 upgrade,
737 timeout,
738 extracted_waker: Some(cx.waker().clone()),
739 };
740 Poll::Pending
741 }
742 },
743 SubstreamRequested::Done => Poll::Ready(Ok(())),
744 }
745 }
746}
747
/// Idle-shutdown state of a connection, as evaluated in `Connection::poll`.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is scheduled.
    None,
    /// Shut down immediately (`poll` returns `ConnectionError::KeepAliveTimeout`).
    Asap,
    /// Shut down once the contained timer fires.
    Later(Delay),
}
766
/// Newtype whose equality and hashing use only the `&str` view (`AsRef<str>`) of the
/// wrapped value, so it can serve as a map key for protocol identifiers.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
    fn eq(&self, other: &Self) -> bool {
        let lhs: &str = self.0.as_ref();
        let rhs: &str = other.0.as_ref();
        lhs == rhs
    }
}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash exactly like the underlying `str` so `Hash` stays consistent with `Eq`.
        let view: &str = self.0.as_ref();
        std::hash::Hash::hash(view, state)
    }
}
785
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Weak},
        time::Instant,
    };

    use futures::{future, AsyncRead, AsyncWrite};
    use ant_libp2p_core::{
        upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        StreamMuxer,
    };
    use quickcheck::*;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::dummy;

    /// A muxer that always offers inbound streams must never have more than
    /// `max_negotiating_inbound_streams` of them alive at once. Each `PendingSubstream`
    /// holds a `Weak` to the shared counter, so the weak count equals the number of live
    /// substreams.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The upgrade timeout of an outbound request starts when the handler requests the
    /// substream, even if the muxer never produces one.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    /// Only the actual differences in locally supported inbound protocols are reported.
    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Second, listen on two protocols.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Third, stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    /// Reports of remote protocol support only reach the handler when they change the
    /// known remote set.
    #[test]
    fn only_propagates_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Second, it adds a protocol but also still includes the first one.
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Third, stop supporting a protocol that was never supported: no event expected.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Fourth, stop supporting a previously supported protocol.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    /// A handler that does not keep the connection alive leads to `KeepAliveTimeout`
    /// once the idle timeout elapses.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    // `Delay` is not `Clone`, so a fresh timer stands in for the original.
                    Shutdown::Later(_) => Shutdown::Later(
                        Delay::new(Duration::from_secs(1)),
                    ),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// Muxer that always yields a new inbound substream and never an outbound one.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Muxer whose every operation stays pending forever.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Substream that never reads or writes. Its `Weak` lets tests count live instances.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request one outbound substream and records the last dial error.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Infallible>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose listen protocols can be changed and which records every local and
    /// remote `ProtocolsChange` it receives.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Infallible>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                // Arms over uninhabited payloads may be flagged as unreachable depending on
                // the compiler version, hence the allows.
                #[allow(unreachable_patterns)]
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => ant_libp2p_core::util::unreachable(protocol),
                #[allow(unreachable_patterns)]
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => ant_libp2p_core::util::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                #[allow(unreachable_patterns)]
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            #[allow(unreachable_patterns)]
            ant_libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            #[allow(unreachable_patterns)]
            ant_libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade advertising an arbitrary list of protocols and passing the stream through.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1377
/// Endpoint information for a connection that is not yet fully established.
///
/// Unlike `ConnectedPoint`, the dialer variant does not carry an address; conversion
/// from `ConnectedPoint` keeps only `role_override` and `port_use`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PendingPoint {
    /// We are dialing the remote.
    Dialer {
        /// Role override carried over from `ConnectedPoint::Dialer`.
        role_override: Endpoint,
        /// Port-reuse setting carried over from `ConnectedPoint::Dialer`.
        port_use: PortUse,
    },
    /// The connection was received by a listener.
    Listener {
        /// Local address the connection was received on.
        local_addr: Multiaddr,
        /// Address to send replies back to the remote.
        send_back_addr: Multiaddr,
    },
}
1399
1400impl From<ConnectedPoint> for PendingPoint {
1401 fn from(endpoint: ConnectedPoint) -> Self {
1402 match endpoint {
1403 ConnectedPoint::Dialer {
1404 role_override,
1405 port_use,
1406 ..
1407 } => PendingPoint::Dialer {
1408 role_override,
1409 port_use,
1410 },
1411 ConnectedPoint::Listener {
1412 local_addr,
1413 send_back_addr,
1414 } => PendingPoint::Listener {
1415 local_addr,
1416 send_back_addr,
1417 },
1418 }
1419 }
1420}