use std::collections::BTreeMap;
use std::fs::File;
use std::io::{self, BufReader};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use parking_lot::RwLock;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier};
use rustls::sign::CertifiedKey;
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector};
use crate::io::reactor::{ConnRole, Transport};
/// Errors produced while loading TLS material or assembling rustls configurations.
#[derive(Debug, Error)]
pub enum TlsError {
    /// Opening or reading a PEM file failed.
    #[error("tls: io reading {path}: {source}")]
    Io {
        /// Path of the file being read, rendered for display.
        path: String,
        /// Underlying I/O failure.
        #[source]
        source: io::Error,
    },
    /// A PEM file was read successfully but contained no item of the expected kind.
    #[error("tls: no usable {kind} found in {path}")]
    NoMaterial {
        /// What was being looked for, e.g. `"certificate"` or `"private key"`.
        kind: &'static str,
        /// Path of the file that was searched.
        path: String,
    },
    /// rustls-side validation failed (CA import, client verifier, cert/key pairing, server
    /// name parsing, ...); carries the rendered underlying error with a short context prefix.
    #[error("tls: rustls rejected configuration: {0}")]
    Rustls(String),
}
/// Installs the `ring` crypto provider as the process-wide rustls default, at most once per
/// process.
///
/// The outcome of `install_default` is intentionally discarded: it only fails when a provider
/// is already installed, and in that case the existing provider stays in effect.
fn ensure_provider_installed() {
    static DONE: OnceLock<()> = OnceLock::new();
    DONE.get_or_init(|| {
        let ring_provider = rustls::crypto::ring::default_provider();
        let _ = ring_provider.install_default();
    });
}
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
let file = File::open(path).map_err(|e| TlsError::Io {
path: path.display().to_string(),
source: e,
})?;
let mut reader = BufReader::new(file);
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<io::Result<Vec<_>>>()
.map_err(|e| TlsError::Io {
path: path.display().to_string(),
source: e,
})?;
if certs.is_empty() {
return Err(TlsError::NoMaterial {
kind: "certificate",
path: path.display().to_string(),
});
}
Ok(certs)
}
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
let file = File::open(path).map_err(|e| TlsError::Io {
path: path.display().to_string(),
source: e,
})?;
let mut reader = BufReader::new(file);
let key = rustls_pemfile::private_key(&mut reader).map_err(|e| TlsError::Io {
path: path.display().to_string(),
source: e,
})?;
key.ok_or_else(|| TlsError::NoMaterial {
kind: "private key",
path: path.display().to_string(),
})
}
pub fn load_server_config(
cert_path: &Path,
key_path: &Path,
client_ca: Option<&Path>,
) -> Result<Arc<ServerConfig>, TlsError> {
ensure_provider_installed();
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let builder = ServerConfig::builder();
let cfg = if let Some(ca_path) = client_ca {
let ca_certs = load_certs(ca_path)?;
let mut roots = RootCertStore::empty();
for c in ca_certs {
roots
.add(c)
.map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
}
let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
builder
.with_client_cert_verifier(verifier)
.with_single_cert(certs, key)
.map_err(|e| TlsError::Rustls(e.to_string()))?
} else {
builder
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| TlsError::Rustls(e.to_string()))?
};
Ok(Arc::new(cfg))
}
/// Builds a rustls client config without client certificates.
///
/// Server certificates are verified against the CAs in `ca_path` when given. Otherwise they are
/// verified against the bundled `webpki_roots` public trust anchors.
///
/// # Errors
/// Any [`TlsError`] from reading `ca_path` or importing its certificates.
pub fn load_client_config(ca_path: Option<&Path>) -> Result<Arc<ClientConfig>, TlsError> {
    ensure_provider_installed();
    let mut trust = RootCertStore::empty();
    match ca_path {
        Some(path) => {
            for anchor in load_certs(path)? {
                trust
                    .add(anchor)
                    .map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
            }
        }
        None => trust.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
    }
    let config = ClientConfig::builder()
        .with_root_certificates(trust)
        .with_no_client_auth();
    Ok(Arc::new(config))
}
#[must_use]
pub fn acceptor_from(server_config: Arc<ServerConfig>) -> TlsAcceptor {
TlsAcceptor::from(server_config)
}
#[must_use]
pub fn connector_from(client_config: Arc<ClientConfig>) -> TlsConnector {
TlsConnector::from(client_config)
}
/// Parses `host` into an owned (`'static`) rustls [`ServerName`] suitable for a connect call.
///
/// # Errors
/// [`TlsError::Rustls`] if `host` is not a valid server name.
pub fn server_name_owned(host: &str) -> Result<ServerName<'static>, TlsError> {
    let owned_host: String = host.to_owned();
    ServerName::try_from(owned_host)
        .map_err(|err| TlsError::Rustls(format!("server name: {err}")))
}
/// Leading fragment of the SNI hostname that carries a datacenter name.
const DC_SNI_PREFIX: &str = "dc-";
/// Fixed domain suffix of every datacenter SNI hostname.
const DC_SNI_SUFFIX: &str = ".dynomite.local";

