Skip to main content

heliosdb_proxy/
client_tls.rs

1//! Client-facing TLS termination.
2//!
3//! The proxy can terminate TLS from PostgreSQL clients: it answers the
4//! `SSLRequest` with `S`, runs a rustls **server** handshake over the TCP
5//! socket, and then speaks the wire protocol over the encrypted stream.
6//! Optionally it requires and verifies a client certificate (mTLS).
7//!
8//! Backend connections stay plain `TcpStream` (or use the separate backend
9//! TLS in `backend::tls`); this module is only about the client side.
10
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::net::TcpStream;
17use tokio_rustls::server::TlsStream;
18use tokio_rustls::TlsAcceptor;
19
20use crate::config::TlsConfig;
21
22/// A client connection that may or may not be TLS-wrapped. Implements
23/// `AsyncRead`/`AsyncWrite` by delegating to the active variant, so the
24/// whole session loop can be written against one stream type regardless of
25/// whether the client negotiated TLS.
26pub enum ClientStream {
27    Plain(TcpStream),
28    Tls(Box<TlsStream<TcpStream>>),
29}
30
31impl ClientStream {
32    /// The peer certificate subject (DER-encoded leaf), if the client
33    /// presented one during an mTLS handshake. Used for identity mapping.
34    pub fn peer_cert_present(&self) -> bool {
35        match self {
36            ClientStream::Plain(_) => false,
37            ClientStream::Tls(s) => s
38                .get_ref()
39                .1
40                .peer_certificates()
41                .map(|c| !c.is_empty())
42                .unwrap_or(false),
43        }
44    }
45}
46
47impl AsyncRead for ClientStream {
48    fn poll_read(
49        self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51        buf: &mut ReadBuf<'_>,
52    ) -> Poll<std::io::Result<()>> {
53        match self.get_mut() {
54            ClientStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
55            ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
56        }
57    }
58}
59
60impl AsyncWrite for ClientStream {
61    fn poll_write(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &[u8],
65    ) -> Poll<std::io::Result<usize>> {
66        match self.get_mut() {
67            ClientStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
68            ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
69        }
70    }
71
72    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
73        match self.get_mut() {
74            ClientStream::Plain(s) => Pin::new(s).poll_flush(cx),
75            ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
76        }
77    }
78
79    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
80        match self.get_mut() {
81            ClientStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
82            ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
83        }
84    }
85}
86
87/// Build a `TlsAcceptor` from the proxy's `[tls]` config: load the server
88/// certificate chain + private key (PEM), and — when `require_client_cert`
89/// is set — a client-certificate verifier rooted at `ca_path` (mTLS).
90pub fn build_tls_acceptor(tls: &TlsConfig) -> Result<TlsAcceptor, String> {
91    use rustls::pki_types::{CertificateDer, PrivateKeyDer};
92
93    let cert_chain: Vec<CertificateDer<'static>> = {
94        let data = std::fs::read(&tls.cert_path)
95            .map_err(|e| format!("reading cert {}: {}", tls.cert_path, e))?;
96        rustls_pemfile::certs(&mut &data[..])
97            .collect::<Result<Vec<_>, _>>()
98            .map_err(|e| format!("parsing cert {}: {}", tls.cert_path, e))?
99    };
100    if cert_chain.is_empty() {
101        return Err(format!("no certificates found in {}", tls.cert_path));
102    }
103
104    let key: PrivateKeyDer<'static> = {
105        let data = std::fs::read(&tls.key_path)
106            .map_err(|e| format!("reading key {}: {}", tls.key_path, e))?;
107        rustls_pemfile::private_key(&mut &data[..])
108            .map_err(|e| format!("parsing key {}: {}", tls.key_path, e))?
109            .ok_or_else(|| format!("no private key found in {}", tls.key_path))?
110    };
111
112    let builder = rustls::ServerConfig::builder();
113
114    let config = if tls.require_client_cert {
115        let ca_path = tls
116            .ca_path
117            .as_ref()
118            .ok_or_else(|| "require_client_cert is set but ca_path is missing".to_string())?;
119        let ca_data =
120            std::fs::read(ca_path).map_err(|e| format!("reading ca {}: {}", ca_path, e))?;
121        let mut roots = rustls::RootCertStore::empty();
122        for ca in rustls_pemfile::certs(&mut &ca_data[..]) {
123            let ca = ca.map_err(|e| format!("parsing ca {}: {}", ca_path, e))?;
124            roots
125                .add(ca)
126                .map_err(|e| format!("adding ca cert: {}", e))?;
127        }
128        let verifier =
129            rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
130                .build()
131                .map_err(|e| format!("building client verifier: {}", e))?;
132        builder
133            .with_client_cert_verifier(verifier)
134            .with_single_cert(cert_chain, key)
135            .map_err(|e| format!("server config (mTLS): {}", e))?
136    } else {
137        builder
138            .with_no_client_auth()
139            .with_single_cert(cert_chain, key)
140            .map_err(|e| format!("server config: {}", e))?
141    };
142
143    Ok(TlsAcceptor::from(Arc::new(config)))
144}