Skip to main content

crabka_security/
reload.rs

//! Hot-reloadable TLS server config.
//!
//! Wraps a [`rustls::ServerConfig`] in an [`arc_swap::ArcSwap`] so the
//! broker can swap cert/key/client-CA without restarting the listener
//! or breaking in-flight connections. New TLS handshakes pick up the
//! latest config; already-established TLS sessions continue to use the
//! `ServerConfig` they negotiated against.
//!
//! Cooperates with the existing [`crate::TlsConfig`] — the
//! path-and-options struct stays as the source of truth, and
//! [`DynamicServerConfig::reload_from`] re-reads the files into a fresh
//! `Arc<ServerConfig>` and atomically swaps it in.
13
14use std::sync::Arc;
15
16use arc_swap::ArcSwap;
17
18use crate::tls::{TlsConfig, TlsError};
19
/// Atomically swappable wrapper around a [`rustls::ServerConfig`].
///
/// The active config is held in an [`ArcSwap`], so a reload replaces it
/// through `&self`; no `&mut` access is required. The type itself is
/// not `Clone`: [`from_tls_config`] hands it out as an `Arc<Self>`, and
/// sharing that handle costs one `Arc` bump. Reads are cheap
/// (lock-free); the only expensive operation is [`reload_from`], which
/// rebuilds the config via `TlsConfig::build_server_config`.
///
/// [`from_tls_config`]: Self::from_tls_config
/// [`reload_from`]: Self::reload_from
#[derive(Debug)]
pub struct DynamicServerConfig {
    // Most recently built server config; replaced wholesale on reload.
    inner: ArcSwap<rustls::ServerConfig>,
}
30
31impl DynamicServerConfig {
32    /// Build a fresh `DynamicServerConfig` from a [`TlsConfig`]. Reads
33    /// cert + key + optional client-CA paths immediately.
34    ///
35    /// # Errors
36    ///
37    /// Propagates the underlying [`TlsError`] from
38    /// [`TlsConfig::build_server_config`].
39    pub fn from_tls_config(cfg: &TlsConfig) -> Result<Arc<Self>, TlsError> {
40        let server_config = cfg.build_server_config()?;
41        Ok(Arc::new(Self {
42            inner: ArcSwap::new(server_config),
43        }))
44    }
45
46    /// Snapshot the current `ServerConfig`. The returned `Arc` is
47    /// independent of subsequent [`reload_from`](Self::reload_from)
48    /// calls — already-running handshakes against this snapshot are
49    /// unaffected by a concurrent reload.
50    #[must_use]
51    pub fn current(&self) -> Arc<rustls::ServerConfig> {
52        // `ArcSwap::load_full` clones the inner Arc (one atomic bump);
53        // we then unwrap the outer `Guard` into a plain `Arc`.
54        self.inner.load_full()
55    }
56
57    /// Re-read cert + key + optional client-CA from disk and swap the
58    /// new `ServerConfig` in atomically. On error the previous config
59    /// is left in place and the error is returned to the caller.
60    ///
61    /// Reload is a no-op semantically when the inputs haven't changed
62    /// — but rustls doesn't expose a content-equality hash on
63    /// `ServerConfig`, so we just rebuild unconditionally. Callers
64    /// that want to skip identical reloads should diff mtimes upstream.
65    ///
66    /// # Errors
67    ///
68    /// Propagates the underlying [`TlsError`] from
69    /// [`TlsConfig::build_server_config`].
70    pub fn reload_from(&self, cfg: &TlsConfig) -> Result<(), TlsError> {
71        let new = cfg.build_server_config()?;
72        self.inner.store(new);
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use std::path::{Path, PathBuf};

    use crate::tls::ClientAuthMode;

    /// Install the `ring` crypto provider as the process default.
    ///
    /// Tests share one process, so a second call may find a provider
    /// already installed; that outcome is fine and the result is
    /// deliberately ignored.
    fn install_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Write `cert_pem` / `key_pem` to `cert.pem` / `key.pem` inside
    /// `dir` and return both paths as `(cert, key)`.
    fn write_pair(dir: &Path, cert_pem: &str, key_pem: &str) -> (PathBuf, PathBuf) {
        let cp = dir.join("cert.pem");
        let kp = dir.join("key.pem");
        std::fs::write(&cp, cert_pem).unwrap();
        std::fs::write(&kp, key_pem).unwrap();
        (cp, kp)
    }

    /// `TlsConfig` for a plain server: the given cert/key paths, no
    /// trust roots, no client CA, client auth disabled.
    fn server_only_cfg(cert_chain_path: PathBuf, private_key_path: PathBuf) -> TlsConfig {
        TlsConfig {
            cert_chain_path,
            private_key_path,
            trust_roots_path: None,
            client_ca_path: None,
            client_auth: ClientAuthMode::Disabled,
        }
    }

    /// A snapshot taken before `reload_from` keeps pointing at the
    /// `ServerConfig` that was current at the time, while `current()`
    /// after the reload returns a different, freshly built one.
    #[test]
    fn snapshot_is_stable_across_reload() {
        install_provider();
        let dir = tempfile::tempdir().unwrap();
        let (cp, kp) = write_pair(
            dir.path(),
            include_str!("../tests/fixtures/dev_cert.pem"),
            include_str!("../tests/fixtures/dev_key.pem"),
        );
        let cfg = server_only_cfg(cp.clone(), kp.clone());
        let dynamic = DynamicServerConfig::from_tls_config(&cfg).unwrap();
        let snap_before = dynamic.current();

        // Overwrite the cert files with the alt fixture and reload.
        std::fs::write(&cp, include_str!("../tests/fixtures/dev_cert_alt.pem")).unwrap();
        std::fs::write(&kp, include_str!("../tests/fixtures/dev_key_alt.pem")).unwrap();
        dynamic.reload_from(&cfg).expect("reload must succeed");
        let snap_after = dynamic.current();

        // `ServerConfig` is not `PartialEq`, so identity is compared
        // with `Arc::ptr_eq`. `snap_before` is still alive here, which
        // means the two handles can only differ if the reload really
        // published a new allocation.
        assert!(
            !Arc::ptr_eq(&snap_before, &snap_after),
            "after-reload snapshot must point at a fresh ServerConfig"
        );
        // Without another reload, repeated reads hand back the same
        // config.
        assert!(
            Arc::ptr_eq(&snap_after, &dynamic.current()),
            "reads without an intervening reload must return the same ServerConfig"
        );
    }

    /// A failed reload leaves the prior config in place. The failure
    /// is provoked by pointing the reload at cert/key paths that do
    /// not exist; afterwards `current()` must still return the
    /// originally loaded config.
    #[test]
    fn reload_error_does_not_swap() {
        install_provider();
        let dir = tempfile::tempdir().unwrap();
        let (cp, kp) = write_pair(
            dir.path(),
            include_str!("../tests/fixtures/dev_cert.pem"),
            include_str!("../tests/fixtures/dev_key.pem"),
        );
        let cfg = server_only_cfg(cp, kp);
        let dynamic = DynamicServerConfig::from_tls_config(&cfg).unwrap();
        let snap_before = dynamic.current();

        let bogus = server_only_cfg(
            dir.path().join("missing.pem"),
            dir.path().join("missing.key"),
        );
        // Deliberately no match on the error variant — which one comes
        // back depends on the order the files are read in. The point
        // is only that `reload_from` errored.
        assert!(
            dynamic.reload_from(&bogus).is_err(),
            "reload from nonexistent paths must fail"
        );

        let snap_after = dynamic.current();
        assert!(
            Arc::ptr_eq(&snap_before, &snap_after),
            "failed reload must leave previous ServerConfig in place"
        );
    }
}