use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::Duration;
use super::{DaemonError, TlsConfigError};
/// Maximum time a single TLS handshake may take; connections whose handshake
/// does not complete within this window are logged and dropped.
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Number of permits in the semaphore each TLS accept loop in this file creates.
/// A permit is acquired after accepting a connection and before spawning its TLS
/// task, so the accept loop stops accepting while all permits are held.
const TLS_MAX_INFLIGHT: usize = 128;
/// An accepted TCP connection, either used as-is or wrapped in a server-side
/// TLS session after a successful handshake.
pub(super) enum MaybeTlsStream {
    /// Raw TCP connection; produced when no TLS acceptor is configured.
    Plain(tokio::net::TcpStream),
    /// TLS-terminated connection. Boxed, presumably to keep the enum's size close
    /// to that of the `Plain` variant (the TLS session state is large).
    Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
}
impl AsyncRead for MaybeTlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Plain(s) => Pin::new(s).poll_read(cx, buf),
Self::Tls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
/// Forwards every write-side operation to the wrapped plain or TLS stream.
///
/// Vectored writes are forwarded too: the trait's default `poll_write_vectored`
/// only writes the first non-empty buffer and `is_write_vectored` defaults to
/// `false`, which would hide the vectored-I/O support of the inner streams and
/// make callers such as hyper fall back to copying buffers together.
impl AsyncWrite for MaybeTlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_write(cx, buf),
            Self::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            Self::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Plain(s) => s.is_write_vectored(),
            Self::Tls(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_flush(cx),
            Self::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            Self::Plain(s) => Pin::new(s).poll_shutdown(cx),
            Self::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}
impl tonic::transport::server::Connected for MaybeTlsStream {
    type ConnectInfo = std::net::SocketAddr;

