use super::conninfo::{ConnInfo, SslMode, SslNegotiation};
use super::wire;
use crate::error::ReplicationError;
use bytes::BytesMut;
use rustls::pki_types::pem::PemObject;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
/// Process-wide rustls crypto provider, built from the aws-lc-rs defaults the
/// first time it is accessed.
static CRYPTO_PROVIDER: LazyLock<Arc<rustls::crypto::CryptoProvider>> = LazyLock::new(|| {
    let provider = rustls::crypto::aws_lc_rs::default_provider();
    Arc::new(provider)
});

/// Returns another handle to the shared [`CRYPTO_PROVIDER`].
fn crypto_provider() -> Arc<rustls::crypto::CryptoProvider> {
    Arc::clone(&*CRYPTO_PROVIDER)
}
/// Capacity, in bytes, of the `BufReader` that wraps every TLS stream created
/// in this module.
// NOTE(review): 16 KiB presumably mirrors the maximum TLS record payload — confirm.
const TLS_BUF_SIZE: usize = 16_384;
/// Byte stream to the server: either a raw TCP socket or a TLS session layered
/// on top of one. Implements `AsyncRead` and `AsyncWrite` by delegating to the
/// wrapped stream.
pub enum Transport {
/// Unencrypted TCP connection.
Plain(TcpStream),
/// TLS session over TCP, wrapped in a `BufReader` for buffered reads.
Tls(BufReader<tokio_rustls::client::TlsStream<TcpStream>>),
}
impl Transport {
    /// Computes the `tls-server-end-point` hash of the server's leaf
    /// certificate.
    ///
    /// Returns `None` for plain connections, when the TLS session exposes no
    /// peer certificates, or when the hash helper yields nothing.
    pub fn tls_server_end_point(&self) -> Option<Vec<u8>> {
        let Transport::Tls(buffered) = self else {
            return None;
        };
        // BufReader -> TlsStream -> (io, client connection state)
        let (_, session) = buffered.get_ref().get_ref();
        let leaf = session.peer_certificates()?.first()?;
        compute_tls_server_end_point_hash(leaf.as_ref())
    }
}
/// Hashes a DER-encoded certificate for `tls-server-end-point` channel binding.
///
/// The digest is chosen from the certificate's signature algorithm, located by
/// scanning the DER bytes for the RSA (`1.2.840.113549.1.1.x`) and ECDSA
/// (`1.2.840.10045.4.3.x`) OID prefixes:
///
/// * SHA-384 / SHA-512 signatures use the matching digest;
/// * MD5, SHA-1, SHA-256 and anything unrecognised use SHA-256.
///
/// When both prefixes occur, the one that appears *first* wins. The signature
/// algorithm identifier precedes the subject public key in a certificate, so an
/// RSA-key certificate signed with ECDSA must not be classified by the
/// `rsaEncryption` OID found later in its public-key info.
///
/// Always returns `Some`; the `Option` is kept for interface compatibility.
fn compute_tls_server_end_point_hash(cert_der: &[u8]) -> Option<Vec<u8>> {
    use aws_lc_rs::digest;

    const RSA_PREFIX: [u8; 8] = [0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01];
    const ECDSA_PREFIX: [u8; 7] = [0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03];

    // (offset of the prefix, offset of the byte following it, is_rsa)
    let rsa_hit =
        find_subsequence(cert_der, &RSA_PREFIX).map(|pos| (pos, pos + RSA_PREFIX.len(), true));
    let ecdsa_hit = find_subsequence(cert_der, &ECDSA_PREFIX)
        .map(|pos| (pos, pos + ECDSA_PREFIX.len(), false));

    // Earliest occurrence decides which family describes the signature.
    let earliest = [rsa_hit, ecdsa_hit]
        .into_iter()
        .flatten()
        .min_by_key(|&(pos, _, _)| pos);

    let algorithm: &'static digest::Algorithm = match earliest {
        Some((_, suffix_pos, is_rsa)) => match (is_rsa, cert_der.get(suffix_pos).copied()) {
            // sha384WithRSAEncryption / ecdsa-with-SHA384
            (true, Some(0x0c)) | (false, Some(0x03)) => &digest::SHA384,
            // sha512WithRSAEncryption / ecdsa-with-SHA512
            (true, Some(0x0d)) | (false, Some(0x04)) => &digest::SHA512,
            // MD5, SHA-1, SHA-256, unknown suffix, or prefix at the very end.
            _ => &digest::SHA256,
        },
        None => &digest::SHA256,
    };

    let hash = digest::digest(algorithm, cert_der);
    Some(hash.as_ref().to_vec())
}
/// Returns the index of the first occurrence of `needle` within `haystack`.
///
/// An empty `needle` matches at offset 0 (mirroring `str::find("")`); without
/// this guard `slice::windows(0)` would panic. Returns `None` when `needle`
/// does not occur, including when it is longer than `haystack`.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Reads are forwarded to whichever stream the transport currently wraps.
impl AsyncRead for Transport {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;

