use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
/// A client connection that is either plain TCP or TCP wrapped in a rustls
/// client session. `AsyncRead`/`AsyncWrite` are implemented by delegating to
/// whichever variant is active.
// NOTE(review): the lint is silenced rather than boxing `Tls`, presumably because
// the rustls session state makes `Tls` much larger than `Plain` and boxing would
// change every construction site — confirm this trade-off is still intended.
#[allow(clippy::large_enum_variant)]
pub(crate) enum MaybeTlsStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TCP connection upgraded to TLS; only exists when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}
/// Transport-encryption policy for a connection, analogous to an `sslmode`
/// setting.
///
/// Whenever TLS is actually negotiated, the server certificate is verified; a
/// failed TLS handshake is an error in every mode, never a plaintext fallback.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TlsMode {
    /// Never send an SSL request; the connection stays plaintext.
    Disable,
    /// Ask for TLS, but continue in plaintext if the server declines.
    /// This is the default mode.
    #[default]
    Prefer,
    /// Ask for TLS and fail the connection if the server declines.
    Require,
}
impl AsyncRead for MaybeTlsStream {
    /// Reads from whichever transport is active.
    ///
    /// Both inner stream types are `Unpin`, so the wrapper can be unpinned and
    /// each inner stream re-pinned for the duration of a single poll.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this: &mut Self = Pin::into_inner(self);
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Forwards every write operation to the active transport.
///
/// `poll_write_vectored` and `is_write_vectored` are forwarded explicitly. The
/// trait's default implementations write only the first non-empty slice and
/// report `false`, which would hide the inner stream's vectored-write support
/// from callers.
impl AsyncWrite for MaybeTlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            #[cfg(feature = "tls")]
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports the vectored-write capability of the currently active transport.
    fn is_write_vectored(&self) -> bool {
        match self {
            MaybeTlsStream::Plain(s) => s.is_write_vectored(),
            #[cfg(feature = "tls")]
            MaybeTlsStream::Tls(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}
impl MaybeTlsStream {
#[allow(dead_code)]
pub(crate) fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
match self {
MaybeTlsStream::Plain(s) => s.peer_addr(),
#[cfg(feature = "tls")]
MaybeTlsStream::Tls(s) => s.get_ref().0.peer_addr(),
}
}
}
/// TLS settings applied when the server accepts the SSL request during
/// `negotiate_tls_with_config`.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot use
/// a struct literal. It must start from `TlsConfig::default()` and assign fields.
// NOTE(review): `Debug` is not derived — presumably to keep private-key bytes
// out of logs; confirm before adding it.
#[cfg(feature = "tls")]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct TlsConfig {
    /// DER-encoded trust anchors used to verify the server certificate.
    /// When empty, the bundled `webpki_roots` set is used instead.
    pub root_certs: Vec<Vec<u8>>,
    /// Optional client certificate for mutual TLS, as
    /// `(DER certificate chain, DER private key)`. `None` disables client auth.
    pub client_cert: Option<(Vec<Vec<u8>>, Vec<u8>)>,
}
/// Negotiates TLS on `stream` with [`TlsMode::Prefer`] and a default
/// [`TlsConfig`]: bundled webpki roots and no client certificate.
///
/// Convenience wrapper around [`negotiate_tls_with_config`]; see it for the
/// full behavior and error cases.
#[cfg(feature = "tls")]
// NOTE(review): presumably kept as a convenience entry point that current
// callers don't use — confirm whether it is still needed.
#[allow(dead_code)]
pub(crate) async fn negotiate_tls(
    stream: TcpStream,
    hostname: &str,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    let defaults = TlsConfig::default();
    negotiate_tls_with_config(stream, hostname, &defaults, TlsMode::Prefer).await
}
/// Runs the PostgreSQL `SSLRequest` exchange on `stream` and, when the server
/// agrees, upgrades the connection to TLS using `config`.
///
/// Behavior by `mode`:
/// - [`TlsMode::Disable`]: nothing is sent; the plain stream is returned as-is.
/// - [`TlsMode::Prefer`]: TLS is requested; an `N` reply returns the plain stream.
/// - [`TlsMode::Require`]: TLS is requested; an `N` reply is an error.
///
/// On an `S` reply, the server certificate is verified for `hostname` against
/// `config.root_certs`, or against the bundled webpki roots when that list is
/// empty. A failed handshake is an error in every mode.
///
/// # Errors
/// Returns an error in any of these cases:
/// - I/O failure while sending the request or reading the one-byte reply;
/// - invalid certificates or key in `config`;
/// - a `hostname` that is not a valid server name;
/// - a failed TLS handshake;
/// - an `N` reply under `Require`;
/// - any reply byte other than `S`/`N`.
#[cfg(feature = "tls")]
pub(crate) async fn negotiate_tls_with_config(
    mut stream: TcpStream,
    hostname: &str,
    config: &TlsConfig,
    mode: TlsMode,
) -> Result<MaybeTlsStream, crate::error::PgWireError> {
    use crate::error::PgWireError;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // SSLRequest: Int32 message length (8, counting itself) followed by the
    // Int32 request code 80877103 (1234 in the high 16 bits, 5679 in the low).
    const SSL_REQUEST_LEN: i32 = 8;
    const SSL_REQUEST_CODE: i32 = 80877103;

    if matches!(mode, TlsMode::Disable) {
        return Ok(MaybeTlsStream::Plain(stream));
    }

    // Both fields are big-endian, as in every protocol message.
    let mut request = [0u8; 8];
    let (len_field, code_field) = request.split_at_mut(4);
    len_field.copy_from_slice(&SSL_REQUEST_LEN.to_be_bytes());
    code_field.copy_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
    stream.write_all(&request).await?;

    // The server answers with a single byte: `S` (accept) or `N` (decline).
    let mut reply = [0u8; 1];
    stream.read_exact(&mut reply).await?;

    match reply[0] {
        b'S' => {
            let client_config = build_tls_client_config(config)?;
            let connector =
                tokio_rustls::TlsConnector::from(std::sync::Arc::new(client_config));
            let server_name = rustls_pki_types::ServerName::try_from(hostname.to_string())
                .map_err(|e| PgWireError::Protocol(format!("invalid hostname: {e}")))?;
            let tls_stream = connector.connect(server_name, stream).await?;
            Ok(MaybeTlsStream::Tls(tls_stream))
        }
        b'N' if mode == TlsMode::Require => Err(PgWireError::Protocol(
            "server does not support TLS but sslmode=require".to_string(),
        )),
        b'N' => Ok(MaybeTlsStream::Plain(stream)),
        other => Err(PgWireError::Protocol(format!(
            "unexpected SSL response: {}",
            other as char
        ))),
    }
}

/// Builds the rustls client configuration described by `config`: trust
/// anchors, the ring crypto provider with its safe default protocol versions,
/// and optional client-certificate authentication.
#[cfg(feature = "tls")]
fn build_tls_client_config(
    config: &TlsConfig,
) -> Result<rustls::ClientConfig, crate::error::PgWireError> {
    use crate::error::PgWireError;

    let roots = build_tls_root_store(config)?;
    let provider = std::sync::Arc::new(rustls::crypto::ring::default_provider());
    let builder = rustls::ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| PgWireError::Protocol(format!("TLS protocol version setup failed: {e}")))?
        .with_root_certificates(roots);

