use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{SinkExt, StreamExt};
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
use rustls_pki_types::CertificateDer;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::{Duration, sleep};
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
use tokio_rustls::server::TlsStream;
use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
use crate::api::auth::StartupHandler;
use crate::api::cancel::CancelHandler;
use crate::api::copy::CopyHandler;
use crate::api::query::{ExtendedQueryHandler, SimpleQueryHandler, send_ready_for_query};
use crate::api::{
ClientInfo, ClientPortalStore, DefaultClient, ErrorHandler, PgWireConnectionState,
PgWireServerHandlers,
};
use crate::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::messages::response::{GssEncResponse, ReadyForQuery, SslResponse, TransactionStatus};
use crate::messages::startup::SecretKey;
use crate::messages::{
DecodeContext, PgWireBackendMessage, PgWireFrontendMessage, ProtocolVersion,
SslNegotiationMetaMessage,
};
/// Deadline, in milliseconds, for a new connection to finish SSL/TLS negotiation
/// (TCP only) and the startup/authentication phase. A single timer covers both
/// phases; the connection is dropped if it has not left
/// `AwaitingStartup`/`AuthenticationInProgress` by then.
const STARTUP_TIMEOUT_MILLIS: u64 = 60_000;
/// Tokio codec for the server side of a PostgreSQL connection.
///
/// It decodes frontend messages, encodes backend messages and owns the
/// per-connection [`DefaultClient`] state that the rest of the server reads
/// and updates through the framed socket.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct PgWireMessageServerCodec<S> {
    /// Connection metadata and protocol state for the connected client.
    pub client_info: DefaultClient<S>,
    /// Decoder hints (protocol version, expected pre-startup packets); they are
    /// refreshed from `client_info` before every decode call.
    #[new(default)]
    decode_context: DecodeContext,
}
impl<S> Decoder for PgWireMessageServerCodec<S> {
    type Item = PgWireFrontendMessage;
    type Error = PgWireError;

    /// Decodes the next frontend message from `src`.
    ///
    /// Before decoding, the decode context is synchronised with the connection:
    /// - it copies the negotiated protocol version;
    /// - it flags which pre-startup packet kinds (SSL/GSS negotiation, startup)
    ///   may appear, based on the current connection state.
    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let (expect_ssl, expect_startup) = match self.client_info.state() {
            PgWireConnectionState::AwaitingSslRequest => (true, true),
            PgWireConnectionState::AwaitingStartup => (false, true),
            _ => (false, false),
        };

        let ctx = &mut self.decode_context;
        ctx.protocol_version = self.client_info.protocol_version;
        ctx.awaiting_frontend_ssl = expect_ssl;
        ctx.awaiting_frontend_startup = expect_startup;

        PgWireFrontendMessage::decode(src, ctx)
    }
}
impl<S> Encoder<PgWireBackendMessage> for PgWireMessageServerCodec<S> {
    type Error = io::Error;

