Skip to main content

pea2pea/
node.rs

1use std::{
2    collections::{HashMap, hash_map::Entry},
3    io::{self, ErrorKind},
4    net::SocketAddr,
5    ops::Deref,
6    sync::{
7        Arc,
8        atomic::{AtomicUsize, Ordering::*},
9    },
10    time::Duration,
11};
12
13use parking_lot::Mutex;
14use tokio::{
15    io::split,
16    net::{TcpListener, TcpSocket, TcpStream},
17    sync::{RwLock, oneshot},
18    task::{self, JoinHandle, JoinSet},
19    time::{sleep, timeout},
20};
21use tracing::*;
22
23#[cfg(doc)]
24use crate::protocols::Handshake;
25use crate::{
26    Config, Stats,
27    connections::{
28        Connection, ConnectionGuard, ConnectionInfo, ConnectionSide, Connections,
29        create_connection_span,
30    },
31    protocols::{Protocol, Protocols},
32};
33
// Runs the handler of the chosen protocol (if one is registered with the node) for a fresh
// connection. The connection is handed to the handler together with a oneshot sender and the
// macro then waits for it to come back; a handler-reported error or a dropped sender makes the
// *enclosing function* return early with an error. Without a registered handler the connection
// is passed through untouched.
macro_rules! enable_protocol {
    ($handler_type: ident, $node:expr, $conn: expr) => {
        match $node.protocols.$handler_type.get() {
            // the protocol is not enabled; nothing to do
            None => $conn,
            Some(handler) => {
                let (conn_returner, conn_retriever) = oneshot::channel();

                handler.trigger(($conn, conn_returner)).await;

                match conn_retriever.await {
                    Ok(Ok(conn)) => conn,
                    // the handler reported a failure; propagate it as-is
                    Ok(Err(e)) => return Err(e),
                    // the sender was dropped without returning the connection
                    Err(_) => return Err(ErrorKind::BrokenPipe.into()),
                }
            }
        }
    };
}
52
/// Source of default names: a process-wide counter whose successive values are handed out to
/// `Node`s created without a pre-configured name.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0);
55
/// Identifies the kinds of long-running tasks a Node can own; used as the key of the node's
/// task map.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum NodeTask {
    /// The task accepting inbound connections.
    Listener,
    /// The task associated with the `OnDisconnect` protocol.
    OnDisconnect,
    /// The task associated with the `Handshake` protocol.
    Handshake,
    /// The task associated with the `OnConnect` protocol.
    OnConnect,
    /// The task associated with the `Reading` protocol.
    Reading,
    /// The task associated with the `Writing` protocol.
    Writing,
}
66
67/// The central object responsible for handling connections.
68///
69/// note: Due to the architecture of protocol handlers capturing the node, a reference cycle exists
70/// that prevents the Node from being dropped automatically. You must call [`Node::shut_down`] when
71/// you are finished with a node to ensure all background tasks are aborted and sockets are closed.
72#[derive(Clone)]
73pub struct Node(Arc<InnerNode>);
74
75impl Deref for Node {
76    type Target = Arc<InnerNode>;
77
78    fn deref(&self) -> &Self::Target {
79        &self.0
80    }
81}
82
83/// The actual node object that gets wrapped in an Arc in the Node.
84#[doc(hidden)]
85pub struct InnerNode {
86    /// The tracing span.
87    span: Span,
88    /// The node's configuration.
89    config: Config,
90    /// The node's current listening address.
91    listening_addr: RwLock<Option<SocketAddr>>,
92    /// Contains objects used by the protocols implemented by the node.
93    pub(crate) protocols: Protocols,
94    /// Contains objects related to the node's active connections.
95    pub(crate) connections: Connections,
96    /// Collects statistics related to the node itself.
97    stats: Stats,
98    /// The node's tasks.
99    pub(crate) tasks: Mutex<HashMap<NodeTask, JoinHandle<()>>>,
100}
101
102impl Node {
103    /// Creates a new [`Node`] using the given [`Config`].
104    pub fn new(mut config: Config) -> Self {
105        // if there is no pre-configured name, assign a sequential numeric identifier
106        if config.name.is_none() {
107            config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, SeqCst).to_string());
108        }
109
110        // create a tracing span containing the node's name
111        let span = create_span(config.name.as_deref().unwrap());
112
113        let node = Node(Arc::new(InnerNode {
114            span,
115            config,
116            listening_addr: Default::default(),
117            protocols: Default::default(),
118            connections: Default::default(),
119            stats: Default::default(),
120            tasks: Default::default(),
121        }));
122
123        debug!(parent: node.span(), "the node is ready");
124
125        node
126    }
127
128    /// Enables or disables listening for inbound connections; returns the actual bound address, which will
129    /// differ from the one in [`Config::listener_addr`] if that one's port was unspecified (i.e. `0`).
130    pub async fn toggle_listener(&self) -> io::Result<Option<SocketAddr>> {
131        // we deliberately maintain the write guard for the entirety of this method
132        let mut listening_addr = self.listening_addr.write().await;
133
134        if let Some(old_listening_addr) = listening_addr.take() {
135            let listener_task = self.tasks.lock().remove(&NodeTask::Listener).unwrap(); // can't fail
136            listener_task.abort();
137            trace!(parent: self.span(), "aborted the listening task");
138            debug!(parent: self.span(), "no longer listening on {old_listening_addr}");
139            *listening_addr = None;
140
141            Ok(None)
142        } else {
143            let listener_addr = self.config().listener_addr.ok_or_else(|| {
144                error!(parent: self.span(), "the listener was toggled on, but Config::listener_addr is not set");
145                ErrorKind::AddrNotAvailable
146            })?;
147            trace!(parent: self.span(), "attempting to listen on {listener_addr}");
148            let listener = TcpListener::bind(listener_addr).await?;
149            let port = listener.local_addr()?.port(); // discover the port if it was unspecified
150            let new_listening_addr = (listener_addr.ip(), port).into();
151
152            // start listening
153            self.start_listening(listener).await;
154            debug!(parent: self.span(), "listening on {new_listening_addr}");
155
156            // update the node's listening address
157            *listening_addr = Some(new_listening_addr);
158
159            Ok(Some(new_listening_addr))
160        }
161    }
162
163    /// Spawn a task responsible for listening for inbound connections.
164    async fn start_listening(&self, listener: TcpListener) {
165        // use a channel to know when the listening task is ready
166        let (tx, rx) = oneshot::channel();
167
168        let node = self.clone();
169        let listening_task = tokio::spawn(async move {
170            trace!(parent: node.span(), "spawned the listening task");
171            if tx.send(()).is_err() {
172                error!(parent: node.span(), "listener setup interrupted; shutting down the listening task");
173                return;
174            }
175
176            loop {
177                match listener.accept().await {
178                    Ok((stream, addr)) => {
179                        // handle connection requests asynchronously
180                        let node = node.clone();
181                        tokio::spawn(async move {
182                            node.handle_connection_request(stream, addr).await.inspect_err(|e|
183                                match e.kind() {
184                                    ErrorKind::QuotaExceeded | ErrorKind::AlreadyExists => {
185                                        debug!(parent: node.span(), "rejecting connection from {addr}: {e}");
186                                    }
187                                    _ => {
188                                        error!(parent: node.span(), "couldn't accept a connection from {addr}: {e}");
189                                    }
190                                }
191                            )
192                        });
193                    }
194                    Err(e) => {
195                        error!(parent: node.span(), "couldn't accept a connection: {e}");
196                        // if we ran out of FDs, sleep to avoid spinning 100% CPU
197                        // while waiting for a slot to free up
198                        sleep(Duration::from_millis(500)).await;
199                    }
200                }
201            }
202        });
203
204        self.tasks.lock().insert(NodeTask::Listener, listening_task);
205        let _ = rx.await;
206    }
207
208    /// Processes a single inbound connection request. Only used in [`Node::start_listening`].
209    async fn handle_connection_request(
210        &self,
211        stream: TcpStream,
212        addr: SocketAddr,
213    ) -> io::Result<()> {
214        // check connection limits and set up a connection guard
215        let guard = self.check_and_reserve(addr)?;
216
217        // finalize the connection
218        self.adapt_stream(stream, addr, ConnectionSide::Responder, guard)
219            .await
220    }
221
222    /// Returns the name assigned to the node.
223    #[inline]
224    pub fn name(&self) -> &str {
225        // safe; can be set as None in Config, but receives a default value on Node creation
226        self.config.name.as_deref().unwrap()
227    }
228
229    /// Returns a reference to the node's config.
230    #[inline]
231    pub fn config(&self) -> &Config {
232        &self.config
233    }
234
235    /// Returns a reference to the node's stats.
236    #[inline]
237    pub fn stats(&self) -> &Stats {
238        &self.stats
239    }
240
241    /// Returns the tracing [`Span`] associated with the node.
242    #[inline]
243    pub fn span(&self) -> &Span {
244        &self.span
245    }
246
247    /// Returns the node's current listening address; returns an error if the node was configured
248    /// to not listen for inbound connections or if the listener is currently disabled.
249    pub async fn listening_addr(&self) -> io::Result<SocketAddr> {
250        self.listening_addr
251            .read()
252            .await
253            .as_ref()
254            .copied()
255            .ok_or_else(|| ErrorKind::AddrNotAvailable.into())
256    }
257
258    /// Enable the applicable protocols for a new connection.
259    async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
260        let mut conn = enable_protocol!(handshake, self, conn);
261
262        // split the stream after the handshake (if not done before)
263        if let Some(stream) = conn.stream.take() {
264            let (reader, writer) = split(stream);
265            conn.reader = Some(Box::new(reader));
266            conn.writer = Some(Box::new(writer));
267        }
268
269        let conn = enable_protocol!(reading, self, conn);
270        let conn = enable_protocol!(writing, self, conn);
271
272        Ok(conn)
273    }
274
275    /// Prepares the freshly acquired connection to handle the protocols the Node implements.
276    async fn adapt_stream(
277        &self,
278        stream: TcpStream,
279        peer_addr: SocketAddr,
280        own_side: ConnectionSide,
281        mut guard: ConnectionGuard<'_>,
282    ) -> io::Result<()> {
283        let conn_span = create_connection_span(peer_addr, self.span());
284        debug!(parent: &conn_span, "establishing connection as the {own_side:?}");
285
286        // register the port seen by the peer
287        if own_side == ConnectionSide::Initiator {
288            if let Ok(addr) = stream.local_addr() {
289                trace!(parent: &conn_span, "the peer is connected on port {}", addr.port());
290            } else {
291                warn!(parent: &conn_span, "couldn't determine the peer-side port");
292            }
293        }
294
295        let connection = Connection::new(peer_addr, stream, !own_side, conn_span.clone());
296
297        // enact the enabled protocols
298        let mut connection = self.enable_protocols(connection).await?;
299
300        // if Reading is enabled, we'll notify the related task when the connection is fully ready
301        let conn_ready_tx = connection.readiness_notifier.take();
302
303        // connecting -> connected
304        self.connections.add(connection);
305        guard.completed = true;
306        drop(guard);
307
308        // send the aforementioned notification so that reading from the socket can commence
309        if let Some(tx) = conn_ready_tx {
310            let _ = tx.send(());
311        }
312
313        debug!(parent: &conn_span, "fully connected");
314
315        // if enabled, enact OnConnect
316        if let Some(handler) = self.protocols.on_connect.get() {
317            trace!(parent: &conn_span, "executing OnConnect logic...");
318            let (sender, receiver) = oneshot::channel();
319            handler.trigger((peer_addr, sender)).await;
320
321            // receive the handle for the running task
322            if let Ok(handle) = receiver.await {
323                if let Some(conn) = self.connections.active.write().get_mut(&peer_addr) {
324                    // add the task to the connection so it gets aborted in case of a disconnect
325                    conn.tasks.push(handle);
326                } else {
327                    // the connection has just been terminated; abort the OnConnect work
328                    handle.abort();
329                }
330            }
331        }
332
333        Ok(())
334    }
335
336    // A helper method to facilitate a common potential disconnect at the callsite.
337    async fn create_stream(
338        &self,
339        addr: SocketAddr,
340        socket: Option<TcpSocket>,
341    ) -> io::Result<TcpStream> {
342        match timeout(
343            Duration::from_millis(self.config().connection_timeout_ms.into()),
344            self.create_stream_inner(addr, socket),
345        )
346        .await
347        {
348            Ok(Ok(stream)) => Ok(stream),
349            Ok(err) => err,
350            Err(err) => Err(io::Error::new(ErrorKind::TimedOut, err)),
351        }
352    }
353
354    /// A wrapper method for greater readability.
355    async fn create_stream_inner(
356        &self,
357        addr: SocketAddr,
358        socket: Option<TcpSocket>,
359    ) -> io::Result<TcpStream> {
360        if let Some(socket) = socket {
361            socket.connect(addr).await
362        } else {
363            TcpStream::connect(addr).await
364        }
365    }
366
367    /// Connects to the provided `SocketAddr`.
368    ///
369    /// note: `pea2pea` identifies connections by their socket address (IP + port). If Node A
370    /// connects to Node B, and Node B simultaneously connects to Node A, the library considers
371    /// these two distinct connections (one outgoing, one incoming). To ensure a single logical
372    /// connection per peer, you must implement a tie-breaking mechanism in your application logic
373    /// in the [`Handshake`] protocol.
374    pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
375        self.connect_inner(addr, None)
376            .await
377            .inspect_err(|e| error!(parent: self.span(), "couldn't connect to {addr}: {e}"))
378    }
379
380    /// Connects to a `SocketAddr` using the provided `TcpSocket`.
381    pub async fn connect_using_socket(
382        &self,
383        addr: SocketAddr,
384        socket: TcpSocket,
385    ) -> io::Result<()> {
386        self.connect_inner(addr, Some(socket))
387            .await
388            .inspect_err(|e| error!(parent: self.span(), "couldn't connect to {addr}: {e}"))
389    }
390
391    /// Connects to the provided `SocketAddr` using an optional `TcpSocket`.
392    async fn connect_inner(&self, addr: SocketAddr, socket: Option<TcpSocket>) -> io::Result<()> {
393        // a simple self-connect attempt check
394        if let Ok(listening_addr) = self.listening_addr().await {
395            if addr == listening_addr
396                || addr.ip().is_loopback() && addr.port() == listening_addr.port()
397            {
398                return Err(io::Error::new(
399                    ErrorKind::AddrInUse,
400                    "can't connect to node's own listening address ({addr})",
401                ));
402            }
403        }
404
405        // make sure the address is not already connected to, unless
406        // duplicate connections are permitted in the config
407        if !self.config.allow_duplicate_connections && self.connections.is_connected(addr) {
408            return Err(io::Error::new(
409                ErrorKind::AlreadyExists,
410                "already connected to {addr}",
411            ));
412        }
413
414        // attempt to reserve a connection slot atomically
415        let guard = self.check_and_reserve(addr)?;
416
417        // attempt to physically connect to the specified address
418        let stream = self.create_stream(addr, socket).await?;
419
420        // attempt to finalize the connection
421        self.adapt_stream(stream, addr, ConnectionSide::Initiator, guard)
422            .await
423    }
424
425    /// Disconnects from the provided `SocketAddr`; returns `true` if an actual disconnect took place.
426    pub async fn disconnect(&self, addr: SocketAddr) -> bool {
427        // claim the disconnect to avoid duplicate executions, or return early if already claimed
428        if let Some(conn) = self.connections.active.read().get(&addr) {
429            if conn.disconnecting.swap(true, Relaxed) {
430                // valid connection, but someone else is already disconnecting it
431                return false;
432            }
433        } else {
434            // not connected
435            return false;
436        };
437
438        let conn_span = create_connection_span(addr, self.span());
439        debug!(parent: &conn_span, "disconnecting...");
440
441        // if the OnDisconnect protocol is enabled, trigger it
442        if let Some(handler) = self.protocols.on_disconnect.get() {
443            trace!(parent: &conn_span, "executing OnDisconnect logic...");
444            let (sender, receiver) = oneshot::channel();
445            handler.trigger((addr, sender)).await;
446            if let Ok((handle, waiter)) = receiver.await {
447                // register the associated task with the connection, in case
448                // it gets terminated before its completion
449                if let Some(conn) = self.connections.active.write().get_mut(&addr) {
450                    conn.tasks.push(handle);
451                }
452                // wait for the OnDisconnect protocol to perform its specified actions
453                // time out, or even panic - we're already disconnecting, so ignore the
454                // result
455                let _ = waiter.await;
456            }
457        }
458
459        // ensure that any OnDisconnect-related writes can conclude
460        if let Some(writing) = self.protocols.writing.get() {
461            // remove the connection's message sender so that
462            // the associated loop can exit organically
463            writing.senders.write().remove(&addr);
464
465            // give the Writing task a chance to process it
466            // and flush any final messages to the kernel
467            task::yield_now().await;
468        }
469
470        // the connection can now be "physically" removed
471        let _ = self.connections.remove(addr);
472
473        // decrement the per-IP connection count
474        if let Entry::Occupied(mut e) = self.connections.limits.lock().ip_counts.entry(addr.ip()) {
475            if *e.get() > 1 {
476                *e.get_mut() -= 1;
477            } else {
478                e.remove();
479            }
480        }
481
482        debug!(parent: &conn_span, "fully disconnected");
483
484        true
485    }
486
487    /// Returns a list containing addresses of active connections.
488    pub fn connected_addrs(&self) -> Vec<SocketAddr> {
489        self.connections.addrs()
490    }
491
492    /// Checks whether the provided address is connected.
493    pub fn is_connected(&self, addr: SocketAddr) -> bool {
494        self.connections.is_connected(addr)
495    }
496
497    /// Checks if the node is currently setting up a connection with the provided address.
498    pub fn is_connecting(&self, addr: SocketAddr) -> bool {
499        self.connections.limits.lock().connecting.contains(&addr)
500    }
501
502    /// Returns the number of active connections.
503    pub fn num_connected(&self) -> usize {
504        self.connections.num_connected()
505    }
506
507    /// Returns the number of connections that are currently being set up.
508    pub fn num_connecting(&self) -> usize {
509        self.connections.limits.lock().connecting.len()
510    }
511
512    /// Returns basic information related to a connection.
513    pub fn connection_info(&self, addr: SocketAddr) -> Option<ConnectionInfo> {
514        self.connections.get_info(addr)
515    }
516
517    /// Returns a list of all active connections and their basic information.
518    pub fn connection_infos(&self) -> HashMap<SocketAddr, ConnectionInfo> {
519        self.connections.infos()
520    }
521
522    /// Atomically checks connection limits and reserves a slot if available.
523    fn check_and_reserve(&self, addr: SocketAddr) -> io::Result<ConnectionGuard<'_>> {
524        // this lock is held for the duration of the check to prevent races
525        let mut limits = self.connections.limits.lock();
526
527        // check the per-IP limit first
528        let ip = addr.ip();
529        let num_ip_conns = *limits.ip_counts.get(&ip).unwrap_or(&0);
530        let per_ip_limit = self.config.max_connections_per_ip as usize;
531        if num_ip_conns >= per_ip_limit {
532            return Err(io::Error::new(
533                ErrorKind::QuotaExceeded,
534                "maximum number ({per_ip_limit}) of per-IP connections reached with {ip}",
535            ));
536        }
537
538        // check the global connecting count limit
539        let num_connecting = limits.connecting.len();
540        let connecting_limit = self.config.max_connecting as usize;
541        if num_connecting >= connecting_limit {
542            return Err(io::Error::new(
543                ErrorKind::QuotaExceeded,
544                "maximum number ({connecting_limit}) of pending connections reached",
545            ));
546        }
547
548        // check the global connection count limit
549        let num_connected = self.connections.num_connected();
550        let connection_limit = self.config.max_connections as usize;
551        if num_connected + num_connecting >= connection_limit {
552            return Err(io::Error::new(
553                ErrorKind::QuotaExceeded,
554                "maximum number ({connection_limit}) of connections reached",
555            ));
556        }
557
558        // check if already connecting (duplicate connection attempt from same node)
559        if limits.connecting.contains(&addr) {
560            return Err(io::Error::new(
561                ErrorKind::AlreadyExists,
562                "already connecting to {addr}",
563            ));
564        }
565
566        // reserve a connecting slot
567        *limits.ip_counts.entry(ip).or_insert(0) += 1;
568        limits.connecting.insert(addr);
569
570        Ok(ConnectionGuard {
571            addr,
572            connections: &self.connections,
573            completed: false,
574        })
575    }
576
577    /// Gracefully shuts the node down.
578    pub async fn shut_down(&self) {
579        debug!(parent: self.span(), "shutting down");
580
581        let mut tasks = std::mem::take(&mut *self.tasks.lock());
582
583        // abort the listening task first (if it exists)
584        if let Some(listening_task) = tasks.remove(&NodeTask::Listener) {
585            listening_task.abort();
586        }
587
588        // disconnect from all the peers
589        let mut disconnect_tasks = JoinSet::new();
590        for addr in self.connected_addrs() {
591            let node = self.clone();
592            disconnect_tasks.spawn(async move {
593                node.disconnect(addr).await;
594            });
595        }
596        while disconnect_tasks.join_next().await.is_some() {}
597
598        // abort the remaining tasks, which should now be inert
599        for handle in tasks.into_values() {
600            handle.abort();
601        }
602    }
603}
604
605/// Creates the node's tracing span based on its name.
606fn create_span(node_name: &str) -> Span {
607    macro_rules! try_span {
608        ($lvl:expr) => {
609            let s = span!($lvl, "node", name = node_name);
610            if !s.is_disabled() {
611                return s;
612            }
613        };
614    }
615
616    try_span!(Level::TRACE);
617    try_span!(Level::DEBUG);
618    try_span!(Level::INFO);
619    try_span!(Level::WARN);
620
621    error_span!("node", name = node_name)
622}