Skip to main content

dynomite/net/
dnode_proxy.rs

1//! DNODE_PEER_PROXY listener.
2//!
3//! Listens for inbound peer connections from other Dynomite nodes
4//! and spawns a [`crate::net::dnode_client::dnode_client_loop`] task
5//! per accepted socket. When configured with a
6//! [`tokio_rustls::TlsAcceptor`] (via [`DnodeProxy::with_tls`])
7//! every accepted socket is upgraded to TLS before handoff.
8//!
9//! # Examples
10//!
11//! ```no_run
12//! use dynomite::net::DnodeProxy;
13//!
14//! let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
15//! let listener = DnodeProxy::bind(addr).unwrap();
16//! let _ = listener.local_addr();
17//! ```
18
19use std::io;
20use std::net::SocketAddr;
21
22use tokio::net::TcpListener;
23use tokio::task::JoinHandle;
24use tokio_rustls::TlsAcceptor;
25use tracing::Instrument as _;
26
27use crate::io::reactor::{ConnRole, TcpTransport, Transport};
28use crate::net::client::ClientHandler;
29use crate::net::conn::Conn;
30use crate::net::dnode_client::dnode_client_loop;
31use crate::net::listener::{bind_dual_stack, BindOptions};
32use crate::net::tls::TlsServerTransport;
33use crate::net::NetError;
34
35/// DNODE_PEER_PROXY listener.
36pub struct DnodeProxy {
37    listener: TcpListener,
38    tls_acceptor: Option<TlsAcceptor>,
39}
40
41impl DnodeProxy {
42    /// Bind a peer-listener to the given address.
43    ///
44    /// # Errors
45    /// Forwarded from the underlying socket calls.
46    pub fn bind<A: Into<SocketAddr>>(addr: A) -> Result<Self, NetError> {
47        let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
48        Ok(Self {
49            listener,
50            tls_acceptor: None,
51        })
52    }
53
54    /// Attach a TLS acceptor; every accepted peer is wrapped via
55    /// [`TlsAcceptor::accept`] before being handed off to the
56    /// per-peer driver. When the acceptor is unset (the default)
57    /// the listener serves plaintext TCP, matching the historical
58    /// behaviour.
59    #[must_use]
60    pub fn with_tls(mut self, acceptor: TlsAcceptor) -> Self {
61        self.tls_acceptor = Some(acceptor);
62        self
63    }
64
65    /// True when the listener is configured with a TLS acceptor.
66    #[must_use]
67    pub fn has_tls(&self) -> bool {
68        self.tls_acceptor.is_some()
69    }
70
71    /// Local address of the listener.
72    pub fn local_addr(&self) -> io::Result<SocketAddr> {
73        self.listener.local_addr()
74    }
75
76    /// Drive the accept loop. The supplied `handler_factory` is
77    /// called once per accepted peer; it receives the
78    /// per-connection responder sender (the matching half of the
79    /// channel the inbound driver reads from) and returns the
80    /// [`ClientHandler`] the per-peer loop should use.
81    ///
82    /// # Errors
83    /// Forwarded from the listener accept call.
84    #[tracing::instrument(
85        name = "dnode_proxy.run",
86        skip_all,
87        fields(
88            local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
89        ),
90    )]
91    pub async fn run<F>(
92        self,
93        cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
94        mut handler_factory: F,
95    ) -> Result<(), NetError>
96    where
97        F: FnMut(
98                tokio::sync::mpsc::Sender<crate::net::dispatcher::OutboundEnvelope>,
99            ) -> ClientHandler
100            + Send,
101    {
102        let mut cancel = cancel;
103        let mut peers: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
104        let tls_acceptor = self.tls_acceptor.clone();
105        loop {
106            tokio::select! {
107                () = &mut cancel => break,
108                res = self.listener.accept() => {
109                    let (sock, peer) = res?;
110                    let role = ConnRole::DnodePeerClient;
111                    let transport: Box<dyn Transport> = if let Some(acc) = tls_acceptor.as_ref() {
112                        match acc.accept(sock).await {
113                            Ok(tls) => Box::new(TlsServerTransport::new(tls, role)),
114                            Err(e) => {
115                                tracing::warn!(?peer, error = %e, "dnode_proxy tls handshake failed; dropping");
116                                continue;
117                            }
118                        }
119                    } else {
120                        Box::new(TcpTransport::new(sock, role))
121                    };
122                    let conn = Conn::new(transport, role);
123                    let (tx, rx) = tokio::sync::mpsc::channel(64);
124                    let handler = handler_factory(tx);
125                    tracing::debug!(?peer, "dnode_proxy accepted peer");
126                    let accept_span = tracing::info_span!(
127                        "dnode_client.accept",
128                        peer = %peer,
129                    );
130                    let h = tokio::spawn(
131                        async move {
132                            dnode_client_loop(conn, handler, rx).await
133                        }
134                        .instrument(accept_span),
135                    );
136                    peers.push(h);
137                }
138            }
139            peers.retain(|h| !h.is_finished());
140        }
141        for h in peers {
142            let _ = h.await;
143        }
144        Ok(())
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address with an OS-assigned port, so tests never collide.
    fn loopback_ephemeral() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    #[tokio::test]
    async fn bind_returns_local_addr() {
        let proxy = DnodeProxy::bind(loopback_ephemeral()).unwrap();
        let local = proxy.local_addr().unwrap();
        assert!(local.ip().is_loopback());
        // A freshly bound listener carries no TLS acceptor.
        assert!(!proxy.has_tls());
    }

    #[tokio::test]
    async fn with_tls_attaches_acceptor() {
        let proxy = DnodeProxy::bind(loopback_ephemeral()).unwrap();

        // Write a throwaway self-signed cert/key pair to disk so the real
        // PEM loader is exercised.
        let generated =
            rcgen::generate_simple_self_signed(vec![String::from("localhost")]).unwrap();
        let scratch = tempfile::tempdir().unwrap();
        let cert_path = scratch.path().join("c.pem");
        let key_path = scratch.path().join("k.pem");
        std::fs::write(&cert_path, generated.cert.pem()).unwrap();
        std::fs::write(&key_path, generated.signing_key.serialize_pem()).unwrap();

        let server_config =
            crate::net::tls::load_server_config(&cert_path, &key_path, None).unwrap();
        let proxy = proxy.with_tls(crate::net::tls::acceptor_from(server_config));
        assert!(proxy.has_tls());
    }
}