Skip to main content

pg_wired/
tls.rs

//! TLS support for PostgreSQL wire connections.
//!
//! When the `tls` feature is enabled, connections can negotiate SSL/TLS with
//! the PostgreSQL server (via the protocol's `SSLRequest` message) using
//! rustls.
5
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9use tokio::net::TcpStream;
10
11/// A stream that is either plain TCP or TLS-wrapped TCP.
12#[allow(clippy::large_enum_variant)]
13pub(crate) enum MaybeTlsStream {
14    Plain(TcpStream),
15    #[cfg(feature = "tls")]
16    Tls(tokio_rustls::client::TlsStream<TcpStream>),
17}
18
/// How to handle TLS negotiation with the server.
///
/// Whenever TLS is actually negotiated (`Prefer` against a server that
/// answers `S`, or `Require`), the server certificate chain is verified
/// against the configured root store and the connection hostname. A
/// verification or handshake failure is an error, not a reason to fall back
/// to plain TCP.
///
/// This is stricter than libpq's `sslmode=prefer` / `sslmode=require`, which
/// do not verify the server certificate by default. With TLS negotiated,
/// this behaves like libpq's `verify-full`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum TlsMode {
    /// Do not send SSLRequest. Use plain TCP.
    Disable,
    /// Send SSLRequest. Upgrade if the server agrees (`S`), fall back to
    /// plain TCP if the server refuses (`N`). Default.
    ///
    /// Once the server has agreed, there is no plaintext fallback.
    #[default]
    Prefer,
    /// Send SSLRequest. Error out if the server refuses.
    Require,
}
32
33impl AsyncRead for MaybeTlsStream {
34    fn poll_read(
35        self: Pin<&mut Self>,
36        cx: &mut Context<'_>,
37        buf: &mut ReadBuf<'_>,
38    ) -> Poll<std::io::Result<()>> {
39        match self.get_mut() {
40            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
41            #[cfg(feature = "tls")]
42            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
43        }
44    }
45}
46
47impl AsyncWrite for MaybeTlsStream {
48    fn poll_write(
49        self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51        buf: &[u8],
52    ) -> Poll<std::io::Result<usize>> {
53        match self.get_mut() {
54            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
55            #[cfg(feature = "tls")]
56            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
57        }
58    }
59
60    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
61        match self.get_mut() {
62            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
63            #[cfg(feature = "tls")]
64            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
65        }
66    }
67
68    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
69        match self.get_mut() {
70            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
71            #[cfg(feature = "tls")]
72            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
73        }
74    }
75}
76
77impl MaybeTlsStream {
78    /// Get the peer address of the underlying TCP stream.
79    #[allow(dead_code)]
80    pub(crate) fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
81        match self {
82            MaybeTlsStream::Plain(s) => s.peer_addr(),
83            #[cfg(feature = "tls")]
84            MaybeTlsStream::Tls(s) => s.get_ref().0.peer_addr(),
85        }
86    }
87}
88
/// TLS configuration for PostgreSQL connections.
///
/// Construct one of these and pass it to
/// [`crate::WireConn::connect_with_tls_config`] to override the default trust
/// store or to authenticate with a client certificate.
///
/// The default trust store is the Mozilla root CA set compiled into the
/// binary by the `webpki-roots` crate. The operating system's certificate
/// store is **not** consulted. To trust a private CA, list its certificates
/// in [`TlsConfig::root_certs`].
///
/// All certificate and key bytes must be DER-encoded. PEM input is not
/// accepted; convert it first (for example with `rustls_pemfile::certs`).
///
/// The `Debug` output reports only certificate counts and never includes
/// the private key, so a `TlsConfig` can be logged safely.
#[cfg(feature = "tls")]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct TlsConfig {
    /// Custom root CA certificates (DER). If empty, the bundled
    /// `webpki-roots` Mozilla CA set is used instead.
    pub root_certs: Vec<Vec<u8>>,
    /// Optional client certificate chain (DER) and private key (DER) for
    /// mutual TLS.
    pub client_cert: Option<(Vec<Vec<u8>>, Vec<u8>)>,
}