/// Builds the synthetic SNI hostname used to select datacenter `dc`'s certificate,
/// e.g. `"dc1"` becomes `"dc-dc1.dynomite.local"`.
#[must_use]
pub fn dc_sni_hostname(dc: &str) -> String {
    [DC_SNI_PREFIX, dc, DC_SNI_SUFFIX].concat()
}

/// Inverse of [`dc_sni_hostname`]. It returns the datacenter name embedded in `name`. It
/// returns `None` when `name` does not follow the pattern or the embedded name is empty.
fn dc_from_sni_label(name: &str) -> Option<&str> {
    let dc = name
        .strip_prefix(DC_SNI_PREFIX)?
        .strip_suffix(DC_SNI_SUFFIX)?;
    (!dc.is_empty()).then_some(dc)
}
/// File locations for one TLS profile: the default profile or a per-DC override.
#[derive(Debug, Clone)]
pub struct TlsProfileSpec {
    /// PEM file holding the certificate chain this node presents.
    pub cert: PathBuf,
    /// PEM file holding the private key matching `cert`.
    pub key: PathBuf,
    /// Optional PEM CA bundle.
    ///
    /// When set, it verifies client certificates on this profile's server config. It is also
    /// the trust root set for outbound connections. When unset, this profile's server config
    /// does not request client certificates, and outbound connections trust the bundled
    /// `webpki_roots`.
    pub ca: Option<PathBuf>,
}
/// Loaded TLS material for an optional default profile plus per-DC overrides.
///
/// Built by `TlsProfileMap::build`. The per-DC accessors fall back to the default profile when a
/// DC has no entry of its own.
#[derive(Clone, Default)]
pub struct TlsProfileMap {
    /// Server configs keyed by DC name.
    per_dc_server: BTreeMap<String, Arc<ServerConfig>>,
    /// Client (dialing) configs keyed by DC name.
    per_dc_client: BTreeMap<String, Arc<ClientConfig>>,
    /// Certified keys used for SNI-based certificate selection, keyed by DC name.
    per_dc_certified: BTreeMap<String, Arc<CertifiedKey>>,
    /// Server config of the default profile, if one was configured.
    default_server: Option<Arc<ServerConfig>>,
    /// Client config of the default profile, if one was configured.
    default_client: Option<Arc<ClientConfig>>,
    /// Certified key of the default profile, served when no per-DC key matches the SNI.
    default_certified: Option<Arc<CertifiedKey>>,
    /// CA certificates pooled from every profile that named a CA file. They are the
    /// client-auth trust set for the SNI acceptor.
    combined_ca_certs: Vec<CertificateDer<'static>>,
    /// True when at least one profile named a CA file.
    has_any_client_ca: bool,
}
/// Summarises the map without dumping key material: it shows the configured DC names, whether a
/// default profile exists, and whether client auth is in play.
impl std::fmt::Debug for TlsProfileMap {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dc_names: Vec<&String> = self.per_dc_server.keys().collect();
        let mut out = f.debug_struct("TlsProfileMap");
        out.field("per_dc", &dc_names);
        out.field("has_default", &self.default_server.is_some());
        out.field("has_any_client_ca", &self.has_any_client_ca);
        out.finish_non_exhaustive()
    }
}
impl TlsProfileMap {
pub fn build(
default: Option<TlsProfileSpec>,
per_dc: BTreeMap<String, TlsProfileSpec>,
) -> Result<Self, TlsError> {
ensure_provider_installed();
let provider = rustls::crypto::CryptoProvider::get_default()
.cloned()
.unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider()));
let mut map = Self::default();
if let Some(spec) = default {
let server_cfg = load_server_config(&spec.cert, &spec.key, spec.ca.as_deref())?;
let client_cfg = load_client_config(spec.ca.as_deref())?;
let certified = load_certified_key(&spec.cert, &spec.key, provider.as_ref())?;
if let Some(ca_path) = spec.ca.as_deref() {
map.combined_ca_certs.extend(load_certs(ca_path)?);
map.has_any_client_ca = true;
}
map.default_server = Some(server_cfg);
map.default_client = Some(client_cfg);
map.default_certified = Some(certified);
}
for (dc, spec) in per_dc {
let server_cfg = load_server_config(&spec.cert, &spec.key, spec.ca.as_deref())?;
let client_cfg = load_client_config(spec.ca.as_deref())?;
let certified = load_certified_key(&spec.cert, &spec.key, provider.as_ref())?;
if let Some(ca_path) = spec.ca.as_deref() {
map.combined_ca_certs.extend(load_certs(ca_path)?);
map.has_any_client_ca = true;
}
map.per_dc_server.insert(dc.clone(), server_cfg);
map.per_dc_client.insert(dc.clone(), client_cfg);
map.per_dc_certified.insert(dc, certified);
}
Ok(map)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.default_server.is_none() && self.per_dc_server.is_empty()
}
#[must_use]
pub fn server_config_for_dc(&self, dc: &str) -> Option<Arc<ServerConfig>> {
self.per_dc_server
.get(dc)
.cloned()
.or_else(|| self.default_server.clone())
}
#[must_use]
pub fn client_config_for_dc(&self, dc: &str) -> Option<Arc<ClientConfig>> {
self.per_dc_client
.get(dc)
.cloned()
.or_else(|| self.default_client.clone())
}
#[must_use]
pub fn default_server_config(&self) -> Option<Arc<ServerConfig>> {
self.default_server.clone()
}
#[must_use]
pub fn default_client_config(&self) -> Option<Arc<ClientConfig>> {
self.default_client.clone()
}
#[must_use]
pub fn requires_client_auth(&self) -> bool {
self.has_any_client_ca
}
#[must_use]
pub fn dc_names(&self) -> Vec<String> {
self.per_dc_certified.keys().cloned().collect()
}
pub fn build_sni_acceptor(&self) -> Result<Option<tokio_rustls::TlsAcceptor>, TlsError> {
if self.is_empty() {
return Ok(None);
}
ensure_provider_installed();
let resolver = DcSniResolver {
by_dc: self.per_dc_certified.clone(),
default: self.default_certified.clone(),
};
let builder = ServerConfig::builder();
let cfg = if self.has_any_client_ca {
let mut roots = RootCertStore::empty();
self.populate_combined_ca_roots(&mut roots)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
builder
.with_client_cert_verifier(verifier)
.with_cert_resolver(Arc::new(resolver))
} else {
builder
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver))
};
Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(cfg))))
}
fn populate_combined_ca_roots(&self, roots: &mut RootCertStore) -> Result<(), TlsError> {
for cert in &self.combined_ca_certs {
roots
.add(cert.clone())
.map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
}
Ok(())
}
}
#[derive(Debug)]
struct DcSniResolver {
by_dc: BTreeMap<String, Arc<CertifiedKey>>,
default: Option<Arc<CertifiedKey>>,
}
impl ResolvesServerCert for DcSniResolver {
fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if let Some(name) = hello.server_name() {
if let Some(dc) = dc_from_sni_label(name) {
if let Some(ck) = self.by_dc.get(dc) {
return Some(ck.clone());
}
}
}
self.default.clone()
}
}
#[derive(Debug)]
struct ReloadingDcSniResolver {
profiles: Arc<RwLock<TlsProfileMap>>,
}
impl ResolvesServerCert for ReloadingDcSniResolver {
fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let profiles = self.profiles.read();
if let Some(name) = hello.server_name() {
if let Some(dc) = dc_from_sni_label(name) {
if let Some(ck) = profiles.per_dc_certified.get(dc) {
return Some(ck.clone());
}
}
}
profiles.default_certified.clone()
}
}
/// Shared, hot-reloadable handle to a `TlsProfileMap`.
///
/// Clones share one underlying map. `replace` swaps it for every holder at once.
#[derive(Clone, Debug, Default)]
pub struct SharedTlsProfiles {
    inner: Arc<RwLock<TlsProfileMap>>,
}

