use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use rustls::{ClientConfig, RootCertStore};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{SslMode, TlsConfig};
use crate::error::{PgWireError, Result};
use crate::protocol::framing::write_ssl_request;
/// A PostgreSQL client transport: either raw TCP or TCP wrapped in TLS.
///
/// The TLS session is stored behind a `Box`, so that variant only holds a
/// pointer inside the enum.
#[derive(Debug)]
pub enum MaybeTlsStream {
    /// Unencrypted TCP (TLS disabled, or declined by the server under `Prefer`).
    Plain(TcpStream),
    /// A client-side TLS session layered over the TCP socket.
    Tls(Box<TlsStream<TcpStream>>),
}

impl MaybeTlsStream {
    /// Returns `true` if traffic on this stream is encrypted.
    #[inline]
    pub fn is_tls(&self) -> bool {
        matches!(*self, Self::Tls(..))
    }

    /// Returns `true` if traffic on this stream is plaintext TCP.
    #[inline]
    pub fn is_plain(&self) -> bool {
        matches!(*self, Self::Plain(..))
    }

    /// Borrows the underlying TCP socket, whether or not TLS is active.
    pub fn get_ref(&self) -> &TcpStream {
        match self {
            Self::Plain(socket) => socket,
            Self::Tls(session) => {
                let (socket, _connection) = session.get_ref();
                socket
            }
        }
    }
}
/// Reads are delegated to whichever transport is active.
impl AsyncRead for MaybeTlsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            MaybeTlsStream::Plain(socket) => Pin::new(socket).poll_read(cx, buf),
            // Both inner types are `Unpin`, so re-pinning through the box is fine.
            MaybeTlsStream::Tls(session) => Pin::new(&mut **session).poll_read(cx, buf),
        }
    }
}
/// Writes, flushes and shutdown are delegated to whichever transport is active.
///
/// Vectored writes are forwarded explicitly. Otherwise the trait's default
/// `poll_write_vectored` would be used, which writes only one buffer per call,
/// even though both inner streams can write several buffers at once.
impl AsyncWrite for MaybeTlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            MaybeTlsStream::Plain(s) => s.is_write_vectored(),
            MaybeTlsStream::Tls(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}
