Skip to main content

heliosdb_proxy/backend/
tls.rs

//! TLS handshake for backend PostgreSQL connections.
//!
//! Flow:
//! 1. Send `SSLRequest` (8 bytes: length=8, code=80877103) on plain TCP.
//!    Under `TlsMode::Disable` this step is skipped and the plain TCP
//!    stream is used as-is.
//! 2. Read one byte: `S` = server accepts TLS, `N` = server refuses. Any
//!    other reply byte is treated as an error.
//! 3. On `S`, run a rustls client handshake on top of the same TCP
//!    socket and continue with the normal PG startup message over the
//!    TLS stream.
//! 4. On `N`, fail (if TLS was required) or fall back to plain TCP.

11use super::error::{BackendError, BackendResult};
12use super::stream::Stream;
13use rustls::{ClientConfig, RootCertStore};
14use rustls_pki_types::ServerName;
15use std::sync::Arc;
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::TcpStream;
18use tokio_rustls::TlsConnector;
19
/// Policy deciding whether a backend connection gets upgraded to TLS.
///
/// The variant names follow libpq's `sslmode` vocabulary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TlsMode {
    /// Stay on plain TCP; no `SSLRequest` is ever sent.
    Disable,
    /// Ask for TLS, but keep going over plain TCP when the server declines.
    /// Mirrors the fallback behaviour of `libpq sslmode=prefer`.
    Prefer,
    /// TLS is mandatory; a server that declines it produces an error.
    Require,
}
31
32/// Build a rustls `ClientConfig` that verifies peer certs against the
33/// Mozilla root set shipped in `webpki-roots`. Keeping the builder here
34/// lets callers reuse it without reconstructing on every connect.
35pub fn default_client_config() -> Arc<ClientConfig> {
36    let mut roots = RootCertStore::empty();
37    roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
38
39    let config = ClientConfig::builder()
40        .with_root_certificates(roots)
41        .with_no_client_auth();
42
43    Arc::new(config)
44}
45
46/// Perform the PG SSLRequest dance and (if accepted) upgrade the TCP
47/// stream to TLS.
48///
49/// `sni` is the server name used for certificate verification; it must
50/// match the server certificate's CN/SAN. Typically the hostname from
51/// the cluster config.
52pub async fn negotiate(
53    mut tcp: TcpStream,
54    mode: TlsMode,
55    config: Arc<ClientConfig>,
56    sni: &str,
57) -> BackendResult<Stream> {
58    if mode == TlsMode::Disable {
59        return Ok(Stream::Plain(tcp));
60    }
61
62    // SSLRequest frame: [length=8][code=80877103]
63    let ssl_request: [u8; 8] = [
64        0x00, 0x00, 0x00, 0x08, // length = 8
65        0x04, 0xd2, 0x16, 0x2f, // 80877103
66    ];
67    tcp.write_all(&ssl_request).await?;
68
69    let mut reply = [0u8; 1];
70    tcp.read_exact(&mut reply).await?;
71
72    match reply[0] {
73        b'S' => {
74            let dns = ServerName::try_from(sni.to_string())
75                .map_err(|_| BackendError::Tls(format!("invalid SNI hostname: {:?}", sni)))?;
76            let connector = TlsConnector::from(config);
77            let tls = connector
78                .connect(dns, tcp)
79                .await
80                .map_err(|e| BackendError::Tls(e.to_string()))?;
81            Ok(Stream::Tls(tls))
82        }
83        b'N' => {
84            if mode == TlsMode::Require {
85                Err(BackendError::Tls(
86                    "server refused TLS and tls_mode=require".to_string(),
87                ))
88            } else {
89                Ok(Stream::Plain(tcp))
90            }
91        }
92        other => Err(BackendError::Tls(format!(
93            "unexpected reply to SSLRequest: 0x{:02x}",
94            other
95        ))),
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_client_config_builds() {
        let _ = default_client_config();
    }

    #[test]
    fn test_tls_mode_variants() {
        // Check every pair of variants, not just adjacent ones: a variant
        // must equal only itself.
        let modes = [TlsMode::Disable, TlsMode::Prefer, TlsMode::Require];
        for (i, a) in modes.iter().enumerate() {
            for (j, b) in modes.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }
    }
}