Skip to main content

mailrs_tls_reload/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4
5use std::io;
6use std::path::Path;
7use std::sync::Arc;
8
9use arc_swap::ArcSwap;
10use rustls::ServerConfig;
11use tokio_rustls::TlsAcceptor;
12
13/// Wrapper around `Arc<ArcSwap<ServerConfig>>` that lets you swap the
14/// active rustls server config atomically. In-flight TLS handshakes
15/// keep the old config (each `acceptor()` call snapshots the current
16/// pointer); new handshakes use the new config immediately after a
17/// [`swap`](Self::swap).
18///
19/// Typical use: hold a `TlsState` in your server, derive a fresh
20/// [`TlsAcceptor`] for each incoming connection via
21/// [`TlsState::acceptor`], and call [`TlsState::swap`] from your
22/// renewal hook (ACME, certbot reload signal, etc.) when new
23/// certificates land on disk.
24///
25/// ```rust,no_run
26/// use mailrs_tls_reload::{TlsState, load_tls_config};
27/// use std::path::Path;
28///
29/// # async fn run() -> std::io::Result<()> {
30/// let cfg = load_tls_config(Path::new("cert.pem"), Path::new("key.pem"))?;
31/// let state = TlsState::new((*cfg).clone());
32///
33/// // ... in your accept loop:
34/// let acceptor = state.acceptor();
35/// // tokio::spawn(handle(acceptor, socket));
36///
37/// // ... later, certs got renewed:
38/// let new_cfg = load_tls_config(Path::new("cert.pem"), Path::new("key.pem"))?;
39/// state.swap((*new_cfg).clone());
40/// // Subsequent acceptor() calls return the new config.
41/// # Ok(())
42/// # }
43/// ```
44#[derive(Clone)]
45pub struct TlsState {
46    inner: Arc<ArcSwap<ServerConfig>>,
47}
48
49impl TlsState {
50    /// Construct from an initial server config.
51    pub fn new(config: ServerConfig) -> Self {
52        Self {
53            inner: Arc::new(ArcSwap::from_pointee(config)),
54        }
55    }
56
57    /// Snapshot the current config into a fresh `TlsAcceptor`. Each
58    /// snapshot is independent — in-flight handshakes that took an
59    /// acceptor before [`swap`](Self::swap) keep using the old config.
60    pub fn acceptor(&self) -> TlsAcceptor {
61        TlsAcceptor::from(self.inner.load_full())
62    }
63
64    /// Replace the active config atomically. New `acceptor()` calls
65    /// after this point return the new config. In-flight handshakes
66    /// are unaffected.
67    pub fn swap(&self, new: ServerConfig) {
68        self.inner.store(Arc::new(new));
69    }
70
71    /// Read the current config (a fresh `Arc`). Cheap; one atomic load.
72    /// Useful for inspecting the active certs in admin endpoints.
73    pub fn current(&self) -> Arc<ServerConfig> {
74        self.inner.load_full()
75    }
76}
77
78/// Load a rustls `ServerConfig` from PEM-encoded cert + key files on
79/// disk. Returns `Arc<ServerConfig>` ready to hand to [`TlsState::new`]
80/// or [`TlsState::swap`] (after a clone).
81///
82/// The cert file may contain a chain (multiple PEM blocks); the key
83/// file must contain exactly one PEM-encoded private key
84/// (PKCS#1, PKCS#8, or SEC1). No client auth is configured.
85///
86/// Errors:
87/// - `io::Error` (NotFound, PermissionDenied) if either file is missing
88///   or unreadable
89/// - `io::Error(InvalidData)` if either file is not valid PEM, the key
90///   is unparseable, or rustls rejects the (certs, key) pair
91pub fn load_tls_config(cert_path: &Path, key_path: &Path) -> io::Result<Arc<ServerConfig>> {
92    let cert_data = std::fs::read(cert_path)?;
93    let key_data = std::fs::read(key_path)?;
94
95    use rustls_pki_types::pem::PemObject;
96    use rustls_pki_types::{CertificateDer, PrivateKeyDer};
97
98    let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(&cert_data)
99        .collect::<Result<Vec<_>, _>>()
100        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{e:?}")))?;
101
102    let key = PrivateKeyDer::from_pem_slice(&key_data)
103        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{e:?}")))?;
104
105    let config = ServerConfig::builder()
106        .with_no_client_auth()
107        .with_single_cert(certs, key)
108        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
109
110    Ok(Arc::new(config))
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Once;

    /// Make sure rustls's default crypto provider is installed exactly
    /// once for the test process. Required by rustls 0.23 before any
    /// `ServerConfig::builder()` call.
    fn install_crypto_provider() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            // `install_default` errs if a provider is already installed.
            // That is fine for tests, so the result is ignored.
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    /// Return a path in the system temp dir that no other call in this
    /// test process returns: `mailrs-tls-reload-{pid}-{nonce}-{suffix}`.
    ///
    /// Every temp-file helper goes through here, so tests running in
    /// parallel never write to the same file. Files are NOT deleted
    /// afterwards; they stay in the temp dir until something else cleans it.
    fn unique_temp_path(suffix: &str) -> PathBuf {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let nonce = COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("mailrs-tls-reload-{pid}-{nonce}-{suffix}"))
    }

    /// Generate a fresh self-signed cert + key for `localhost` and write
    /// them as PEM to two unique temp files. Returns `(cert_path, key_path)`.
    fn make_self_signed_pem_files() -> (PathBuf, PathBuf) {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .expect("rcgen self-signed");
        let cert_path = unique_temp_path("cert.pem");
        let key_path = unique_temp_path("key.pem");
        std::fs::write(&cert_path, generated.cert.pem()).unwrap();
        std::fs::write(&key_path, generated.key_pair.serialize_pem()).unwrap();
        (cert_path, key_path)
    }

    /// Load a config from a freshly generated self-signed cert + key.
    /// Each call yields a distinct `Arc`.
    fn fresh_config() -> Arc<ServerConfig> {
        let (cert_path, key_path) = make_self_signed_pem_files();
        load_tls_config(&cert_path, &key_path).expect("valid PEMs should load")
    }

    /// Helper: write `data` to a uniquely-named temp file and return the
    /// path. `name` becomes the file-name suffix.
    fn tempfile_for(name: &str, data: &[u8]) -> PathBuf {
        let path = unique_temp_path(name);
        let mut f = std::fs::File::create(&path).expect("create temp");
        f.write_all(data).expect("write temp");
        path
    }

    #[test]
    fn load_tls_config_succeeds_with_valid_self_signed() {
        install_crypto_provider();
        let (cert_path, key_path) = make_self_signed_pem_files();
        // A successful load is the assertion. ServerConfig exposes no
        // public accessor for its cert chain, but Ok means rustls accepted
        // the (chain, key) pair.
        let _cfg = load_tls_config(&cert_path, &key_path).expect("valid PEMs should load");
    }

    #[test]
    fn tls_state_new_returns_acceptor() {
        install_crypto_provider();
        let state = TlsState::new((*fresh_config()).clone());
        // Construction only. Actual TLS handshakes are covered by the
        // mailrs-server bin tests.
        let _acceptor = state.acceptor();
    }

    #[test]
    fn tls_state_swap_changes_current_pointer() {
        install_crypto_provider();
        let state = TlsState::new((*fresh_config()).clone());
        let before = state.current();
        state.swap((*fresh_config()).clone());
        let after = state.current();
        assert!(
            !Arc::ptr_eq(&before, &after),
            "swap should produce a fresh Arc"
        );
    }

    #[test]
    fn tls_state_acceptor_snapshots_at_call_time() {
        install_crypto_provider();
        let state = TlsState::new((*fresh_config()).clone());
        let before = state.current();
        let acc1 = state.acceptor();

        // Swap mid-flight.
        state.swap((*fresh_config()).clone());
        assert!(
            !Arc::ptr_eq(&before, &state.current()),
            "swap should install the new config"
        );

        // TlsAcceptor doesn't expose its config, but it must still hold a
        // strong reference to the pre-swap config: dropping it releases
        // exactly one reference to `before`'s allocation.
        let refs_with_acceptor = Arc::strong_count(&before);
        drop(acc1);
        assert_eq!(Arc::strong_count(&before), refs_with_acceptor - 1);
    }

    #[test]
    fn tls_state_current_returns_arc() {
        install_crypto_provider();
        let state = TlsState::new((*fresh_config()).clone());
        // Without a swap in between, current() returns Arcs to the same
        // underlying config.
        let a = state.current();
        let b = state.current();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn tls_state_clone_shares_inner() {
        install_crypto_provider();
        let state1 = TlsState::new((*fresh_config()).clone());
        let state2 = state1.clone();
        // Both clones share one inner ArcSwap, so a swap through state1
        // must be visible through state2.
        let before = state2.current();
        state1.swap((*fresh_config()).clone());
        let after_via_state2 = state2.current();
        assert!(
            !Arc::ptr_eq(&before, &after_via_state2),
            "clone should observe swap"
        );
    }

    #[test]
    fn load_tls_config_rejects_missing_files() {
        let err = load_tls_config(
            Path::new("/nonexistent/path/cert.pem"),
            Path::new("/nonexistent/path/key.pem"),
        )
        .expect_err("missing files must not load");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn load_tls_config_rejects_invalid_pem() {
        let cert_temp = tempfile_for("not_a_cert.pem", b"definitely not a PEM file");
        let key_temp = tempfile_for("not_a_key.pem", b"also not a PEM file");
        let err = load_tls_config(&cert_temp, &key_temp).expect_err("garbage must not load");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn load_tls_config_rejects_empty_cert_file() {
        let cert_temp = tempfile_for("empty_cert.pem", b"");
        let key_temp = tempfile_for("empty_key.pem", b"");
        let err = load_tls_config(&cert_temp, &key_temp).expect_err("empty files must not load");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn tempfile_helper_works() {
        let p = tempfile_for("smoke.txt", b"hi");
        assert_eq!(std::fs::read(&p).unwrap(), b"hi");
    }
}