impl SharedTlsProfiles {
    /// Wraps an already-built profile map.
    #[must_use]
    pub fn from_map(map: TlsProfileMap) -> Self {
        Self {
            inner: Arc::new(RwLock::new(map)),
        }
    }

    /// Replaces the profile map seen by every clone of this handle, under the write lock.
    ///
    /// Acceptors built by `build_sni_acceptor` serve the new certificates on their next
    /// handshake. They keep the client-auth trust roots captured when they were built, so
    /// rebuild the acceptor to pick up CA changes.
    pub fn replace(&self, map: TlsProfileMap) {
        *self.inner.write() = map;
    }

    /// True when the current map has neither a default nor any per-DC profile.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.read().is_empty()
    }

    /// Client config for dialing `dc` from the current map, falling back to the default.
    #[must_use]
    pub fn client_config_for_dc(&self, dc: &str) -> Option<Arc<ClientConfig>> {
        self.inner.read().client_config_for_dc(dc)
    }

    /// True when any profile in the current map named a CA file.
    #[must_use]
    pub fn requires_client_auth(&self) -> bool {
        self.inner.read().requires_client_auth()
    }

    /// Names of DCs with their own profile in the current map, sorted.
    #[must_use]
    pub fn dc_names(&self) -> Vec<String> {
        self.inner.read().dc_names()
    }

    /// Builds an acceptor whose certificate choice follows later `replace` calls.
    ///
    /// The emptiness check, the client-auth decision and the CA trust set are all read under
    /// one read guard. A concurrent `replace` therefore cannot mix state from two maps; mixed
    /// state could, for example, request client auth with an empty root set and make the
    /// verifier build fail spuriously. The trust roots are fixed for the lifetime of the
    /// returned acceptor; only the served certificates are reloaded per handshake.
    ///
    /// Returns `Ok(None)` when the current map is empty.
    ///
    /// # Errors
    /// [`TlsError::Rustls`] if a pooled CA cannot be imported or the verifier cannot be built.
    pub fn build_sni_acceptor(&self) -> Result<Option<TlsAcceptor>, TlsError> {
        let client_roots = {
            let profiles = self.inner.read();
            if profiles.is_empty() {
                return Ok(None);
            }
            ensure_provider_installed();
            if profiles.has_any_client_ca {
                let mut roots = RootCertStore::empty();
                profiles.populate_combined_ca_roots(&mut roots)?;
                Some(roots)
            } else {
                None
            }
        };
        let resolver = Arc::new(ReloadingDcSniResolver {
            profiles: Arc::clone(&self.inner),
        });
        let builder = ServerConfig::builder();
        let cfg = match client_roots {
            Some(roots) => {
                let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
                    .build()
                    .map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
                builder
                    .with_client_cert_verifier(verifier)
                    .with_cert_resolver(resolver)
            }
            None => builder.with_no_client_auth().with_cert_resolver(resolver),
        };
        Ok(Some(TlsAcceptor::from(Arc::new(cfg))))
    }
}
/// Loads a certificate chain and private key from disk into a rustls [`CertifiedKey`], using
/// `provider` to turn the key into a signing key.
///
/// # Errors
/// Any [`TlsError`] from reading either file, or [`TlsError::Rustls`] if `provider` cannot use
/// the key with the chain.
fn load_certified_key(
    cert_path: &Path,
    key_path: &Path,
    provider: &rustls::crypto::CryptoProvider,
) -> Result<Arc<CertifiedKey>, TlsError> {
    let chain = load_certs(cert_path)?;
    let private_key = load_private_key(key_path)?;
    CertifiedKey::from_der(chain, private_key, provider)
        .map(Arc::new)
        .map_err(|err| TlsError::Rustls(format!("certified key: {err}")))
}
/// An accepted (server-side) TLS connection over TCP, tagged with its connection role.
#[derive(Debug)]
pub struct TlsServerTransport {
    inner: tokio_rustls::server::TlsStream<TcpStream>,
    role: ConnRole,
    /// Remote address captured once at construction; `None` if the socket could not report it.
    peer_addr: Option<SocketAddr>,
}

