use std::future::{Future, IntoFuture};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as AutoBuilder;
use hyper_util::server::graceful::GracefulShutdown;
use tokio::net::TcpListener;
use tower::ServiceExt;
use crate::server::{
DEFAULT_TLS_HANDSHAKE_TIMEOUT, PeerAddr, PeerCerts, is_transient_accept_error,
};
/// Serves `router` over TLS for every connection accepted from `listener`.
///
/// Returns a [`ServeTls`] builder; nothing is accepted until it is `.await`ed.
/// The handshake timeout starts at `DEFAULT_TLS_HANDSHAKE_TIMEOUT` and no
/// shutdown signal is installed. Without one, the server only stops on an accept
/// error that is not classified as transient. Each request reaching `router`
/// carries a `PeerAddr` extension, plus a `PeerCerts` extension when the client
/// presented a certificate chain.
pub fn serve_tls(
    listener: TcpListener,
    router: axum::Router,
    tls_config: Arc<rustls::ServerConfig>,
) -> ServeTls {
    let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
    ServeTls {
        shutdown: None,
        tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
        acceptor,
        router,
        listener,
    }
}
/// Pending TLS server returned by [`serve_tls`].
///
/// Adjust it with [`ServeTls::tls_handshake_timeout`] and
/// [`ServeTls::with_graceful_shutdown`], then `.await` it to run the server.
#[must_use = "ServeTls does nothing unless `.await`ed"]
pub struct ServeTls {
    /// Socket that new TCP connections are accepted from.
    listener: TcpListener,
    /// Router that every decoded HTTP request is dispatched to.
    router: axum::Router,
    /// Performs the server side of each TLS handshake.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Maximum time a single client may spend completing the TLS handshake.
    tls_handshake_timeout: Duration,
    /// Optional future whose completion stops accepting and starts draining.
    shutdown: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
}

impl ServeTls {
    /// Replaces the TLS handshake timeout.
    ///
    /// A connection whose handshake has not finished within `timeout` is
    /// dropped and logged at `warn` level.
    #[must_use = "ServeTls does nothing unless `.await`ed"]
    pub fn tls_handshake_timeout(self, timeout: Duration) -> Self {
        Self {
            tls_handshake_timeout: timeout,
            ..self
        }
    }

