Skip to main content

heliosdb_proxy/backend/
tls.rs

//! TLS handshake for backend PostgreSQL connections.
//!
//! Flow:
//! 1. Send `SSLRequest` (8 bytes: length=8, code=80877103) on plain TCP.
//! 2. Read one byte: `S` = server accepts TLS, `N` = server refuses. Any
//!    other byte is treated as a protocol error.
//! 3. On `S`, run a rustls client handshake on top of the same TCP
//!    socket and continue with the normal PG startup message over the
//!    TLS stream.
//! 4. On `N`, fail (if TLS was required) or fall back to plain.
10
11use super::error::{BackendError, BackendResult};
12use super::stream::Stream;
13use rustls::{ClientConfig, RootCertStore};
14use rustls_pki_types::ServerName;
15use std::sync::Arc;
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::TcpStream;
18use tokio_rustls::TlsConnector;
19
/// Policy controlling whether a backend connection is upgraded to TLS.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TlsMode {
    /// Skip TLS negotiation entirely and keep the plain TCP stream.
    Disable,
    /// Ask the server for TLS, but continue over plain TCP when the
    /// server declines. Equivalent to `libpq sslmode=prefer`.
    Prefer,
    /// Ask the server for TLS and treat a refusal as an error.
    Require,
}
31
32/// Build a rustls `ClientConfig` that verifies peer certs against the
33/// Mozilla root set shipped in `webpki-roots`. Keeping the builder here
34/// lets callers reuse it without reconstructing on every connect.
35pub fn default_client_config() -> Arc<ClientConfig> {
36    let mut roots = RootCertStore::empty();
37    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
38
39    let config = ClientConfig::builder()
40        .with_root_certificates(roots)
41        .with_no_client_auth();
42
43    Arc::new(config)
44}
45
46/// Perform the PG SSLRequest dance and (if accepted) upgrade the TCP
47/// stream to TLS.
48///
49/// `sni` is the server name used for certificate verification; it must
50/// match the server certificate's CN/SAN. Typically the hostname from
51/// the cluster config.
52pub async fn negotiate(
53    mut tcp: TcpStream,
54    mode: TlsMode,
55    config: Arc<ClientConfig>,
56    sni: &str,
57) -> BackendResult<Stream> {
58    if mode == TlsMode::Disable {
59        return Ok(Stream::Plain(tcp));
60    }
61
62    // SSLRequest frame: [length=8][code=80877103]
63    let ssl_request: [u8; 8] = [
64        0x00, 0x00, 0x00, 0x08, // length = 8
65        0x04, 0xd2, 0x16, 0x2f, // 80877103
66    ];
67    tcp.write_all(&ssl_request).await?;
68
69    let mut reply = [0u8; 1];
70    tcp.read_exact(&mut reply).await?;
71
72    match reply[0] {
73        b'S' => {
74            let dns = ServerName::try_from(sni.to_string()).map_err(|_| {
75                BackendError::Tls(format!("invalid SNI hostname: {:?}", sni))
76            })?;
77            let connector = TlsConnector::from(config);
78            let tls = connector
79                .connect(dns, tcp)
80                .await
81                .map_err(|e| BackendError::Tls(e.to_string()))?;
82            Ok(Stream::Tls(tls))
83        }
84        b'N' => {
85            if mode == TlsMode::Require {
86                Err(BackendError::Tls(
87                    "server refused TLS and tls_mode=require".to_string(),
88                ))
89            } else {
90                Ok(Stream::Plain(tcp))
91            }
92        }
93        other => Err(BackendError::Tls(format!(
94            "unexpected reply to SSLRequest: 0x{:02x}",
95            other
96        ))),
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::{default_client_config, TlsMode};

    /// Smoke test: constructing the default client config must not panic.
    #[test]
    fn test_default_client_config_builds() {
        let config = default_client_config();
        drop(config);
    }

    /// Neighbouring policy variants must compare as distinct values.
    #[test]
    fn test_tls_mode_variants() {
        assert!(TlsMode::Disable != TlsMode::Prefer);
        assert!(TlsMode::Prefer != TlsMode::Require);
    }
}