dynomite/net/
dnode_proxy.rs1use std::io;
20use std::net::SocketAddr;
21
22use tokio::net::TcpListener;
23use tokio::task::JoinHandle;
24use tokio_rustls::TlsAcceptor;
25use tracing::Instrument as _;
26
27use crate::io::reactor::{ConnRole, TcpTransport, Transport};
28use crate::net::client::ClientHandler;
29use crate::net::conn::Conn;
30use crate::net::dnode_client::dnode_client_loop;
31use crate::net::listener::{bind_dual_stack, BindOptions};
32use crate::net::tls::TlsServerTransport;
33use crate::net::NetError;
34
35pub struct DnodeProxy {
37 listener: TcpListener,
38 tls_acceptor: Option<TlsAcceptor>,
39}
40
41impl DnodeProxy {
42 pub fn bind<A: Into<SocketAddr>>(addr: A) -> Result<Self, NetError> {
47 let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
48 Ok(Self {
49 listener,
50 tls_acceptor: None,
51 })
52 }
53
54 #[must_use]
60 pub fn with_tls(mut self, acceptor: TlsAcceptor) -> Self {
61 self.tls_acceptor = Some(acceptor);
62 self
63 }
64
65 #[must_use]
67 pub fn has_tls(&self) -> bool {
68 self.tls_acceptor.is_some()
69 }
70
71 pub fn local_addr(&self) -> io::Result<SocketAddr> {
73 self.listener.local_addr()
74 }
75
76 #[tracing::instrument(
85 name = "dnode_proxy.run",
86 skip_all,
87 fields(
88 local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
89 ),
90 )]
91 pub async fn run<F>(
92 self,
93 cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
94 mut handler_factory: F,
95 ) -> Result<(), NetError>
96 where
97 F: FnMut(
98 tokio::sync::mpsc::Sender<crate::net::dispatcher::OutboundEnvelope>,
99 ) -> ClientHandler
100 + Send,
101 {
102 let mut cancel = cancel;
103 let mut peers: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
104 let tls_acceptor = self.tls_acceptor.clone();
105 loop {
106 tokio::select! {
107 () = &mut cancel => break,
108 res = self.listener.accept() => {
109 let (sock, peer) = res?;
110 let role = ConnRole::DnodePeerClient;
111 let transport: Box<dyn Transport> = if let Some(acc) = tls_acceptor.as_ref() {
112 match acc.accept(sock).await {
113 Ok(tls) => Box::new(TlsServerTransport::new(tls, role)),
114 Err(e) => {
115 tracing::warn!(?peer, error = %e, "dnode_proxy tls handshake failed; dropping");
116 continue;
117 }
118 }
119 } else {
120 Box::new(TcpTransport::new(sock, role))
121 };
122 let conn = Conn::new(transport, role);
123 let (tx, rx) = tokio::sync::mpsc::channel(64);
124 let handler = handler_factory(tx);
125 tracing::debug!(?peer, "dnode_proxy accepted peer");
126 let accept_span = tracing::info_span!(
127 "dnode_client.accept",
128 peer = %peer,
129 );
130 let h = tokio::spawn(
131 async move {
132 dnode_client_loop(conn, handler, rx).await
133 }
134 .instrument(accept_span),
135 );
136 peers.push(h);
137 }
138 }
139 peers.retain(|h| !h.is_finished());
140 }
141 for h in peers {
142 let _ = h.await;
143 }
144 Ok(())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a proxy on an OS-assigned loopback port.
    fn loopback_proxy() -> DnodeProxy {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        DnodeProxy::bind(addr).unwrap()
    }

    #[tokio::test]
    async fn bind_returns_local_addr() {
        let proxy = loopback_proxy();
        let bound = proxy.local_addr().unwrap();
        assert!(bound.ip().is_loopback());
        assert!(!proxy.has_tls());
    }

    #[tokio::test]
    async fn with_tls_attaches_acceptor() {
        let proxy = loopback_proxy();
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("c.pem");
        let key_path = dir.path().join("k.pem");
        std::fs::write(&cert_path, cert.cert.pem()).unwrap();
        std::fs::write(&key_path, cert.signing_key.serialize_pem()).unwrap();
        let server_cfg = crate::net::tls::load_server_config(&cert_path, &key_path, None).unwrap();
        let proxy = proxy.with_tls(crate::net::tls::acceptor_from(server_cfg));
        assert!(proxy.has_tls());
    }
}