    /// Reports the peer address of the underlying TCP socket, falling back to
    /// `0.0.0.0:0` when the address cannot be queried.
    fn connect_info(&self) -> Self::ConnectInfo {
        // For TLS, the raw socket is the first element of the `(io, session)` pair.
        let tcp = match self {
            Self::Plain(s) => s,
            Self::Tls(s) => s.get_ref().0,
        };
        tcp.peer_addr()
            .unwrap_or_else(|_| std::net::SocketAddr::from(([0, 0, 0, 0], 0)))
    }
}
pub(super) fn tls_tcp_incoming(
listener: tokio::net::TcpListener,
tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
) -> tokio_stream::wrappers::ReceiverStream<Result<MaybeTlsStream, std::io::Error>> {
let (tx, rx) = mpsc::channel(128);
let semaphore = Arc::new(tokio::sync::Semaphore::new(TLS_MAX_INFLIGHT));
tokio::spawn(async move {
loop {
let (tcp, addr) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
tracing::debug!("TCP accept error: {e}");
continue;
}
};
let Some(acceptor) = tls_acceptor.clone() else {
if tx.send(Ok(MaybeTlsStream::Plain(tcp))).await.is_err() {
break; }
continue;
};
let Ok(permit) = semaphore.clone().acquire_owned().await else {
break; };
let tx = tx.clone();
tokio::spawn(async move {
let stream =
match tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, acceptor.accept(tcp)).await {
Ok(Ok(tls)) => MaybeTlsStream::Tls(Box::new(tls)),
Ok(Err(e)) => {
tracing::debug!("TLS handshake failed from {addr}: {e}");
drop(permit);
return;
}
Err(_) => {
tracing::debug!("TLS handshake timed out from {addr}");
drop(permit);
return;
}
};
let _ = tx.send(Ok(stream)).await;
drop(permit);
});
}
});
tokio_stream::wrappers::ReceiverStream::new(rx)
}
/// Reads the PEM-encoded certificate chain and private key from disk.
///
/// Returns `(cert_bytes, key_bytes)` exactly as read; no parsing happens here.
/// The certificate is read first, so when both files are unreadable the error
/// refers to the certificate.
///
/// # Errors
/// `DaemonError::TlsConfig(TlsConfigError::ReadCert { .. })` or
/// `DaemonError::TlsConfig(TlsConfigError::ReadKey { .. })`, carrying the
/// offending path and the underlying I/O error.
pub(super) fn load_tls_pem(
    cert_path: &str,
    key_path: &str,
) -> Result<(Vec<u8>, Vec<u8>), DaemonError> {
    let cert = match std::fs::read(cert_path) {
        Ok(bytes) => bytes,
        Err(source) => {
            return Err(DaemonError::TlsConfig(TlsConfigError::ReadCert {
                path: cert_path.to_owned(),
                source,
            }));
        }
    };
    let key = match std::fs::read(key_path) {
        Ok(bytes) => bytes,
        Err(source) => {
            return Err(DaemonError::TlsConfig(TlsConfigError::ReadKey {
                path: key_path.to_owned(),
                source,
            }));
        }
    };
    Ok((cert, key))
}
/// Builds a rustls-backed TLS acceptor from a PEM certificate chain and PEM private key.
///
/// Client authentication is disabled. The server advertises ALPN `h2` and
/// `http/1.1`: gRPC clients (e.g. tonic ≥ 0.12, grpc-go ≥ 1.67) refuse TLS
/// connections where `h2` was not negotiated via ALPN, and HTTPS clients need
/// ALPN to pick HTTP/2.
///
/// # Errors
/// * `TlsConfigError::ParseCerts` if a PEM section in `cert_pem` is malformed.
/// * `TlsConfigError::ParseKey` if no usable private key can be read from `key_pem`.
/// * `TlsConfigError::ServerConfig` if rustls rejects the certificate/key pair.
///
/// All are wrapped in `DaemonError::TlsConfig`.
pub(super) fn build_tls_acceptor(
    cert_pem: &[u8],
    key_pem: &[u8],
) -> Result<tokio_rustls::TlsAcceptor, DaemonError> {
    use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
    let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(cert_pem)
        .collect::<Result<_, _>>()
        .map_err(|e| DaemonError::TlsConfig(TlsConfigError::ParseCerts(e)))?;
    let key = PrivateKeyDer::from_pem_slice(key_pem)
        .map_err(|e| DaemonError::TlsConfig(TlsConfigError::ParseKey(e)))?;
    let mut config = tokio_rustls::rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| DaemonError::TlsConfig(TlsConfigError::ServerConfig(e)))?;
    // Preference order: HTTP/2 first (required by gRPC), HTTP/1.1 as fallback.
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(config)))
}
pub(super) async fn serve_https(
listener: tokio::net::TcpListener,
app: axum::Router,
tls_acceptor: tokio_rustls::TlsAcceptor,
) {
use tower::ServiceExt;
let semaphore = Arc::new(tokio::sync::Semaphore::new(TLS_MAX_INFLIGHT));
loop {
let (tcp_stream, remote_addr) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
tracing::debug!("TCP accept error: {e}");
continue;
}
};
let Ok(permit) = semaphore.clone().acquire_owned().await else {
break; };
let acceptor = tls_acceptor.clone();
let app = app.clone();
tokio::spawn(async move {
let tls_stream = match tokio::time::timeout(
TLS_HANDSHAKE_TIMEOUT,
acceptor.accept(tcp_stream),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
tracing::debug!("TLS handshake failed from {remote_addr}: {e}");
drop(permit);
return;
}
Err(_) => {
tracing::debug!("TLS handshake timed out from {remote_addr}");
drop(permit);
return;
}
};
let io = hyper_util::rt::TokioIo::new(tls_stream);
let service =
hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let app = app.clone();
async move {
let (parts, body) = req.into_parts();
let req = hyper::Request::from_parts(parts, axum::body::Body::new(body));
Ok::<_, std::convert::Infallible>(
app.oneshot(req).await.unwrap_or_else(|err| match err {}),
)
}
});
if let Err(e) =
hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(io, service)
.await
{
tracing::debug!("HTTPS connection error from {remote_addr}: {e}");
}
drop(permit);
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A missing certificate file is reported as `ReadCert` carrying that path.
    #[test]
    fn load_tls_pem_returns_read_cert_error_for_missing_file() {
        let missing_cert = "/nonexistent/cert.pem";
        let failure = load_tls_pem(missing_cert, "/nonexistent/key.pem").unwrap_err();
        match failure {
            DaemonError::TlsConfig(TlsConfigError::ReadCert { path, .. }) => {
                assert_eq!(path, missing_cert);
            }
            other => panic!("expected ReadCert error, got: {other:?}"),
        }
    }

    /// With a readable certificate, a missing key is reported as `ReadKey`.
    #[test]
    fn load_tls_pem_returns_read_key_error_when_cert_exists_but_key_missing() {
        let scratch = tempfile::tempdir().unwrap();
        let cert_file = scratch.path().join("cert.pem");
        std::fs::write(&cert_file, b"dummy").unwrap();
        let cert_arg = cert_file.to_str().unwrap();
        match load_tls_pem(cert_arg, "/nonexistent/key.pem").unwrap_err() {
            DaemonError::TlsConfig(TlsConfigError::ReadKey { path, .. }) => {
                assert_eq!(path, "/nonexistent/key.pem");
            }
            other => panic!("expected ReadKey error, got: {other:?}"),
        }
    }

    /// Non-PEM input must be rejected by one of the parsing stages.
    #[test]
    fn build_tls_acceptor_rejects_invalid_cert_pem() {
        let outcome = build_tls_acceptor(b"not a pem certificate", b"not a pem key");
        match outcome {
            Err(DaemonError::TlsConfig(TlsConfigError::ParseCerts(_)))
            | Err(DaemonError::TlsConfig(TlsConfigError::ParseKey(_))) => {}
            Ok(_) => panic!("expected build_tls_acceptor to reject invalid PEM"),
            Err(other) => panic!("expected ParseCerts or ParseKey, got: {other:?}"),
        }
    }

    /// The rendered error mentions TLS and the path that failed.
    #[test]
    fn tls_config_error_display_contains_source_context() {
        let rendered = DaemonError::TlsConfig(TlsConfigError::ReadCert {
            path: "/etc/foo.pem".to_string(),
            source: std::io::Error::other("permission denied"),
        })
        .to_string();
        for needle in ["TLS", "/etc/foo.pem"] {
            assert!(rendered.contains(needle));
        }
    }
}