1use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9use tokio::net::TcpStream;
10
11#[allow(clippy::large_enum_variant)]
13pub(crate) enum MaybeTlsStream {
14 Plain(TcpStream),
15 #[cfg(feature = "tls")]
16 Tls(tokio_rustls::client::TlsStream<TcpStream>),
17}
18
/// How hard the connection should try to negotiate TLS with the server.
///
/// The default is [`TlsMode::Prefer`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TlsMode {
    /// Never send an `SSLRequest`; the plaintext socket is used as-is.
    Disable,
    /// Send an `SSLRequest` and use TLS if the server accepts; if the server
    /// declines, continue over plaintext.
    #[default]
    Prefer,
    /// Send an `SSLRequest` and fail the connection if the server declines.
    Require,
}
32
33impl AsyncRead for MaybeTlsStream {
34 fn poll_read(
35 self: Pin<&mut Self>,
36 cx: &mut Context<'_>,
37 buf: &mut ReadBuf<'_>,
38 ) -> Poll<std::io::Result<()>> {
39 match self.get_mut() {
40 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
41 #[cfg(feature = "tls")]
42 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
43 }
44 }
45}
46
47impl AsyncWrite for MaybeTlsStream {
48 fn poll_write(
49 self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 buf: &[u8],
52 ) -> Poll<std::io::Result<usize>> {
53 match self.get_mut() {
54 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
55 #[cfg(feature = "tls")]
56 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
57 }
58 }
59
60 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
61 match self.get_mut() {
62 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
63 #[cfg(feature = "tls")]
64 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
65 }
66 }
67
68 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
69 match self.get_mut() {
70 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
71 #[cfg(feature = "tls")]
72 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
73 }
74 }
75}
76
77impl MaybeTlsStream {
78 #[allow(dead_code)]
80 pub(crate) fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
81 match self {
82 MaybeTlsStream::Plain(s) => s.peer_addr(),
83 #[cfg(feature = "tls")]
84 MaybeTlsStream::Tls(s) => s.get_ref().0.peer_addr(),
85 }
86 }
87}
88
/// Certificate material used when the connection is upgraded to TLS.
///
/// `TlsConfig::default()` has no custom roots and no client certificate.
#[cfg(feature = "tls")]
#[non_exhaustive]
#[derive(Default, Clone)]
pub struct TlsConfig {
    /// DER-encoded trust anchors used to verify the server certificate.
    /// When empty, the bundled `webpki_roots` set is used instead.
    pub root_certs: Vec<Vec<u8>>,
    /// Optional client certificate for mutual TLS: a DER-encoded certificate
    /// chain and the DER-encoded private key matching its first certificate.
    /// `None` means no client certificate is presented.
    pub client_cert: Option<(Vec<Vec<u8>>, Vec<u8>)>,
}
109
/// Upgrades `stream` to TLS when the server accepts, verifying the server
/// against the bundled `webpki_roots` trust anchors and `hostname`.
///
/// Shorthand for `negotiate_tls_with_config` with a default `TlsConfig`
/// (no custom roots, no client certificate) and `TlsMode::Prefer`. A server
/// that declines TLS therefore yields a plaintext stream.
///
/// # Errors
/// Propagates any error from `negotiate_tls_with_config`.
#[cfg(feature = "tls")]
// Convenience entry point that may be unused within the crate.
#[allow(dead_code)]
pub(crate) async fn negotiate_tls(
    stream: TcpStream,
    hostname: &str,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    let defaults = TlsConfig::default();
    negotiate_tls_with_config(stream, hostname, &defaults, TlsMode::Prefer).await
}
122
/// Runs the `SSLRequest` exchange and, if the server accepts, upgrades
/// `stream` to TLS.
///
/// Behaviour by `mode`:
/// - `TlsMode::Disable`: no `SSLRequest` is sent; the plain stream is returned.
/// - `TlsMode::Prefer`: TLS is used if the server answers `S`; the plain
///   stream is returned if it answers `N`.
/// - `TlsMode::Require`: like `Prefer`, but an `N` answer is an error.
///
/// When TLS is used, the server certificate is verified against
/// `config.root_certs` (or the bundled `webpki_roots` when that list is
/// empty) and `hostname`. The client certificate in `config.client_cert` is
/// presented when one is configured.
///
/// # Errors
/// Returns `PgWireError::Protocol` in these cases:
/// - a root certificate, client key, client certificate or hostname is invalid;
/// - the server refuses TLS under `Require`;
/// - the server replies with a byte other than `S`/`N`.
///
/// I/O errors from the request exchange or the TLS handshake are propagated.
#[cfg(feature = "tls")]
pub(crate) async fn negotiate_tls_with_config(
    mut stream: TcpStream,
    hostname: &str,
    config: &TlsConfig,
    mode: TlsMode,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    if mode == TlsMode::Disable {
        return Ok(MaybeTlsStream::Plain(stream));
    }

    match send_ssl_request(&mut stream).await? {
        b'S' => {
            let tls_config = build_tls_client_config(config)?;
            let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
            let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string())
                .map_err(|e| {
                    crate::error::PgWireError::Protocol(format!("invalid hostname: {e}"))
                })?;

            let tls_stream = connector.connect(server_name, stream).await?;
            Ok(MaybeTlsStream::Tls(tls_stream))
        }
        b'N' if mode == TlsMode::Require => Err(crate::error::PgWireError::Protocol(
            "server does not support TLS but sslmode=require".to_string(),
        )),
        b'N' => Ok(MaybeTlsStream::Plain(stream)),
        // Escape the byte so a control or non-ASCII value from the server
        // cannot inject raw characters into the error text; printable ASCII
        // (e.g. `E`) is rendered unchanged.
        other => Err(crate::error::PgWireError::Protocol(format!(
            "unexpected SSL response: {}",
            (other as char).escape_default()
        ))),
    }
}