    /// Serialises a backend message into `dst`, converting encoding failures
    /// into `io::Error` as required by the framed sink.
    fn encode(
        &mut self,
        item: PgWireBackendMessage,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), Self::Error> {
        match item.encode(dst) {
            Ok(()) => Ok(()),
            Err(encode_err) => Err(encode_err.into()),
        }
    }
}
/// Exposes the codec's [`DefaultClient`] through [`ClientInfo`], so handlers
/// can read and update connection metadata directly on the framed socket.
impl<T: 'static, S> ClientInfo for Framed<T, PgWireMessageServerCodec<S>> {
    fn socket_addr(&self) -> std::net::SocketAddr {
        self.codec().client_info.socket_addr
    }

    fn is_secure(&self) -> bool {
        self.codec().client_info.is_secure
    }

    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
        self.codec().client_info.pid_and_secret_key()
    }

    fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey) {
        self.codec_mut()
            .client_info
            .set_pid_and_secret_key(pid, secret_key);
    }

    fn protocol_version(&self) -> ProtocolVersion {
        self.codec().client_info.protocol_version()
    }

    fn set_protocol_version(&mut self, version: ProtocolVersion) {
        self.codec_mut().client_info.set_protocol_version(version);
    }

    fn state(&self) -> PgWireConnectionState {
        self.codec().client_info.state
    }

    fn set_state(&mut self, new_state: PgWireConnectionState) {
        self.codec_mut().client_info.set_state(new_state);
    }

    fn metadata(&self) -> &std::collections::HashMap<String, String> {
        self.codec().client_info.metadata()
    }

    fn metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
        self.codec_mut().client_info.metadata_mut()
    }

    fn transaction_status(&self) -> TransactionStatus {
        self.codec().client_info.transaction_status()
    }

    fn set_transaction_status(&mut self, new_status: TransactionStatus) {
        self.codec_mut()
            .client_info
            .set_transaction_status(new_status);
    }

    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn sni_server_name(&self) -> Option<&str> {
        self.codec().client_info.sni_server_name()
    }

    /// Returns the certificates presented by the client during the TLS
    /// handshake, or `None` for plaintext connections.
    ///
    /// `negotiate_tls` wraps TLS connections in [`MaybeTls::Tls`], while other
    /// callers may frame a `TlsStream<TcpStream>` directly, so both transports
    /// are recognised. Any other transport yields `None` instead of panicking.
    /// The previous `downcast_ref::<TlsStream<TcpStream>>().unwrap()` always
    /// panicked on `MaybeTls`.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
        if !self.is_secure() {
            return None;
        }
        let io: &dyn std::any::Any = self.get_ref();
        let tls: &TlsStream<TcpStream> = match io.downcast_ref::<MaybeTls>() {
            Some(MaybeTls::Tls(boxed)) => &**boxed,
            _ => io.downcast_ref::<TlsStream<TcpStream>>()?,
        };
        let (_, tls_session) = tls.get_ref();
        tls_session.peer_certificates()
    }
}
/// Gives handlers access to the portal/statement store kept in the codec's
/// [`DefaultClient`].
impl<T, S> ClientPortalStore for Framed<T, PgWireMessageServerCodec<S>> {
    type PortalStore = <DefaultClient<S> as ClientPortalStore>::PortalStore;