        match self.get_mut() {
            Transport::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Transport::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Writes, flushes and shutdowns are forwarded to the wrapped stream.
impl AsyncWrite for Transport {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::pin::Pin;

        match self.get_mut() {
            Transport::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Transport::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;

        match self.get_mut() {
            Transport::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Transport::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;

        match self.get_mut() {
            Transport::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Transport::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}
/// Opens a connection described by `info`, negotiates TLS according to
/// `info.sslmode`, and runs startup and authentication.
///
/// * `Disable` — plain TCP only.
/// * `Prefer` / `Allow` — TLS is attempted; on failure a fresh TCP connection
///   is opened and used unencrypted.
/// * any other mode — TLS is mandatory and a negotiation failure is returned.
///
/// # Returns
/// The ready transport, the server version reported by `startup_and_auth`, and
/// the read buffer used during startup.
///
/// # Errors
/// Propagates errors from connecting, TLS negotiation (mandatory modes) and
/// startup/authentication.
pub async fn connect(info: &ConnInfo) -> Result<(Transport, i32, BytesMut), ReplicationError> {
    let addr = format!("{}:{}", info.host, info.port);
    let socket = tcp_connect(&addr, info).await?;

    let tls_is_optional = matches!(info.sslmode, SslMode::Prefer | SslMode::Allow);
    let mut transport = if matches!(info.sslmode, SslMode::Disable) {
        Transport::Plain(socket)
    } else {
        match negotiate_tls(socket, info).await {
            Ok(negotiated) => negotiated,
            Err(e) if tls_is_optional => {
                tracing::debug!("TLS negotiation failed, falling back to plain: {e}");
                // The original socket was consumed by the failed attempt.
                Transport::Plain(tcp_connect(&addr, info).await?)
            }
            Err(e) => return Err(e),
        }
    };

    let mut buf = BytesMut::with_capacity(8192);
    let server_version = startup_and_auth(&mut transport, &mut buf, info).await?;
    Ok((transport, server_version, buf))
}
/// Establishes a TCP connection to `addr`.
///
/// A non-zero `info.connect_timeout` (seconds) bounds the connection attempt;
/// zero means no timeout. `TCP_NODELAY` is enabled on a best-effort basis and
/// keepalive is configured when `info.keepalives` is set.
///
/// # Errors
/// Returns a transient connection error on timeout or connect failure.
async fn tcp_connect(addr: &str, info: &ConnInfo) -> Result<TcpStream, ReplicationError> {
    let connect_failed = |e: std::io::Error| {
        ReplicationError::transient_connection(format!("Failed to connect to {addr}: {e}"))
    };

    let stream = match info.connect_timeout {
        0 => TcpStream::connect(addr).await.map_err(connect_failed)?,
        secs => {
            let limit = Duration::from_secs(secs);
            match tokio::time::timeout(limit, TcpStream::connect(addr)).await {
                Ok(attempt) => attempt.map_err(connect_failed)?,
                Err(_) => {
                    return Err(ReplicationError::transient_connection(format!(
                        "Connection to {addr} timed out after {}s",
                        info.connect_timeout
                    )));
                }
            }
        }
    };

    // Failing to disable Nagle is not fatal.
    stream.set_nodelay(true).ok();
    if info.keepalives {
        configure_tcp_keepalive(&stream, info);
    }
    Ok(stream)
}
/// Applies the keepalive settings from `info` to `tcp`.
///
/// Idle time and probe interval are always set; the retry count is set only on
/// the platforms listed in the `cfg` below. Failure is logged as a warning and
/// otherwise ignored.
fn configure_tcp_keepalive(tcp: &TcpStream, info: &ConnInfo) {
    let idle = Duration::from_secs(info.keepalives_idle);
    let interval = Duration::from_secs(info.keepalives_interval);
    let settings = socket2::TcpKeepalive::new()
        .with_time(idle)
        .with_interval(interval);
    #[cfg(any(
        target_os = "linux",
        target_os = "macos",
        target_os = "ios",
        target_os = "freebsd",
        target_os = "netbsd",
    ))]
    let settings = settings.with_retries(info.keepalives_count);

    match socket2::SockRef::from(tcp).set_tcp_keepalive(&settings) {
        Ok(()) => {
            tracing::debug!(
                "TCP keepalive configured: idle={}s, interval={}s",
                info.keepalives_idle,
                info.keepalives_interval
            );
        }
        Err(e) => {
            tracing::warn!("Failed to set TCP keepalive: {e}");
        }
    }
}
/// Upgrades `tcp` to TLS using the negotiation style selected in `info`.
///
/// With `SslNegotiation::Direct` a direct TLS handshake is tried first; if it
/// fails, a new TCP connection is opened and the standard SSLRequest flow is
/// used instead. Any other setting goes straight to the standard flow.
async fn negotiate_tls(tcp: TcpStream, info: &ConnInfo) -> Result<Transport, ReplicationError> {
    if info.sslnegotiation != SslNegotiation::Direct {
        return negotiate_tls_standard(tcp, info).await;
    }

    let e = match negotiate_tls_direct(tcp, info).await {
        Ok(transport) => return Ok(transport),
        Err(e) => e,
    };
    tracing::debug!("Direct SSL negotiation failed, falling back to standard SSLRequest: {e}");

    // The socket was consumed by the failed handshake; start over.
    let addr = format!("{}:{}", info.host, info.port);
    let fresh = tcp_connect(&addr, info).await?;
    negotiate_tls_standard(fresh, info).await
}
/// Performs the SSLRequest exchange on `tcp` and, if the server agrees,
/// the TLS handshake.
///
/// The server's single-byte reply decides the outcome:
/// * `S` — proceed with the TLS handshake;
/// * `N` — error for `Require` / `VerifyCa` / `VerifyFull`, otherwise continue
///   on the plain socket;
/// * anything else — protocol error.
async fn negotiate_tls_standard(
    mut tcp: TcpStream,
    info: &ConnInfo,
) -> Result<Transport, ReplicationError> {
    let request = wire::build_ssl_request();
    wire::write_all(&mut tcp, &request).await?;
    wire::flush(&mut tcp).await?;

    match wire::read_byte(&mut tcp).await? {
        b'S' => {}
        b'N' => {
            let tls_mandatory = matches!(
                info.sslmode,
                SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull
            );
            if tls_mandatory {
                return Err(ReplicationError::permanent_connection(
                    "Server does not support SSL but sslmode=require".to_string(),
                ));
            }
            tracing::debug!("Server doesn't support SSL, falling back to plain");
            return Ok(Transport::Plain(tcp));
        }
        other => {
            return Err(ReplicationError::protocol(format!(
                "Unexpected SSLRequest response: 0x{other:02x}"
            )));
        }
    }

    // Server accepted: run the handshake on the same socket.
    let tls_config = build_tls_config(info)?;
    let connector = TlsConnector::from(Arc::new(tls_config));
    let server_name = match rustls::pki_types::ServerName::try_from(info.host.as_str()) {
        Ok(name) => name.to_owned(),
        Err(e) => {
            return Err(ReplicationError::permanent_connection(format!(
                "Invalid server name for TLS: {e}"
            )));
        }
    };
    let tls_stream = match connector.connect(server_name, tcp).await {
        Ok(stream) => stream,
        Err(e) => {
            return Err(ReplicationError::transient_connection(format!(
                "TLS handshake failed: {e}"
            )));
        }
    };
    tracing::debug!("TLS connection established");

    let buffered = BufReader::with_capacity(TLS_BUF_SIZE, tls_stream);
    Ok(Transport::Tls(buffered))
}
/// Starts a TLS handshake on `tcp` immediately, without an SSLRequest
/// round-trip, advertising the `postgresql` ALPN protocol.
///
/// # Errors
/// Permanent error for an invalid server name or TLS configuration failure;
/// transient error when the handshake itself fails.
async fn negotiate_tls_direct(
    tcp: TcpStream,
    info: &ConnInfo,
) -> Result<Transport, ReplicationError> {
    let connector = {
        let mut tls_config = build_tls_config(info)?;
        tls_config.alpn_protocols = vec![b"postgresql".to_vec()];
        TlsConnector::from(Arc::new(tls_config))
    };

    let server_name = match rustls::pki_types::ServerName::try_from(info.host.as_str()) {
        Ok(name) => name.to_owned(),
        Err(e) => {
            return Err(ReplicationError::permanent_connection(format!(
                "Invalid server name for TLS: {e}"
            )));
        }
    };

    let tls_stream = match connector.connect(server_name, tcp).await {
        Ok(stream) => stream,
        Err(e) => {
            return Err(ReplicationError::transient_connection(format!(
                "Direct TLS handshake failed: {e}"
            )));
        }
    };
    tracing::debug!("Direct TLS connection established (ALPN postgresql)");

    let buffered = BufReader::with_capacity(TLS_BUF_SIZE, tls_stream);
    Ok(Transport::Tls(buffered))
}
/// Builds the rustls client configuration matching `info.sslmode`.
///
/// * `VerifyFull` — chain and hostname verification against the root store.
/// * `VerifyCa` — chain verification only, via `NoHostnameVerifier`.
/// * `Require` / `Prefer` / `Allow` — encryption without certificate
///   verification, via `NoVerification`.
///
/// No client certificate is configured in any mode.
///
/// # Panics
/// Panics if called with `SslMode::Disable`.
fn build_tls_config(info: &ConnInfo) -> Result<rustls::ClientConfig, ReplicationError> {
    let provider = crypto_provider();

    // Shared first half of every builder chain.
    let versioned_builder = |p: Arc<rustls::crypto::CryptoProvider>| {
        rustls::ClientConfig::builder_with_provider(p)
            .with_safe_default_protocol_versions()
            .map_err(|e| {
                ReplicationError::permanent_connection(format!(
                    "Failed to configure TLS protocol versions: {e}"
                ))
            })
    };

    let config = match info.sslmode {
        SslMode::Disable => {
            unreachable!("TLS config not needed when sslmode=disable")
        }
        SslMode::VerifyFull => {
            let roots = build_root_store(info.sslrootcert.as_deref())?;
            versioned_builder(provider)?
                .with_root_certificates(roots)
                .with_no_client_auth()
        }
        SslMode::VerifyCa => {
            let roots = build_root_store(info.sslrootcert.as_deref())?;
            let verifier = NoHostnameVerifier::new(roots, provider.clone());
            versioned_builder(provider)?
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(verifier))
                .with_no_client_auth()
        }
        SslMode::Require | SslMode::Prefer | SslMode::Allow => {
            let builder = versioned_builder(provider.clone())?;
            builder
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(NoVerification(provider)))
                .with_no_client_auth()
        }
    };
    Ok(config)
}
/// Builds the set of trusted root certificates.
///
/// With `Some(path)`, every PEM certificate in that file is loaded and nothing
/// else is trusted. With `None`, the bundled `webpki_roots` set is used.
///
/// # Errors
/// Returns a permanent connection error when the file cannot be opened, its
/// PEM content cannot be parsed, it contains no certificates, or a certificate
/// is rejected by the store.
fn build_root_store(sslrootcert: Option<&str>) -> Result<rustls::RootCertStore, ReplicationError> {
    let mut store = rustls::RootCertStore::empty();

    let Some(path) = sslrootcert else {
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        return Ok(store);
    };

    let file = match std::fs::File::open(path) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(ReplicationError::permanent_connection(format!(
                "Failed to open sslrootcert file '{path}': {e}"
            )));
        }
    };
    let mut reader = std::io::BufReader::new(file);

    let parsed: Result<Vec<rustls::pki_types::CertificateDer<'static>>, _> =
        rustls::pki_types::CertificateDer::pem_reader_iter(&mut reader).collect();
    let certs = parsed.map_err(|e| {
        ReplicationError::permanent_connection(format!(
            "Failed to parse PEM certificates from '{path}': {e}"
        ))
    })?;

    if certs.is_empty() {
        return Err(ReplicationError::permanent_connection(format!(
            "No certificates found in sslrootcert file '{path}'"
        )));
    }

    certs.into_iter().try_for_each(|cert| {
        store.add(cert).map_err(|e| {
            ReplicationError::permanent_connection(format!(
                "Failed to add certificate from '{path}': {e}"
            ))
        })
    })?;
    Ok(store)
}
/// Certificate verifier that accepts every server certificate and handshake
/// signature without checking them.
///
/// The wrapped provider is used only to report the supported signature
/// schemes.
#[derive(Debug)]
struct NoVerification(Arc<rustls::crypto::CryptoProvider>);

