madsim_tokio_postgres/
connect_tls.rs1use crate::config::SslMode;
2use crate::maybe_tls_stream::MaybeTlsStream;
3use crate::tls::private::ForcePrivateApi;
4use crate::tls::TlsConnect;
5use crate::Error;
6use bytes::BytesMut;
7use postgres_protocol::message::frontend;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10pub async fn connect_tls<S, T>(
11 mut stream: S,
12 mode: SslMode,
13 tls: T,
14) -> Result<MaybeTlsStream<S, T::Stream>, Error>
15where
16 S: AsyncRead + AsyncWrite + Unpin,
17 T: TlsConnect<S>,
18{
19 match mode {
20 SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
21 SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
22 return Ok(MaybeTlsStream::Raw(stream))
23 }
24 SslMode::Prefer | SslMode::Require => {}
25 }
26
27 let mut buf = BytesMut::new();
28 frontend::ssl_request(&mut buf);
29 stream.write_all(&buf).await.map_err(Error::io)?;
30
31 let mut buf = [0];
32 stream.read_exact(&mut buf).await.map_err(Error::io)?;
33
34 if buf[0] != b'S' {
35 if SslMode::Require == mode {
36 return Err(Error::tls("server does not support TLS".into()));
37 } else {
38 return Ok(MaybeTlsStream::Raw(stream));
39 }
40 }
41
42 let stream = tls
43 .connect(stream)
44 .await
45 .map_err(|e| Error::tls(e.into()))?;
46
47 Ok(MaybeTlsStream::Tls(stream))
48}