Skip to main content

dynomite/net/
tls.rs

1//! TLS helpers for the peer plane and the Riak gateways.
2//!
3//! Two small surfaces:
4//!
5//! * [`load_server_config`] reads PEM cert + key from disk and
6//!   returns an [`Arc<rustls::ServerConfig>`] suitable for
7//!   wrapping [`tokio_rustls::TlsAcceptor`]. When a CA path is
8//!   given, client certificates are verified against that CA
9//!   (mTLS); when it is `None`, client cert verification is
10//!   disabled.
11//! * [`load_client_config`] builds an [`Arc<rustls::ClientConfig>`]
12//!   suitable for [`tokio_rustls::TlsConnector`]. When a CA path
13//!   is given the root store is loaded from that file; otherwise
14//!   the bundled `webpki_roots` Mozilla bundle is used.
15//!
16//! Two thin newtypes wrap an established TLS stream and expose
17//! the [`crate::io::reactor::Transport`] interface so the rest of
18//! the network stack stays unchanged:
19//!
20//! * [`TlsServerTransport`] wraps the inbound side of a TLS
21//!   connection ([`tokio_rustls::server::TlsStream<TcpStream>`]).
22//! * [`TlsClientTransport`] wraps the outbound side
23//!   ([`tokio_rustls::client::TlsStream<TcpStream>`]).
24//!
25//! Mismatched config (cert without key or key without cert) is
26//! caught by the conf validator (see
27//! [`crate::conf::ConfPool::validate`]); this module assumes the
28//! caller has already cross-checked.
29//!
30//! # Examples
31//!
32//! ```no_run
33//! use std::path::PathBuf;
34//! use dynomite::net::tls::{load_client_config, load_server_config};
35//!
36//! let cert = PathBuf::from("/etc/dynomite/peer.crt");
37//! let key = PathBuf::from("/etc/dynomite/peer.key");
38//! let _server = load_server_config(&cert, &key, None).unwrap();
39//! let _client = load_client_config(None).unwrap();
40//! ```
41//!
42//! # Provider selection
43//!
//! `rustls` 0.23 requires an installed [crypto provider]. The
//! loaders in this module install the `ring` provider as the
//! process default the first time any of them runs, guarded by a
//! `OnceLock`. If some other provider was installed first, that one
//! stays in effect. The lock is local to this module so callers do
//! not have to think about it.
49//!
50//! [crypto provider]: rustls::crypto::CryptoProvider
51
52use std::collections::BTreeMap;
53use std::fs::File;
54use std::io::{self, BufReader};
55use std::net::SocketAddr;
56use std::path::{Path, PathBuf};
57use std::pin::Pin;
58use std::sync::{Arc, OnceLock};
59use std::task::{Context, Poll};
60
61use parking_lot::RwLock;
62use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
63use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier};
64use rustls::sign::CertifiedKey;
65use rustls::{ClientConfig, RootCertStore, ServerConfig};
66use thiserror::Error;
67use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
68use tokio::net::TcpStream;
69use tokio_rustls::{TlsAcceptor, TlsConnector};
70
71use crate::io::reactor::{ConnRole, Transport};
72
/// Errors raised by the TLS loaders and profile builders in this
/// module.
///
/// Every variant renders with a `tls:` prefix so log lines are easy
/// to attribute.
#[derive(Debug, Error)]
pub enum TlsError {
    /// Failed to open a PEM file, or the PEM reader reported an
    /// error while parsing its contents.
    #[error("tls: io reading {path}: {source}")]
    Io {
        /// Path that failed, rendered via `Path::display` (non-UTF-8
        /// components are replaced lossily).
        path: String,
        /// Underlying I/O error.
        #[source]
        source: io::Error,
    },
    /// PEM file was readable but did not contain a usable
    /// certificate or key.
    #[error("tls: no usable {kind} found in {path}")]
    NoMaterial {
        /// Either `"certificate"` or `"private key"`.
        kind: &'static str,
        /// Path that came up empty.
        path: String,
    },
    /// `rustls` rejected the supplied material, or a host string was
    /// not a valid server name. The payload is a human-readable
    /// description, usually prefixed with the failing step (for
    /// example `ca add:` or `client verifier:`).
    #[error("tls: rustls rejected configuration: {0}")]
    Rustls(String),
}
97
98/// One-time installer for the rustls process-default crypto
99/// provider. Selecting `ring` keeps us off `aws-lc-rs` (whose
100/// build-time C dependency fails on the project's nix shell) and
101/// matches the `quiche` transport's bundled provider.
102fn ensure_provider_installed() {
103    static INSTALL: OnceLock<()> = OnceLock::new();
104    INSTALL.get_or_init(|| {
105        // Ignore the result: another thread or another caller in
106        // the same process may have installed a provider already
107        // (the test harness, for example, links every binary into
108        // one process). The provider we install is the same
109        // either way.
110        let _ = rustls::crypto::ring::default_provider().install_default();
111    });
112}
113
114fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
115    let file = File::open(path).map_err(|e| TlsError::Io {
116        path: path.display().to_string(),
117        source: e,
118    })?;
119    let mut reader = BufReader::new(file);
120    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
121        .collect::<io::Result<Vec<_>>>()
122        .map_err(|e| TlsError::Io {
123            path: path.display().to_string(),
124            source: e,
125        })?;
126    if certs.is_empty() {
127        return Err(TlsError::NoMaterial {
128            kind: "certificate",
129            path: path.display().to_string(),
130        });
131    }
132    Ok(certs)
133}
134
135fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
136    let file = File::open(path).map_err(|e| TlsError::Io {
137        path: path.display().to_string(),
138        source: e,
139    })?;
140    let mut reader = BufReader::new(file);
141    let key = rustls_pemfile::private_key(&mut reader).map_err(|e| TlsError::Io {
142        path: path.display().to_string(),
143        source: e,
144    })?;
145    key.ok_or_else(|| TlsError::NoMaterial {
146        kind: "private key",
147        path: path.display().to_string(),
148    })
149}
150
151/// Build a [`ServerConfig`] from PEM cert + key.
152///
153/// When `client_ca` is `Some(p)`, every accepted connection must
154/// present a certificate signed by a CA from that PEM bundle
155/// (mutual TLS). When `None`, client certificates are not
156/// requested and the server accepts plaintext authentication.
157///
158/// # Errors
159/// Returns [`TlsError`] if any file is missing, malformed, or
160/// rejected by rustls.
161pub fn load_server_config(
162    cert_path: &Path,
163    key_path: &Path,
164    client_ca: Option<&Path>,
165) -> Result<Arc<ServerConfig>, TlsError> {
166    ensure_provider_installed();
167    let certs = load_certs(cert_path)?;
168    let key = load_private_key(key_path)?;
169
170    let builder = ServerConfig::builder();
171    let cfg = if let Some(ca_path) = client_ca {
172        let ca_certs = load_certs(ca_path)?;
173        let mut roots = RootCertStore::empty();
174        for c in ca_certs {
175            roots
176                .add(c)
177                .map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
178        }
179        let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
180            .build()
181            .map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
182        builder
183            .with_client_cert_verifier(verifier)
184            .with_single_cert(certs, key)
185            .map_err(|e| TlsError::Rustls(e.to_string()))?
186    } else {
187        builder
188            .with_no_client_auth()
189            .with_single_cert(certs, key)
190            .map_err(|e| TlsError::Rustls(e.to_string()))?
191    };
192    Ok(Arc::new(cfg))
193}
194
195/// Build a [`ClientConfig`] for outbound TLS.
196///
197/// When `ca_path` is `Some(p)`, the supplied PEM bundle is the
198/// only trust anchor. When `None`, the Mozilla bundle from
199/// [`webpki_roots`] is loaded; this is appropriate for clusters
200/// whose peer certs chain to a public CA, and is the conservative
201/// default for outbound calls in tests where the operator has
202/// not pinned a CA.
203///
204/// # Errors
205/// Returns [`TlsError`] if a CA file is missing or malformed.
206pub fn load_client_config(ca_path: Option<&Path>) -> Result<Arc<ClientConfig>, TlsError> {
207    ensure_provider_installed();
208    let mut roots = RootCertStore::empty();
209    if let Some(p) = ca_path {
210        let ca_certs = load_certs(p)?;
211        for c in ca_certs {
212            roots
213                .add(c)
214                .map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
215        }
216    } else {
217        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
218    }
219    let cfg = ClientConfig::builder()
220        .with_root_certificates(roots)
221        .with_no_client_auth();
222    Ok(Arc::new(cfg))
223}
224
225/// Convenience wrapper that adapts a [`ServerConfig`] into a
226/// [`TlsAcceptor`].
227#[must_use]
228pub fn acceptor_from(server_config: Arc<ServerConfig>) -> TlsAcceptor {
229    TlsAcceptor::from(server_config)
230}
231
232/// Convenience wrapper that adapts a [`ClientConfig`] into a
233/// [`TlsConnector`].
234#[must_use]
235pub fn connector_from(client_config: Arc<ClientConfig>) -> TlsConnector {
236    TlsConnector::from(client_config)
237}
238
239/// Parse a [`ServerName`] from a host string.
240///
241/// # Errors
242/// Returns [`TlsError::Rustls`] when the input is not a valid
243/// DNS name or IP literal.
244pub fn server_name_owned(host: &str) -> Result<ServerName<'static>, TlsError> {
245    ServerName::try_from(host.to_string())
246        .map_err(|e| TlsError::Rustls(format!("server name: {e}")))
247}
248
/// Build the SNI host name the peer plane uses for datacenter `dc`.
///
/// Both ends of a peer-plane TLS handshake use
/// `dc-<peer-dc>.dynomite.local` as the SNI value. On the listener
/// side, the SNI resolver behind
/// [`TlsProfileMap::build_sni_acceptor`] parses the label back into
/// a DC name to choose the certificate.
///
/// # Examples
///
/// ```
/// use dynomite::net::tls::dc_sni_hostname;
/// assert_eq!(dc_sni_hostname("dc1"), "dc-dc1.dynomite.local");
/// ```
#[must_use]
pub fn dc_sni_hostname(dc: &str) -> String {
    ["dc-", dc, ".dynomite.local"].concat()
}
267
/// Recover the DC name from an SNI label built by
/// [`dc_sni_hostname`].
///
/// Returns `None` unless `name` has the exact shape
/// `dc-<dc>.dynomite.local` with a non-empty `<dc>`.
fn dc_from_sni_label(name: &str) -> Option<&str> {
    let without_prefix = name.strip_prefix("dc-")?;
    let dc = without_prefix.strip_suffix(".dynomite.local")?;
    (!dc.is_empty()).then_some(dc)
}
276
/// PEM material for one TLS profile.
///
/// Used by [`TlsProfileMap::build`] to assemble per-DC server
/// and client configs from on-disk paths. `cert` and `key` are
/// required; `ca` is optional and, when present, pins the
/// trust anchor for both directions and turns the listener
/// into a mutual-TLS deployment.
///
/// The same shape describes the default (fallback) profile. The
/// struct only holds paths; nothing is read from disk until a map
/// is built.
#[derive(Debug, Clone)]
pub struct TlsProfileSpec {
    /// PEM certificate path. Every certificate in the file is loaded
    /// as the chain.
    pub cert: PathBuf,
    /// PEM private-key path matching [`Self::cert`]. A single key is
    /// read from the file.
    pub key: PathBuf,
    /// Optional PEM CA bundle. Every certificate in it becomes a
    /// trust anchor.
    pub ca: Option<PathBuf>,
}
293
/// Bundle of precompiled rustls configs for the peer plane,
/// keyed by datacenter name plus an optional default profile
/// used as a fallback for any DC without an explicit entry.
///
/// The map is built by [`TlsProfileMap::build`] (at startup, and
/// again whenever a reload swaps in fresh material through
/// `SharedTlsProfiles`). It is shared across the dnode listener and
/// every per-peer outbound supervisor. The rustls configs and
/// certified keys live behind `Arc`s, so a clone shares them.
/// Cloning still copies the `BTreeMap` nodes and the pooled CA
/// certificates, though. Lookups are O(log n) in the number of DCs.
///
/// # Examples
///
/// ```no_run
/// use std::collections::BTreeMap;
/// use std::path::PathBuf;
/// use dynomite::net::tls::{TlsProfileMap, TlsProfileSpec};
///
/// let mut per_dc = BTreeMap::new();
/// per_dc.insert(
///     "dc1".to_string(),
///     TlsProfileSpec {
///         cert: PathBuf::from("/etc/dynomite/dc1.pem"),
///         key: PathBuf::from("/etc/dynomite/dc1.key"),
///         ca: None,
///     },
/// );
/// let map = TlsProfileMap::build(None, per_dc).unwrap();
/// assert!(map.client_config_for_dc("dc1").is_some());
/// assert!(map.client_config_for_dc("dc-without-profile").is_none());
/// ```
#[derive(Clone, Default)]
pub struct TlsProfileMap {
    /// Per-DC server configs, keyed by DC name.
    per_dc_server: BTreeMap<String, Arc<ServerConfig>>,
    /// Per-DC client configs used when dialing a peer in that DC.
    per_dc_client: BTreeMap<String, Arc<ClientConfig>>,
    /// Per-DC certified keys handed out by the SNI resolvers.
    per_dc_certified: BTreeMap<String, Arc<CertifiedKey>>,
    /// Server config of the default (fallback) profile, if any.
    default_server: Option<Arc<ServerConfig>>,
    /// Client config of the default profile, if any.
    default_client: Option<Arc<ClientConfig>>,
    /// Certified key of the default profile; the SNI resolvers fall
    /// back to it when no per-DC entry matches.
    default_certified: Option<Arc<CertifiedKey>>,
    /// Pooled CA certificates (DER) from every profile that
    /// supplied a CA bundle. Used by
    /// [`Self::build_sni_acceptor`] to assemble a single client
    /// verifier that trusts any configured CA.
    combined_ca_certs: Vec<CertificateDer<'static>>,
    /// Set once any profile supplies a CA bundle. When set, the SNI
    /// listener verifies client certificates against
    /// `combined_ca_certs`.
    has_any_client_ca: bool,
}
339
340impl std::fmt::Debug for TlsProfileMap {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        f.debug_struct("TlsProfileMap")
343            .field("per_dc", &self.per_dc_server.keys().collect::<Vec<_>>())
344            .field("has_default", &self.default_server.is_some())
345            .field("has_any_client_ca", &self.has_any_client_ca)
346            .finish_non_exhaustive()
347    }
348}
349
350impl TlsProfileMap {
351    /// Build a map from a default profile (the legacy
352    /// `peer_tls_*` triple) plus a `dc -> TlsProfileSpec` map.
353    ///
354    /// Either argument may be empty: a `None` `default` plus
355    /// an empty `per_dc` produces an empty map (peer plane runs
356    /// plaintext).
357    ///
358    /// # Errors
359    /// Returns the first [`TlsError`] from a failing PEM load.
360    pub fn build(
361        default: Option<TlsProfileSpec>,
362        per_dc: BTreeMap<String, TlsProfileSpec>,
363    ) -> Result<Self, TlsError> {
364        ensure_provider_installed();
365        let provider = rustls::crypto::CryptoProvider::get_default()
366            .cloned()
367            .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider()));
368
369        let mut map = Self::default();
370
371        if let Some(spec) = default {
372            let server_cfg = load_server_config(&spec.cert, &spec.key, spec.ca.as_deref())?;
373            let client_cfg = load_client_config(spec.ca.as_deref())?;
374            let certified = load_certified_key(&spec.cert, &spec.key, provider.as_ref())?;
375            if let Some(ca_path) = spec.ca.as_deref() {
376                map.combined_ca_certs.extend(load_certs(ca_path)?);
377                map.has_any_client_ca = true;
378            }
379            map.default_server = Some(server_cfg);
380            map.default_client = Some(client_cfg);
381            map.default_certified = Some(certified);
382        }
383
384        for (dc, spec) in per_dc {
385            let server_cfg = load_server_config(&spec.cert, &spec.key, spec.ca.as_deref())?;
386            let client_cfg = load_client_config(spec.ca.as_deref())?;
387            let certified = load_certified_key(&spec.cert, &spec.key, provider.as_ref())?;
388            if let Some(ca_path) = spec.ca.as_deref() {
389                map.combined_ca_certs.extend(load_certs(ca_path)?);
390                map.has_any_client_ca = true;
391            }
392            map.per_dc_server.insert(dc.clone(), server_cfg);
393            map.per_dc_client.insert(dc.clone(), client_cfg);
394            map.per_dc_certified.insert(dc, certified);
395        }
396
397        Ok(map)
398    }
399
400    /// True when no profile (default or per-DC) is configured.
401    /// In this state the peer plane runs plaintext.
402    #[must_use]
403    pub fn is_empty(&self) -> bool {
404        self.default_server.is_none() && self.per_dc_server.is_empty()
405    }
406
407    /// Server config to use for a connection negotiated with a
408    /// peer in `dc`. Returns the per-DC entry if present,
409    /// otherwise the default profile, otherwise `None`.
410    #[must_use]
411    pub fn server_config_for_dc(&self, dc: &str) -> Option<Arc<ServerConfig>> {
412        self.per_dc_server
413            .get(dc)
414            .cloned()
415            .or_else(|| self.default_server.clone())
416    }
417
418    /// Client config to use when dialing a peer in `dc`.
419    /// Returns the per-DC entry if present, otherwise the
420    /// default profile, otherwise `None`.
421    #[must_use]
422    pub fn client_config_for_dc(&self, dc: &str) -> Option<Arc<ClientConfig>> {
423        self.per_dc_client
424            .get(dc)
425            .cloned()
426            .or_else(|| self.default_client.clone())
427    }
428
429    /// Default server config (the legacy / fallback profile).
430    #[must_use]
431    pub fn default_server_config(&self) -> Option<Arc<ServerConfig>> {
432        self.default_server.clone()
433    }
434
435    /// Default client config (the legacy / fallback profile).
436    #[must_use]
437    pub fn default_client_config(&self) -> Option<Arc<ClientConfig>> {
438        self.default_client.clone()
439    }
440
441    /// True when at least one configured profile carries a CA
442    /// bundle. When set, the SNI listener requires every
443    /// inbound peer to present a certificate signed by one of
444    /// the configured CAs (mTLS).
445    #[must_use]
446    pub fn requires_client_auth(&self) -> bool {
447        self.has_any_client_ca
448    }
449
450    /// Names of the DCs with explicit per-DC entries (sorted).
451    #[must_use]
452    pub fn dc_names(&self) -> Vec<String> {
453        self.per_dc_certified.keys().cloned().collect()
454    }
455
456    /// Build a single [`tokio_rustls::TlsAcceptor`] whose
457    /// `ServerConfig` picks the certificate by SNI hostname
458    /// (`dc-<dc-name>.dynomite.local`) and falls back to the
459    /// default profile when SNI is missing or does not match.
460    ///
461    /// Returns `None` when [`Self::is_empty`] is true.
462    ///
463    /// # Errors
464    /// Returns [`TlsError::Rustls`] when rustls rejects the
465    /// assembled root store / verifier (e.g. a malformed CA
466    /// certificate that slipped through the loader).
467    pub fn build_sni_acceptor(&self) -> Result<Option<tokio_rustls::TlsAcceptor>, TlsError> {
468        if self.is_empty() {
469            return Ok(None);
470        }
471        ensure_provider_installed();
472        let resolver = DcSniResolver {
473            by_dc: self.per_dc_certified.clone(),
474            default: self.default_certified.clone(),
475        };
476        let builder = ServerConfig::builder();
477        let cfg = if self.has_any_client_ca {
478            // Combine every CA bundle (default + per-DC) into a
479            // single root store for client verification. This
480            // keeps the listener's mTLS check uniform across
481            // SNI-routed certs; an inbound peer chains to any
482            // of the configured CAs.
483            let mut roots = RootCertStore::empty();
484            self.populate_combined_ca_roots(&mut roots)?;
485            let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
486                .build()
487                .map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
488            builder
489                .with_client_cert_verifier(verifier)
490                .with_cert_resolver(Arc::new(resolver))
491        } else {
492            builder
493                .with_no_client_auth()
494                .with_cert_resolver(Arc::new(resolver))
495        };
496        Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(cfg))))
497    }
498
499    /// Populate a [`RootCertStore`] with the CAs from every
500    /// per-DC entry plus the default profile. Used by
501    /// [`Self::build_sni_acceptor`] when at least one profile
502    /// carries a CA.
503    fn populate_combined_ca_roots(&self, roots: &mut RootCertStore) -> Result<(), TlsError> {
504        for cert in &self.combined_ca_certs {
505            roots
506                .add(cert.clone())
507                .map_err(|e| TlsError::Rustls(format!("ca add: {e}")))?;
508        }
509        Ok(())
510    }
511}
512
513/// Custom rustls cert resolver: maps SNI of the shape
514/// `dc-<dc-name>.dynomite.local` to a per-DC `CertifiedKey`,
515/// falling back to the default profile when the SNI is missing
516/// or does not match.
517#[derive(Debug)]
518struct DcSniResolver {
519    by_dc: BTreeMap<String, Arc<CertifiedKey>>,
520    default: Option<Arc<CertifiedKey>>,
521}
522
523impl ResolvesServerCert for DcSniResolver {
524    fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
525        if let Some(name) = hello.server_name() {
526            if let Some(dc) = dc_from_sni_label(name) {
527                if let Some(ck) = self.by_dc.get(dc) {
528                    return Some(ck.clone());
529                }
530            }
531        }
532        self.default.clone()
533    }
534}
535
536/// SNI resolver that reads its certified-key map from a shared
537/// [`SharedTlsProfiles`]. Every handshake re-borrows the inner
538/// [`TlsProfileMap`] via the read lock, so a SIGHUP-driven
539/// [`SharedTlsProfiles::replace`] takes effect on the next
540/// inbound connection without rebuilding the [`TlsAcceptor`].
541#[derive(Debug)]
542struct ReloadingDcSniResolver {
543    profiles: Arc<RwLock<TlsProfileMap>>,
544}
545
546impl ResolvesServerCert for ReloadingDcSniResolver {
547    fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
548        let profiles = self.profiles.read();
549        if let Some(name) = hello.server_name() {
550            if let Some(dc) = dc_from_sni_label(name) {
551                if let Some(ck) = profiles.per_dc_certified.get(dc) {
552                    return Some(ck.clone());
553                }
554            }
555        }
556        profiles.default_certified.clone()
557    }
558}
559
/// Reloadable wrapper around [`TlsProfileMap`].
///
/// Holds an [`Arc<parking_lot::RwLock<TlsProfileMap>>`] so the
/// inbound listener (via [`Self::build_sni_acceptor`]) and every
/// outbound peer supervisor can pick up profile changes without
/// rebinding sockets.
///
/// What a [`Self::replace`] reaches without further work:
/// * Server certificates and keys. The resolver behind
///   [`Self::build_sni_acceptor`] reads the inner map on every
///   handshake, so the next inbound connection sees them.
/// * Outbound client configs. [`Self::client_config_for_dc`] reads
///   the inner map at call time.
///
/// What it does NOT reach: the inbound client-certificate verifier.
/// Two things are captured once, when [`Self::build_sni_acceptor`]
/// runs: the CA roots, and whether client auth is required at all.
/// After a CA change, build a new acceptor to enforce the new bundle.
///
/// `Clone` is `Arc`-cheap.
///
/// # Examples
///
/// ```
/// use std::collections::BTreeMap;
/// use dynomite::net::tls::{SharedTlsProfiles, TlsProfileMap};
/// let map = TlsProfileMap::build(None, BTreeMap::new()).unwrap();
/// let shared = SharedTlsProfiles::from_map(map);
/// assert!(shared.is_empty());
/// ```
#[derive(Clone, Debug, Default)]
pub struct SharedTlsProfiles {
    /// Swappable profile map. Every acceptor built from this handle
    /// also holds a clone of this `Arc`.
    inner: Arc<RwLock<TlsProfileMap>>,
}
585
586impl SharedTlsProfiles {
587    /// Wrap an existing [`TlsProfileMap`] in a shared cell.
588    #[must_use]
589    pub fn from_map(map: TlsProfileMap) -> Self {
590        Self {
591            inner: Arc::new(RwLock::new(map)),
592        }
593    }
594
595    /// Atomically replace the inner profile map.
596    ///
597    /// Subsequent handshakes (and outbound dials that consult
598    /// [`Self::client_config_for_dc`]) observe the new material;
599    /// already-negotiated TLS sessions are not affected.
600    pub fn replace(&self, map: TlsProfileMap) {
601        *self.inner.write() = map;
602    }
603
604    /// True when the wrapped map is empty (peer plane plaintext).
605    #[must_use]
606    pub fn is_empty(&self) -> bool {
607        self.inner.read().is_empty()
608    }
609
610    /// Per-DC client config, with the legacy default as fallback.
611    /// Reads the inner map at call time.
612    #[must_use]
613    pub fn client_config_for_dc(&self, dc: &str) -> Option<Arc<ClientConfig>> {
614        self.inner.read().client_config_for_dc(dc)
615    }
616
617    /// True when at least one wrapped profile pins a CA bundle.
618    #[must_use]
619    pub fn requires_client_auth(&self) -> bool {
620        self.inner.read().requires_client_auth()
621    }
622
623    /// Names of the DCs with explicit per-DC entries (sorted).
624    #[must_use]
625    pub fn dc_names(&self) -> Vec<String> {
626        self.inner.read().dc_names()
627    }
628
629    /// Build a SIGHUP-aware [`tokio_rustls::TlsAcceptor`].
630    ///
631    /// The acceptor's underlying [`ServerConfig`] holds a
632    /// resolver that re-reads the wrapped
633    /// [`Arc<parking_lot::RwLock<TlsProfileMap>>`] on every
634    /// handshake, so [`Self::replace`] takes effect on the next
635    /// inbound connection without rebinding the listener.
636    ///
637    /// Returns `None` when the inner map is empty (caller stays
638    /// plaintext).
639    ///
640    /// # Errors
641    /// Returns [`TlsError::Rustls`] when rustls rejects the
642    /// assembled root store or the verifier (e.g. a CA cert
643    /// the loader missed).
644    pub fn build_sni_acceptor(&self) -> Result<Option<TlsAcceptor>, TlsError> {
645        if self.is_empty() {
646            return Ok(None);
647        }
648        ensure_provider_installed();
649        let resolver = ReloadingDcSniResolver {
650            profiles: self.inner.clone(),
651        };
652        let has_any_client_ca = self.inner.read().has_any_client_ca;
653        let builder = ServerConfig::builder();
654        let cfg = if has_any_client_ca {
655            let mut roots = RootCertStore::empty();
656            self.inner.read().populate_combined_ca_roots(&mut roots)?;
657            let verifier = WebPkiClientVerifier::builder(Arc::new(roots))
658                .build()
659                .map_err(|e| TlsError::Rustls(format!("client verifier: {e}")))?;
660            builder
661                .with_client_cert_verifier(verifier)
662                .with_cert_resolver(Arc::new(resolver))
663        } else {
664            builder
665                .with_no_client_auth()
666                .with_cert_resolver(Arc::new(resolver))
667        };
668        Ok(Some(TlsAcceptor::from(Arc::new(cfg))))
669    }
670}
671
672fn load_certified_key(
673    cert_path: &Path,
674    key_path: &Path,
675    provider: &rustls::crypto::CryptoProvider,
676) -> Result<Arc<CertifiedKey>, TlsError> {
677    let certs = load_certs(cert_path)?;
678    let key = load_private_key(key_path)?;
679    let ck = CertifiedKey::from_der(certs, key, provider)
680        .map_err(|e| TlsError::Rustls(format!("certified key: {e}")))?;
681    Ok(Arc::new(ck))
682}
683
684/// [`Transport`] wrapping a server-side TLS stream over a TCP
685/// connection.
686#[derive(Debug)]
687pub struct TlsServerTransport {
688    inner: tokio_rustls::server::TlsStream<TcpStream>,
689    role: ConnRole,
690    peer_addr: Option<SocketAddr>,
691}
692
693impl TlsServerTransport {
694    /// Wrap an established server-side TLS stream.
695    #[must_use]
696    pub fn new(stream: tokio_rustls::server::TlsStream<TcpStream>, role: ConnRole) -> Self {
697        let peer_addr = stream.get_ref().0.peer_addr().ok();
698        Self {
699            inner: stream,
700            role,
701            peer_addr,
702        }
703    }
704}
705
706impl Transport for TlsServerTransport {
707    fn role(&self) -> ConnRole {
708        self.role
709    }
710    fn peer_addr(&self) -> Option<SocketAddr> {
711        self.peer_addr
712    }
713}
714
715impl AsyncRead for TlsServerTransport {
716    fn poll_read(
717        mut self: Pin<&mut Self>,
718        cx: &mut Context<'_>,
719        buf: &mut ReadBuf<'_>,
720    ) -> Poll<io::Result<()>> {
721        Pin::new(&mut self.inner).poll_read(cx, buf)
722    }
723}
724
725impl AsyncWrite for TlsServerTransport {
726    fn poll_write(
727        mut self: Pin<&mut Self>,
728        cx: &mut Context<'_>,
729        buf: &[u8],
730    ) -> Poll<io::Result<usize>> {
731        Pin::new(&mut self.inner).poll_write(cx, buf)
732    }
733    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
734        Pin::new(&mut self.inner).poll_flush(cx)
735    }
736    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
737        Pin::new(&mut self.inner).poll_shutdown(cx)
738    }
739}
740
741/// [`Transport`] wrapping a client-side TLS stream over a TCP
742/// connection.
743#[derive(Debug)]
744pub struct TlsClientTransport {
745    inner: tokio_rustls::client::TlsStream<TcpStream>,
746    role: ConnRole,
747    peer_addr: Option<SocketAddr>,
748}
749
750impl TlsClientTransport {
751    /// Wrap an established client-side TLS stream.
752    #[must_use]
753    pub fn new(stream: tokio_rustls::client::TlsStream<TcpStream>, role: ConnRole) -> Self {
754        let peer_addr = stream.get_ref().0.peer_addr().ok();
755        Self {
756            inner: stream,
757            role,
758            peer_addr,
759        }
760    }
761}
762
763impl Transport for TlsClientTransport {
764    fn role(&self) -> ConnRole {
765        self.role
766    }
767    fn peer_addr(&self) -> Option<SocketAddr> {
768        self.peer_addr
769    }
770}
771
772impl AsyncRead for TlsClientTransport {
773    fn poll_read(
774        mut self: Pin<&mut Self>,
775        cx: &mut Context<'_>,
776        buf: &mut ReadBuf<'_>,
777    ) -> Poll<io::Result<()>> {
778        Pin::new(&mut self.inner).poll_read(cx, buf)
779    }
780}
781
782impl AsyncWrite for TlsClientTransport {
783    fn poll_write(
784        mut self: Pin<&mut Self>,
785        cx: &mut Context<'_>,
786        buf: &[u8],
787    ) -> Poll<io::Result<usize>> {
788        Pin::new(&mut self.inner).poll_write(cx, buf)
789    }
790    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
791        Pin::new(&mut self.inner).poll_flush(cx)
792    }
793    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
794        Pin::new(&mut self.inner).poll_shutdown(cx)
795    }
796}
797
#[cfg(test)]
mod tests {
    //! Unit tests for the PEM loaders, the DC SNI helpers and `TlsProfileMap`.
    //!
    //! Certificates are minted on the fly with `rcgen` and written into a
    //! per-test `TempDir`, so every test is hermetic.