/// Negotiates TLS on a freshly connected PostgreSQL socket.
///
/// With `SslMode::Disable` the socket is returned unchanged. Otherwise an
/// `SSLRequest` is sent and the server's one-byte reply is read:
///
/// * `'S'` — a TLS handshake is performed. The server name used is
///   `tls.sni_hostname` if set, otherwise `host`.
/// * `'N'` — the server declined. Under `Prefer` the plaintext socket is
///   returned; every other mode fails.
/// * anything else — the connection fails. The unexpected reply may be followed
///   by more unread bytes, so the socket cannot be reused.
///
/// # Errors
///
/// * I/O errors while sending the request or reading the reply.
/// * [`PgWireError::Tls`] for the following cases:
///   * the server declines TLS when TLS is required;
///   * the server sends an unexpected reply byte;
///   * the TLS configuration is invalid;
///   * the SNI name is invalid;
///   * the handshake fails.
pub async fn maybe_upgrade_to_tls(
    mut tcp: TcpStream,
    tls: &TlsConfig,
    host: &str,
) -> Result<MaybeTlsStream> {
    use tokio::io::AsyncReadExt;

    // Exhaustive on purpose: a new SslMode variant should force a decision here.
    match tls.mode {
        SslMode::Disable => return Ok(MaybeTlsStream::Plain(tcp)),
        SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
    }
    // Best effort: the result is deliberately ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();
    write_ssl_request(&mut tcp).await?;

    // The reply is exactly one byte, read straight off the unbuffered socket.
    // Anything the peer sends after it is therefore consumed by the TLS
    // handshake, never treated as plaintext protocol data.
    let mut resp = [0u8; 1];
    tcp.read_exact(&mut resp).await?;
    match resp[0] {
        b'S' => {}
        b'N' => {
            return match tls.mode {
                SslMode::Prefer => Ok(MaybeTlsStream::Plain(tcp)),
                _ => Err(PgWireError::Tls(
                    "server does not support TLS (SSLRequest rejected)".into(),
                )),
            };
        }
        // Any other byte (for example 'E', an ErrorResponse from a server that
        // rejects SSLRequest) leaves unread data on the socket. The stream can
        // then neither be upgraded nor reused for a plaintext startup, so fail
        // even under `Prefer`. The payload is deliberately not read, because the
        // peer is unauthenticated at this point.
        other => {
            return Err(PgWireError::Tls(format!(
                "unexpected response to SSLRequest: 0x{other:02x} (expected 'S' or 'N')"
            )));
        }
    }

    let verify_chain = matches!(tls.mode, SslMode::VerifyCa | SslMode::VerifyFull);
    let verify_hostname = matches!(tls.mode, SslMode::VerifyFull);
    let cfg = build_rustls_config(tls, verify_chain, verify_hostname, host)?;
    let connector = TlsConnector::from(Arc::new(cfg));

    let sni = tls.sni_hostname.as_deref().unwrap_or(host);
    let server_name = rustls::pki_types::ServerName::try_from(sni.to_string())
        .map_err(|_| PgWireError::Tls(format!("invalid SNI hostname '{sni}'")))?;

    let tls_stream = connector
        .connect(server_name, tcp)
        .await
        .map_err(|e| PgWireError::Tls(format!("TLS handshake failed: {e}")))?;
    Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
/// Builds the rustls client configuration for one connection attempt.
///
/// * `verify_chain` — when `false`, any server certificate is accepted
///   (`NoVerifier`). When `true`, the chain is validated against the store
///   returned by `build_root_store`.
/// * `verify_hostname` — when `verify_chain` is `true` and this is `false`, a
///   certificate whose names do not match is tolerated (`VerifyChainOnly`).
///   When both are `true`, the builder's default verifier is kept.
/// * `host` — used here only to reject an IP-address host under hostname
///   verification when `tls.sni_hostname` is not set.
///
/// A client identity (mTLS) is attached when both `client_cert_pem_path` and
/// `client_key_pem_path` are set.
///
/// # Errors
///
/// [`PgWireError::Tls`] in any of these cases:
/// * only one of the client cert/key paths is set;
/// * an IP host is used under hostname verification with no SNI override;
/// * the CA, certificate or key files are unreadable or invalid;
/// * rustls rejects the client cert/key pair;
/// * the chain-only verifier cannot be built.
fn build_rustls_config(
    tls: &TlsConfig,
    verify_chain: bool,
    verify_hostname: bool,
    host: &str,
) -> Result<ClientConfig> {
    // Pair the mTLS paths up front. Mixed presence is a configuration error,
    // and pattern matching removes the need to unwrap later.
    let client_auth = match (
        tls.client_cert_pem_path.as_ref(),
        tls.client_key_pem_path.as_ref(),
    ) {
        (Some(cert_path), Some(key_path)) => Some((cert_path, key_path)),
        (None, None) => None,
        (cert, key) => {
            let (has_cert, has_key) = (cert.is_some(), key.is_some());
            return Err(PgWireError::Tls(format!(
                "TLS config error: mTLS requires both client_cert_pem_path and client_key_pem_path \
                 (got cert={has_cert} key={has_key})"
            )));
        }
    };

    if verify_hostname && host.parse::<std::net::IpAddr>().is_ok() && tls.sni_hostname.is_none() {
        return Err(PgWireError::Tls(format!(
            "TLS config error: VerifyFull enabled but host '{host}' is an IP address. \
             Hint: use a DNS name matching the certificate, or set tls.sni_hostname, \
             or use VerifyCa mode."
        )));
    }

    // One shared store: the builder and the chain-only verifier both take an
    // `Arc`, so the store does not need to be deep-cloned.
    let roots = Arc::new(build_root_store(tls)?);
    let builder = ClientConfig::builder().with_root_certificates(Arc::clone(&roots));

    let mut cfg: ClientConfig = match client_auth {
        Some((cert_path, key_path)) => {
            let cert_chain = load_cert_chain(cert_path)?;
            let key = load_private_key(key_path)?;
            builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
                PgWireError::Tls(format!("TLS config error: invalid client cert/key: {e}"))
            })?
        }
        None => builder.with_no_client_auth(),
    };

    if !verify_chain {
        // Encrypted but unauthenticated (Prefer / Require).
        cfg.dangerous()
            .set_certificate_verifier(Arc::new(NoVerifier));
    } else if !verify_hostname {
        // Chain is validated, but a name mismatch is tolerated (VerifyCa).
        let inner = rustls::client::WebPkiServerVerifier::builder(roots)
            .build()
            .map_err(|e| PgWireError::Tls(format!("TLS config error: build verifier: {e}")))?;
        cfg.dangerous()
            .set_certificate_verifier(Arc::new(VerifyChainOnly { inner }));
    }
    Ok(cfg)
}
/// Builds the trust anchors used to validate the server's certificate chain.
///
/// If `tls.ca_pem_path` is set, the store holds the certificates from that PEM
/// file; certificates that `add_parsable_certificates` cannot use are skipped.
/// Otherwise the store is seeded with `webpki_roots::TLS_SERVER_ROOTS`.
///
/// # Errors
///
/// [`PgWireError::Tls`] in any of these cases:
/// * the CA file cannot be opened;
/// * a PEM item in it fails to parse;
/// * no certificate from it is accepted.
fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
    use rustls::pki_types::pem::PemObject;
    use rustls::pki_types::CertificateDer;

    let mut store = RootCertStore::empty();

    let ca_path = match &tls.ca_pem_path {
        Some(p) => p,
        None => {
            store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            return Ok(store);
        }
    };

    let pem_items = CertificateDer::pem_file_iter(ca_path).map_err(|e| {
        PgWireError::Tls(format!(
            "TLS config error: failed to open CA PEM '{}': {e}",
            ca_path.display()
        ))
    })?;

    // Stop at the first malformed PEM item, as a fallible collect would.
    let mut parsed: Vec<CertificateDer<'static>> = Vec::new();
    for item in pem_items {
        let cert = item.map_err(|e| {
            PgWireError::Tls(format!(
                "TLS config error: failed to parse CA PEM '{}': {e}",
                ca_path.display()
            ))
        })?;
        parsed.push(cert);
    }

    let (accepted, _rejected) = store.add_parsable_certificates(parsed);
    if accepted > 0 {
        Ok(store)
    } else {
        Err(PgWireError::Tls(format!(
            "TLS config error: no valid CA certificates found in '{}'",
            ca_path.display()
        )))
    }
}
/// Reads the client certificate chain (every certificate item) from a PEM file.
///
/// # Errors
///
/// [`PgWireError::Tls`] in any of these cases:
/// * the file cannot be opened;
/// * a PEM item fails to parse;
/// * the file contains no certificates.
fn load_cert_chain(
    path: &std::path::Path,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
    use rustls::pki_types::pem::PemObject;
    use rustls::pki_types::CertificateDer;

    let shown = path.display();

    let pem_items = CertificateDer::pem_file_iter(path).map_err(|e| {
        PgWireError::Tls(format!(
            "TLS config error: failed to open client certificate '{shown}': {e}"
        ))
    })?;

    let chain = pem_items
        .collect::<std::result::Result<Vec<CertificateDer<'static>>, _>>()
        .map_err(|e| {
            PgWireError::Tls(format!(
                "TLS config error: failed to parse client certificate '{shown}': {e}"
            ))
        })?;

    if chain.is_empty() {
        Err(PgWireError::Tls(format!(
            "TLS config error: no certificates found in '{shown}'"
        )))
    } else {
        Ok(chain)
    }
}
/// Loads the client's private key from a PEM file.
///
/// # Errors
///
/// [`PgWireError::Tls`] if `PrivateKeyDer::from_pem_file` fails. The message
/// names the file and lists the supported key formats.
fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
    use rustls::pki_types::pem::PemObject;
    use rustls::pki_types::PrivateKeyDer;

    match PrivateKeyDer::from_pem_file(path) {
        Ok(key) => Ok(key),
        Err(e) => {
            let shown = path.display();
            Err(PgWireError::Tls(format!(
                "TLS config error: failed to load private key from '{shown}': {e}. \
                 Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)"
            )))
        }
    }
}
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
#[derive(Debug)]
struct VerifyChainOnly {
inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
}
impl rustls::client::danger::ServerCertVerifier for VerifyChainOnly {
fn verify_server_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
intermediates: &[rustls::pki_types::CertificateDer<'_>],
server_name: &rustls::pki_types::ServerName<'_>,
ocsp: &[u8],
now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
match self
.inner
.verify_server_cert(end_entity, intermediates, server_name, ocsp, now)
{
Ok(ok) => Ok(ok),
Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
Err(e) => Err(e),
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// A client identity with only one half configured is rejected up front.
    #[test]
    fn mtls_requires_both_cert_and_key() {
        let tls = TlsConfig {
            client_cert_pem_path: Some("/path/to/cert.pem".into()),
            client_key_pem_path: None,
            ..Default::default()
        };
        let err = build_rustls_config(&tls, false, false, "localhost").unwrap_err();
        assert!(err.to_string().contains("mTLS requires both"));
        let tls = TlsConfig {
            client_cert_pem_path: None,
            client_key_pem_path: Some("/path/to/key.pem".into()),
            ..Default::default()
        };
        let err = build_rustls_config(&tls, false, false, "localhost").unwrap_err();
        assert!(err.to_string().contains("mTLS requires both"));
    }

    /// Hostname verification against a bare IP host needs an SNI override.
    #[test]
    fn verify_full_rejects_ip_without_sni_override() {
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            ..Default::default()
        };
        let err = build_rustls_config(&tls, true, true, "192.168.1.1").unwrap_err();
        assert!(err.to_string().contains("IP address"));
    }

    /// With an SNI override, an IP host is accepted under hostname verification.
    #[test]
    fn verify_full_accepts_ip_with_sni_override() {
        // Mirror maybe_upgrade_to_tls: make sure a crypto provider is installed
        // before a ClientConfig is built.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            sni_hostname: Some("db.example.com".into()),
            ..Default::default()
        };
        assert!(build_rustls_config(&tls, true, true, "192.168.1.1").is_ok());
    }