    fn portal_store(&self) -> &Self::PortalStore {
        let codec = self.codec();
        codec.client_info.portal_store()
    }
}
/// Dispatches one decoded frontend message to the handler for the connection's
/// current state.
///
/// - `CancelRequest` goes to `cancel_handler`, after which the socket is closed.
/// - In `AwaitingStartup`/`AuthenticationInProgress`, messages go to `authenticator`.
/// - In `AwaitingSync` (after an extended-query error), everything except `Sync`
///   is discarded; `Sync` returns the connection to `ReadyForQuery`.
/// - In `CopyInProgress`, `CopyData`/`CopyDone`/`CopyFail` go to `copy_handler`
///   and other messages are ignored.
/// - Otherwise, simple and extended query messages go to their handlers.
///
/// # Errors
/// Returns any error produced by a handler or by writing to the socket; the
/// caller is expected to report it to the client (see `process_error`).
pub async fn process_message<S, A, Q, EQ, C, CR>(
    message: PgWireFrontendMessage,
    socket: &mut Framed<S, PgWireMessageServerCodec<EQ::Statement>>,
    authenticator: Arc<A>,
    query_handler: Arc<Q>,
    extended_query_handler: Arc<EQ>,
    copy_handler: Arc<C>,
    cancel_handler: Arc<CR>,
) -> PgWireResult<()>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    A: StartupHandler,
    Q: SimpleQueryHandler,
    EQ: ExtendedQueryHandler,
    C: CopyHandler,
    CR: CancelHandler,
{
    if let PgWireFrontendMessage::CancelRequest(cancel) = message {
        cancel_handler.on_cancel_request(cancel).await;
        socket.close().await?;
        return Ok(());
    }
    match socket.state() {
        PgWireConnectionState::AwaitingStartup
        | PgWireConnectionState::AuthenticationInProgress => {
            authenticator.on_startup(socket, message).await?;
        }
        PgWireConnectionState::AwaitingSync => {
            // After an extended-query error, everything up to Sync is discarded.
            if let PgWireFrontendMessage::Sync(sync) = message {
                extended_query_handler.on_sync(socket, sync).await?;
                socket.set_state(PgWireConnectionState::ReadyForQuery);
            }
        }
        PgWireConnectionState::CopyInProgress(is_extended_query) => match message {
            PgWireFrontendMessage::CopyData(copy_data) => {
                copy_handler.on_copy_data(socket, copy_data).await?;
            }
            PgWireFrontendMessage::CopyDone(copy_done) => {
                let result = copy_handler.on_copy_done(socket, copy_done).await;
                // A simple-protocol COPY ends here; an extended-protocol COPY
                // ends with the client's Sync, which sends ReadyForQuery.
                if !is_extended_query {
                    socket.set_state(PgWireConnectionState::ReadyForQuery);
                }
                result?;
                if !is_extended_query {
                    // Report the tracked transaction status: a COPY issued inside
                    // an explicit transaction block must not claim to be Idle.
                    let status = socket.transaction_status();
                    send_ready_for_query(socket, status).await?;
                }
            }
            PgWireFrontendMessage::CopyFail(copy_fail) => {
                let error = copy_handler.on_copy_fail(socket, copy_fail).await;
                if !is_extended_query {
                    socket.set_state(PgWireConnectionState::ReadyForQuery);
                }
                return Err(error);
            }
            _ => {}
        },
        _ => match message {
            PgWireFrontendMessage::Query(query) => {
                query_handler.on_query(socket, query).await?;
            }
            PgWireFrontendMessage::Parse(parse) => {
                extended_query_handler.on_parse(socket, parse).await?;
            }
            PgWireFrontendMessage::Bind(bind) => {
                extended_query_handler.on_bind(socket, bind).await?;
            }
            PgWireFrontendMessage::Execute(execute) => {
                extended_query_handler.on_execute(socket, execute).await?;
            }
            PgWireFrontendMessage::Describe(describe) => {
                extended_query_handler.on_describe(socket, describe).await?;
            }
            PgWireFrontendMessage::Flush(flush) => {
                extended_query_handler.on_flush(socket, flush).await?;
            }
            PgWireFrontendMessage::Sync(sync) => {
                extended_query_handler.on_sync(socket, sync).await?;
            }
            PgWireFrontendMessage::Close(close) => {
                extended_query_handler.on_close(socket, close).await?;
            }
            _ => {}
        },
    }
    Ok(())
}
/// Reports `error` to the client and moves the connection into the state that
/// follows an error.
///
/// - If the connection is still in `AwaitingStartup`/`AuthenticationInProgress`,
///   the ErrorResponse is sent and the socket closed. The state is left
///   untouched, so a client that never authenticated can never reach
///   `ReadyForQuery`. PostgreSQL likewise ends the session on any startup error.
/// - Otherwise, the transaction status moves to its error state.
///   - With `wait_for_sync` (extended query protocol), the connection waits for
///     the client's `Sync`.
///   - Without it, `ReadyForQuery` is sent immediately.
///   - Fatal errors then close the socket.
///
/// # Errors
/// Returns any I/O error raised while writing to or closing the socket.
pub async fn process_error<S, ST>(
    socket: &mut Framed<S, PgWireMessageServerCodec<ST>>,
    error: PgWireError,
    wait_for_sync: bool,
) -> Result<(), io::Error>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
    let in_startup = matches!(
        socket.state(),
        PgWireConnectionState::AwaitingStartup | PgWireConnectionState::AuthenticationInProgress
    );
    let error_info: ErrorInfo = error.into();
    let is_fatal = error_info.is_fatal();
    socket
        .send(PgWireBackendMessage::ErrorResponse(error_info.into()))
        .await?;
    if in_startup {
        // Never advance an unauthenticated connection to ReadyForQuery: that
        // would let the client issue queries after a failed startup.
        return socket.close().await;
    }
    let transaction_status = socket.transaction_status().to_error_state();
    socket.set_transaction_status(transaction_status);
    if wait_for_sync {
        socket.set_state(PgWireConnectionState::AwaitingSync);
    } else {
        socket.set_state(PgWireConnectionState::ReadyForQuery);
        socket
            .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
                transaction_status,
            )))
            .await?;
    }
    socket.flush().await?;
    if is_fatal {
        return socket.close().await;
    }
    Ok(())
}
/// Outcome of the pre-startup encryption negotiation done by `peek_for_sslrequest`.
#[derive(Debug, PartialEq, Eq)]
enum SslNegotiationType {
    /// The client sent an SSLRequest and the server answered `S`; a TLS handshake follows.
    Postgres,
    /// The client opened with a TLS handshake record directly, without an SSLRequest.
    Direct,
    /// No encryption: none was requested, or every request was refused.
    None,
}
/// Peeks (without consuming) the first byte of the connection and reports
/// whether it is a TLS handshake record header (content type `0x16`), i.e. the
/// client started TLS directly instead of sending an SSLRequest.
async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, io::Error> {
    const TLS_HANDSHAKE_RECORD: u8 = 0x16;
    let mut first_byte = [0u8; 1];
    let peeked = tcp_socket.peek(&mut first_byte).await?;
    let is_tls = peeked != 0 && first_byte[0] == TLS_HANDSHAKE_RECORD;
    Ok(is_tls)
}
/// Runs the encryption negotiation that precedes the startup message.
///
/// A direct TLS handshake is detected by peeking. Otherwise, SSLRequest and
/// GSSENCRequest packets are answered as they arrive:
/// - SSL is accepted when `ssl_supported` and refused otherwise;
/// - GSS encryption is always refused.
///
/// Negotiation ends with `None` once both have been refused, or as soon as
/// anything else (or nothing) is read.
async fn peek_for_sslrequest<ST>(
    socket: &mut Framed<TcpStream, PgWireMessageServerCodec<ST>>,
    ssl_supported: bool,
) -> Result<SslNegotiationType, io::Error> {
    if check_ssl_direct_negotiation(socket.get_ref()).await? {
        return Ok(SslNegotiationType::Direct);
    }

    let mut ssl_refused = false;
    let mut gss_refused = false;
    loop {
        let next = socket.next().await;
        match next {
            Some(Ok(PgWireFrontendMessage::SslNegotiation(
                SslNegotiationMetaMessage::PostgresSsl(_),
            ))) if ssl_supported => {
                socket
                    .send(PgWireBackendMessage::SslResponse(SslResponse::Accept))
                    .await?;
                return Ok(SslNegotiationType::Postgres);
            }
            Some(Ok(PgWireFrontendMessage::SslNegotiation(
                SslNegotiationMetaMessage::PostgresSsl(_),
            ))) => {
                socket
                    .send(PgWireBackendMessage::SslResponse(SslResponse::Refuse))
                    .await?;
                ssl_refused = true;
                if gss_refused {
                    return Ok(SslNegotiationType::None);
                }
            }
            Some(Ok(PgWireFrontendMessage::SslNegotiation(
                SslNegotiationMetaMessage::PostgresGss(_),
            ))) => {
                socket
                    .send(PgWireBackendMessage::GssEncResponse(GssEncResponse::Refuse))
                    .await?;
                gss_refused = true;
                if ssl_refused {
                    return Ok(SslNegotiationType::None);
                }
            }
            _ => return Ok(SslNegotiationType::None),
        }
    }
}
/// Requires a direct-TLS connection to have negotiated the PostgreSQL ALPN
/// protocol; otherwise returns an `InvalidData` error.
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), io::Error> {
    let (_, session) = tls_socket.get_ref();
    match session.alpn_protocol() {
        Some(alpn) if alpn == super::POSTGRESQL_ALPN_NAME => Ok(()),
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "received direct SSL connection request without ALPN protocol negotiation extension",
        )),
    }
}
/// Transport under a framed server connection: plain TCP, a Unix-domain socket,
/// or TLS over TCP. `AsyncRead`/`AsyncWrite` delegate to the active variant.
#[non_exhaustive]
pub enum MaybeTls {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// Unix-domain socket connection (unix targets only).
    #[cfg(unix)]
    Unix(UnixStream),
    /// TLS-encrypted TCP connection (boxed to keep the enum small).
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    Tls(Box<TlsStream<TcpStream>>),
}
/// Forwards a `poll_*` call on a pinned `MaybeTls` to whichever transport is active.
///
/// Usage: `maybe_tls!(self, poll_read(cx, buf))`. Each inner stream is re-pinned
/// with `Pin::new`, which requires the inner streams to be `Unpin`.
macro_rules! maybe_tls {
    ($pinned:ident, $method:ident($($arg:expr),*)) => {
        match $pinned.get_mut() {
            MaybeTls::Plain(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(unix)]
            MaybeTls::Unix(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
            MaybeTls::Tls(inner) => Pin::new(inner).$method($($arg),*),
        }
    };
}
/// Reads from whichever transport this connection uses.
impl AsyncRead for MaybeTls {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Delegates to the active variant's `poll_read`.
        maybe_tls!(self, poll_read(cx, buf))
    }
}
/// Writes to, flushes and shuts down whichever transport this connection uses.
impl AsyncWrite for MaybeTls {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        maybe_tls!(self, poll_write(cx, buf))
    }
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        maybe_tls!(self, poll_flush(cx))
    }
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        maybe_tls!(self, poll_shutdown(cx))
    }
}
/// Accepts a fresh TCP connection and performs PostgreSQL encryption negotiation.
///
/// On success, returns a framed socket in `AwaitingStartup` state:
/// - plaintext when no encryption was negotiated;
/// - TLS-wrapped when the client negotiated TLS (via SSLRequest or direct TLS)
///   and `tls_acceptor` is configured. In that case `is_secure` is set and the
///   SNI server name, if any, is recorded.
///
/// Returns `Ok(None)` when the client asked for TLS but it cannot be served here.
///
/// # Errors
/// - I/O errors from the socket or the TLS handshake.
/// - `InvalidData` when a direct-TLS client did not negotiate the PostgreSQL ALPN
///   protocol.
/// - `InvalidData` when plaintext bytes were pipelined after an accepted SSLRequest.
pub async fn negotiate_tls<S>(
    tcp_socket: TcpStream,
    tls_acceptor: Option<crate::tokio::TlsAcceptor>,
) -> io::Result<Option<Framed<MaybeTls, PgWireMessageServerCodec<S>>>> {
    let addr = tcp_socket.peer_addr()?;
    tcp_socket.set_nodelay(true)?;
    let client_info = DefaultClient::new(addr, false);
    let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));
    let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?;
    let old_parts = tcp_socket.into_parts();
    if ssl == SslNegotiationType::None {
        // Plaintext: keep whatever negotiation already buffered (the start of
        // the startup packet belongs to this same plaintext stream).
        let mut parts = FramedParts::new(MaybeTls::Plain(old_parts.io), old_parts.codec);
        parts.read_buf = old_parts.read_buf;
        parts.write_buf = old_parts.write_buf;
        let mut socket = Framed::from_parts(parts);
        socket.set_state(PgWireConnectionState::AwaitingStartup);
        return Ok(Some(socket));
    }
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    if let Some(tls_acceptor) = tls_acceptor {
        // Anything already buffered arrived in plaintext *before* the TLS
        // handshake. Carrying it into the encrypted stream would let a
        // man-in-the-middle inject messages that are then trusted as
        // encrypted (CVE-2021-23214). Refuse the connection instead.
        // Direct TLS only peeks, so its buffer is always empty here.
        if !old_parts.read_buf.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "received unencrypted data after SSL request",
            ));
        }
        let mut client_info = DefaultClient::new(addr, true);
        let ssl_socket = Box::new(tls_acceptor.accept(old_parts.io).await?);
        if ssl == SslNegotiationType::Direct {
            check_alpn_for_direct_ssl(&ssl_socket)?;
        }
        let sni = {
            let (_, conn) = ssl_socket.get_ref();
            conn.server_name().map(|s| s.to_string())
        };
        if let Some(s) = sni {
            client_info.sni_server_name = Some(s);
        }
        let mut parts = FramedParts::new(
            MaybeTls::Tls(ssl_socket),
            PgWireMessageServerCodec::new(client_info),
        );
        parts.write_buf = old_parts.write_buf;
        let mut socket = Framed::from_parts(parts);
        socket.set_state(PgWireConnectionState::AwaitingStartup);
        return Ok(Some(socket));
    }
    Ok(None)
}
/// Drives the per-connection message loop shared by `process_socket` and
/// `process_socket_unix`.
///
/// Reads frontend messages from `$socket` until one of these happens:
/// - the stream ends;
/// - a message fails to decode;
/// - the pinned startup deadline `$startup_timeout` fires while the client is
///   still in startup/authentication.
///
/// Each message is dispatched with `process_message`. Handler errors go through
/// the error handler and then `process_error`. An error raised during
/// startup/authentication ends the loop, so the connection is dropped rather
/// than kept open without authentication.
///
/// Must be expanded inside an `async fn` returning `Result<_, io::Error>`
/// (it uses `?`).
macro_rules! process_socket_messages {
    ($socket:expr, $startup_timeout:expr, $handlers:expr) => {{
        let startup_handler = $handlers.startup_handler();
        let simple_query_handler = $handlers.simple_query_handler();
        let extended_query_handler = $handlers.extended_query_handler();
        let copy_handler = $handlers.copy_handler();
        let cancel_handler = $handlers.cancel_handler();
        let error_handler = $handlers.error_handler();
        let socket = &mut $socket;
        loop {
            // Messages read in these states are routed to the startup handler
            // and are subject to the startup deadline.
            let in_startup = matches!(
                socket.state(),
                PgWireConnectionState::AwaitingStartup
                    | PgWireConnectionState::AuthenticationInProgress
            );
            let msg = if in_startup {
                tokio::select! {
                    _ = &mut $startup_timeout => None,
                    msg = socket.next() => msg,
                }
            } else {
                socket.next().await
            };
            let Some(Ok(msg)) = msg else {
                break;
            };
            let is_extended_query = match socket.state() {
                PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
                _ => msg.is_extended_query(),
            };
            if let Err(mut e) = process_message(
                msg,
                socket,
                startup_handler.clone(),
                simple_query_handler.clone(),
                extended_query_handler.clone(),
                copy_handler.clone(),
                cancel_handler.clone(),
            )
            .await
            {
                error_handler.on_error(socket, &mut e);
                process_error(socket, e, is_extended_query).await?;
                if in_startup {
                    // PostgreSQL terminates the session on any error before
                    // authentication completes; continuing could let the
                    // client proceed on an unauthenticated connection.
                    break;
                }
            }
        }
    }};
}
/// Serves one PostgreSQL session over a Unix-domain socket.
///
/// No encryption negotiation happens: the connection starts directly in
/// `AwaitingStartup`, and startup/authentication must finish within
/// `STARTUP_TIMEOUT_MILLIS`.
#[cfg(unix)]
pub async fn process_socket_unix<H>(unix_socket: UnixStream, handlers: H) -> Result<(), io::Error>
where
    H: PgWireServerHandlers,
{
    let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
    tokio::pin!(startup_timeout);

    // Unix sockets have no peer IP; a fixed placeholder address is recorded.
    let placeholder_addr = "127.0.0.1:0".parse().unwrap();
    let codec = PgWireMessageServerCodec::new(DefaultClient::new(placeholder_addr, false));
    let mut socket = Framed::new(MaybeTls::Unix(unix_socket), codec);
    socket.set_state(PgWireConnectionState::AwaitingStartup);

    process_socket_messages!(socket, startup_timeout, handlers);
    Ok(())
}
/// Serves one PostgreSQL session over TCP, optionally upgrading it to TLS.
///
/// Encryption negotiation and the startup/authentication phase share one
/// `STARTUP_TIMEOUT_MILLIS` deadline. If it expires during negotiation, or
/// negotiation yields no usable socket, the function returns `Ok(())` and the
/// connection is dropped.
pub async fn process_socket<H>(
    tcp_socket: TcpStream,
    tls_acceptor: Option<crate::tokio::TlsAcceptor>,
    handlers: H,
) -> Result<(), io::Error>
where
    H: PgWireServerHandlers,
{
    let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
    tokio::pin!(startup_timeout);

    let negotiated = tokio::select! {
        _ = &mut startup_timeout => None,
        result = negotiate_tls(tcp_socket, tls_acceptor) => result?,
    };

    if let Some(mut socket) = negotiated {
        process_socket_messages!(socket, startup_timeout, handlers);
    }
    Ok(())
}
#[cfg(all(test, any(feature = "_ring", feature = "_aws-lc-rs")))]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{BufReader, Error as IOError};
    use std::sync::Arc;
    use tokio::sync::oneshot;
    use tokio_rustls::TlsAcceptor;
    use tokio_rustls::TlsConnector;
    use tokio_rustls::rustls;
    use tokio_rustls::rustls::crypto::CryptoProvider;

    /// Installs a process-wide default rustls `CryptoProvider` for the enabled backend(s).
    ///
    /// `install_default` returns `Err` when a default is already installed. That
    /// happens when another test got there first, or when both backend features
    /// are enabled and the second call runs. It is harmless here, so the result is
    /// ignored; unwrapping it made `--all-features` test runs panic.
    fn install_crypto_provider() {
        #[cfg(feature = "_aws-lc-rs")]
        let _ = CryptoProvider::install_default(
            tokio_rustls::rustls::crypto::aws_lc_rs::default_provider(),
        );
        #[cfg(feature = "_ring")]
        let _ =
            CryptoProvider::install_default(tokio_rustls::rustls::crypto::ring::default_provider());
    }

    /// Builds a server config from the example certificate/key in `examples/ssl`,
    /// advertising the PostgreSQL ALPN protocol.
    fn load_test_server_config() -> Result<rustls::ServerConfig, IOError> {
        use rustls_pemfile::{certs, pkcs8_private_keys};
        use rustls_pki_types::{CertificateDer, PrivateKeyDer};
        let certs = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?))
            .collect::<Result<Vec<CertificateDer>, _>>()?;
        let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?))
            .map(|key| key.map(PrivateKeyDer::from))
            .collect::<Result<Vec<PrivateKeyDer>, _>>()?
            .remove(0);
        let mut cfg = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        Ok(cfg)
    }

    /// Server-certificate verifier that accepts everything. Test use only.
    #[derive(Debug)]
    struct NoCertVerifier;

    impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &rustls::pki_types::CertificateDer<'_>,
            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
            _server_name: &rustls::pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls::pki_types::UnixTime,
        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
            Ok(rustls::client::danger::ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            vec![
                rustls::SignatureScheme::RSA_PKCS1_SHA256,
                rustls::SignatureScheme::RSA_PKCS1_SHA384,
                rustls::SignatureScheme::RSA_PKCS1_SHA512,
                rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
                rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA256,
                rustls::SignatureScheme::RSA_PSS_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA512,
            ]
        }
    }

    /// Client config that skips certificate verification and offers the
    /// PostgreSQL ALPN protocol.
    fn make_test_client_config() -> rustls::ClientConfig {
        let mut cfg = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoCertVerifier))
            .with_no_client_auth();
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        cfg
    }

    fn make_test_client_connector() -> Result<TlsConnector, IOError> {
        Ok(TlsConnector::from(Arc::new(make_test_client_config())))
    }

    #[tokio::test]
    #[ignore]
    async fn server_name_metadata_is_set_from_tls_sni() {
        use std::net::SocketAddr;
        use tokio::io::duplex;
        install_crypto_provider();
        let server_cfg = load_test_server_config().expect("server config");
        let acceptor = TlsAcceptor::from(Arc::new(server_cfg));
        let connector = make_test_client_connector().expect("client connector");
        let (server_io, client_io) = duplex(64 * 1024);
        let (tx, rx) = oneshot::channel::<Option<String>>();
        tokio::spawn(async move {
            let tls = acceptor.accept(server_io).await.unwrap();
            let sni = {
                let (_, conn) = tls.get_ref();
                conn.server_name().map(|s| s.to_string())
            };
            let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
            if let Some(s) = sni {
                ci.sni_server_name = Some(s);
            }
            let framed = Framed::new(tls, PgWireMessageServerCodec::new(ci));
            let server_name = framed.sni_server_name().map(str::to_string);
            let _ = tx.send(server_name);
        });
        let server_name = rustls_pki_types::ServerName::try_from("localhost").unwrap();
        let _ = connector.connect(server_name, client_io).await.unwrap();
        let observed = rx.await.expect("server_name from server");
        assert_eq!(observed.as_deref(), Some("localhost"));
    }

    #[tokio::test]
    async fn server_name_metadata_is_set_from_tls_sni_in_memory() {
        use std::net::SocketAddr;
        install_crypto_provider();
        let server_cfg = Arc::new(load_test_server_config().expect("server config"));
        let client_cfg = Arc::new(make_test_client_config());
        let mut server_conn = rustls::ServerConnection::new(server_cfg).unwrap();
        let mut client_conn = rustls::ClientConnection::new(
            client_cfg,
            rustls_pki_types::ServerName::try_from("localhost").unwrap(),
        )
        .unwrap();
        // Pump handshake records between the two in-memory endpoints.
        let mut c2s = Vec::new();
        let mut s2c = Vec::new();
        for _ in 0..1000 {
            let _ = client_conn.write_tls(&mut c2s);
            if !c2s.is_empty() {
                let mut cur = std::io::Cursor::new(&c2s);
                let _ = server_conn.read_tls(&mut cur);
                c2s.clear();
                server_conn.process_new_packets().unwrap();
            }
            let _ = server_conn.write_tls(&mut s2c);
            if !s2c.is_empty() {
                let mut cur = std::io::Cursor::new(&s2c);
                let _ = client_conn.read_tls(&mut cur);
                s2c.clear();
                client_conn.process_new_packets().unwrap();
            }
            if !client_conn.is_handshaking() && !server_conn.is_handshaking() {
                break;
            }
        }
        let sni = server_conn.server_name().map(|s| s.to_string());
        let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
        if let Some(s) = sni {
            ci.sni_server_name = Some(s);
        }
        assert_eq!(ci.sni_server_name(), Some("localhost"));
    }
}