    use super::*;
    use tempfile::TempDir;

    /// Writes `body` to `name` inside `dir` and returns the full path.
    fn write_pem(dir: &TempDir, name: &str, body: &str) -> std::path::PathBuf {
        let p = dir.path().join(name);
        std::fs::write(&p, body)
            .unwrap_or_else(|e| panic!("writing {}: {e}", p.display()));
        p
    }

    /// Mints a fresh self-signed certificate for `localhost` and returns
    /// `(cert_pem, key_pem)`.
    fn issue_self_signed() -> (String, String) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        (cert.cert.pem(), cert.signing_key.serialize_pem())
    }

    /// Writes a fresh self-signed pair as `{prefix}-cert.pem` /
    /// `{prefix}-key.pem` inside `dir` and returns `(cert_path, key_path)`.
    fn write_self_signed(dir: &TempDir, prefix: &str) -> (std::path::PathBuf, std::path::PathBuf) {
        let (cert_pem, key_pem) = issue_self_signed();
        (
            write_pem(dir, &format!("{prefix}-cert.pem"), &cert_pem),
            write_pem(dir, &format!("{prefix}-key.pem"), &key_pem),
        )
    }

    #[test]
    fn load_server_config_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let (cert, key) = write_self_signed(&dir, "server");
        // The successful load *is* the assertion: a well-formed
        // rcgen-issued cert/key pair must be accepted without a CA.
        // (Checking `Arc::strong_count >= 1` on a live Arc is a tautology.)
        let _cfg = load_server_config(&cert, &key, None)
            .expect("self-signed cert/key pair must load");
    }

    #[test]
    fn load_server_config_rejects_missing_cert() {
        let dir = tempfile::tempdir().unwrap();
        let bogus = dir.path().join("missing.pem");
        let key = write_pem(&dir, "key.pem", "");
        let err = load_server_config(&bogus, &key, None).expect_err("missing");
        assert!(matches!(err, TlsError::Io { .. }), "got {err:?}");
    }

    #[test]
    fn load_server_config_rejects_empty_cert_file() {
        let dir = tempfile::tempdir().unwrap();
        let cert = write_pem(&dir, "cert.pem", "");
        let key = write_pem(&dir, "key.pem", "");
        let err = load_server_config(&cert, &key, None).expect_err("empty");
        assert!(
            matches!(
                err,
                TlsError::NoMaterial {
                    kind: "certificate",
                    ..
                }
            ),
            "got {err:?}"
        );
    }

    #[test]
    fn load_client_config_with_webpki_default() {
        // As above, a successful build with the bundled roots is the check.
        let _cfg = load_client_config(None).expect("webpki-roots client config must build");
    }

    #[test]
    fn server_name_owned_accepts_dns_label() {
        assert!(server_name_owned("localhost").is_ok());
    }

    #[test]
    fn dc_sni_hostname_round_trips() {
        assert_eq!(dc_sni_hostname("dc1"), "dc-dc1.dynomite.local");
        assert_eq!(dc_from_sni_label("dc-dc1.dynomite.local"), Some("dc1"));
        assert_eq!(dc_from_sni_label("localhost"), None);
        assert_eq!(dc_from_sni_label("dc-.dynomite.local"), None);
        assert_eq!(dc_from_sni_label("dc-dc1.example.com"), None);
    }

    #[test]
    fn tls_profile_map_empty_is_empty() {
        let map = TlsProfileMap::build(None, BTreeMap::new()).unwrap();
        assert!(map.is_empty());
        assert!(map.client_config_for_dc("dc1").is_none());
        assert!(map.server_config_for_dc("dc1").is_none());
        assert!(!map.requires_client_auth());
        assert!(map.build_sni_acceptor().unwrap().is_none());
    }

    #[test]
    fn tls_profile_map_default_only_falls_back() {
        let dir = tempfile::tempdir().unwrap();
        let (cert, key) = write_self_signed(&dir, "default");
        let map = TlsProfileMap::build(
            Some(TlsProfileSpec {
                cert,
                key,
                ca: None,
            }),
            BTreeMap::new(),
        )
        .unwrap();
        assert!(!map.is_empty());
        // Any DC name resolves to the default.
        assert!(map.client_config_for_dc("dc1").is_some());
        assert!(map.server_config_for_dc("dc-without-profile").is_some());
        assert!(map.default_client_config().is_some());
        assert!(map.build_sni_acceptor().unwrap().is_some());
    }

    #[test]
    fn tls_profile_map_per_dc_overrides_default() {
        let dir = tempfile::tempdir().unwrap();
        let (def_cert, def_key) = write_self_signed(&dir, "default");
        let (dc1_cert, dc1_key) = write_self_signed(&dir, "dc1");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc1".into(),
            TlsProfileSpec {
                cert: dc1_cert,
                key: dc1_key,
                ca: None,
            },
        );
        let map = TlsProfileMap::build(
            Some(TlsProfileSpec {
                cert: def_cert,
                key: def_key,
                ca: None,
            }),
            per_dc,
        )
        .unwrap();
        // dc1 must hit its own entry.
        let dc1 = map.client_config_for_dc("dc1").unwrap();
        // A distinct DC must fall back to the default.
        let other = map.client_config_for_dc("other-dc").unwrap();
        assert!(
            !Arc::ptr_eq(&dc1, &other),
            "per-DC entry must differ from the default fallback"
        );
        assert_eq!(map.dc_names(), vec!["dc1".to_string()]);
    }

    #[test]
    fn tls_profile_map_per_dc_only_no_default() {
        let dir = tempfile::tempdir().unwrap();
        let (cert, key) = write_self_signed(&dir, "dc2");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc2".into(),
            TlsProfileSpec {
                cert,
                key,
                ca: None,
            },
        );
        let map = TlsProfileMap::build(None, per_dc).unwrap();
        assert!(map.client_config_for_dc("dc2").is_some());
        // No default: an unknown DC returns None and the
        // caller falls back to plaintext.
        assert!(map.client_config_for_dc("dc3").is_none());
        assert!(map.server_config_for_dc("dc3").is_none());
    }

    #[test]
    fn tls_profile_map_propagates_load_error() {
        let dir = tempfile::tempdir().unwrap();
        // Cert path that does not exist.
        let bogus = dir.path().join("missing.pem");
        let mut per_dc = BTreeMap::new();
        per_dc.insert(
            "dc1".into(),
            TlsProfileSpec {
                cert: bogus.clone(),
                key: bogus,
                ca: None,
            },
        );
        let err = TlsProfileMap::build(None, per_dc).expect_err("missing");
        assert!(matches!(err, TlsError::Io { .. }), "got {err:?}");
    }
}