1use std::{
2 collections::{HashMap, hash_map::Entry},
3 io::{self, ErrorKind},
4 net::SocketAddr,
5 ops::Deref,
6 sync::{
7 Arc,
8 atomic::{AtomicUsize, Ordering::*},
9 },
10 time::Duration,
11};
12
13use parking_lot::Mutex;
14use tokio::{
15 io::split,
16 net::{TcpListener, TcpSocket, TcpStream},
17 sync::{RwLock, oneshot},
18 task::{self, JoinHandle, JoinSet},
19 time::{sleep, timeout},
20};
21use tracing::*;
22
23#[cfg(doc)]
24use crate::protocols::Handshake;
25use crate::{
26 Config, Stats,
27 connections::{
28 Connection, ConnectionGuard, ConnectionInfo, ConnectionSide, Connections,
29 create_connection_span,
30 },
31 protocols::{Protocol, Protocols},
32};
33
// Runs the protocol stored in `$node.protocols.$handler_type` (if one is registered) on the
// connection `$conn`, evaluating to the connection the protocol hands back.
//
// The handler is triggered with `($conn, sender)` and is expected to return the connection
// (or an error) through the oneshot `sender`. The macro `.await`s and may `return` early, so it
// must be expanded inside an `async fn` returning `io::Result<_>`:
// - if the sender is dropped without a reply, the enclosing fn returns `ErrorKind::BrokenPipe`;
// - if the protocol replies with an error, that error is returned unchanged.
//
// When no handler is registered, `$conn` is passed through untouched.
macro_rules! enable_protocol {
    ($handler_type: ident, $node:expr, $conn: expr) => {
        match $node.protocols.$handler_type.get() {
            None => $conn,
            Some(handler) => {
                let (conn_tx, conn_rx) = oneshot::channel();

                handler.trigger(($conn, conn_tx)).await;

                match conn_rx.await {
                    Ok(Ok(conn)) => conn,
                    Ok(Err(e)) => return Err(e),
                    Err(_) => return Err(ErrorKind::BrokenPipe.into()),
                }
            }
        }
    };
}
52
/// Process-wide counter used to name nodes whose `Config::name` is unset: each such node takes
/// the current value (starting at zero) and bumps it by one.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0_usize);
55
/// Keys for the long-running task handles a node keeps in its `tasks` map.
///
/// Only `Listener` is registered in this module (by the listener setup). The remaining variants
/// are named after the node's protocols and are presumably registered by their respective
/// protocol setups (not visible here).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum NodeTask {
    /// The task accepting inbound connections.
    Listener,
    /// Presumably the `OnDisconnect` protocol's task.
    OnDisconnect,
    /// Presumably the `Handshake` protocol's task.
    Handshake,
    /// Presumably the `OnConnect` protocol's task.
    OnConnect,
    /// Presumably the `Reading` protocol's task.
    Reading,
    /// Presumably the `Writing` protocol's task.
    Writing,
}
66
67#[derive(Clone)]
73pub struct Node(Arc<InnerNode>);
74
75impl Deref for Node {
76 type Target = Arc<InnerNode>;
77
78 fn deref(&self) -> &Self::Target {
79 &self.0
80 }
81}
82
/// The state shared by all clones of a [`Node`]; reached through `Node`'s `Deref` impl.
// NOTE: field order determines drop order; keep it stable.
#[doc(hidden)]
pub struct InnerNode {
    /// The tracing span used as the parent of the node's log events.
    span: Span,
    /// The node's configuration; `Node::new` always fills in its `name`.
    config: Config,
    /// The address the listener is bound to, or `None` while the listener is disabled.
    listening_addr: RwLock<Option<SocketAddr>>,
    /// Handlers of the protocols (handshake, reading, writing, on-connect, on-disconnect)
    /// enabled for this node.
    pub(crate) protocols: Protocols,
    /// The node's active connections plus the bookkeeping for pending ones and the per-IP
    /// connection counts.
    pub(crate) connections: Connections,
    /// The node's statistics.
    stats: Stats,
    /// Handles to the node's long-running tasks, keyed by their kind.
    pub(crate) tasks: Mutex<HashMap<NodeTask, JoinHandle<()>>>,
}
101
102impl Node {
103 pub fn new(mut config: Config) -> Self {
105 if config.name.is_none() {
107 config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, SeqCst).to_string());
108 }
109
110 let span = create_span(config.name.as_deref().unwrap());
112
113 let node = Node(Arc::new(InnerNode {
114 span,
115 config,
116 listening_addr: Default::default(),
117 protocols: Default::default(),
118 connections: Default::default(),
119 stats: Default::default(),
120 tasks: Default::default(),
121 }));
122
123 debug!(parent: node.span(), "the node is ready");
124
125 node
126 }
127
128 pub async fn toggle_listener(&self) -> io::Result<Option<SocketAddr>> {
131 let mut listening_addr = self.listening_addr.write().await;
133
134 if let Some(old_listening_addr) = listening_addr.take() {
135 let listener_task = self.tasks.lock().remove(&NodeTask::Listener).unwrap(); listener_task.abort();
137 trace!(parent: self.span(), "aborted the listening task");
138 debug!(parent: self.span(), "no longer listening on {old_listening_addr}");
139 *listening_addr = None;
140
141 Ok(None)
142 } else {
143 let listener_addr = self.config().listener_addr.ok_or_else(|| {
144 error!(parent: self.span(), "the listener was toggled on, but Config::listener_addr is not set");
145 ErrorKind::AddrNotAvailable
146 })?;
147 trace!(parent: self.span(), "attempting to listen on {listener_addr}");
148 let listener = TcpListener::bind(listener_addr).await?;
149 let port = listener.local_addr()?.port(); let new_listening_addr = (listener_addr.ip(), port).into();
151
152 self.start_listening(listener).await;
154 debug!(parent: self.span(), "listening on {new_listening_addr}");
155
156 *listening_addr = Some(new_listening_addr);
158
159 Ok(Some(new_listening_addr))
160 }
161 }
162
163 async fn start_listening(&self, listener: TcpListener) {
165 let (tx, rx) = oneshot::channel();
167
168 let node = self.clone();
169 let listening_task = tokio::spawn(async move {
170 trace!(parent: node.span(), "spawned the listening task");
171 if tx.send(()).is_err() {
172 error!(parent: node.span(), "listener setup interrupted; shutting down the listening task");
173 return;
174 }
175
176 loop {
177 match listener.accept().await {
178 Ok((stream, addr)) => {
179 let node = node.clone();
181 tokio::spawn(async move {
182 node.handle_connection_request(stream, addr).await.inspect_err(|e|
183 match e.kind() {
184 ErrorKind::QuotaExceeded | ErrorKind::AlreadyExists => {
185 debug!(parent: node.span(), "rejecting connection from {addr}: {e}");
186 }
187 _ => {
188 error!(parent: node.span(), "couldn't accept a connection from {addr}: {e}");
189 }
190 }
191 )
192 });
193 }
194 Err(e) => {
195 error!(parent: node.span(), "couldn't accept a connection: {e}");
196 sleep(Duration::from_millis(500)).await;
199 }
200 }
201 }
202 });
203
204 self.tasks.lock().insert(NodeTask::Listener, listening_task);
205 let _ = rx.await;
206 }
207
208 async fn handle_connection_request(
210 &self,
211 stream: TcpStream,
212 addr: SocketAddr,
213 ) -> io::Result<()> {
214 let guard = self.check_and_reserve(addr)?;
216
217 self.adapt_stream(stream, addr, ConnectionSide::Responder, guard)
219 .await
220 }
221
222 #[inline]
224 pub fn name(&self) -> &str {
225 self.config.name.as_deref().unwrap()
227 }
228
229 #[inline]
231 pub fn config(&self) -> &Config {
232 &self.config
233 }
234
235 #[inline]
237 pub fn stats(&self) -> &Stats {
238 &self.stats
239 }
240
241 #[inline]
243 pub fn span(&self) -> &Span {
244 &self.span
245 }
246
247 pub async fn listening_addr(&self) -> io::Result<SocketAddr> {
250 self.listening_addr
251 .read()
252 .await
253 .as_ref()
254 .copied()
255 .ok_or_else(|| ErrorKind::AddrNotAvailable.into())
256 }
257
258 async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
260 let mut conn = enable_protocol!(handshake, self, conn);
261
262 if let Some(stream) = conn.stream.take() {
264 let (reader, writer) = split(stream);
265 conn.reader = Some(Box::new(reader));
266 conn.writer = Some(Box::new(writer));
267 }
268
269 let conn = enable_protocol!(reading, self, conn);
270 let conn = enable_protocol!(writing, self, conn);
271
272 Ok(conn)
273 }
274
275 async fn adapt_stream(
277 &self,
278 stream: TcpStream,
279 peer_addr: SocketAddr,
280 own_side: ConnectionSide,
281 mut guard: ConnectionGuard<'_>,
282 ) -> io::Result<()> {
283 let conn_span = create_connection_span(peer_addr, self.span());
284 debug!(parent: &conn_span, "establishing connection as the {own_side:?}");
285
286 if own_side == ConnectionSide::Initiator {
288 if let Ok(addr) = stream.local_addr() {
289 trace!(parent: &conn_span, "the peer is connected on port {}", addr.port());
290 } else {
291 warn!(parent: &conn_span, "couldn't determine the peer-side port");
292 }
293 }
294
295 let connection = Connection::new(peer_addr, stream, !own_side, conn_span.clone());
296
297 let mut connection = self.enable_protocols(connection).await?;
299
300 let conn_ready_tx = connection.readiness_notifier.take();
302
303 self.connections.add(connection);
305 guard.completed = true;
306 drop(guard);
307
308 if let Some(tx) = conn_ready_tx {
310 let _ = tx.send(());
311 }
312
313 debug!(parent: &conn_span, "fully connected");
314
315 if let Some(handler) = self.protocols.on_connect.get() {
317 trace!(parent: &conn_span, "executing OnConnect logic...");
318 let (sender, receiver) = oneshot::channel();
319 handler.trigger((peer_addr, sender)).await;
320
321 if let Ok(handle) = receiver.await {
323 if let Some(conn) = self.connections.active.write().get_mut(&peer_addr) {
324 conn.tasks.push(handle);
326 } else {
327 handle.abort();
329 }
330 }
331 }
332
333 Ok(())
334 }
335
336 async fn create_stream(
338 &self,
339 addr: SocketAddr,
340 socket: Option<TcpSocket>,
341 ) -> io::Result<TcpStream> {
342 match timeout(
343 Duration::from_millis(self.config().connection_timeout_ms.into()),
344 self.create_stream_inner(addr, socket),
345 )
346 .await
347 {
348 Ok(Ok(stream)) => Ok(stream),
349 Ok(err) => err,
350 Err(err) => Err(io::Error::new(ErrorKind::TimedOut, err)),
351 }
352 }
353
354 async fn create_stream_inner(
356 &self,
357 addr: SocketAddr,
358 socket: Option<TcpSocket>,
359 ) -> io::Result<TcpStream> {
360 if let Some(socket) = socket {
361 socket.connect(addr).await
362 } else {
363 TcpStream::connect(addr).await
364 }
365 }
366
367 pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
375 self.connect_inner(addr, None)
376 .await
377 .inspect_err(|e| error!(parent: self.span(), "couldn't connect to {addr}: {e}"))
378 }
379
380 pub async fn connect_using_socket(
382 &self,
383 addr: SocketAddr,
384 socket: TcpSocket,
385 ) -> io::Result<()> {
386 self.connect_inner(addr, Some(socket))
387 .await
388 .inspect_err(|e| error!(parent: self.span(), "couldn't connect to {addr}: {e}"))
389 }
390
391 async fn connect_inner(&self, addr: SocketAddr, socket: Option<TcpSocket>) -> io::Result<()> {
393 if let Ok(listening_addr) = self.listening_addr().await {
395 if addr == listening_addr
396 || addr.ip().is_loopback() && addr.port() == listening_addr.port()
397 {
398 return Err(io::Error::new(
399 ErrorKind::AddrInUse,
400 "can't connect to node's own listening address ({addr})",
401 ));
402 }
403 }
404
405 if !self.config.allow_duplicate_connections && self.connections.is_connected(addr) {
408 return Err(io::Error::new(
409 ErrorKind::AlreadyExists,
410 "already connected to {addr}",
411 ));
412 }
413
414 let guard = self.check_and_reserve(addr)?;
416
417 let stream = self.create_stream(addr, socket).await?;
419
420 self.adapt_stream(stream, addr, ConnectionSide::Initiator, guard)
422 .await
423 }
424
425 pub async fn disconnect(&self, addr: SocketAddr) -> bool {
427 if let Some(conn) = self.connections.active.read().get(&addr) {
429 if conn.disconnecting.swap(true, Relaxed) {
430 return false;
432 }
433 } else {
434 return false;
436 };
437
438 let conn_span = create_connection_span(addr, self.span());
439 debug!(parent: &conn_span, "disconnecting...");
440
441 if let Some(handler) = self.protocols.on_disconnect.get() {
443 trace!(parent: &conn_span, "executing OnDisconnect logic...");
444 let (sender, receiver) = oneshot::channel();
445 handler.trigger((addr, sender)).await;
446 if let Ok((handle, waiter)) = receiver.await {
447 if let Some(conn) = self.connections.active.write().get_mut(&addr) {
450 conn.tasks.push(handle);
451 }
452 let _ = waiter.await;
456 }
457 }
458
459 if let Some(writing) = self.protocols.writing.get() {
461 writing.senders.write().remove(&addr);
464
465 task::yield_now().await;
468 }
469
470 let _ = self.connections.remove(addr);
472
473 if let Entry::Occupied(mut e) = self.connections.limits.lock().ip_counts.entry(addr.ip()) {
475 if *e.get() > 1 {
476 *e.get_mut() -= 1;
477 } else {
478 e.remove();
479 }
480 }
481
482 debug!(parent: &conn_span, "fully disconnected");
483
484 true
485 }
486
487 pub fn connected_addrs(&self) -> Vec<SocketAddr> {
489 self.connections.addrs()
490 }
491
492 pub fn is_connected(&self, addr: SocketAddr) -> bool {
494 self.connections.is_connected(addr)
495 }
496
497 pub fn is_connecting(&self, addr: SocketAddr) -> bool {
499 self.connections.limits.lock().connecting.contains(&addr)
500 }
501
502 pub fn num_connected(&self) -> usize {
504 self.connections.num_connected()
505 }
506
507 pub fn num_connecting(&self) -> usize {
509 self.connections.limits.lock().connecting.len()
510 }
511
512 pub fn connection_info(&self, addr: SocketAddr) -> Option<ConnectionInfo> {
514 self.connections.get_info(addr)
515 }
516
517 pub fn connection_infos(&self) -> HashMap<SocketAddr, ConnectionInfo> {
519 self.connections.infos()
520 }
521
522 fn check_and_reserve(&self, addr: SocketAddr) -> io::Result<ConnectionGuard<'_>> {
524 let mut limits = self.connections.limits.lock();
526
527 let ip = addr.ip();
529 let num_ip_conns = *limits.ip_counts.get(&ip).unwrap_or(&0);
530 let per_ip_limit = self.config.max_connections_per_ip as usize;
531 if num_ip_conns >= per_ip_limit {
532 return Err(io::Error::new(
533 ErrorKind::QuotaExceeded,
534 "maximum number ({per_ip_limit}) of per-IP connections reached with {ip}",
535 ));
536 }
537
538 let num_connecting = limits.connecting.len();
540 let connecting_limit = self.config.max_connecting as usize;
541 if num_connecting >= connecting_limit {
542 return Err(io::Error::new(
543 ErrorKind::QuotaExceeded,
544 "maximum number ({connecting_limit}) of pending connections reached",
545 ));
546 }
547
548 let num_connected = self.connections.num_connected();
550 let connection_limit = self.config.max_connections as usize;
551 if num_connected + num_connecting >= connection_limit {
552 return Err(io::Error::new(
553 ErrorKind::QuotaExceeded,
554 "maximum number ({connection_limit}) of connections reached",
555 ));
556 }
557
558 if limits.connecting.contains(&addr) {
560 return Err(io::Error::new(
561 ErrorKind::AlreadyExists,
562 "already connecting to {addr}",
563 ));
564 }
565
566 *limits.ip_counts.entry(ip).or_insert(0) += 1;
568 limits.connecting.insert(addr);
569
570 Ok(ConnectionGuard {
571 addr,
572 connections: &self.connections,
573 completed: false,
574 })
575 }
576
577 pub async fn shut_down(&self) {
579 debug!(parent: self.span(), "shutting down");
580
581 let mut tasks = std::mem::take(&mut *self.tasks.lock());
582
583 if let Some(listening_task) = tasks.remove(&NodeTask::Listener) {
585 listening_task.abort();
586 }
587
588 let mut disconnect_tasks = JoinSet::new();
590 for addr in self.connected_addrs() {
591 let node = self.clone();
592 disconnect_tasks.spawn(async move {
593 node.disconnect(addr).await;
594 });
595 }
596 while disconnect_tasks.join_next().await.is_some() {}
597
598 for handle in tasks.into_values() {
600 handle.abort();
601 }
602 }
603}
604
605fn create_span(node_name: &str) -> Span {
607 macro_rules! try_span {
608 ($lvl:expr) => {
609 let s = span!($lvl, "node", name = node_name);
610 if !s.is_disabled() {
611 return s;
612 }
613 };
614 }
615
616 try_span!(Level::TRACE);
617 try_span!(Level::DEBUG);
618 try_span!(Level::INFO);
619 try_span!(Level::WARN);
620
621 error_span!("node", name = node_name)
622}