1use std::{
17 collections::HashSet,
18 fmt,
19 io,
20 net::{IpAddr, SocketAddr},
21 ops::Deref,
22 sync::{
23 Arc,
24 atomic::{AtomicUsize, Ordering::*},
25 },
26 time::{Duration, Instant},
27};
28
29#[cfg(feature = "locktick")]
30use locktick::parking_lot::Mutex;
31use once_cell::sync::OnceCell;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::Mutex;
34use tokio::{
35 io::split,
36 net::{TcpListener, TcpSocket, TcpStream},
37 sync::oneshot,
38 task::JoinHandle,
39 time::timeout,
40};
41use tracing::*;
42
43use crate::{
44 BannedPeers,
45 Config,
46 KnownPeers,
47 Stats,
48 connections::{Connection, ConnectionSide, Connections},
49 protocols::{Protocol, Protocols},
50};
51
/// A process-wide counter that supplies default names for nodes created without
/// an explicit one; every such node takes the current value and increments it.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0);
54
55#[derive(Clone)]
57pub struct Tcp(Arc<InnerTcp>);
58
59impl Deref for Tcp {
60 type Target = Arc<InnerTcp>;
61
62 fn deref(&self) -> &Self::Target {
63 &self.0
64 }
65}
66
67#[allow(missing_docs)]
69#[derive(thiserror::Error, Debug)]
70pub enum ConnectError {
71 #[error("already reached the maximum number of {limit} connections")]
72 MaximumConnectionsReached { limit: u16 },
73 #[error("already connecting to node at {address:?}")]
74 AlreadyConnecting { address: SocketAddr },
75 #[error("already connected to node at {address:?}")]
76 AlreadyConnected { address: SocketAddr },
77 #[error("attempt to self-connect (at address {address:?}")]
78 SelfConnect { address: SocketAddr },
79 #[error("I/O error: {0}")]
80 IoError(std::io::Error),
81}
82
83impl From<std::io::Error> for ConnectError {
84 fn from(inner: std::io::Error) -> Self {
85 Self::IoError(inner)
86 }
87}
88
89#[doc(hidden)]
90pub struct InnerTcp {
91 span: Span,
93 config: Config,
95 listening_addr: OnceCell<SocketAddr>,
97 pub(crate) protocols: Protocols,
99 connecting: Mutex<HashSet<SocketAddr>>,
101 connections: Connections,
103 known_peers: KnownPeers,
105 banned_peers: BannedPeers,
107 stats: Stats,
109 pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
111}
112
113impl Tcp {
114 pub fn new(mut config: Config) -> Self {
116 if config.name.is_none() {
118 config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());
119 }
120
121 let span = crate::helpers::create_span(config.name.as_deref().unwrap());
123
124 let tcp = Tcp(Arc::new(InnerTcp {
126 span,
127 config,
128 listening_addr: Default::default(),
129 protocols: Default::default(),
130 connecting: Default::default(),
131 connections: Default::default(),
132 known_peers: Default::default(),
133 banned_peers: Default::default(),
134 stats: Stats::new(Instant::now()),
135 tasks: Default::default(),
136 }));
137
138 debug!(parent: tcp.span(), "The node is ready");
139
140 tcp
141 }
142
143 #[inline]
145 pub fn name(&self) -> &str {
146 self.config.name.as_deref().unwrap()
148 }
149
150 #[inline]
152 pub fn config(&self) -> &Config {
153 &self.config
154 }
155
156 pub fn listening_addr(&self) -> io::Result<SocketAddr> {
159 self.listening_addr.get().copied().ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
160 }
161
162 pub fn is_connected(&self, addr: SocketAddr) -> bool {
164 self.connections.is_connected(addr)
165 }
166
167 pub fn is_connecting(&self, addr: SocketAddr) -> bool {
169 self.connecting.lock().contains(&addr)
170 }
171
172 pub fn num_connected(&self) -> usize {
174 self.connections.num_connected()
175 }
176
177 pub fn num_connecting(&self) -> usize {
179 self.connecting.lock().len()
180 }
181
182 pub fn connected_addrs(&self) -> Vec<SocketAddr> {
184 self.connections.addrs()
185 }
186
187 pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
189 self.connecting.lock().iter().copied().collect()
190 }
191
192 #[inline]
194 pub fn known_peers(&self) -> &KnownPeers {
195 &self.known_peers
196 }
197
198 #[inline]
200 pub fn banned_peers(&self) -> &BannedPeers {
201 &self.banned_peers
202 }
203
204 #[inline]
206 pub fn stats(&self) -> &Stats {
207 &self.stats
208 }
209
210 #[inline]
212 pub fn span(&self) -> &Span {
213 &self.span
214 }
215
216 pub async fn shut_down(&self) {
218 debug!(parent: self.span(), "Shutting down the TCP stack");
219
220 let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();
222
223 if let Some(listening_task) = tasks.next() {
225 listening_task.abort(); }
227 for addr in self.connected_addrs() {
229 self.disconnect(addr).await;
230 }
231 for handle in tasks {
233 handle.abort();
234 }
235 }
236}
237
238impl Tcp {
239 pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> {
241 if let Ok(listening_addr) = self.listening_addr() {
242 if addr == listening_addr || self.is_self_connect(addr) {
244 error!(parent: self.span(), "Attempted to self-connect ({addr})");
245 return Err(ConnectError::SelfConnect { address: addr });
246 }
247 }
248
249 if !self.can_add_connection() {
250 error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
251 return Err(ConnectError::MaximumConnectionsReached { limit: self.config.max_connections });
252 }
253
254 if self.is_connected(addr) {
255 warn!(parent: self.span(), "Already connected to {addr}");
256 return Err(ConnectError::AlreadyConnected { address: addr });
257 }
258
259 if !self.connecting.lock().insert(addr) {
260 warn!(parent: self.span(), "Already connecting to {addr}");
261 return Err(ConnectError::AlreadyConnecting { address: addr });
262 }
263
264 let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());
265
266 let res = if let Some(listen_ip) = self.config().listener_ip {
269 timeout(timeout_duration, self.connect_with_specific_interface(listen_ip, addr)).await
270 } else {
271 timeout(timeout_duration, TcpStream::connect(addr)).await
272 };
273
274 let stream = match res {
275 Ok(Ok(stream)) => Ok(stream),
276 Ok(err) => {
277 self.connecting.lock().remove(&addr);
278 err
279 }
280 Err(err) => {
281 self.connecting.lock().remove(&addr);
282 error!("connection timeout error: {}", err);
283 Err(io::ErrorKind::TimedOut.into())
284 }
285 }?;
286
287 let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await;
288
289 if let Err(ref e) = ret {
290 self.connecting.lock().remove(&addr);
291 self.known_peers().register_failure(addr.ip());
292 error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
293 }
294
295 ret.map_err(|err| err.into())
296 }
297
298 async fn connect_with_specific_interface(&self, listen_ip: IpAddr, addr: SocketAddr) -> io::Result<TcpStream> {
299 let sock = if listen_ip.is_ipv4() { TcpSocket::new_v4()? } else { TcpSocket::new_v6()? };
300 sock.bind(SocketAddr::new(listen_ip, 0))?;
302 sock.connect(addr).await
303 }
304
305 pub async fn disconnect(&self, addr: SocketAddr) -> bool {
309 if let Some(conn) = self.connections.0.read().get(&addr) {
311 if conn.disconnecting.swap(true, Relaxed) {
312 return false;
314 }
315 } else {
316 return false;
318 };
319
320 if let Some(handler) = self.protocols.disconnect.get() {
321 let (sender, receiver) = oneshot::channel();
322 handler.trigger((addr, sender));
323 let _ = receiver.await; }
325
326 let conn = self.connections.remove(addr);
327
328 if let Some(ref conn) = conn {
329 debug!(parent: self.span(), "Disconnecting from {}", conn.addr());
330
331 for task in conn.tasks.iter().rev() {
333 task.abort();
334 }
335
336 debug!(parent: self.span(), "Disconnected from {}", conn.addr());
337 } else {
338 warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
339 }
340
341 conn.is_some()
342 }
343}
344
345impl Tcp {
346 pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
348 let listener_ip =
350 self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set");
351
352 let listener = self.create_listener(listener_ip).await?;
354
355 let port = listener.local_addr()?.port();
357
358 let listening_addr = (listener_ip, port).into();
360 self.listening_addr.set(listening_addr).expect("The node's listener was started more than once");
361
362 let (tx, rx) = oneshot::channel();
364
365 let tcp = self.clone();
366 let listening_task = tokio::spawn(async move {
367 trace!(parent: tcp.span(), "Spawned the listening task");
368 tx.send(()).unwrap(); loop {
371 match listener.accept().await {
373 Ok((stream, addr)) => tcp.handle_connection(stream, addr),
374 Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"),
375 }
376 }
377 });
378 self.tasks.lock().push(listening_task);
379 let _ = rx.await;
380 debug!(parent: self.span(), "Listening on {listening_addr}");
381
382 Ok(listening_addr)
383 }
384
385 async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
387 debug!("Creating a TCP listener on {listener_ip}...");
388 let listener = if let Some(port) = self.config().desired_listening_port {
389 let desired_listening_addr = SocketAddr::new(listener_ip, port);
391 match TcpListener::bind(desired_listening_addr).await {
393 Ok(listener) => listener,
394 Err(e) => {
395 if self.config().allow_random_port {
396 warn!(
397 parent: self.span(),
398 "Trying any listening port, as the desired port is unavailable: {e}"
399 );
400 let random_available_addr = SocketAddr::new(listener_ip, 0);
401 TcpListener::bind(random_available_addr).await?
402 } else {
403 error!(parent: self.span(), "The desired listening port is unavailable: {e}");
404 return Err(e);
405 }
406 }
407 }
408 } else if self.config().allow_random_port {
409 let random_available_addr = SocketAddr::new(listener_ip, 0);
410 TcpListener::bind(random_available_addr).await?
411 } else {
412 panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set");
413 };
414
415 Ok(listener)
416 }
417
418 fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) {
420 debug!(parent: self.span(), "Received a connection from {addr}");
421
422 if !self.can_add_connection() || self.is_self_connect(addr) {
423 debug!(parent: self.span(), "Rejecting the connection from {addr}");
424 return;
425 }
426
427 self.connecting.lock().insert(addr);
428
429 let tcp = self.clone();
430 tokio::spawn(async move {
431 if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await {
432 tcp.connecting.lock().remove(&addr);
433 tcp.known_peers().register_failure(addr.ip());
434 error!(parent: tcp.span(), "Failed to connect with {addr}: {e}");
435 }
436 });
437 }
438
439 fn is_self_connect(&self, addr: SocketAddr) -> bool {
441 let listening_addr = self.listening_addr().unwrap();
443
444 match listening_addr.ip().is_loopback() {
445 true => listening_addr.port() == addr.port(),
448 false => listening_addr.ip() == addr.ip(),
450 }
451 }
452
453 fn can_add_connection(&self) -> bool {
455 let num_connected = self.num_connected();
457 let limit = self.config.max_connections as usize;
459
460 if num_connected >= limit {
461 warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached");
462 false
463 } else if num_connected + self.num_connecting() >= limit {
464 warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached");
465 false
466 } else {
467 true
468 }
469 }
470
471 async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> {
473 self.known_peers.add(peer_addr.ip());
474
475 if own_side == ConnectionSide::Initiator {
477 if let Ok(addr) = stream.local_addr() {
478 debug!(
479 parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
480 peer_addr, addr.port()
481 );
482 } else {
483 warn!(parent: self.span(), "couldn't determine the peer's port");
484 }
485 }
486
487 let connection = Connection::new(peer_addr, stream, !own_side);
488
489 let mut connection = self.enable_protocols(connection).await?;
491
492 let conn_ready_tx = connection.readiness_notifier.take();
494
495 self.connections.add(connection);
496 self.connecting.lock().remove(&peer_addr);
497
498 if let Some(tx) = conn_ready_tx {
500 let _ = tx.send(());
501 }
502
503 if let Some(handler) = self.protocols.on_connect.get() {
505 let (sender, receiver) = oneshot::channel();
506 handler.trigger((peer_addr, sender));
507 let _ = receiver.await; }
509
510 Ok(())
511 }
512
513 async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
515 macro_rules! enable_protocol {
517 ($handler_type: ident, $node:expr, $conn: expr) => {
518 if let Some(handler) = $node.protocols.$handler_type.get() {
519 let (conn_returner, conn_retriever) = oneshot::channel();
520
521 handler.trigger(($conn, conn_returner));
522
523 match conn_retriever.await {
524 Ok(Ok(conn)) => conn,
525 Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
526 Ok(e) => return e,
527 }
528 } else {
529 $conn
530 }
531 };
532 }
533
534 let mut conn = enable_protocol!(handshake, self, conn);
535
536 if let Some(stream) = conn.stream.take() {
538 let (reader, writer) = split(stream);
539 conn.reader = Some(Box::new(reader));
540 conn.writer = Some(Box::new(writer));
541 }
542
543 let conn = enable_protocol!(reading, self, conn);
544 let conn = enable_protocol!(writing, self, conn);
545
546 Ok(conn)
547 }
548}
549
550impl fmt::Debug for Tcp {
551 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552 write!(f, "The TCP stack config: {:?}", self.config)
553 }
554}
555
556#[cfg(test)]
557mod tests {
558 use super::*;
559
560 use std::{
561 net::{IpAddr, Ipv4Addr},
562 str::FromStr,
563 };
564
565 #[tokio::test]
566 async fn test_new() {
567 let tcp = Tcp::new(Config {
568 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
569 max_connections: 200,
570 ..Default::default()
571 });
572
573 assert_eq!(tcp.config.max_connections, 200);
574 assert_eq!(tcp.config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
575 assert_eq!(tcp.enable_listener().await.unwrap().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
576
577 assert_eq!(tcp.num_connected(), 0);
578 assert_eq!(tcp.num_connecting(), 0);
579 }
580
581 #[tokio::test]
582 async fn test_connect() {
583 let tcp = Tcp::new(Config::default());
584 let node_ip = tcp.enable_listener().await.unwrap();
585
586 let result = tcp.connect(node_ip).await;
588 assert!(matches!(result, Err(ConnectError::SelfConnect { .. })));
589
590 assert_eq!(tcp.num_connected(), 0);
591 assert_eq!(tcp.num_connecting(), 0);
592 assert!(!tcp.is_connected(node_ip));
593 assert!(!tcp.is_connecting(node_ip));
594
595 let peer = Tcp::new(Config {
597 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
598 desired_listening_port: Some(0),
599 max_connections: 1,
600 ..Default::default()
601 });
602 let peer_ip = peer.enable_listener().await.unwrap();
603
604 tcp.connect(peer_ip).await.unwrap();
606 assert_eq!(tcp.num_connected(), 1);
607 assert_eq!(tcp.num_connecting(), 0);
608 assert!(tcp.is_connected(peer_ip));
609 assert!(!tcp.is_connecting(peer_ip));
610 }
611
612 #[tokio::test]
613 async fn test_disconnect() {
614 let tcp = Tcp::new(Config::default());
615 let _node_ip = tcp.enable_listener().await.unwrap();
616
617 let peer = Tcp::new(Config {
619 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
620 desired_listening_port: Some(0),
621 max_connections: 1,
622 ..Default::default()
623 });
624 let peer_ip = peer.enable_listener().await.unwrap();
625
626 tcp.connect(peer_ip).await.unwrap();
628 assert_eq!(tcp.num_connected(), 1);
629 assert_eq!(tcp.num_connecting(), 0);
630 assert!(tcp.is_connected(peer_ip));
631 assert!(!tcp.is_connecting(peer_ip));
632
633 let has_disconnected = tcp.disconnect(peer_ip).await;
635 assert!(has_disconnected);
636 assert_eq!(tcp.num_connected(), 0);
637 assert_eq!(tcp.num_connecting(), 0);
638 assert!(!tcp.is_connected(peer_ip));
639 assert!(!tcp.is_connecting(peer_ip));
640
641 let has_disconnected = tcp.disconnect(peer_ip).await;
643 assert!(!has_disconnected);
644 assert_eq!(tcp.num_connected(), 0);
645 assert_eq!(tcp.num_connecting(), 0);
646 assert!(!tcp.is_connected(peer_ip));
647 assert!(!tcp.is_connecting(peer_ip));
648 }
649
650 #[tokio::test]
651 async fn test_can_add_connection() {
652 let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
653
654 let peer = Tcp::new(Config {
656 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
657 desired_listening_port: Some(0),
658 max_connections: 1,
659 ..Default::default()
660 });
661 let peer_ip = peer.enable_listener().await.unwrap();
662
663 assert!(tcp.can_add_connection());
664
665 let stream = TcpStream::connect(peer_ip).await.unwrap();
667 tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Initiator));
668 assert!(!tcp.can_add_connection());
669
670 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
673 let result = tcp.connect(another_ip).await;
674 assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));
675
676 tcp.connections.remove(peer_ip);
678 assert!(tcp.can_add_connection());
679
680 tcp.connecting.lock().insert(peer_ip);
682 assert!(!tcp.can_add_connection());
683
684 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
686 let result = tcp.connect(another_ip).await;
687 assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));
688
689 tcp.connecting.lock().remove(&peer_ip);
691 assert!(tcp.can_add_connection());
692
693 let stream = TcpStream::connect(peer_ip).await.unwrap();
695 tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Responder));
696 tcp.connecting.lock().insert(peer_ip);
697 assert!(!tcp.can_add_connection());
698
699 tcp.connections.remove(peer_ip);
701 tcp.connecting.lock().remove(&peer_ip);
702 assert!(tcp.can_add_connection());
703 }
704
705 #[tokio::test]
706 async fn test_handle_connection() {
707 let tcp = Tcp::new(Config {
708 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
709 max_connections: 1,
710 ..Default::default()
711 });
712
713 let peer1 = Tcp::new(Config {
715 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
716 desired_listening_port: Some(0),
717 max_connections: 1,
718 ..Default::default()
719 });
720 let peer1_ip = peer1.enable_listener().await.unwrap();
721
722 let stream = TcpStream::connect(peer1_ip).await.unwrap();
724 tcp.connections.add(Connection::new(peer1_ip, stream, ConnectionSide::Responder));
725 assert!(!tcp.can_add_connection());
726 assert_eq!(tcp.num_connected(), 1);
727 assert_eq!(tcp.num_connecting(), 0);
728 assert!(tcp.is_connected(peer1_ip));
729 assert!(!tcp.is_connecting(peer1_ip));
730
731 let peer2 = Tcp::new(Config {
733 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
734 desired_listening_port: Some(0),
735 max_connections: 1,
736 ..Default::default()
737 });
738 let peer2_ip = peer2.enable_listener().await.unwrap();
739
740 let stream = TcpStream::connect(peer2_ip).await.unwrap();
742 tcp.handle_connection(stream, peer2_ip);
743 assert!(!tcp.can_add_connection());
744 assert_eq!(tcp.num_connected(), 1);
745 assert_eq!(tcp.num_connecting(), 0);
746 assert!(tcp.is_connected(peer1_ip));
747 assert!(!tcp.is_connected(peer2_ip));
748 assert!(!tcp.is_connecting(peer1_ip));
749 assert!(!tcp.is_connecting(peer2_ip));
750 }
751
752 #[tokio::test]
753 async fn test_adapt_stream() {
754 let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
755
756 let peer = Tcp::new(Config {
758 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
759 desired_listening_port: Some(0),
760 max_connections: 1,
761 ..Default::default()
762 });
763 let peer_ip = peer.enable_listener().await.unwrap();
764
765 tcp.connecting.lock().insert(peer_ip);
767 assert_eq!(tcp.num_connected(), 0);
768 assert_eq!(tcp.num_connecting(), 1);
769 assert!(!tcp.is_connected(peer_ip));
770 assert!(tcp.is_connecting(peer_ip));
771
772 let stream = TcpStream::connect(peer_ip).await.unwrap();
774 tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder).await.unwrap();
775 assert_eq!(tcp.num_connected(), 1);
776 assert_eq!(tcp.num_connecting(), 0);
777 assert!(tcp.is_connected(peer_ip));
778 assert!(!tcp.is_connecting(peer_ip));
779 }
780}