impl rustls::client::danger::ServerCertVerifier for NoVerification {
    /// Unconditionally reports the certificate as verified.
    fn verify_server_cert(
        &self,
        _: &rustls::pki_types::CertificateDer<'_>,
        _: &[rustls::pki_types::CertificateDer<'_>],
        _: &rustls::pki_types::ServerName<'_>,
        _: &[u8],
        _: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;
        Ok(ServerCertVerified::assertion())
    }

    /// Unconditionally reports the TLS 1.2 handshake signature as valid.
    fn verify_tls12_signature(
        &self,
        _: &[u8],
        _: &rustls::pki_types::CertificateDer<'_>,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Unconditionally reports the TLS 1.3 handshake signature as valid.
    fn verify_tls13_signature(
        &self,
        _: &[u8],
        _: &rustls::pki_types::CertificateDer<'_>,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Advertises every scheme the wrapped provider can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let Self(provider) = self;
        provider
            .signature_verification_algorithms
            .supported_schemes()
    }
}
/// Certificate verifier that delegates to a WebPKI verifier but tolerates
/// certificates whose names do not match the requested server name.
#[derive(Debug)]
struct NoHostnameVerifier {
    /// Underlying verifier performing the actual checks.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}

impl NoHostnameVerifier {
    /// Creates a verifier trusting `roots`, using `provider` for cryptography.
    ///
    /// # Panics
    /// Panics if the underlying WebPKI verifier cannot be built.
    fn new(roots: rustls::RootCertStore, provider: Arc<rustls::crypto::CryptoProvider>) -> Self {
        let trusted = Arc::new(roots);
        let built = rustls::client::WebPkiServerVerifier::builder_with_provider(trusted, provider)
            .build();
        Self {
            inner: built.expect("failed to build WebPkiServerVerifier"),
        }
    }
}
impl rustls::client::danger::ServerCertVerifier for NoHostnameVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
intermediates: &[rustls::pki_types::CertificateDer<'_>],
server_name: &rustls::pki_types::ServerName<'_>,
ocsp_response: &[u8],
now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
match self.inner.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(v) => Ok(v),
Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::NotValidForNameContext { .. },
)) => {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
Err(other) => Err(other), }
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Sends the startup message, authenticates, and consumes server messages
/// until the server signals readiness.
///
/// # Parameters
/// * `transport` — connected (optionally TLS-wrapped) stream.
/// * `buf` — read buffer reused for incoming messages.
/// * `info` — supplies user, database, password and replication mode.
///
/// # Returns
/// The server version parsed from the `server_version` parameter status
/// (0 if the server never sent one, which is then rejected below).
///
/// # Errors
/// Propagates wire and authentication errors, returns a permanent connection
/// error when the server sends an error message during startup, and rejects
/// servers whose parsed version is below 140000.
async fn startup_and_auth(
    transport: &mut Transport,
    buf: &mut BytesMut,
    info: &ConnInfo,
) -> Result<i32, ReplicationError> {
    let mut params: Vec<(&str, &str)> = vec![
        ("user", &info.user),
        ("database", &info.dbname),
        ("client_encoding", "UTF8"),
    ];
    // Static literals are enough here; no owned String is needed.
    let replication = match info.replication {
        super::conninfo::ReplicationMode::Database => Some("database"),
        super::conninfo::ReplicationMode::Physical => Some("true"),
        super::conninfo::ReplicationMode::None => None,
    };
    if let Some(mode) = replication {
        params.push(("replication", mode));
    }

    let startup_msg = wire::build_startup_message(&params);
    wire::write_all(transport, &startup_msg).await?;
    wire::flush(transport).await?;

    // Computed before authentication so it can be offered for channel binding.
    let tls_server_end_point = transport.tls_server_end_point();
    super::auth::authenticate(
        transport,
        buf,
        &info.user,
        info.password.as_deref(),
        tls_server_end_point,
    )
    .await?;

    let mut server_version = 0i32;
    loop {
        let msg = wire::read_message(transport, buf).await?;
        if msg.is_empty() {
            continue;
        }
        // NOTE(review): the payload is taken from byte 5 onward, which assumes
        // every message carries a 1-byte type plus 4-byte length header —
        // confirm `wire::read_message` guarantees at least 5 bytes.
        match msg[0] {
            // ParameterStatus: key/value pair as two C strings.
            b'S' => {
                let payload = &msg[5..];
                let (key, consumed) = wire::read_cstring(payload);
                let (value, _) = wire::read_cstring(&payload[consumed..]);
                tracing::debug!("ParameterStatus: {key}={value}");
                if key == "server_version" {
                    server_version = parse_server_version(value);
                }
            }
            b'K' => {
                tracing::debug!("Received BackendKeyData");
            }
            // ReadyForQuery: startup is finished.
            b'Z' => {
                tracing::debug!("Startup complete, server ready");
                break;
            }
            b'E' => {
                let fields = super::error::parse_error_fields(&msg[5..]);
                return Err(ReplicationError::permanent_connection(format!(
                    "Server error during startup: {}",
                    fields
                )));
            }
            b'N' => {
                let fields = super::error::parse_error_fields(&msg[5..]);
                tracing::info!("Server notice: {}", fields);
            }
            _ => {
                tracing::debug!("Skipping message type '{}' during startup", msg[0] as char);
            }
        }
    }

    if server_version < 140000 {
        return Err(ReplicationError::permanent_connection(format!(
            "PostgreSQL version {} is not supported. Requires 14+",
            server_version
        )));
    }
    Ok(server_version)
}
/// Converts a `server_version` string into a single comparable integer.
///
/// Only the first whitespace-separated token is considered, split on `.`:
///
/// * one component   → `major * 10000`            (`"15"`     → 150000)
/// * two components  → `major * 10000 + minor`    (`"16.1"`   → 160001)
/// * three or more   → `major * 10000 + minor * 100 + patch` (`"9.6.24"` → 90624)
///
/// Each component contributes its leading decimal digits, so pre-release
/// versions such as `"17beta1"` or `"18devel"` yield their major version
/// instead of 0. A component with no leading digits counts as 0. Arithmetic
/// saturates rather than overflowing on absurd server-supplied values.
fn parse_server_version(version_str: &str) -> i32 {
    /// Parses the run of ASCII digits at the start of `part`; 0 if none.
    fn leading_number(part: &str) -> i32 {
        let end = part
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(part.len());
        part[..end].parse::<i32>().unwrap_or(0)
    }

    let version = version_str.split_whitespace().next().unwrap_or("");
    let mut components = version.split('.').map(leading_number);

    // `split` always yields at least one item, even for an empty string.
    let major = components.next().unwrap_or(0).saturating_mul(10_000);
    let second = components.next();
    let third = components.next();

    match (second, third) {
        (None, _) => major,
        (Some(minor), None) => major.saturating_add(minor),
        (Some(minor), Some(patch)) => major
            .saturating_add(minor.saturating_mul(100))
            .saturating_add(patch),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs the aws-lc-rs provider as the process default; the result is
    /// ignored because another test may already have installed it.
    fn install_provider() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    /// Connection settings pointing at localhost with the given TLS options.
    fn test_conninfo(sslmode: SslMode, sslrootcert: Option<String>) -> ConnInfo {
        ConnInfo {
            host: String::from("localhost"),
            port: 5432,
            user: String::from("test"),
            password: None,
            dbname: String::from("test"),
            sslmode,
            sslrootcert,
            sslnegotiation: SslNegotiation::Postgres,
            replication: super::super::conninfo::ReplicationMode::None,
            connect_timeout: 0,
            keepalives: true,
            keepalives_idle: 120,
            keepalives_interval: 10,
            keepalives_count: 3,
        }
    }

    /// Writes `contents` to `<tmp>/<dir_name>/<file_name>`, runs `check` with
    /// the file path, then removes the file and directory (best effort).
    fn with_pem_file<R>(
        dir_name: &str,
        file_name: &str,
        contents: &[u8],
        check: impl FnOnce(&str) -> R,
    ) -> R {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let pem_path = dir.join(file_name);
        std::fs::write(&pem_path, contents).unwrap();
        let outcome = check(pem_path.to_str().unwrap());
        let _ = std::fs::remove_file(&pem_path);
        let _ = std::fs::remove_dir(&dir);
        outcome
    }

    /// Minimal fake certificate: a SEQUENCE header, an OID tag/length, the OID
    /// bytes ending in `suffix`, then 20 zero bytes of padding.
    fn fake_cert(oid_header: [u8; 2], oid_prefix: &[u8], suffix: u8) -> Vec<u8> {
        let mut data = vec![0x30, 0x82, 0x00, 0x10];
        data.extend_from_slice(&oid_header);
        data.extend_from_slice(oid_prefix);
        data.push(suffix);
        data.extend_from_slice(&[0x00; 20]);
        data
    }

    fn cert_with_rsa_oid(suffix: u8) -> Vec<u8> {
        fake_cert(
            [0x06, 0x09],
            &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01],
            suffix,
        )
    }

    fn cert_with_ecdsa_oid(suffix: u8) -> Vec<u8> {
        fake_cert(
            [0x06, 0x08],
            &[0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03],
            suffix,
        )
    }

    /// Asserts that the channel-binding hash of `cert` equals the digest of
    /// `cert` under `algorithm`.
    fn assert_channel_binding_digest(
        cert: &[u8],
        algorithm: &'static aws_lc_rs::digest::Algorithm,
    ) {
        let actual = compute_tls_server_end_point_hash(cert).unwrap();
        let expected = aws_lc_rs::digest::digest(algorithm, cert);
        assert_eq!(actual, expected.as_ref());
    }

    // ---- parse_server_version ----

    #[test]
    fn test_parse_server_version() {
        let cases = [
            ("16.1", 160001),
            ("14.2", 140002),
            ("16.1 (Debian 16.1-1)", 160001),
            ("15", 150000),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_server_version(input), expected);
        }
    }

    #[test]
    fn test_parse_server_version_three_part() {
        for (input, expected) in [("14.2.1", 140201), ("9.6.24", 90624)] {
            assert_eq!(parse_server_version(input), expected);
        }
    }

    #[test]
    fn test_parse_server_version_empty() {
        assert_eq!(parse_server_version(""), 0);
    }

    #[test]
    fn test_parse_server_version_garbage() {
        for input in ["abc", "not.a.version"] {
            assert_eq!(parse_server_version(input), 0);
        }
    }

    #[test]
    fn test_parse_server_version_with_extra_text() {
        assert_eq!(
            parse_server_version("16.4 (Ubuntu 16.4-1.pgdg22.04+1)"),
            160004
        );
        assert_eq!(parse_server_version("16.12 - Azure"), 160012);
    }

    // ---- build_tls_config / build_root_store ----

    #[test]
    fn test_build_tls_config_require() {
        install_provider();
        let info = test_conninfo(SslMode::Require, None);
        assert!(build_tls_config(&info).is_ok());
    }

    #[test]
    fn test_build_tls_config_verify_full() {
        install_provider();
        let info = test_conninfo(SslMode::VerifyFull, None);
        assert!(build_tls_config(&info).is_ok());
    }

    #[test]
    fn test_build_tls_config_verify_ca() {
        install_provider();
        let info = test_conninfo(SslMode::VerifyCa, None);
        assert!(build_tls_config(&info).is_ok());
    }

    #[test]
    fn test_build_root_store_default() {
        install_provider();
        let store = build_root_store(None).unwrap();
        assert!(store.len() > 50, "Expected many CAs, got {}", store.len());
    }

    #[test]
    fn test_build_root_store_custom_file() {
        install_provider();
        let pem_content = include_str!("../../../load-tests/fixtures/test_ca.pem");
        with_pem_file(
            "pg_walstream_test",
            "test_ca.pem",
            pem_content.as_bytes(),
            |path| {
                let store = build_root_store(Some(path)).unwrap();
                assert!(
                    !store.is_empty(),
                    "Expected at least 1 CA from custom file, got {}",
                    store.len()
                );
            },
        );
    }

    #[test]
    fn test_build_root_store_missing_file() {
        install_provider();
        let result = build_root_store(Some("/nonexistent/path/ca.pem"));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("sslrootcert"),
            "Error should mention sslrootcert: {err}"
        );
    }

