pgwire_replication/tls/
rustls.rs

1//! TLS support using rustls.
2//!
3//! This module provides TLS/SSL connection upgrade for PostgreSQL connections
4//! using the rustls library. It supports:
5//!
6//! - All PostgreSQL SSL modes (disable, prefer, require, verify-ca, verify-full)
7//! - Custom CA certificates
8//! - Client certificate authentication (mTLS)
9//! - SNI hostname override
10//!
11//! # SSL Modes
12//!
13//! | Mode | Chain Verified | Hostname Verified | Falls back to plain |
14//! |------|----------------|-------------------|---------------------|
15//! | `Disable` | - | - | N/A (never uses TLS) |
16//! | `Prefer` | No | No | Yes |
17//! | `Require` | No | No | No |
18//! | `VerifyCa` | Yes | No | No |
19//! | `VerifyFull` | Yes | Yes | No |
20//!
21//! # Security Considerations
22//!
23//! - `Prefer` and `Require` modes are vulnerable to MITM attacks
24//! - `VerifyCa` protects against MITM but allows any hostname
25//! - `VerifyFull` provides full protection (recommended for production)
26//!
27//! # Example
28//!
29//! ```no_run
30//! use pgwire_replication::config::TlsConfig;
31//! use pgwire_replication::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
32//! use tokio::net::TcpStream;
33//! use std::path::PathBuf;
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
37//!     let tls_config = TlsConfig::verify_full(Some(PathBuf::new()))
38//!         .with_sni_hostname("db.example.com");
39//!
40//!     let tcp_stream = TcpStream::connect(("db.example.com", 5432)).await?;
41//!
42//!     let stream = maybe_upgrade_to_tls(tcp_stream, &tls_config, "db.example.com").await?;
43//!     match stream {
44//!         MaybeTlsStream::Plain(_) => {}
45//!         MaybeTlsStream::Tls(_) => {}
46//!     }
47//!
48//!     Ok(())
49//! }
50//! ```
51
52use std::pin::Pin;
53use std::task::{Context, Poll};
54use std::{fs::File, io::BufReader, sync::Arc};
55
56use rustls::{ClientConfig, RootCertStore};
57use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
58use tokio::net::TcpStream;
59use tokio_rustls::{client::TlsStream, TlsConnector};
60
61use crate::config::{SslMode, TlsConfig};
62use crate::error::{PgWireError, Result};
63use crate::protocol::framing::write_ssl_request;
64
65/// A stream that may or may not be TLS-encrypted.
66///
67/// This enum allows code to work with both plain TCP and TLS connections
68/// through a unified interface via `AsyncRead` and `AsyncWrite` implementations.
69#[derive(Debug)]
70pub enum MaybeTlsStream {
71    /// Unencrypted TCP connection
72    Plain(TcpStream),
73    /// TLS-encrypted connection (boxed to reduce enum size)
74    Tls(Box<TlsStream<TcpStream>>),
75}
76
77impl MaybeTlsStream {
78    /// Returns `true` if this is a TLS-encrypted stream.
79    #[inline]
80    pub fn is_tls(&self) -> bool {
81        matches!(self, MaybeTlsStream::Tls(_))
82    }
83
84    /// Returns `true` if this is a plain (unencrypted) stream.
85    #[inline]
86    pub fn is_plain(&self) -> bool {
87        matches!(self, MaybeTlsStream::Plain(_))
88    }
89
90    /// Returns a reference to the underlying `TcpStream`.
91    ///
92    /// For TLS streams, this returns the inner TCP stream.
93    pub fn get_ref(&self) -> &TcpStream {
94        match self {
95            MaybeTlsStream::Plain(s) => s,
96            MaybeTlsStream::Tls(s) => s.get_ref().0,
97        }
98    }
99}
100
101impl AsyncRead for MaybeTlsStream {
102    fn poll_read(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        buf: &mut ReadBuf<'_>,
106    ) -> Poll<std::io::Result<()>> {
107        match self.get_mut() {
108            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
109            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
110        }
111    }
112}
113
114impl AsyncWrite for MaybeTlsStream {
115    fn poll_write(
116        self: Pin<&mut Self>,
117        cx: &mut Context<'_>,
118        buf: &[u8],
119    ) -> Poll<std::io::Result<usize>> {
120        match self.get_mut() {
121            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
122            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
123        }
124    }
125
126    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
127        match self.get_mut() {
128            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
129            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
130        }
131    }
132
133    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134        match self.get_mut() {
135            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
136            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
137        }
138    }
139}
140
141/// Attempt to upgrade a TCP connection to TLS based on configuration.
142///
143/// This function implements PostgreSQL's SSL negotiation protocol:
144/// 1. Send SSLRequest message
145/// 2. Read single-byte response ('S' = proceed, 'N' = rejected)
146/// 3. If proceeding, perform TLS handshake
147///
148/// # Arguments
149/// * `tcp` - The TCP connection to potentially upgrade
150/// * `tls` - TLS configuration specifying mode and certificates
151/// * `host` - Target hostname (used for SNI and verification)
152///
153/// # Returns
154/// A `MaybeTlsStream` that is either the original TCP stream or a TLS-wrapped stream.
155///
156/// # Errors
157/// - Server rejects TLS when mode requires it
158/// - TLS handshake failure
159/// - Certificate verification failure (for VerifyCa/VerifyFull)
160/// - Invalid certificate/key files
161pub async fn maybe_upgrade_to_tls(
162    mut tcp: TcpStream,
163    tls: &TlsConfig,
164    host: &str,
165) -> Result<MaybeTlsStream> {
166    match tls.mode {
167        SslMode::Disable => return Ok(MaybeTlsStream::Plain(tcp)),
168        SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
169    }
170
171    // Install crypto provider (required for rustls 0.23+)
172    // This is idempotent - safe to call multiple times
173    let _ = rustls::crypto::ring::default_provider().install_default();
174
175    // PostgreSQL TLS negotiation: send SSLRequest, expect 'S' or 'N'
176    write_ssl_request(&mut tcp).await?;
177
178    let mut resp = [0u8; 1];
179    use tokio::io::AsyncReadExt;
180    tcp.read_exact(&mut resp).await?;
181
182    if resp[0] != b'S' {
183        // Server refused TLS
184        return match tls.mode {
185            SslMode::Prefer => Ok(MaybeTlsStream::Plain(tcp)),
186            _ => Err(PgWireError::Tls(
187                "server does not support TLS (SSLRequest rejected)".into(),
188            )),
189        };
190    }
191
192    // Determine verification requirements based on SSL mode
193    let verify_chain = matches!(tls.mode, SslMode::VerifyCa | SslMode::VerifyFull);
194    let verify_hostname = matches!(tls.mode, SslMode::VerifyFull);
195
196    let cfg = build_rustls_config(tls, verify_chain, verify_hostname, host)?;
197    let connector = TlsConnector::from(Arc::new(cfg));
198
199    // SNI hostname: use override if provided, otherwise use connection host
200    let sni = tls.sni_hostname.as_deref().unwrap_or(host);
201    let server_name = rustls::pki_types::ServerName::try_from(sni.to_string())
202        .map_err(|_| PgWireError::Tls(format!("invalid SNI hostname '{sni}'")))?;
203
204    let tls_stream = connector
205        .connect(server_name, tcp)
206        .await
207        .map_err(|e| PgWireError::Tls(format!("TLS handshake failed: {e}")))?;
208
209    Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
210}
211
212/// Build rustls ClientConfig based on TLS settings.
213fn build_rustls_config(
214    tls: &TlsConfig,
215    verify_chain: bool,
216    verify_hostname: bool,
217    host: &str,
218) -> Result<ClientConfig> {
219    // ---- mTLS config validation ----
220    let has_cert = tls.client_cert_pem_path.is_some();
221    let has_key = tls.client_key_pem_path.is_some();
222    if has_cert ^ has_key {
223        return Err(PgWireError::Tls(format!(
224            "TLS config error: mTLS requires both client_cert_pem_path and client_key_pem_path \
225             (got cert={has_cert} key={has_key})"
226        )));
227    }
228
229    // Operator hint: VerifyFull + IP literal is a common misconfiguration
230    if verify_hostname && host.parse::<std::net::IpAddr>().is_ok() && tls.sni_hostname.is_none() {
231        return Err(PgWireError::Tls(format!(
232            "TLS config error: VerifyFull enabled but host '{host}' is an IP address. \
233             Hint: use a DNS name matching the certificate, or set tls.sni_hostname, \
234             or use VerifyCa mode."
235        )));
236    }
237
238    // ---- Root certificate store ----
239    let roots = build_root_store(tls)?;
240    let roots_arc = Arc::new(roots.clone());
241
242    // ---- Base config builder ----
243    let builder = ClientConfig::builder().with_root_certificates(roots);
244
245    // ---- Client authentication (mTLS) ----
246    let mut cfg: ClientConfig = if has_cert {
247        let cert_path = tls.client_cert_pem_path.as_ref().unwrap();
248        let key_path = tls.client_key_pem_path.as_ref().unwrap();
249
250        let cert_chain = load_cert_chain(cert_path)?;
251        let key = load_private_key(key_path)?;
252
253        builder
254            .with_client_auth_cert(cert_chain, key)
255            .map_err(|e| {
256                PgWireError::Tls(format!("TLS config error: invalid client cert/key: {e}"))
257            })?
258    } else {
259        builder.with_no_client_auth()
260    };
261
262    // ---- Custom verification policy ----
263    if !verify_chain {
264        // Prefer/Require: skip all verification (dangerous but matches PostgreSQL behavior)
265        cfg.dangerous()
266            .set_certificate_verifier(Arc::new(NoVerifier));
267        return Ok(cfg);
268    }
269
270    if verify_chain && !verify_hostname {
271        // VerifyCa: verify certificate chain but ignore hostname mismatch
272        let inner = rustls::client::WebPkiServerVerifier::builder(roots_arc)
273            .build()
274            .map_err(|e| PgWireError::Tls(format!("TLS config error: build verifier: {e}")))?;
275
276        cfg.dangerous()
277            .set_certificate_verifier(Arc::new(VerifyChainOnly { inner }));
278    }
279
280    // VerifyFull: rustls default behavior already verifies chain + hostname
281    Ok(cfg)
282}
283
284/// Build root certificate store from config or system defaults.
285fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
286    use rustls::pki_types::CertificateDer;
287
288    let mut roots = RootCertStore::empty();
289
290    if let Some(path) = &tls.ca_pem_path {
291        // Load custom CA certificates
292        let f = File::open(path).map_err(|e| {
293            PgWireError::Tls(format!(
294                "TLS config error: failed to open CA PEM '{}': {e}",
295                path.display()
296            ))
297        })?;
298        let mut rd = BufReader::new(f);
299
300        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
301            .collect::<std::result::Result<Vec<_>, _>>()
302            .map_err(|e| {
303                PgWireError::Tls(format!(
304                    "TLS config error: failed to parse CA PEM '{}': {e}",
305                    path.display()
306                ))
307            })?
308            .into_iter()
309            .map(|c| c.into_owned())
310            .collect();
311
312        let (added, _ignored) = roots.add_parsable_certificates(certs);
313        if added == 0 {
314            return Err(PgWireError::Tls(format!(
315                "TLS config error: no valid CA certificates found in '{}'",
316                path.display()
317            )));
318        }
319    } else {
320        // Use Mozilla's root certificates (webpki-roots)
321        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
322    }
323
324    Ok(roots)
325}
326
327/// Load certificate chain from PEM file.
328fn load_cert_chain(
329    path: &std::path::Path,
330) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
331    use rustls::pki_types::CertificateDer;
332
333    let f = File::open(path).map_err(|e| {
334        PgWireError::Tls(format!(
335            "TLS config error: failed to open client certificate '{}': {e}",
336            path.display()
337        ))
338    })?;
339    let mut rd = BufReader::new(f);
340
341    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
342        .collect::<std::result::Result<Vec<_>, _>>()
343        .map_err(|e| {
344            PgWireError::Tls(format!(
345                "TLS config error: failed to parse client certificate '{}': {e}",
346                path.display()
347            ))
348        })?
349        .into_iter()
350        .map(|c| c.into_owned())
351        .collect();
352
353    if certs.is_empty() {
354        return Err(PgWireError::Tls(format!(
355            "TLS config error: no certificates found in '{}'",
356            path.display()
357        )));
358    }
359
360    Ok(certs)
361}
362
363/// Load private key from PEM file.
364///
365/// Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC) key formats.
366fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
367    // Try PKCS#8 first (most common modern format)
368    if let Some(key) = try_load_pkcs8_key(path)? {
369        return Ok(key);
370    }
371
372    // Try RSA PKCS#1 format
373    if let Some(key) = try_load_rsa_key(path)? {
374        return Ok(key);
375    }
376
377    // Try EC SEC1 format
378    if let Some(key) = try_load_ec_key(path)? {
379        return Ok(key);
380    }
381
382    Err(PgWireError::Tls(format!(
383        "TLS config error: no private key found in '{}'. \
384         Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
385        path.display()
386    )))
387}
388
389fn try_load_pkcs8_key(
390    path: &std::path::Path,
391) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
392    use rustls::pki_types::PrivateKeyDer;
393
394    let f = File::open(path).map_err(|e| {
395        PgWireError::Tls(format!(
396            "TLS config error: failed to open private key '{}': {e}",
397            path.display()
398        ))
399    })?;
400    let mut rd = BufReader::new(f);
401
402    let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::pkcs8_private_keys(&mut rd)
403        .filter_map(|r| r.ok())
404        .map(PrivateKeyDer::from)
405        .collect();
406
407    match keys.len() {
408        0 => Ok(None),
409        1 => Ok(Some(keys.into_iter().next().unwrap())),
410        n => Err(PgWireError::Tls(format!(
411            "TLS config error: found {n} PKCS#8 keys in '{}', expected 1",
412            path.display()
413        ))),
414    }
415}
416
417fn try_load_rsa_key(
418    path: &std::path::Path,
419) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
420    use rustls::pki_types::PrivateKeyDer;
421
422    let f = File::open(path).map_err(|e| {
423        PgWireError::Tls(format!(
424            "TLS config error: failed to open private key '{}': {e}",
425            path.display()
426        ))
427    })?;
428    let mut rd = BufReader::new(f);
429
430    let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::rsa_private_keys(&mut rd)
431        .filter_map(|r| r.ok())
432        .map(PrivateKeyDer::from)
433        .collect();
434
435    match keys.len() {
436        0 => Ok(None),
437        1 => Ok(Some(keys.into_iter().next().unwrap())),
438        n => Err(PgWireError::Tls(format!(
439            "TLS config error: found {n} RSA keys in '{}', expected 1",
440            path.display()
441        ))),
442    }
443}
444
445fn try_load_ec_key(
446    path: &std::path::Path,
447) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
448    use rustls::pki_types::PrivateKeyDer;
449
450    let f = File::open(path).map_err(|e| {
451        PgWireError::Tls(format!(
452            "TLS config error: failed to open private key '{}': {e}",
453            path.display()
454        ))
455    })?;
456    let mut rd = BufReader::new(f);
457
458    let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::ec_private_keys(&mut rd)
459        .filter_map(|r| r.ok())
460        .map(PrivateKeyDer::from)
461        .collect();
462
463    match keys.len() {
464        0 => Ok(None),
465        1 => Ok(Some(keys.into_iter().next().unwrap())),
466        n => Err(PgWireError::Tls(format!(
467            "TLS config error: found {n} EC keys in '{}', expected 1",
468            path.display()
469        ))),
470    }
471}
472
473// ==================== Custom Certificate Verifiers ====================
474
475/// Verifier that accepts any certificate without verification.
476///
477/// Used for `SslMode::Prefer` and `SslMode::Require`.
478///
479/// # Security Warning
480/// This provides NO protection against man-in-the-middle attacks.
481#[derive(Debug)]
482struct NoVerifier;
483
484impl rustls::client::danger::ServerCertVerifier for NoVerifier {
485    fn verify_server_cert(
486        &self,
487        _end_entity: &rustls::pki_types::CertificateDer<'_>,
488        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
489        _server_name: &rustls::pki_types::ServerName<'_>,
490        _ocsp: &[u8],
491        _now: rustls::pki_types::UnixTime,
492    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
493        Ok(rustls::client::danger::ServerCertVerified::assertion())
494    }
495
496    fn verify_tls12_signature(
497        &self,
498        _message: &[u8],
499        _cert: &rustls::pki_types::CertificateDer<'_>,
500        _dss: &rustls::DigitallySignedStruct,
501    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
502        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
503    }
504
505    fn verify_tls13_signature(
506        &self,
507        _message: &[u8],
508        _cert: &rustls::pki_types::CertificateDer<'_>,
509        _dss: &rustls::DigitallySignedStruct,
510    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
511        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
512    }
513
514    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
515        // Support all common signature schemes
516        vec![
517            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
518            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
519            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
520            rustls::SignatureScheme::ED25519,
521            rustls::SignatureScheme::RSA_PKCS1_SHA256,
522            rustls::SignatureScheme::RSA_PKCS1_SHA384,
523            rustls::SignatureScheme::RSA_PKCS1_SHA512,
524            rustls::SignatureScheme::RSA_PSS_SHA256,
525            rustls::SignatureScheme::RSA_PSS_SHA384,
526            rustls::SignatureScheme::RSA_PSS_SHA512,
527        ]
528    }
529}
530
531/// Verifier that validates the certificate chain but ignores hostname mismatch.
532///
533/// Used for `SslMode::VerifyCa`.
534#[derive(Debug)]
535struct VerifyChainOnly {
536    inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
537}
538
539impl rustls::client::danger::ServerCertVerifier for VerifyChainOnly {
540    fn verify_server_cert(
541        &self,
542        end_entity: &rustls::pki_types::CertificateDer<'_>,
543        intermediates: &[rustls::pki_types::CertificateDer<'_>],
544        server_name: &rustls::pki_types::ServerName<'_>,
545        ocsp: &[u8],
546        now: rustls::pki_types::UnixTime,
547    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
548        match self
549            .inner
550            .verify_server_cert(end_entity, intermediates, server_name, ocsp, now)
551        {
552            Ok(ok) => Ok(ok),
553            // VerifyCa: ignore hostname mismatch but enforce chain validation
554            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
555                Ok(rustls::client::danger::ServerCertVerified::assertion())
556            }
557            Err(e) => Err(e),
558        }
559    }
560
561    fn verify_tls12_signature(
562        &self,
563        message: &[u8],
564        cert: &rustls::pki_types::CertificateDer<'_>,
565        dss: &rustls::DigitallySignedStruct,
566    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
567        self.inner.verify_tls12_signature(message, cert, dss)
568    }
569
570    fn verify_tls13_signature(
571        &self,
572        message: &[u8],
573        cert: &rustls::pki_types::CertificateDer<'_>,
574        dss: &rustls::DigitallySignedStruct,
575    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
576        self.inner.verify_tls13_signature(message, cert, dss)
577    }
578
579    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
580        self.inner.supported_verify_schemes()
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // ---------- mTLS misconfiguration ----------
    // Supplying only half of the cert/key pair must be rejected up front.

    #[test]
    fn mtls_requires_both_cert_and_key() {
        let cert_without_key = TlsConfig {
            client_cert_pem_path: Some("/path/to/cert.pem".into()),
            client_key_pem_path: None,
            ..Default::default()
        };
        let key_without_cert = TlsConfig {
            client_cert_pem_path: None,
            client_key_pem_path: Some("/path/to/key.pem".into()),
            ..Default::default()
        };

        for tls in [cert_without_key, key_without_cert] {
            let message = build_rustls_config(&tls, false, false, "localhost")
                .unwrap_err()
                .to_string();
            assert!(message.contains("mTLS requires both"));
        }
    }

    // ---------- VerifyFull with an IP literal ----------
    // Hostname verification against an IP address is flagged as a config error.

    #[test]
    fn verify_full_rejects_ip_without_sni_override() {
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            ..Default::default()
        };

        // Expected to fail before any TLS machinery is touched.
        let message = build_rustls_config(&tls, true, true, "192.168.1.1")
            .unwrap_err()
            .to_string();
        assert!(message.contains("IP address"));
    }

    // Note: Testing that SNI override allows IP addresses requires a CryptoProvider
    // which isn't available in unit tests. This is covered by integration tests.

    // ---------- File handling ----------
    // Common file problems must surface as readable error messages.

    #[test]
    fn missing_ca_file_gives_clear_error() {
        let tls = TlsConfig {
            ca_pem_path: Some("/nonexistent/ca.pem".into()),
            ..Default::default()
        };

        let message = build_root_store(&tls).unwrap_err().to_string();
        assert!(message.contains("failed to open"));
        assert!(message.contains("ca.pem"));
    }

    #[test]
    fn empty_ca_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();
        let tls = TlsConfig {
            ca_pem_path: Some(empty.path().to_path_buf()),
            ..Default::default()
        };

        let message = build_root_store(&tls).unwrap_err().to_string();
        assert!(message.contains("no valid CA certificates"));
    }

    #[test]
    fn empty_key_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();

        let message = load_private_key(empty.path()).unwrap_err().to_string();
        assert!(message.contains("no private key"));
    }

    #[test]
    fn invalid_pem_gives_clear_error() {
        let mut garbage = NamedTempFile::new().unwrap();
        garbage
            .write_all(b"this is not a valid PEM file")
            .unwrap();

        // Both loaders must report an error rather than panic.
        let key_result = load_private_key(garbage.path());
        assert!(key_result.is_err());

        let chain_result = load_cert_chain(garbage.path());
        assert!(chain_result.is_err());
    }
}
676}