// Hand-written rather than derived so that private-key bytes never end up
// in logs or panic messages.
#[cfg(feature = "tls")]
impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let client_cert = self
            .client_cert
            .as_ref()
            .map(|(chain, _key)| format!("{} cert(s), key redacted", chain.len()));
        f.debug_struct("TlsConfig")
            .field("root_certs", &format_args!("{} cert(s)", self.root_certs.len()))
            .field("client_cert", &client_cert)
            .finish()
    }
}
109
/// Negotiate TLS with the default configuration: the Mozilla root CAs
/// bundled by `webpki-roots` (not the OS trust store) and no client
/// certificate.
///
/// Uses `TlsMode::Prefer`: sends SSLRequest, upgrades on `S`, falls back to
/// plain TCP on `N`.
///
/// # Errors
///
/// Same as `negotiate_tls_with_config`. That covers I/O failures, an
/// unexpected reply to the SSLRequest, and a failed TLS handshake or
/// certificate verification.
#[cfg(feature = "tls")]
// NOTE(review): presumably kept as a convenience entry point with no current
// in-crate caller; dead_code is allowed rather than deleting it.
#[allow(dead_code)]
pub(crate) async fn negotiate_tls(
    stream: TcpStream,
    hostname: &str,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    negotiate_tls_with_config(stream, hostname, &TlsConfig::default(), TlsMode::Prefer).await
}
122
/// Negotiate TLS with custom configuration (custom CAs, client certs).
///
/// Behavior depends on `mode`:
/// - `Disable`: skip SSLRequest, return plain stream.
/// - `Prefer`: send SSLRequest, upgrade on `S`, fall back on `N`.
/// - `Require`: send SSLRequest, error if server responds `N`.
///
/// Once the server answers `S`, there is no plaintext fallback in any mode.
/// The server certificate is verified against the root store and
/// `hostname`, and any handshake failure is returned as an error. The root
/// store is `config.root_certs`, or the bundled `webpki-roots` set when that
/// list is empty.
///
/// # Errors
///
/// - I/O errors while sending the SSLRequest, reading the reply, or during
///   the TLS handshake.
/// - `PgWireError::Protocol` in any of these cases:
///   - an invalid root certificate, client key, client-auth setup, or
///     hostname;
///   - `mode` is `Require` and the server answers `N`;
///   - the server answers the SSLRequest with an ErrorResponse;
///   - the server sends any other unexpected reply byte.
#[cfg(feature = "tls")]
pub(crate) async fn negotiate_tls_with_config(
    mut stream: TcpStream,
    hostname: &str,
    config: &TlsConfig,
    mode: TlsMode,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    use bytes::{BufMut, BytesMut};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    if mode == TlsMode::Disable {
        return Ok(MaybeTlsStream::Plain(stream));
    }

    // Send SSLRequest: length=8, code=80877103
    let mut buf = BytesMut::with_capacity(8);
    buf.put_i32(8);
    buf.put_i32(80877103); // SSL request code (1234 << 16 | 5679)
    stream.write_all(&buf).await?;

    // Read the 1-byte response directly from the unbuffered socket. Anything
    // the server sends after this byte is left in the socket for the TLS
    // handshake, so it cannot be mistaken for authenticated data.
    let mut response = [0u8; 1];
    stream.read_exact(&mut response).await?;

    match response[0] {
        b'S' => {
            // Server supports SSL — upgrade.
            let mut root_store = rustls::RootCertStore::empty();
            if config.root_certs.is_empty() {
                // Use the Mozilla root CAs bundled by `webpki-roots` (this is
                // not the operating system's trust store).
                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            } else {
                // Use custom root CAs.
                for cert_der in &config.root_certs {
                    root_store
                        .add(rustls_pki_types::CertificateDer::from(cert_der.clone()))
                        .map_err(|e| {
                            crate::error::PgWireError::Protocol(format!(
                                "invalid root certificate: {e}"
                            ))
                        })?;
                }
            }

            // Use the `ring` provider that the `tls` feature enables. We pass
            // it explicitly via `builder_with_provider` to avoid relying on
            // the process-level default, which a host application may not
            // have installed.
            let provider = std::sync::Arc::new(rustls::crypto::ring::default_provider());
            let builder = rustls::ClientConfig::builder_with_provider(provider)
                .with_safe_default_protocol_versions()
                .map_err(|e| {
                    crate::error::PgWireError::Protocol(format!(
                        "TLS protocol version setup failed: {e}"
                    ))
                })?
                .with_root_certificates(root_store);

            let tls_config = if let Some((ref cert_chain, ref key_der)) = config.client_cert {
                // mTLS: client certificate authentication.
                let certs: Vec<rustls_pki_types::CertificateDer<'static>> = cert_chain
                    .iter()
                    .map(|c| rustls_pki_types::CertificateDer::from(c.clone()))
                    .collect();
                let key =
                    rustls_pki_types::PrivateKeyDer::try_from(key_der.clone()).map_err(|e| {
                        crate::error::PgWireError::Protocol(format!(
                            "invalid client private key: {e}"
                        ))
                    })?;
                builder.with_client_auth_cert(certs, key).map_err(|e| {
                    crate::error::PgWireError::Protocol(format!(
                        "TLS client auth config error: {e}"
                    ))
                })?
            } else {
                builder.with_no_client_auth()
            };

            let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
            let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string())
                .map_err(|e| {
                    crate::error::PgWireError::Protocol(format!("invalid hostname: {e}"))
                })?;

            let tls_stream = connector.connect(server_name, stream).await?;
            Ok(MaybeTlsStream::Tls(tls_stream))
        }
        b'N' => {
            if mode == TlsMode::Require {
                return Err(crate::error::PgWireError::Protocol(
                    "server does not support TLS but sslmode=require".to_string(),
                ));
            }
            // Prefer mode: server doesn't support SSL, continue with plain TCP.
            Ok(MaybeTlsStream::Plain(stream))
        }
        b'E' => {
            // The server replied with an ErrorResponse. Per the protocol docs,
            // this comes from servers that predate SSL support. The message body
            // comes from an unauthenticated peer and must not be shown to the
            // application, so it is deliberately not read. The connection is
            // abandoned.
            Err(crate::error::PgWireError::Protocol(
                "server rejected SSLRequest with an ErrorResponse".to_string(),
            ))
        }
        // Escape the byte and show its hex value so that control characters
        // or NUL do not end up raw inside the error message.
        other => Err(crate::error::PgWireError::Protocol(format!(
            "unexpected SSL response: {:?} (0x{:02x})",
            other as char, other
        ))),
    }
}