    #[test]
    fn missing_ca_file_gives_clear_error() {
        let tls = TlsConfig {
            ca_pem_path: Some("/nonexistent/ca.pem".into()),
            ..Default::default()
        };
        let err = build_root_store(&tls).unwrap_err().to_string();
        assert!(err.contains("failed to open"));
        assert!(err.contains("ca.pem"));
    }

    #[test]
    fn empty_ca_file_gives_clear_error() {
        let f = NamedTempFile::new().unwrap();
        let tls = TlsConfig {
            ca_pem_path: Some(f.path().to_path_buf()),
            ..Default::default()
        };
        let err = build_root_store(&tls).unwrap_err().to_string();
        assert!(err.contains("no valid CA certificates"));
    }

    #[test]
    fn empty_key_file_gives_clear_error() {
        let f = NamedTempFile::new().unwrap();
        let err = load_private_key(f.path()).unwrap_err().to_string();
        assert!(err.contains("failed to load private key"));
    }

    /// Non-PEM input must produce the crate's descriptive config errors,
    /// not just some error.
    #[test]
    fn invalid_pem_gives_clear_error() {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(b"this is not a valid PEM file").unwrap();
        let key_err = load_private_key(f.path()).unwrap_err().to_string();
        assert!(key_err.contains("failed to load private key"));
        let cert_err = load_cert_chain(f.path()).unwrap_err().to_string();
        assert!(cert_err.contains("TLS config error"));
    }
}