/// Sends the 8-byte `SSLRequest` packet and returns the server's one-byte
/// reply (`S` = proceed with TLS, `N` = TLS not offered).
#[cfg(feature = "tls")]
async fn send_ssl_request(stream: &mut TcpStream) -> Result<u8, crate::error::PgWireError> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Request code identifying an SSLRequest: (1234 << 16) | 5679.
    const SSL_REQUEST_CODE: i32 = 80877103;

    // Message length (including itself) followed by the request code,
    // both as big-endian i32.
    let mut request = [0u8; 8];
    request[..4].copy_from_slice(&8i32.to_be_bytes());
    request[4..].copy_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
    stream.write_all(&request).await?;

    // Read exactly one byte straight from the socket. Nothing after the reply
    // byte is consumed here, so any bytes that follow reach the TLS handshake
    // instead of being buffered as plaintext.
    let mut reply = [0u8; 1];
    stream.read_exact(&mut reply).await?;
    Ok(reply[0])
}

/// Builds the trust anchors for server verification: `config.root_certs`
/// when non-empty, otherwise the bundled `webpki_roots` set.
#[cfg(feature = "tls")]
fn build_root_cert_store(
    config: &TlsConfig,
) -> Result<rustls::RootCertStore, crate::error::PgWireError> {
    let mut root_store = rustls::RootCertStore::empty();
    if config.root_certs.is_empty() {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    } else {
        for cert_der in &config.root_certs {
            root_store
                .add(rustls_pki_types::CertificateDer::from(cert_der.clone()))
                .map_err(|e| {
                    crate::error::PgWireError::Protocol(format!("invalid root certificate: {e}"))
                })?;
        }
    }
    Ok(root_store)
}

/// Builds the rustls client configuration with the `ring` provider, safe
/// default protocol versions, the root store from `build_root_cert_store`,
/// and optional client authentication from `config.client_cert`.
#[cfg(feature = "tls")]
fn build_tls_client_config(
    config: &TlsConfig,
) -> Result<rustls::ClientConfig, crate::error::PgWireError> {
    let root_store = build_root_cert_store(config)?;

    let provider = std::sync::Arc::new(rustls::crypto::ring::default_provider());
    let builder = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| {
            crate::error::PgWireError::Protocol(format!(
                "TLS protocol version setup failed: {e}"
            ))
        })?
        .with_root_certificates(root_store);

    match &config.client_cert {
        Some((cert_chain, key_der)) => {
            let certs: Vec<rustls_pki_types::CertificateDer<'static>> = cert_chain
                .iter()
                .map(|c| rustls_pki_types::CertificateDer::from(c.clone()))
                .collect();
            let key =
                rustls_pki_types::PrivateKeyDer::try_from(key_der.clone()).map_err(|e| {
                    crate::error::PgWireError::Protocol(format!(
                        "invalid client private key: {e}"
                    ))
                })?;
            builder.with_client_auth_cert(certs, key).map_err(|e| {
                crate::error::PgWireError::Protocol(format!("TLS client auth config error: {e}"))
            })
        }
        None => Ok(builder.with_no_client_auth()),
    }
}