    /// Installs `signal` as the graceful-shutdown trigger.
    ///
    /// Once `signal` resolves, the server stops accepting new connections and
    /// waits for the already-accepted ones to finish before the `ServeTls`
    /// future completes. Calling this again replaces any earlier signal.
    #[must_use = "ServeTls does nothing unless `.await`ed"]
    pub fn with_graceful_shutdown<F>(self, signal: F) -> Self
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let boxed: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(signal);
        Self {
            shutdown: Some(boxed),
            ..self
        }
    }
}
impl std::fmt::Debug for ServeTls {
    // Shows the listener, the handshake timeout and whether a shutdown signal
    // is installed; the router and TLS acceptor are elided (non-exhaustive).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_shutdown_signal = self.shutdown.is_some();
        let mut out = f.debug_struct("ServeTls");
        out.field("listener", &self.listener);
        out.field("tls_handshake_timeout", &self.tls_handshake_timeout);
        out.field("shutdown", &has_shutdown_signal);
        out.finish_non_exhaustive()
    }
}
/// Lets callers write `serve_tls(..).await` to run the server.
impl IntoFuture for ServeTls {
    type Output = std::io::Result<()>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        // Boxed because the anonymous `async fn` future type cannot be named
        // in the associated `IntoFuture` type.
        let server = Self::run(self);
        Box::pin(server)
    }
}
impl ServeTls {
async fn run(self) -> std::io::Result<()> {
let ServeTls {
listener,
router,
acceptor,
tls_handshake_timeout,
shutdown,
} = self;
let mut shutdown = shutdown.unwrap_or_else(|| Box::pin(std::future::pending()));
let graceful = GracefulShutdown::new();
loop {
let (stream, remote_addr) = tokio::select! {
biased;
_ = &mut shutdown => {
tracing::info!("Shutdown signal received; draining connections");
break;
}
accepted = listener.accept() => match accepted {
Ok(conn) => conn,
Err(err) if is_transient_accept_error(&err) => {
tracing::warn!("Transient accept error (continuing): {err}");
continue;
}
Err(err) => return Err(err),
},
};
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!("failed to set TCP_NODELAY: {e}");
}
let acceptor = acceptor.clone();
let router = router.clone();
let watcher = graceful.watcher();
tokio::spawn(async move {
let tls_stream = match tokio::time::timeout(
tls_handshake_timeout,
acceptor.accept(stream),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(err)) => {
tracing::debug!(remote_addr = %remote_addr, error = ?err, "TLS handshake failed");
return;
}
Err(_) => {
tracing::warn!(
remote_addr = %remote_addr,
"TLS handshake timed out after {tls_handshake_timeout:?}",
);
return;
}
};
let (_, conn) = tls_stream.get_ref();
let peer_addr = PeerAddr(remote_addr);
let peer_certs = conn
.peer_certificates()
.map(|chain| PeerCerts(chain.iter().map(|c| c.clone().into_owned()).collect()));
let svc = hyper::service::service_fn(
move |mut req: hyper::Request<hyper::body::Incoming>| {
req.extensions_mut().insert(peer_addr.clone());
if let Some(c) = &peer_certs {
req.extensions_mut().insert(c.clone());
}
router.clone().oneshot(req.map(axum::body::Body::new))
},
);
let conn = AutoBuilder::new(TokioExecutor::new())
.serve_connection_with_upgrades(TokioIo::new(tls_stream), svc)
.into_owned();
if let Err(err) = watcher.watch(conn).await {
tracing::trace!(remote_addr = %remote_addr, error = %err, "Connection ended with error");
}
});
}
drop(listener);
graceful.shutdown().await;
tracing::info!("All connections drained; shutdown complete");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use crate::{Response as ConnectResponse, Router as ConnectRouter, handler_fn};
    use rcgen::{CertificateParams, CertifiedIssuer, IsCa, KeyPair, SanType};
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Upper bound on a single client TLS round-trip, so a server regression
    /// fails the test with a message instead of hanging the test run.
    const CLIENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// `(server config, client config, client leaf certificate DER)`.
    type Pki = (
        Arc<rustls::ServerConfig>,
        Arc<rustls::ClientConfig>,
        CertificateDer<'static>,
    );

    /// Generates a throwaway CA and issues two leaf certificates from it: one
    /// for `localhost` (server) and one with no SANs (client).
    ///
    /// The server config requires a client certificate that chains to the CA
    /// (mTLS). The client config trusts the same CA and presents its certificate.
    fn pki() -> Pki {
        // Result ignored: an earlier test in this process may have installed it.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let ca_key = KeyPair::generate().unwrap();
        let mut ca_params = CertificateParams::default();
        ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let ca = CertifiedIssuer::self_signed(ca_params, ca_key).unwrap();
        let issue = |sans: &[SanType]| {
            let k = KeyPair::generate().unwrap();
            let mut p = CertificateParams::default();
            p.subject_alt_names = sans.to_vec();
            let c = p.signed_by(&k, &ca).unwrap();
            (
                CertificateDer::from(c.der().to_vec()),
                PrivatePkcs8KeyDer::from(k.serialized_der().to_vec()).into(),
            )
        };
        let (srv_cert, srv_key) = issue(&[SanType::DnsName("localhost".try_into().unwrap())]);
        let (cli_cert, cli_key) = issue(&[]);
        let mut roots = rustls::RootCertStore::empty();
        roots.add(CertificateDer::from(ca.der().to_vec())).unwrap();
        let roots = Arc::new(roots);
        let cv = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots))
            .build()
            .unwrap();
        let server = rustls::ServerConfig::builder()
            .with_client_cert_verifier(cv)
            .with_single_cert(vec![srv_cert], srv_key)
            .unwrap();
        let client = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_client_auth_cert(vec![cli_cert.clone()], cli_key)
            .unwrap();
        (Arc::new(server), Arc::new(client), cli_cert)
    }

    /// Raw HTTP/1.1 request for the `svc/Echo` route with an empty body.
    /// `Connection: close` asks the server to close after responding, so the
    /// client can read the response to EOF.
    const ECHO_REQ: &[u8] = b"POST /svc/Echo HTTP/1.1\r\n\
Host: localhost\r\n\
Content-Type: application/proto\r\n\
Content-Length: 0\r\n\
Connection: close\r\n\
\r\n";

    /// End-to-end mTLS request: the handler must observe the client's socket
    /// address as `PeerAddr` and its certificate chain as `PeerCerts`.
    #[tokio::test]
    async fn serve_tls_injects_peer_identity() {
        let (server_cfg, client_cfg, expected_client_der) = pki();
        // Filled in by the handler with the peer identity it observed.
        type Captured = Arc<Mutex<Option<(PeerAddr, Option<PeerCerts>)>>>;
        let captured: Captured = Arc::new(Mutex::new(None));
        let handler_captured = Arc::clone(&captured);
        let connect = ConnectRouter::new().route(
            "svc",
            "Echo",
            handler_fn(
                move |ctx: crate::RequestContext, _req: buffa_types::Empty| {
                    let cap = Arc::clone(&handler_captured);
                    async move {
                        *cap.lock().unwrap() = Some((
                            ctx.extensions.get::<PeerAddr>().cloned().unwrap(),
                            ctx.extensions.get::<PeerCerts>().cloned(),
                        ));
                        ConnectResponse::ok(buffa_types::Empty::default())
                    }
                },
            ),
        );
        let app = axum::Router::new().fallback_service(connect.into_axum_service());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        let serve = tokio::spawn(
            serve_tls(listener, app, server_cfg)
                .with_graceful_shutdown(async {
                    rx.await.ok();
                })
                .into_future(),
        );
        let resp = echo_over_tls(addr, client_cfg).await;
        assert!(
            resp.starts_with(b"HTTP/1.1 2"),
            "expected 2xx, got: {}",
            String::from_utf8_lossy(&resp[..resp.len().min(120)])
        );
        tx.send(()).unwrap();
        tokio::time::timeout(Duration::from_secs(5), serve)
            .await
            .expect("serve should shut down within timeout")
            .unwrap()
            .unwrap();
        let (peer_addr, peer_certs) = captured.lock().unwrap().take().expect("handler ran");
        assert_eq!(peer_addr.0.ip(), addr.ip());
        let certs = peer_certs.expect("mTLS client should present a cert chain");
        assert_eq!(certs.0.len(), 1);
        assert_eq!(certs.0[0].as_ref(), expected_client_der.as_ref());
    }

    /// Opens an mTLS connection to `addr`, sends `ECHO_REQ`, and returns the
    /// raw response bytes read until the server closes the connection.
    ///
    /// # Panics
    ///
    /// Panics if connecting, the handshake, or any I/O fails, or if the whole
    /// exchange takes longer than `CLIENT_TIMEOUT`.
    async fn echo_over_tls(
        addr: std::net::SocketAddr,
        client_cfg: Arc<rustls::ClientConfig>,
    ) -> Vec<u8> {
        tokio::time::timeout(CLIENT_TIMEOUT, async move {
            let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
            let connector = tokio_rustls::TlsConnector::from(client_cfg);
            let sni = rustls::pki_types::ServerName::try_from("localhost").unwrap();
            let mut tls = connector.connect(sni, tcp).await.unwrap();
            tls.write_all(ECHO_REQ).await.unwrap();
            let mut resp = Vec::new();
            tls.read_to_end(&mut resp).await.unwrap();
            resp
        })
        .await
        .expect("TLS echo round-trip timed out")
    }

    /// A client that connects but never sends a ClientHello must be dropped by
    /// the handshake timeout, so it does not block the graceful drain.
    #[tokio::test]
    async fn handshake_timeout_drops_stalled_connection() {
        let (server_cfg, _, _) = pki();
        let app = axum::Router::new();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        let serve = tokio::spawn(
            serve_tls(listener, app, server_cfg)
                .tls_handshake_timeout(Duration::from_millis(100))
                .with_graceful_shutdown(async {
                    rx.await.ok();
                })
                .into_future(),
        );
        // Held open (and silent) for longer than the 100ms handshake timeout.
        let _stalled = tokio::net::TcpStream::connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_millis(250)).await;
        tx.send(()).unwrap();
        tokio::time::timeout(Duration::from_secs(5), serve)
            .await
            .expect("handshake timeout must release the watcher so drain completes")
            .unwrap()
            .unwrap();
    }

    /// A plaintext client fails the TLS handshake; the accept loop must keep
    /// running, and a subsequent valid mTLS client must still be served.
    #[tokio::test]
    async fn handshake_error_does_not_kill_accept_loop() {
        let (server_cfg, client_cfg, _) = pki();
        let calls = Arc::new(Mutex::new(0u32));
        let handler_calls = Arc::clone(&calls);
        let connect = ConnectRouter::new().route(
            "svc",
            "Echo",
            handler_fn(
                move |_ctx: crate::RequestContext, _req: buffa_types::Empty| {
                    let calls = Arc::clone(&handler_calls);
                    async move {
                        *calls.lock().unwrap() += 1;
                        ConnectResponse::ok(buffa_types::Empty::default())
                    }
                },
            ),
        );
        let app = axum::Router::new().fallback_service(connect.into_axum_service());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        let serve = tokio::spawn(
            serve_tls(listener, app, server_cfg)
                .with_graceful_shutdown(async {
                    rx.await.ok();
                })
                .into_future(),
        );
        let mut bad = tokio::net::TcpStream::connect(addr).await.unwrap();
        bad.write_all(b"GET / HTTP/1.1\r\n\r\n").await.unwrap();
        let mut buf = [0u8; 64];
        // Whatever comes back (presumably a TLS alert or EOF) is irrelevant;
        // this only waits for the server to process the failed handshake.
        let _ = bad.read(&mut buf).await;
        drop(bad);
        let resp = echo_over_tls(addr, client_cfg).await;
        assert!(
            resp.starts_with(b"HTTP/1.1 2"),
            "valid client must succeed after a handshake error: {}",
            String::from_utf8_lossy(&resp[..resp.len().min(120)])
        );
        tx.send(()).unwrap();
        tokio::time::timeout(Duration::from_secs(5), serve)
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        assert_eq!(
            *calls.lock().unwrap(),
            1,
            "only the valid client reaches the handler"
        );
    }

    /// A request that is in flight when shutdown is signalled must still
    /// complete, and the server must finish draining afterwards.
    #[tokio::test]
    async fn graceful_shutdown_drains_in_flight_request() {
        let (server_cfg, client_cfg, _) = pki();
        // `in_flight` tells the test the handler has started; `release` lets
        // the handler finish once shutdown has been signalled.
        let (in_flight_tx, in_flight_rx) = tokio::sync::oneshot::channel::<()>();
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
        let in_flight_tx = Arc::new(Mutex::new(Some(in_flight_tx)));
        let release_rx = Arc::new(Mutex::new(Some(release_rx)));
        let connect = ConnectRouter::new().route(
            "svc",
            "Echo",
            handler_fn(
                move |_ctx: crate::RequestContext, _req: buffa_types::Empty| {
                    let in_flight = in_flight_tx.lock().unwrap().take();
                    let release = release_rx.lock().unwrap().take();
                    async move {
                        if let Some(tx) = in_flight {
                            tx.send(()).ok();
                        }
                        if let Some(rx) = release {
                            rx.await.ok();
                        }
                        ConnectResponse::ok(buffa_types::Empty::default())
                    }
                },
            ),
        );
        let app = axum::Router::new().fallback_service(connect.into_axum_service());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let serve = tokio::spawn(
            serve_tls(listener, app, server_cfg)
                .with_graceful_shutdown(async {
                    shutdown_rx.await.ok();
                })
                .into_future(),
        );
        let client = tokio::spawn(echo_over_tls(addr, client_cfg));
        in_flight_rx.await.unwrap();
        shutdown_tx.send(()).unwrap();
        release_tx.send(()).unwrap();
        let resp = tokio::time::timeout(Duration::from_secs(5), client)
            .await
            .expect("in-flight request should complete during drain")
            .unwrap();
        assert!(
            resp.starts_with(b"HTTP/1.1 2"),
            "in-flight request must complete: {}",
            String::from_utf8_lossy(&resp[..resp.len().min(120)])
        );
        tokio::time::timeout(Duration::from_secs(5), serve)
            .await
            .expect("serve should drain after the in-flight request completes")
            .unwrap()
            .unwrap();
    }
}