use std::io;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use tokio_rustls::TlsAcceptor;
use tracing::Instrument as _;
use crate::io::reactor::{ConnRole, TcpTransport, Transport};
use crate::net::client::ClientHandler;
use crate::net::conn::Conn;
use crate::net::dnode_client::dnode_client_loop;
use crate::net::listener::{bind_dual_stack, BindOptions};
use crate::net::tls::TlsServerTransport;
use crate::net::NetError;
/// Listener that accepts inbound dnode peer connections and runs a client loop for each one.
///
/// Created by [`DnodeProxy::bind`] with no TLS acceptor, so connections are plain TCP until
/// [`DnodeProxy::with_tls`] attaches one.
pub struct DnodeProxy {
    /// Bound listening socket (produced by `bind_dual_stack`).
    listener: TcpListener,
    /// Server-side TLS acceptor; `None` means accepted sockets are used as plain TCP.
    tls_acceptor: Option<TlsAcceptor>,
}
impl DnodeProxy {
pub fn bind<A: Into<SocketAddr>>(addr: A) -> Result<Self, NetError> {
let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
Ok(Self {
listener,
tls_acceptor: None,
})
}
#[must_use]
pub fn with_tls(mut self, acceptor: TlsAcceptor) -> Self {
self.tls_acceptor = Some(acceptor);
self
}
#[must_use]
pub fn has_tls(&self) -> bool {
self.tls_acceptor.is_some()
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
#[tracing::instrument(
name = "dnode_proxy.run",
skip_all,
fields(
local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
),
)]
pub async fn run<F>(
self,
cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
mut handler_factory: F,
) -> Result<(), NetError>
where
F: FnMut(
tokio::sync::mpsc::Sender<crate::net::dispatcher::OutboundEnvelope>,
) -> ClientHandler
+ Send,
{
let mut cancel = cancel;
let mut peers: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
let tls_acceptor = self.tls_acceptor.clone();
loop {
tokio::select! {
() = &mut cancel => break,
res = self.listener.accept() => {
let (sock, peer) = res?;
let role = ConnRole::DnodePeerClient;
let transport: Box<dyn Transport> = if let Some(acc) = tls_acceptor.as_ref() {
match acc.accept(sock).await {
Ok(tls) => Box::new(TlsServerTransport::new(tls, role)),
Err(e) => {
tracing::warn!(?peer, error = %e, "dnode_proxy tls handshake failed; dropping");
continue;
}
}
} else {
Box::new(TcpTransport::new(sock, role))
};
let conn = Conn::new(transport, role);
let (tx, rx) = tokio::sync::mpsc::channel(64);
let handler = handler_factory(tx);
tracing::debug!(?peer, "dnode_proxy accepted peer");
let accept_span = tracing::info_span!(
"dnode_client.accept",
peer = %peer,
);
let h = tokio::spawn(
async move {
dnode_client_loop(conn, handler, rx).await
}
.instrument(accept_span),
);
peers.push(h);
}
}
peers.retain(|h| !h.is_finished());
}
for h in peers {
let _ = h.await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a fresh proxy on an OS-assigned loopback port.
    fn loopback_proxy() -> DnodeProxy {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        DnodeProxy::bind(addr).unwrap()
    }

    /// A freshly bound proxy reports a loopback address and starts without TLS.
    #[tokio::test]
    async fn bind_returns_local_addr() {
        let proxy = loopback_proxy();
        let bound = proxy.local_addr().unwrap();
        assert!(bound.ip().is_loopback());
        assert!(!proxy.has_tls());
    }

    /// Attaching an acceptor built from a self-signed cert flips `has_tls` on.
    #[tokio::test]
    async fn with_tls_attaches_acceptor() {
        let proxy = loopback_proxy();
        let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let scratch = tempfile::tempdir().unwrap();
        let cert_path = scratch.path().join("c.pem");
        let key_path = scratch.path().join("k.pem");
        std::fs::write(&cert_path, self_signed.cert.pem()).unwrap();
        std::fs::write(&key_path, self_signed.signing_key.serialize_pem()).unwrap();
        let server_cfg =
            crate::net::tls::load_server_config(&cert_path, &key_path, None).unwrap();
        let proxy = proxy.with_tls(crate::net::tls::acceptor_from(server_cfg));
        assert!(proxy.has_tls());
    }
}