impl TlsServerTransport {
    /// Wraps an established server TLS stream, recording the peer address up front.
    #[must_use]
    pub fn new(stream: tokio_rustls::server::TlsStream<TcpStream>, role: ConnRole) -> Self {
        let (tcp, _session) = stream.get_ref();
        let peer_addr = tcp.peer_addr().ok();
        Self {
            peer_addr,
            role,
            inner: stream,
        }
    }
}

impl Transport for TlsServerTransport {
    fn role(&self) -> ConnRole {
        self.role
    }

    fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }
}

// The I/O traits delegate straight to the inner TLS stream; the wrapper is Unpin, so it can
// be unpinned to reach the field.
impl AsyncRead for TlsServerTransport {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}

impl AsyncWrite for TlsServerTransport {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
/// An outbound (client-side) TLS connection over TCP, tagged with its connection role.
#[derive(Debug)]
pub struct TlsClientTransport {
    inner: tokio_rustls::client::TlsStream<TcpStream>,
    role: ConnRole,
    /// Remote address captured once at construction; `None` if the socket could not report it.
    peer_addr: Option<SocketAddr>,
}

impl TlsClientTransport {
    /// Wraps an established client TLS stream, recording the peer address up front.
    #[must_use]
    pub fn new(stream: tokio_rustls::client::TlsStream<TcpStream>, role: ConnRole) -> Self {
        let (tcp, _session) = stream.get_ref();
        let peer_addr = tcp.peer_addr().ok();
        Self {
            peer_addr,
            role,
            inner: stream,
        }
    }
}

impl Transport for TlsClientTransport {
    fn role(&self) -> ConnRole {
        self.role
    }

    fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }
}

// The I/O traits delegate straight to the inner TLS stream; the wrapper is Unpin, so it can
// be unpinned to reach the field.
impl AsyncRead for TlsClientTransport {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}

impl AsyncWrite for TlsClientTransport {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
// Unit tests for PEM loading, config construction and the per-DC profile map.
// Certificates are generated with `rcgen` and written into a per-test temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes `body` to `dir/name` and returns the path of the new file.
    fn write_pem(dir: &TempDir, name: &str, body: &str) -> std::path::PathBuf {
        let p = dir.path().join(name);
        let mut f = File::create(&p).unwrap();
        f.write_all(body.as_bytes()).unwrap();
        p
    }

    /// Generates a self-signed certificate for `localhost`; returns `(cert_pem, key_pem)`.
    fn issue_self_signed() -> (String, String) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        (cert.cert.pem(), cert.signing_key.serialize_pem())
    }

    // A valid cert/key pair without a client CA yields a server config.
    #[test]
    fn load_server_config_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_pem, key_pem) = issue_self_signed();
        let cert = write_pem(&dir, "cert.pem", &cert_pem);
        let key = write_pem(&dir, "key.pem", &key_pem);
        let cfg = load_server_config(&cert, &key, None).unwrap();
        assert!(Arc::strong_count(&cfg) >= 1);
    }

    // A nonexistent certificate file surfaces as an I/O error.
    #[test]
    fn load_server_config_rejects_missing_cert() {
        let dir = tempfile::tempdir().unwrap();
        let bogus = dir.path().join("missing.pem");
        let key = write_pem(&dir, "key.pem", "");
        let err = load_server_config(&bogus, &key, None).expect_err("missing");
        assert!(matches!(err, TlsError::Io { .. }), "got {err:?}");
    }

    // An empty certificate file is reported as missing material, not an I/O error.
    #[test]
    fn load_server_config_rejects_empty_cert_file() {
        let dir = tempfile::tempdir().unwrap();
        let cert = write_pem(&dir, "cert.pem", "");
        let key = write_pem(&dir, "key.pem", "");
        let err = load_server_config(&cert, &key, None).expect_err("empty");
        assert!(matches!(
            err,
            TlsError::NoMaterial {
                kind: "certificate",
                ..
            }
        ));
    }

    // With no CA path the client config falls back to the bundled webpki roots.
    #[test]
    fn load_client_config_with_webpki_default() {
        let cfg = load_client_config(None).unwrap();
        assert!(Arc::strong_count(&cfg) >= 1);
    }

    #[test]
    fn server_name_owned_accepts_dns_label() {
        assert!(server_name_owned("localhost").is_ok());
    }

    /// Writes a fresh self-signed pair as `<prefix>-cert.pem` / `<prefix>-key.pem`;
    /// returns `(cert_path, key_path)`.
    fn write_self_signed(dir: &TempDir, prefix: &str) -> (std::path::PathBuf, std::path::PathBuf) {
        let (cert_pem, key_pem) = issue_self_signed();
        (
            write_pem(dir, &format!("{prefix}-cert.pem"), &cert_pem),
            write_pem(dir, &format!("{prefix}-key.pem"), &key_pem),
        )
    }

    // Encoding and decoding of the synthetic per-DC SNI hostname, including rejects.
    #[test]
    fn dc_sni_hostname_round_trips() {
        assert_eq!(dc_sni_hostname("dc1"), "dc-dc1.dynomite.local");
        assert_eq!(dc_from_sni_label("dc-dc1.dynomite.local"), Some("dc1"));
        assert_eq!(dc_from_sni_label("localhost"), None);
        assert_eq!(dc_from_sni_label("dc-.dynomite.local"), None);
        assert_eq!(dc_from_sni_label("dc-dc1.example.com"), None);
    }

    // No profiles at all: every lookup is None and no SNI acceptor is built.
    #[test]
    fn tls_profile_map_empty_is_empty() {
        let map = TlsProfileMap::build(None, BTreeMap::new()).unwrap();
        assert!(map.is_empty());
        assert!(map.client_config_for_dc("dc1").is_none());
        assert!(map.server_config_for_dc("dc1").is_none());
        assert!(!map.requires_client_auth());
        assert!(map.build_sni_acceptor().unwrap().is_none());
    }

    // With only a default profile, any DC lookup falls back to it.
    #[test]
    fn tls_profile_map_default_only_falls_back() {
        let dir = tempfile::tempdir().unwrap();
        let (cert, key) = write_self_signed(&dir, "default");
        let map = TlsProfileMap::build(
            Some(TlsProfileSpec {
                cert,
                key,
                ca: None,
            }),
            BTreeMap::new(),
        )
        .unwrap();
        assert!(!map.is_empty());
        assert!(map.client_config_for_dc("dc1").is_some());
        assert!(map.server_config_for_dc("dc-without-profile").is_some());
        assert!(map.default_client_config().is_some());
        assert!(map.build_sni_acceptor().unwrap().is_some());
    }

    // A per-DC profile is returned for its DC instead of the default; identity is compared
    // with `Arc::ptr_eq`.
    #[test]
    fn tls_profile_map_per_dc_overrides_default() {
        let dir = tempfile::tempdir().unwrap();
        let (def_cert, def_key) = write_self_signed(&dir, "default");
        let (dc1_cert, dc1_key) = write_self_signed(&dir, "dc1");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc1".into(),
            TlsProfileSpec {
                cert: dc1_cert,
                key: dc1_key,
                ca: None,
            },
        );
        let map = TlsProfileMap::build(
            Some(TlsProfileSpec {
                cert: def_cert,
                key: def_key,
                ca: None,
            }),
            per_dc,
        )
        .unwrap();
        let dc1 = map.client_config_for_dc("dc1").unwrap();
        let other = map.client_config_for_dc("other-dc").unwrap();
        assert!(
            !Arc::ptr_eq(&dc1, &other),
            "per-DC entry must differ from the default fallback"
        );
        assert_eq!(map.dc_names(), vec!["dc1".to_string()]);
    }

    // Without a default, DCs that lack their own profile get nothing.
    #[test]
    fn tls_profile_map_per_dc_only_no_default() {
        let dir = tempfile::tempdir().unwrap();
        let (cert, key) = write_self_signed(&dir, "dc2");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc2".into(),
            TlsProfileSpec {
                cert,
                key,
                ca: None,
            },
        );
        let map = TlsProfileMap::build(None, per_dc).unwrap();
        assert!(map.client_config_for_dc("dc2").is_some());
        assert!(map.client_config_for_dc("dc3").is_none());
        assert!(map.server_config_for_dc("dc3").is_none());
    }

    // A load failure in any per-DC profile aborts the whole build.
    #[test]
    fn tls_profile_map_propagates_load_error() {
        let dir = tempfile::tempdir().unwrap();
        let bogus = dir.path().join("missing.pem");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc1".into(),
            TlsProfileSpec {
                cert: bogus.clone(),
                key: bogus,
                ca: None,
            },
        );
        let err = TlsProfileMap::build(None, per_dc).expect_err("missing");
        assert!(matches!(err, TlsError::Io { .. }), "got {err:?}");
    }
}