1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
use crate::config::SslMode; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::private::ForcePrivateApi; use crate::tls::{ChannelBinding, TlsConnect}; use crate::Error; use bytes::BytesMut; use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls<S, T>( mut stream: S, mode: SslMode, tls: T, ) -> Result<(MaybeTlsStream<S, T::Stream>, ChannelBinding), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect<S>, { match mode { SslMode::Disable => return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())), SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())) } SslMode::Prefer | SslMode::Require => {} SslMode::__NonExhaustive => unreachable!(), } let mut buf = BytesMut::new(); frontend::ssl_request(&mut buf); stream.write_all(&buf).await.map_err(Error::io)?; let mut buf = [0]; stream.read_exact(&mut buf).await.map_err(Error::io)?; if buf[0] != b'S' { if SslMode::Require == mode { return Err(Error::tls("server does not support TLS".into())); } else { return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())); } } let (stream, channel_binding) = tls .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; Ok((MaybeTlsStream::Tls(stream), channel_binding)) }