    #[test]
    fn test_build_root_store_empty_pem_file() {
        install_provider();
        with_pem_file("pg_walstream_test_empty", "empty.pem", b"", |path| {
            let result = build_root_store(Some(path));
            assert!(result.is_err());
            let err = result.unwrap_err().to_string();
            assert!(
                err.contains("No certificates found"),
                "Error should mention no certs: {err}"
            );
        });
    }

    #[test]
    fn test_build_root_store_malformed_pem_file() {
        install_provider();
        with_pem_file(
            "pg_walstream_test_bad",
            "bad.pem",
            b"-----BEGIN CERTIFICATE-----\nnot-valid-base64-or-der!!!\n-----END CERTIFICATE-----\n",
            |path| {
                let result = build_root_store(Some(path));
                assert!(result.is_err());
                let err = result.unwrap_err().to_string();
                assert!(
                    err.to_lowercase().contains("pem") || err.contains("sslrootcert"),
                    "Error should mention PEM/sslrootcert: {err}"
                );
            },
        );
    }

    #[test]
    fn test_build_root_store_default_uses_webpki_roots() {
        install_provider();
        let store = build_root_store(None).unwrap();
        assert_eq!(
            store.len(),
            webpki_roots::TLS_SERVER_ROOTS.len(),
            "Default root store should contain exactly the Mozilla bundle"
        );
    }

