Skip to main content

s4_server/
tls.rs

1//! TLS termination helpers.
2//!
3//! Used by the binary's listener wiring. Kept as a separate library module so
4//! parsing logic (`load_tls_config`) is unit-testable and the `tokio-rustls`
5//! dependency is centralised here.
6
7use std::error::Error;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls::ServerConfig;
14use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
15
16/// Loads PEM cert + key files into a rustls `ServerConfig` ready for
17/// `TlsAcceptor::from`. Supports PKCS#8 and RSA private keys via
18/// `rustls_pemfile::private_key`.
19///
20/// ALPN protocols default to `h2` then `http/1.1` — matching the
21/// `hyper_util::server::conn::auto::Builder` upstream so HTTP/2 is negotiated
22/// when the client offers it.
23pub fn load_tls_config(
24    cert_path: &Path,
25    key_path: &Path,
26) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
27    use std::fs::File;
28    use std::io::BufReader;
29
30    let mut cert_reader = BufReader::new(File::open(cert_path)?);
31    let certs: Vec<CertificateDer<'static>> =
32        rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
33    if certs.is_empty() {
34        return Err(format!("no certificates found in {}", cert_path.display()).into());
35    }
36
37    let mut key_reader = BufReader::new(File::open(key_path)?);
38    let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
39        .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
40
41    let mut config = ServerConfig::builder()
42        .with_no_client_auth()
43        .with_single_cert(certs, key)?;
44    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
45    Ok(Arc::new(config))
46}
47
48/// Installs the `ring` crypto provider as the process-wide default. rustls
49/// 0.23+ requires this before any `ServerConfig::builder()` call. Idempotent.
50pub fn install_default_crypto_provider() {
51    let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
52}
53
/// Reloadable TLS state (v0.3 #10). Wraps an `ArcSwap<ServerConfig>` so the
/// listener can swap the cert/key pair atomically on SIGHUP without dropping
/// any in-flight connections. Construct via [`TlsState::load`] and pass
/// `Arc<TlsState>` to both the accept loop and the SIGHUP handler.
pub struct TlsState {
    /// Currently active server config; replaced wholesale by
    /// [`TlsState::reload`] and read by [`TlsState::acceptor`].
    cfg: ArcSwap<ServerConfig>,
    /// Certificate PEM path captured at load time and re-read on each reload.
    cert_path: PathBuf,
    /// Private-key PEM path captured at load time and re-read on each reload.
    key_path: PathBuf,
}
63
64impl TlsState {
65    /// Initial load — fails on parse error.
66    pub fn load(
67        cert_path: impl Into<PathBuf>,
68        key_path: impl Into<PathBuf>,
69    ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
70        let cert_path = cert_path.into();
71        let key_path = key_path.into();
72        let cfg = load_tls_config(&cert_path, &key_path)?;
73        Ok(Self {
74            cfg: ArcSwap::from(cfg),
75            cert_path,
76            key_path,
77        })
78    }
79
80    /// Build a fresh `TlsAcceptor` from the current config. Cheap (one
81    /// atomic load + Arc clone). Call this once per accepted connection.
82    pub fn acceptor(&self) -> TlsAcceptor {
83        TlsAcceptor::from(self.cfg.load_full())
84    }
85
86    /// Re-read the cert + key from disk and atomically swap the active
87    /// config. Returns `Ok(())` on success and `Err(...)` if the new pair
88    /// failed to parse — the previous config remains in effect either way,
89    /// so a bad reload never causes a listener outage.
90    pub fn reload(&self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
91        let new_cfg = load_tls_config(&self.cert_path, &self.key_path)?;
92        self.cfg.store(new_cfg);
93        Ok(())
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Helper: write a self-signed cert+key pair to two NamedTempFiles using
    /// rcgen and return them so the test can pass paths to load_tls_config.
    /// The files are deleted when the returned handles are dropped, so
    /// callers must keep them alive for as long as the paths are in use.
    fn write_self_signed_pair() -> (NamedTempFile, NamedTempFile) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let mut cert_file = NamedTempFile::new().unwrap();
        cert_file.write_all(cert.cert.pem().as_bytes()).unwrap();
        cert_file.flush().unwrap();
        let mut key_file = NamedTempFile::new().unwrap();
        key_file
            .write_all(cert.key_pair.serialize_pem().as_bytes())
            .unwrap();
        key_file.flush().unwrap();
        (cert_file, key_file)
    }

    /// Happy path: a valid pair loads and advertises `h2` before `http/1.1`.
    #[test]
    fn loads_pkcs8_cert_and_key() {
        install_default_crypto_provider();
        let (cert, key) = write_self_signed_pair();
        let cfg = load_tls_config(cert.path(), key.path()).expect("config should load");
        assert_eq!(
            cfg.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    /// A nonexistent cert path surfaces the OS "not found" message
    /// (Unix and Windows wordings are both accepted).
    #[test]
    fn rejects_missing_cert_file() {
        let (_cert, key) = write_self_signed_pair();
        let err = load_tls_config(std::path::Path::new("/nonexistent/cert.pem"), key.path())
            .expect_err("should fail on missing cert");
        assert!(
            err.to_string().contains("No such file") || err.to_string().contains("cannot find")
        );
    }

    /// A readable cert file with no PEM sections is rejected explicitly.
    #[test]
    fn rejects_empty_cert_file() {
        let cert = NamedTempFile::new().unwrap();
        let (_, key) = write_self_signed_pair();
        let err =
            load_tls_config(cert.path(), key.path()).expect_err("should fail on empty cert PEM");
        assert!(err.to_string().contains("no certificates found"));
    }

    /// Key-side counterpart of `rejects_missing_cert_file`: the cert loads,
    /// then opening the key fails with the OS "not found" message.
    #[test]
    fn rejects_missing_key_file() {
        let (cert, _key) = write_self_signed_pair();
        let err = load_tls_config(cert.path(), std::path::Path::new("/nonexistent/key.pem"))
            .expect_err("should fail on missing key");
        assert!(
            err.to_string().contains("No such file") || err.to_string().contains("cannot find")
        );
    }

    /// Key-side counterpart of `rejects_empty_cert_file`: a readable key
    /// file with no PEM sections is rejected explicitly.
    #[test]
    fn rejects_empty_key_file() {
        let (cert, _) = write_self_signed_pair();
        let key = NamedTempFile::new().unwrap();
        let err =
            load_tls_config(cert.path(), key.path()).expect_err("should fail on empty key PEM");
        assert!(err.to_string().contains("no private key found"));
    }

    /// v0.3 #10: reload swaps the config atomically. Verify by capturing
    /// the active `Arc<ServerConfig>` before + after the reload and
    /// confirming the two are different allocations.
    #[test]
    fn reload_swaps_active_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();

        // Start with cert A, copy to a stable path the TlsState owns.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        std::fs::copy(cert_a.path(), &cert_path).unwrap();
        std::fs::copy(key_a.path(), &key_path).unwrap();

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_v1: Arc<ServerConfig> = state.cfg.load_full();

        // Swap the on-disk files to a freshly-generated cert B.
        let (cert_b, key_b) = write_self_signed_pair();
        std::fs::copy(cert_b.path(), &cert_path).unwrap();
        std::fs::copy(key_b.path(), &key_path).unwrap();

        state.reload().expect("reload should succeed");
        let cfg_v2: Arc<ServerConfig> = state.cfg.load_full();

        // Pointer identity is the cleanest check: ArcSwap::store replaces
        // the inner Arc, so cfg_v1 and cfg_v2 must NOT be the same Arc.
        assert!(!Arc::ptr_eq(&cfg_v1, &cfg_v2));
    }

    /// Reload failure (bad PEM) must not break the active config.
    #[test]
    fn reload_failure_keeps_previous_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        std::fs::copy(cert_a.path(), &cert_path).unwrap();
        std::fs::copy(key_a.path(), &key_path).unwrap();

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_before: Arc<ServerConfig> = state.cfg.load_full();

        // Corrupt the cert file in-place — reload should fail.
        std::fs::write(&cert_path, b"not a pem certificate").unwrap();
        let err = state.reload().expect_err("reload should fail");
        assert!(err.to_string().contains("no certificates found"));

        let cfg_after: Arc<ServerConfig> = state.cfg.load_full();
        // Same Arc → previous config preserved.
        assert!(Arc::ptr_eq(&cfg_before, &cfg_after));
    }
}