#![cfg(feature = "mtls")]
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use futures::Stream;
use rustls::ServerConfig;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::warn;
use crate::error::{NetError, NetResult};
/// Borrowed PEM-encoded TLS credentials: a certificate chain and its private key.
///
/// `Debug` is implemented by hand instead of derived so that formatting a
/// `TlsCredsRef` (e.g. in a log or tracing field) never prints the private key
/// bytes; only the byte lengths of both buffers are shown.
#[derive(Clone, Copy)]
pub struct TlsCredsRef<'a> {
    /// PEM text containing one or more `CERTIFICATE` blocks. rustls expects the
    /// server's own (end-entity) certificate to come first in the chain.
    pub cert_pem: &'a [u8],
    /// PEM text containing the private key matching the first certificate.
    /// `build_rustls_config` accepts PKCS#8, PKCS#1 (RSA) or SEC1 (EC) keys.
    pub key_pem: &'a [u8],
}

impl std::fmt::Debug for TlsCredsRef<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TlsCredsRef")
            .field("cert_pem", &format_args!("{} bytes", self.cert_pem.len()))
            // Never render secret key material, not even in debug output.
            .field("key_pem", &format_args!("<redacted {} bytes>", self.key_pem.len()))
            .finish()
    }
}

impl<'a> TlsCredsRef<'a> {
    /// Bundles borrowed certificate-chain and private-key PEM bytes.
    pub fn new(cert_pem: &'a [u8], key_pem: &'a [u8]) -> Self {
        Self { cert_pem, key_pem }
    }
}
/// Builds a rustls server configuration from borrowed PEM credentials.
///
/// Side effect: attempts to install `ring` as the process-wide rustls crypto
/// provider. The outcome is deliberately ignored: installation only fails when a
/// process default is already installed, and `ServerConfig::builder()` then
/// uses that existing provider.
///
/// NOTE(review): the config is built with `with_no_client_auth()`, so client
/// certificates are not requested even though this module is gated behind the
/// `mtls` feature — confirm this is intended.
///
/// # Errors
/// Returns `NetError::TlsError` when the certificate PEM cannot be parsed or
/// contains no certificate, when no usable private key is found, or when rustls
/// rejects the certificate/key pair.
pub fn build_rustls_config(creds: &TlsCredsRef<'_>) -> NetResult<ServerConfig> {
    let _ = rustls::crypto::ring::default_provider().install_default();

    let chain = rustls_pemfile::certs(&mut std::io::Cursor::new(creds.cert_pem))
        .collect::<Result<Vec<CertificateDer<'static>>, _>>()
        .map_err(|e| NetError::TlsError(format!("Failed to parse cert PEM: {e}")))?;
    if chain.is_empty() {
        return Err(NetError::TlsError(
            "Cert PEM contained no certificates".to_string(),
        ));
    }

    let key = parse_private_key(creds.key_pem)?;

    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, key)
        .map_err(|e| NetError::TlsError(format!("rustls rejected cert/key: {e}")))
}
fn parse_private_key(key_pem: &[u8]) -> NetResult<PrivateKeyDer<'static>> {
let mut cursor = std::io::Cursor::new(key_pem);
if let Some(key) = rustls_pemfile::pkcs8_private_keys(&mut cursor)
.next()
.transpose()
.map_err(|e| NetError::TlsError(format!("PKCS#8 parse error: {e}")))?
{
return Ok(PrivateKeyDer::Pkcs8(key));
}
let mut cursor = std::io::Cursor::new(key_pem);
if let Some(key) = rustls_pemfile::rsa_private_keys(&mut cursor)
.next()
.transpose()
.map_err(|e| NetError::TlsError(format!("RSA parse error: {e}")))?
{
return Ok(PrivateKeyDer::Pkcs1(key));
}
let mut cursor = std::io::Cursor::new(key_pem);
if let Some(key) = rustls_pemfile::ec_private_keys(&mut cursor)
.next()
.transpose()
.map_err(|e| NetError::TlsError(format!("EC parse error: {e}")))?
{
return Ok(PrivateKeyDer::Sec1(key));
}
Err(NetError::TlsError(
"No valid private key in PEM (tried PKCS#8, RSA, EC)".to_string(),
))
}
pub struct LiveTlsAcceptor {
listener: TcpListener,
store: Arc<ArcSwap<ServerConfig>>,
}
impl LiveTlsAcceptor {
pub fn new(listener: TcpListener, store: Arc<ArcSwap<ServerConfig>>) -> Self {
Self { listener, store }
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
pub fn store(&self) -> Arc<ArcSwap<ServerConfig>> {
Arc::clone(&self.store)
}
pub async fn accept(
&self,
) -> io::Result<(tokio_rustls::server::TlsStream<TcpStream>, SocketAddr)> {
let (tcp, peer) = self.listener.accept().await?;
let cfg = Arc::clone(&self.store.load());
let acceptor = TlsAcceptor::from(cfg);
let tls = acceptor.accept(tcp).await?;
Ok((tls, peer))
}
pub fn into_stream(
self,
) -> impl Stream<Item = io::Result<tokio_rustls::server::TlsStream<TcpStream>>> {
async_stream::stream! {
loop {
match self.accept().await {
Ok((tls, _peer)) => yield Ok(tls),
Err(e) => {
warn!("LiveTlsAcceptor: accept/handshake failed: {e}");
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls::SelfSignedGenerator;
use rustls::pki_types::ServerName;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_rustls::TlsConnector;
/// Generates a fresh self-signed certificate carrying `san` and `127.0.0.1` as
/// subject alternative names, and returns it as `(cert_pem, key_pem)` PEM bytes.
fn pem_pair_with_san(san: &str) -> (Vec<u8>, Vec<u8>) {
    let generator = SelfSignedGenerator::new(san).with_san(san).with_san("127.0.0.1");
    let (cert_der, key_der) = generator.generate().expect("generate cert");
    // Choose the PEM label that matches the key encoding the generator produced.
    let (key_label, key_bytes): (&str, &[u8]) = match &key_der {
        PrivateKeyDer::Pkcs8(k) => ("PRIVATE KEY", k.secret_pkcs8_der()),
        PrivateKeyDer::Pkcs1(k) => ("RSA PRIVATE KEY", k.secret_pkcs1_der()),
        PrivateKeyDer::Sec1(k) => ("EC PRIVATE KEY", k.secret_sec1_der()),
        _ => panic!("unexpected key kind"),
    };
    (
        pem_encode("CERTIFICATE", cert_der.as_ref()),
        pem_encode(key_label, key_bytes),
    )
}
/// Wraps `der` in a PEM envelope labelled `label`: a BEGIN line, the base64 body
/// split into lines of at most 64 characters, and an END line, each ending in `\n`.
fn pem_encode(label: &str, der: &[u8]) -> Vec<u8> {
    let header = format!("-----BEGIN {label}-----\n");
    let footer = format!("-----END {label}-----\n");
    let body: Vec<u8> = base64_encode(der)
        .into_bytes()
        .chunks(64)
        .flat_map(|line| line.iter().copied().chain(std::iter::once(b'\n')))
        .collect();
    [header.as_bytes(), body.as_slice(), footer.as_bytes()].concat()
}
/// Encodes `data` as standard base64 (RFC 4648 alphabet, `=` padding).
///
/// This is a tiny dependency-free encoder, so the tests can build PEM text by hand.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the 6-bit group located `shift` bits up inside `word` to its symbol.
    let symbol = |word: u32, shift: u32| char::from(ALPHABET[((word >> shift) & 0x3f) as usize]);

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    let mut triples = data.chunks_exact(3);
    for triple in triples.by_ref() {
        let word = (u32::from(triple[0]) << 16) | (u32::from(triple[1]) << 8) | u32::from(triple[2]);
        for shift in [18, 12, 6, 0] {
            encoded.push(symbol(word, shift));
        }
    }
    // A trailing 1 or 2 bytes produce 2 or 3 symbols plus padding.
    match *triples.remainder() {
        [a] => {
            let word = u32::from(a) << 16;
            encoded.push(symbol(word, 18));
            encoded.push(symbol(word, 12));
            encoded.push_str("==");
        }
        [a, b] => {
            let word = (u32::from(a) << 16) | (u32::from(b) << 8);
            encoded.push(symbol(word, 18));
            encoded.push(symbol(word, 12));
            encoded.push(symbol(word, 6));
            encoded.push('=');
        }
        _ => {}
    }
    encoded
}
/// Builds a rustls client config whose only trust anchors are the certificates
/// in `cert_pem`; no client certificate is presented.
///
/// Panics if a certificate cannot be parsed or added to the root store.
fn client_config_trusting(cert_pem: &[u8]) -> rustls::ClientConfig {
    let mut trust_anchors = rustls::RootCertStore::empty();
    let mut pem = std::io::Cursor::new(cert_pem);
    rustls_pemfile::certs(&mut pem)
        .map(|parsed| parsed.expect("parse cert"))
        .for_each(|der| trust_anchors.add(der).expect("add to root store"));
    rustls::ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth()
}
/// Reports whether the DER certificate `der` lists `expected` as a DNS subject
/// alternative name (exact, case-sensitive match).
///
/// Returns `false` if the certificate cannot be parsed or has no usable SAN extension.
fn cert_has_san(der: &[u8], expected: &str) -> bool {
    use x509_parser::prelude::*;
    let Ok((_, cert)) = X509Certificate::from_der(der) else {
        return false;
    };
    let Ok(Some(san)) = cert.subject_alternative_name() else {
        return false;
    };
    san.value
        .general_names
        .iter()
        .any(|gn| matches!(gn, GeneralName::DNSName(dns) if *dns == expected))
}
/// A freshly generated self-signed cert/key pair must yield a server config.
/// Success is the assertion: a failure panics in `expect`.
#[test]
fn test_build_rustls_config_from_creds() {
    let (cert_pem, key_pem) = pem_pair_with_san("server.test");
    let _config = build_rustls_config(&TlsCredsRef::new(&cert_pem, &key_pem))
        .expect("build rustls config");
}
/// PEM with no `CERTIFICATE` block (here an unknown `GARBAGE` section) must be
/// rejected with a `TlsError`, either as a parse failure or as an empty chain.
#[test]
fn test_build_rustls_config_invalid_cert_errors() {
    const GARBAGE_PEM: &[u8] = b"-----BEGIN GARBAGE-----\nnope\n-----END GARBAGE-----\n";
    let outcome = build_rustls_config(&TlsCredsRef::new(GARBAGE_PEM, b""));
    let err = outcome.expect_err("should fail");
    assert!(matches!(err, NetError::TlsError(_)), "got {err:?}");
}
/// Empty certificate and key input must be rejected with a `TlsError`.
#[test]
fn test_build_rustls_config_empty_cert_errors() {
    let empty: &[u8] = &[];
    let err = build_rustls_config(&TlsCredsRef::new(empty, empty)).expect_err("should fail");
    assert!(matches!(err, NetError::TlsError(_)), "got {err:?}");
}
/// Spawns a task that accepts TLS connections on an ephemeral localhost port,
/// using the config currently held in `store`, and echoes back every byte each
/// client sends.
///
/// Each accepted connection is served in its own task. Returns the bound address
/// and the accept task's handle; abort the handle to stop accepting.
///
/// A failed accept or handshake is reported on stderr and skipped rather than
/// ending the accept task. Otherwise a single bad connection would turn every
/// later connect into "connection refused" and hide the real failure.
async fn spawn_echo_acceptor(
    store: Arc<ArcSwap<ServerConfig>>,
) -> (SocketAddr, tokio::task::JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
    let addr = listener.local_addr().expect("local addr");
    let acceptor = LiveTlsAcceptor::new(listener, store);
    let handle = tokio::spawn(async move {
        loop {
            let (mut tls, _peer) = match acceptor.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    eprintln!("echo acceptor: accept/handshake failed: {e}");
                    continue;
                }
            };
            tokio::spawn(async move {
                let mut buf = [0u8; 64];
                loop {
                    match tls.read(&mut buf).await {
                        // EOF or read error ends this connection only.
                        Ok(0) | Err(_) => return,
                        Ok(n) => {
                            if tls.write_all(&buf[..n]).await.is_err()
                                || tls.flush().await.is_err()
                            {
                                return;
                            }
                        }
                    }
                }
            });
        }
    });
    (addr, handle)
}
/// Connects to `addr` over TCP and performs a TLS client handshake that trusts
/// only the certificates in `cert_pem` and sends `sni` as the server name.
///
/// Panics (failing the test) if the connect, name parsing or handshake fails.
async fn connect_client(
    addr: SocketAddr,
    cert_pem: &[u8],
    sni: &str,
) -> tokio_rustls::client::TlsStream<TcpStream> {
    let socket = TcpStream::connect(addr).await.expect("client connect");
    let connector = TlsConnector::from(Arc::new(client_config_trusting(cert_pem)));
    let domain = ServerName::try_from(sni.to_string()).expect("server name");
    connector.connect(domain, socket).await.expect("tls handshake")
}
/// Happy path: a freshly constructed acceptor serves the certificate from the
/// config it was given.
#[tokio::test]
async fn test_live_tls_acceptor_serves_initial_cert() {
let (cert, key) = pem_pair_with_san("v1.test");
let creds = TlsCredsRef::new(&cert, &key);
let cfg = build_rustls_config(&creds).expect("rustls cfg");
// Publish the v1 config through the shared store the acceptor reads from.
let store = Arc::new(ArcSwap::from_pointee(cfg));
let (addr, handle) = spawn_echo_acceptor(Arc::clone(&store)).await;
// The client's root store holds only the v1 cert, so the handshake is expected
// to fail if the server presents anything else.
let mut client = connect_client(addr, &cert, "v1.test").await;
// Echo round trip: application data flows both ways over the TLS session.
client.write_all(b"ping").await.expect("write");
let mut buf = [0u8; 4];
client.read_exact(&mut buf).await.expect("read echo");
assert_eq!(&buf, b"ping");
// Inspect what the server actually presented: one self-signed cert whose SAN
// list contains v1.test.
let (_io, conn) = client.get_ref();
let peer_certs = conn.peer_certificates().expect("peer certs");
assert_eq!(peer_certs.len(), 1);
assert!(
cert_has_san(peer_certs[0].as_ref(), "v1.test"),
"expected v1.test SAN in initial cert"
);
// Close the client, then stop the accept loop.
drop(client);
handle.abort();
}
/// After the stored config is swapped, a *new* connection must be served the
/// v2 certificate.
#[tokio::test]
async fn test_live_tls_acceptor_swap_changes_cert_for_new_connection() {
let (cert_v1, key_v1) = pem_pair_with_san("v1.test");
let (cert_v2, key_v2) = pem_pair_with_san("v2.test");
let cfg_v1 = build_rustls_config(&TlsCredsRef::new(&cert_v1, &key_v1)).expect("v1");
let store = Arc::new(ArcSwap::from_pointee(cfg_v1));
let (addr, handle) = spawn_echo_acceptor(Arc::clone(&store)).await;
// Connection `a` is established while v1 is the stored config.
let mut a = connect_client(addr, &cert_v1, "v1.test").await;
a.write_all(b"a").await.expect("write");
let mut buf = [0u8; 1];
a.read_exact(&mut buf).await.expect("read");
let (_io, conn_a) = a.get_ref();
let cert_a = conn_a.peer_certificates().expect("certs")[0].clone();
assert!(
cert_has_san(cert_a.as_ref(), "v1.test"),
"expected v1.test SAN before swap"
);
// Hot swap: replace the config in the store shared with the acceptor.
let cfg_v2 = build_rustls_config(&TlsCredsRef::new(&cert_v2, &key_v2)).expect("v2");
store.store(Arc::new(cfg_v2));
// NOTE(review): this sleep looks unnecessary. `store.store` has returned before
// `b` connects, and the acceptor loads the config only after accepting b's TCP
// connection. The hold-connection test below swaps without sleeping.
tokio::time::sleep(Duration::from_millis(10)).await;
// Connection `b` trusts only the v2 cert. Only the arrival of the echo byte is
// checked here, not its value.
let mut b = connect_client(addr, &cert_v2, "v2.test").await;
b.write_all(b"b").await.expect("write");
b.read_exact(&mut buf).await.expect("read");
let (_io, conn_b) = b.get_ref();
let cert_b = conn_b.peer_certificates().expect("certs")[0].clone();
assert!(
cert_has_san(cert_b.as_ref(), "v2.test"),
"expected v2.test SAN after swap"
);
assert!(
!cert_has_san(cert_b.as_ref(), "v1.test"),
"v2 cert should not advertise v1 SAN"
);
drop(a);
drop(b);
handle.abort();
}
/// A connection established before a config swap must keep working, and keep
/// reporting the v1 certificate, while new connections get v2.
#[tokio::test]
async fn test_live_tls_acceptor_existing_connection_continues_on_old_cert() {
let (cert_v1, key_v1) = pem_pair_with_san("v1.test");
let (cert_v2, key_v2) = pem_pair_with_san("v2.test");
let cfg_v1 = build_rustls_config(&TlsCredsRef::new(&cert_v1, &key_v1)).expect("v1");
let store = Arc::new(ArcSwap::from_pointee(cfg_v1));
let (addr, handle) = spawn_echo_acceptor(Arc::clone(&store)).await;
// `client_a` is opened under v1 and held open across the swap.
let mut client_a = connect_client(addr, &cert_v1, "v1.test").await;
client_a.write_all(b"hold").await.expect("write");
let mut buf = [0u8; 4];
client_a.read_exact(&mut buf).await.expect("read echo v1");
assert_eq!(&buf, b"hold");
let (_io, conn_a) = client_a.get_ref();
let cert_a = conn_a.peer_certificates().expect("certs")[0].clone();
assert!(
cert_has_san(cert_a.as_ref(), "v1.test"),
"expected v1.test SAN before swap"
);
// Swap to v2 while client_a's server-side echo task is still running.
let cfg_v2 = build_rustls_config(&TlsCredsRef::new(&cert_v2, &key_v2)).expect("v2");
store.store(Arc::new(cfg_v2));
// A new connection must be served v2 (its client trusts only the v2 cert).
let mut client_b = connect_client(addr, &cert_v2, "v2.test").await;
client_b.write_all(b"new!").await.expect("write");
client_b.read_exact(&mut buf).await.expect("read echo v2");
assert_eq!(&buf, b"new!");
let (_io, conn_b) = client_b.get_ref();
let cert_b = conn_b.peer_certificates().expect("certs")[0].clone();
assert!(
cert_has_san(cert_b.as_ref(), "v2.test"),
"expected v2.test SAN on new connection"
);
// The held v1 session must still carry data after the swap.
let mut buf2 = [0u8; 5];
client_a
.write_all(b"alive")
.await
.expect("post-swap write through held v1 connection");
client_a
.read_exact(&mut buf2)
.await
.expect("post-swap read through held v1 connection");
assert_eq!(
&buf2, b"alive",
"held v1 connection must still echo through after server-side swap"
);
// The certificate a session reports is the one from its handshake; the swap
// must not change it.
let (_io, conn_a_after) = client_a.get_ref();
let cert_a_after = conn_a_after.peer_certificates().expect("certs")[0].clone();
assert!(
cert_has_san(cert_a_after.as_ref(), "v1.test"),
"held connection should still report v1.test"
);
assert!(
!cert_has_san(cert_a_after.as_ref(), "v2.test"),
"held connection should not advertise v2 SAN"
);
drop(client_a);
drop(client_b);
handle.abort();
}
}