    #[test]
    fn test_build_tls_config_verify_full_with_custom_ca() {
        install_provider();
        let missing_ca = Some(String::from("/nonexistent/ca.pem"));
        let info = test_conninfo(SslMode::VerifyFull, missing_ca);
        assert!(build_tls_config(&info).is_err());
    }

    #[test]
    fn test_build_tls_config_require_ignores_sslrootcert() {
        install_provider();
        let missing_ca = Some(String::from("/nonexistent/ca.pem"));
        let info = test_conninfo(SslMode::Require, missing_ca);
        assert!(build_tls_config(&info).is_ok());
    }

    #[test]
    fn test_build_tls_config_has_no_alpn_by_default() {
        install_provider();
        let info = test_conninfo(SslMode::Require, None);
        let config = build_tls_config(&info).unwrap();
        assert!(
            config.alpn_protocols.is_empty(),
            "Standard TLS config should have no ALPN protocols set"
        );
    }

    #[test]
    fn test_direct_ssl_sets_alpn_postgresql() {
        install_provider();
        let info = test_conninfo(SslMode::Require, None);
        let mut config = build_tls_config(&info).unwrap();
        config.alpn_protocols = vec![b"postgresql".to_vec()];
        assert_eq!(config.alpn_protocols, vec![b"postgresql".to_vec()]);
    }

