Skip to main content

pgwire_replication/tls/
rustls.rs

1//! TLS support using rustls.
2//!
3//! This module provides TLS/SSL connection upgrade for PostgreSQL connections
4//! using the rustls library. It supports:
5//!
6//! - All PostgreSQL SSL modes (disable, prefer, require, verify-ca, verify-full)
7//! - Custom CA certificates
8//! - Client certificate authentication (mTLS)
9//! - SNI hostname override
10//!
11//! # SSL Modes
12//!
13//! | Mode | Chain Verified | Hostname Verified | Falls back to plain |
14//! |------|----------------|-------------------|---------------------|
15//! | `Disable` | - | - | N/A (never uses TLS) |
16//! | `Prefer` | No | No | Yes |
17//! | `Require` | No | No | No |
18//! | `VerifyCa` | Yes | No | No |
19//! | `VerifyFull` | Yes | Yes | No |
20//!
21//! # Security Considerations
22//!
23//! - `Prefer` and `Require` modes are vulnerable to MITM attacks
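//! - `VerifyCa` verifies the certificate chains to a trusted CA but accepts any hostname, so anyone holding a certificate from that CA (e.g. any public CA when the default Mozilla roots are used) can still impersonate the server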
25//! - `VerifyFull` provides full protection (recommended for production)
26//!
27//! # Example
28//!
29//! ```no_run
30//! use pgwire_replication::config::TlsConfig;
31//! use pgwire_replication::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
32//! use tokio::net::TcpStream;
33//! use std::path::PathBuf;
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let tls_config = TlsConfig::verify_full(Some(PathBuf::from("/path/to/root.crt")))
38//!         .with_sni_hostname("db.example.com");
39//!
40//!     let tcp_stream = TcpStream::connect(("db.example.com", 5432)).await?;
41//!
42//!     let stream = maybe_upgrade_to_tls(tcp_stream, &tls_config, "db.example.com").await?;
43//!     match stream {
44//!         MaybeTlsStream::Plain(_) => {}
45//!         MaybeTlsStream::Tls(_) => {}
46//!     }
47//!
48//!     Ok(())
49//! }
50//! ```
51
52use std::pin::Pin;
53use std::sync::Arc;
54use std::task::{Context, Poll};
55
56use rustls::{ClientConfig, RootCertStore};
57use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
58use tokio::net::TcpStream;
59use tokio_rustls::{client::TlsStream, TlsConnector};
60
61use crate::config::{SslMode, TlsConfig};
62use crate::error::{PgWireError, Result};
63use crate::protocol::framing::write_ssl_request;
64
65/// A stream that may or may not be TLS-encrypted.
66///
67/// This enum allows code to work with both plain TCP and TLS connections
68/// through a unified interface via `AsyncRead` and `AsyncWrite` implementations.
69#[derive(Debug)]
70pub enum MaybeTlsStream {
71    /// Unencrypted TCP connection
72    Plain(TcpStream),
73    /// TLS-encrypted connection (boxed to reduce enum size)
74    Tls(Box<TlsStream<TcpStream>>),
75}
76
77impl MaybeTlsStream {
78    /// Returns `true` if this is a TLS-encrypted stream.
79    #[inline]
80    pub fn is_tls(&self) -> bool {
81        matches!(self, MaybeTlsStream::Tls(_))
82    }
83
84    /// Returns `true` if this is a plain (unencrypted) stream.
85    #[inline]
86    pub fn is_plain(&self) -> bool {
87        matches!(self, MaybeTlsStream::Plain(_))
88    }
89
90    /// Returns a reference to the underlying `TcpStream`.
91    ///
92    /// For TLS streams, this returns the inner TCP stream.
93    pub fn get_ref(&self) -> &TcpStream {
94        match self {
95            MaybeTlsStream::Plain(s) => s,
96            MaybeTlsStream::Tls(s) => s.get_ref().0,
97        }
98    }
99}
100
101impl AsyncRead for MaybeTlsStream {
102    fn poll_read(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        buf: &mut ReadBuf<'_>,
106    ) -> Poll<std::io::Result<()>> {
107        match self.get_mut() {
108            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
109            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
110        }
111    }
112}
113
114impl AsyncWrite for MaybeTlsStream {
115    fn poll_write(
116        self: Pin<&mut Self>,
117        cx: &mut Context<'_>,
118        buf: &[u8],
119    ) -> Poll<std::io::Result<usize>> {
120        match self.get_mut() {
121            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
122            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
123        }
124    }
125
126    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
127        match self.get_mut() {
128            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
129            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
130        }
131    }
132
133    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134        match self.get_mut() {
135            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
136            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
137        }
138    }
139}
140
141/// Attempt to upgrade a TCP connection to TLS based on configuration.
142///
143/// This function implements PostgreSQL's SSL negotiation protocol:
144/// 1. Send SSLRequest message
145/// 2. Read single-byte response ('S' = proceed, 'N' = rejected)
146/// 3. If proceeding, perform TLS handshake
147///
148/// # Arguments
149/// * `tcp` - The TCP connection to potentially upgrade
150/// * `tls` - TLS configuration specifying mode and certificates
151/// * `host` - Target hostname (used for SNI and verification)
152///
153/// # Returns
154/// A `MaybeTlsStream` that is either the original TCP stream or a TLS-wrapped stream.
155///
156/// # Errors
157/// - Server rejects TLS when mode requires it
158/// - TLS handshake failure
159/// - Certificate verification failure (for VerifyCa/VerifyFull)
160/// - Invalid certificate/key files
161pub async fn maybe_upgrade_to_tls(
162    mut tcp: TcpStream,
163    tls: &TlsConfig,
164    host: &str,
165) -> Result<MaybeTlsStream> {
166    match tls.mode {
167        SslMode::Disable => return Ok(MaybeTlsStream::Plain(tcp)),
168        SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
169    }
170
171    // Install crypto provider (required for rustls 0.23+)
172    // This is idempotent - safe to call multiple times
173    let _ = rustls::crypto::ring::default_provider().install_default();
174
175    // PostgreSQL TLS negotiation: send SSLRequest, expect 'S' or 'N'
176    write_ssl_request(&mut tcp).await?;
177
178    let mut resp = [0u8; 1];
179    use tokio::io::AsyncReadExt;
180    tcp.read_exact(&mut resp).await?;
181
182    if resp[0] != b'S' {
183        // Server refused TLS
184        return match tls.mode {
185            SslMode::Prefer => Ok(MaybeTlsStream::Plain(tcp)),
186            _ => Err(PgWireError::Tls(
187                "server does not support TLS (SSLRequest rejected)".into(),
188            )),
189        };
190    }
191
192    // Determine verification requirements based on SSL mode
193    let verify_chain = matches!(tls.mode, SslMode::VerifyCa | SslMode::VerifyFull);
194    let verify_hostname = matches!(tls.mode, SslMode::VerifyFull);
195
196    let cfg = build_rustls_config(tls, verify_chain, verify_hostname, host)?;
197    let connector = TlsConnector::from(Arc::new(cfg));
198
199    // SNI hostname: use override if provided, otherwise use connection host
200    let sni = tls.sni_hostname.as_deref().unwrap_or(host);
201    let server_name = rustls::pki_types::ServerName::try_from(sni.to_string())
202        .map_err(|_| PgWireError::Tls(format!("invalid SNI hostname '{sni}'")))?;
203
204    let tls_stream = connector
205        .connect(server_name, tcp)
206        .await
207        .map_err(|e| PgWireError::Tls(format!("TLS handshake failed: {e}")))?;
208
209    Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
210}
211
212/// Build rustls ClientConfig based on TLS settings.
213fn build_rustls_config(
214    tls: &TlsConfig,
215    verify_chain: bool,
216    verify_hostname: bool,
217    host: &str,
218) -> Result<ClientConfig> {
219    // ---- mTLS config validation ----
220    let has_cert = tls.client_cert_pem_path.is_some();
221    let has_key = tls.client_key_pem_path.is_some();
222    if has_cert ^ has_key {
223        return Err(PgWireError::Tls(format!(
224            "TLS config error: mTLS requires both client_cert_pem_path and client_key_pem_path \
225             (got cert={has_cert} key={has_key})"
226        )));
227    }
228
229    // Operator hint: VerifyFull + IP literal is a common misconfiguration
230    if verify_hostname && host.parse::<std::net::IpAddr>().is_ok() && tls.sni_hostname.is_none() {
231        return Err(PgWireError::Tls(format!(
232            "TLS config error: VerifyFull enabled but host '{host}' is an IP address. \
233             Hint: use a DNS name matching the certificate, or set tls.sni_hostname, \
234             or use VerifyCa mode."
235        )));
236    }
237
238    // ---- Root certificate store ----
239    let roots = build_root_store(tls)?;
240    let roots_arc = Arc::new(roots.clone());
241
242    // ---- Base config builder ----
243    let builder = ClientConfig::builder().with_root_certificates(roots);
244
245    // ---- Client authentication (mTLS) ----
246    let mut cfg: ClientConfig = if has_cert {
247        let cert_path = tls.client_cert_pem_path.as_ref().unwrap();
248        let key_path = tls.client_key_pem_path.as_ref().unwrap();
249
250        let cert_chain = load_cert_chain(cert_path)?;
251        let key = load_private_key(key_path)?;
252
253        builder
254            .with_client_auth_cert(cert_chain, key)
255            .map_err(|e| {
256                PgWireError::Tls(format!("TLS config error: invalid client cert/key: {e}"))
257            })?
258    } else {
259        builder.with_no_client_auth()
260    };
261
262    // ---- Custom verification policy ----
263    if !verify_chain {
264        // Prefer/Require: skip all verification (dangerous but matches PostgreSQL behavior)
265        cfg.dangerous()
266            .set_certificate_verifier(Arc::new(NoVerifier));
267        return Ok(cfg);
268    }
269
270    if verify_chain && !verify_hostname {
271        // VerifyCa: verify certificate chain but ignore hostname mismatch
272        let inner = rustls::client::WebPkiServerVerifier::builder(roots_arc)
273            .build()
274            .map_err(|e| PgWireError::Tls(format!("TLS config error: build verifier: {e}")))?;
275
276        cfg.dangerous()
277            .set_certificate_verifier(Arc::new(VerifyChainOnly { inner }));
278    }
279
280    // VerifyFull: rustls default behavior already verifies chain + hostname
281    Ok(cfg)
282}
283
284/// Build root certificate store from config or system defaults.
285fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
286    use rustls::pki_types::CertificateDer;
287
288    let mut roots = RootCertStore::empty();
289
290    if let Some(path) = &tls.ca_pem_path {
291        // Load custom CA certificates
292        use rustls::pki_types::pem::PemObject;
293
294        let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
295            .map_err(|e| {
296                PgWireError::Tls(format!(
297                    "TLS config error: failed to open CA PEM '{}': {e}",
298                    path.display()
299                ))
300            })?
301            .collect::<std::result::Result<Vec<_>, _>>()
302            .map_err(|e| {
303                PgWireError::Tls(format!(
304                    "TLS config error: failed to parse CA PEM '{}': {e}",
305                    path.display()
306                ))
307            })?;
308
309        let (added, _ignored) = roots.add_parsable_certificates(certs);
310        if added == 0 {
311            return Err(PgWireError::Tls(format!(
312                "TLS config error: no valid CA certificates found in '{}'",
313                path.display()
314            )));
315        }
316    } else {
317        // Use Mozilla's root certificates (webpki-roots)
318        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
319    }
320
321    Ok(roots)
322}
323
324/// Load certificate chain from PEM file.
325fn load_cert_chain(
326    path: &std::path::Path,
327) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
328    use rustls::pki_types::pem::PemObject;
329    use rustls::pki_types::CertificateDer;
330
331    let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
332        .map_err(|e| {
333            PgWireError::Tls(format!(
334                "TLS config error: failed to open client certificate '{}': {e}",
335                path.display()
336            ))
337        })?
338        .collect::<std::result::Result<Vec<_>, _>>()
339        .map_err(|e| {
340            PgWireError::Tls(format!(
341                "TLS config error: failed to parse client certificate '{}': {e}",
342                path.display()
343            ))
344        })?;
345
346    if certs.is_empty() {
347        return Err(PgWireError::Tls(format!(
348            "TLS config error: no certificates found in '{}'",
349            path.display()
350        )));
351    }
352
353    Ok(certs)
354}
355
356/// Load private key from PEM file.
357///
358/// Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC) key formats.
359fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
360    use rustls::pki_types::pem::PemObject;
361    use rustls::pki_types::PrivateKeyDer;
362
363    PrivateKeyDer::from_pem_file(path).map_err(|e| {
364        PgWireError::Tls(format!(
365            "TLS config error: failed to load private key from '{}': {e}. \
366             Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
367            path.display()
368        ))
369    })
370}
371
372// ==================== Custom Certificate Verifiers ====================
373
374/// Verifier that accepts any certificate without verification.
375///
376/// Used for `SslMode::Prefer` and `SslMode::Require`.
377///
378/// # Security Warning
379/// This provides NO protection against man-in-the-middle attacks.
380#[derive(Debug)]
381struct NoVerifier;
382
383impl rustls::client::danger::ServerCertVerifier for NoVerifier {
384    fn verify_server_cert(
385        &self,
386        _end_entity: &rustls::pki_types::CertificateDer<'_>,
387        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
388        _server_name: &rustls::pki_types::ServerName<'_>,
389        _ocsp: &[u8],
390        _now: rustls::pki_types::UnixTime,
391    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
392        Ok(rustls::client::danger::ServerCertVerified::assertion())
393    }
394
395    fn verify_tls12_signature(
396        &self,
397        _message: &[u8],
398        _cert: &rustls::pki_types::CertificateDer<'_>,
399        _dss: &rustls::DigitallySignedStruct,
400    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
401        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
402    }
403
404    fn verify_tls13_signature(
405        &self,
406        _message: &[u8],
407        _cert: &rustls::pki_types::CertificateDer<'_>,
408        _dss: &rustls::DigitallySignedStruct,
409    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
410        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
411    }
412
413    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
414        // Support all common signature schemes
415        vec![
416            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
417            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
418            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
419            rustls::SignatureScheme::ED25519,
420            rustls::SignatureScheme::RSA_PKCS1_SHA256,
421            rustls::SignatureScheme::RSA_PKCS1_SHA384,
422            rustls::SignatureScheme::RSA_PKCS1_SHA512,
423            rustls::SignatureScheme::RSA_PSS_SHA256,
424            rustls::SignatureScheme::RSA_PSS_SHA384,
425            rustls::SignatureScheme::RSA_PSS_SHA512,
426        ]
427    }
428}
429
430/// Verifier that validates the certificate chain but ignores hostname mismatch.
431///
432/// Used for `SslMode::VerifyCa`.
433#[derive(Debug)]
434struct VerifyChainOnly {
435    inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
436}
437
438impl rustls::client::danger::ServerCertVerifier for VerifyChainOnly {
439    fn verify_server_cert(
440        &self,
441        end_entity: &rustls::pki_types::CertificateDer<'_>,
442        intermediates: &[rustls::pki_types::CertificateDer<'_>],
443        server_name: &rustls::pki_types::ServerName<'_>,
444        ocsp: &[u8],
445        now: rustls::pki_types::UnixTime,
446    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
447        match self
448            .inner
449            .verify_server_cert(end_entity, intermediates, server_name, ocsp, now)
450        {
451            Ok(ok) => Ok(ok),
452            // VerifyCa: ignore hostname mismatch but enforce chain validation
453            Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
454                Ok(rustls::client::danger::ServerCertVerified::assertion())
455            }
456            Err(e) => Err(e),
457        }
458    }
459
460    fn verify_tls12_signature(
461        &self,
462        message: &[u8],
463        cert: &rustls::pki_types::CertificateDer<'_>,
464        dss: &rustls::DigitallySignedStruct,
465    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
466        self.inner.verify_tls12_signature(message, cert, dss)
467    }
468
469    fn verify_tls13_signature(
470        &self,
471        message: &[u8],
472        cert: &rustls::pki_types::CertificateDer<'_>,
473        dss: &rustls::DigitallySignedStruct,
474    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
475        self.inner.verify_tls13_signature(message, cert, dss)
476    }
477
478    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
479        self.inner.supported_verify_schemes()
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Config with only the mTLS certificate/key paths populated.
    fn mtls_only(cert: Option<&str>, key: Option<&str>) -> TlsConfig {
        TlsConfig {
            client_cert_pem_path: cert.map(Into::into),
            client_key_pem_path: key.map(Into::into),
            ..Default::default()
        }
    }

    // ==================== mTLS misconfiguration detection ====================
    // Supplying a client certificate without its key (or vice versa) is a common mistake.

    #[test]
    fn mtls_requires_both_cert_and_key() {
        let half_configured = [
            ("cert without key", mtls_only(Some("/path/to/cert.pem"), None)),
            ("key without cert", mtls_only(None, Some("/path/to/key.pem"))),
        ];

        for (label, tls) in half_configured {
            let msg = build_rustls_config(&tls, false, false, "localhost")
                .unwrap_err()
                .to_string();
            assert!(msg.contains("mTLS requires both"), "{label}: {msg}");
        }
    }

    // ==================== VerifyFull + IP address detection ====================
    // VerifyFull needs a hostname to match; an IP literal is rejected up front.

    #[test]
    fn verify_full_rejects_ip_without_sni_override() {
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            ..Default::default()
        };

        let msg = build_rustls_config(&tls, true, true, "192.168.1.1")
            .unwrap_err()
            .to_string();
        assert!(msg.contains("IP address"), "{msg}");
    }

    // Note: Testing that SNI override allows IP addresses requires a CryptoProvider
    // which isn't available in unit tests. This is covered by integration tests.

    // ==================== File error handling ====================
    // Error messages should name the problem and the offending file.

    #[test]
    fn missing_ca_file_gives_clear_error() {
        let tls = TlsConfig {
            ca_pem_path: Some("/nonexistent/ca.pem".into()),
            ..Default::default()
        };

        let msg = build_root_store(&tls).unwrap_err().to_string();
        assert!(msg.contains("failed to open"), "{msg}");
        assert!(msg.contains("ca.pem"), "{msg}");
    }

    #[test]
    fn empty_ca_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();
        let tls = TlsConfig {
            ca_pem_path: Some(empty.path().to_path_buf()),
            ..Default::default()
        };

        let msg = build_root_store(&tls).unwrap_err().to_string();
        assert!(msg.contains("no valid CA certificates"), "{msg}");
    }

    #[test]
    fn empty_key_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();

        let msg = load_private_key(empty.path()).unwrap_err().to_string();
        assert!(msg.contains("failed to load private key"), "{msg}");
    }

    #[test]
    fn invalid_pem_gives_clear_error() {
        let mut garbage = NamedTempFile::new().unwrap();
        garbage.write_all(b"this is not a valid PEM file").unwrap();

        // Must surface as errors, never as panics.
        assert!(load_private_key(garbage.path()).is_err());
        assert!(load_cert_chain(garbage.path()).is_err());
    }
}