    match &config.client_cert {
        None => Ok(builder.with_no_client_auth()),
        Some((chain, key_der)) => {
            let certs: Vec<rustls_pki_types::CertificateDer<'static>> = chain
                .iter()
                .map(|der| rustls_pki_types::CertificateDer::from(der.clone()))
                .collect();
            let key = rustls_pki_types::PrivateKeyDer::try_from(key_der.clone())
                .map_err(|e| PgWireError::Protocol(format!("invalid client private key: {e}")))?;
            builder
                .with_client_auth_cert(certs, key)
                .map_err(|e| PgWireError::Protocol(format!("TLS client auth config error: {e}")))
        }
    }
}

/// Collects the trust anchors for server verification. It uses the DER
/// certificates in `config.root_certs`, or the bundled webpki roots when that
/// list is empty.
#[cfg(feature = "tls")]
fn build_tls_root_store(
    config: &TlsConfig,
) -> Result<rustls::RootCertStore, crate::error::PgWireError> {
    let mut store = rustls::RootCertStore::empty();
    if config.root_certs.is_empty() {
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        return Ok(store);
    }
    for der in &config.root_certs {
        store
            .add(rustls_pki_types::CertificateDer::from(der.clone()))
            .map_err(|e| {
                crate::error::PgWireError::Protocol(format!("invalid root certificate: {e}"))
            })?;
    }
    Ok(store)
}