    #[test]
    fn test_sslnegotiation_default_is_postgres() {
        let info = test_conninfo(SslMode::Require, None);
        assert_eq!(info.sslnegotiation, SslNegotiation::Postgres);
    }

    // ---- find_subsequence ----

    #[test]
    fn test_find_subsequence_found() {
        let haystack = [0x01, 0x02, 0x03, 0x04, 0x05];
        assert_eq!(find_subsequence(&haystack, &[0x02, 0x03]), Some(1));
    }

    #[test]
    fn test_find_subsequence_at_start() {
        let haystack = [0x01, 0x02, 0x03];
        assert_eq!(find_subsequence(&haystack, &[0x01, 0x02]), Some(0));
    }

    #[test]
    fn test_find_subsequence_at_end() {
        let haystack = [0x01, 0x02, 0x03];
        assert_eq!(find_subsequence(&haystack, &[0x02, 0x03]), Some(1));
    }

    #[test]
    fn test_find_subsequence_not_found() {
        let haystack = [0x01, 0x02, 0x03];
        assert_eq!(find_subsequence(&haystack, &[0x04, 0x05]), None);
    }

    #[test]
    fn test_find_subsequence_single_byte_needle() {
        let haystack = [0x01, 0x02, 0x03];
        assert_eq!(find_subsequence(&haystack, &[0x02]), Some(1));
    }

    #[test]
    fn test_find_subsequence_needle_longer_than_haystack() {
        let haystack = [0x01];
        assert_eq!(find_subsequence(&haystack, &[0x01, 0x02]), None);
    }

    // ---- compute_tls_server_end_point_hash ----

    #[test]
    fn test_channel_binding_hash_sha256_with_rsa() {
        let cert = cert_with_rsa_oid(0x0b);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    #[test]
    fn test_channel_binding_hash_sha384_with_rsa() {
        let cert = cert_with_rsa_oid(0x0c);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA384);
    }

    #[test]
    fn test_channel_binding_hash_sha512_with_rsa() {
        let cert = cert_with_rsa_oid(0x0d);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA512);
    }

    #[test]
    fn test_channel_binding_hash_md5_falls_back_to_sha256() {
        let cert = cert_with_rsa_oid(0x04);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    #[test]
    fn test_channel_binding_hash_sha1_falls_back_to_sha256() {
        let cert = cert_with_rsa_oid(0x05);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    #[test]
    fn test_channel_binding_hash_ecdsa_sha256() {
        let cert = cert_with_ecdsa_oid(0x02);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    #[test]
    fn test_channel_binding_hash_ecdsa_sha384() {
        let cert = cert_with_ecdsa_oid(0x03);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA384);
    }

    #[test]
    fn test_channel_binding_hash_ecdsa_sha512() {
        let cert = cert_with_ecdsa_oid(0x04);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA512);
    }

    #[test]
    fn test_channel_binding_hash_unknown_oid_defaults_sha256() {
        let cert = vec![0x30, 0x82, 0x00, 0x10, 0xAA, 0xBB, 0xCC, 0xDD];
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    #[test]
    fn test_channel_binding_hash_unknown_rsa_suffix_defaults_sha256() {
        let cert = cert_with_rsa_oid(0xFF);
        assert_channel_binding_digest(&cert, &aws_lc_rs::digest::SHA256);
    }

    // ---- Transport ----

    #[test]
    fn test_transport_plain_returns_none_for_tls_endpoint() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let std_tcp = std::net::TcpStream::connect(addr).unwrap();
        std_tcp.set_nonblocking(true).unwrap();
        let _peer = listener.accept().unwrap();

        // `from_std` needs a runtime context; create a temporary one if the
        // test is not already running inside Tokio.
        let tcp = if tokio::runtime::Handle::try_current().is_ok() {
            tokio::net::TcpStream::from_std(std_tcp).unwrap()
        } else {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .build()
                .unwrap();
            let _guard = rt.enter();
            tokio::net::TcpStream::from_std(std_tcp).unwrap()
        };

        let transport = Transport::Plain(tcp);
        assert!(transport.tls_server_end_point().is_none());
    }
}