#[cfg(doc)]
use crate::{Client, WorkerBuilder};
use crate::error::{self, Error};
use crate::proto::{self, utils};
use crate::Reconnect;
use std::io;
use std::ops::{Deref, DerefMut};
use tokio::io::{AsyncRead, AsyncWrite, BufStream};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::TlsStream as NativeTlsStream;
use tokio_native_tls::{native_tls::TlsConnector, TlsConnector as AsyncTlsConnector};
#[pin_project::pin_project]
pub struct TlsStream<S> {
connector: AsyncTlsConnector,
hostname: String,
#[pin]
stream: NativeTlsStream<S>,
}
impl TlsStream<TokioTcpStream> {
pub async fn connect() -> Result<Self, Error> {
TlsStream::with_connector(
TlsConnector::builder()
.build()
.map_err(error::Stream::NativeTls)?,
None,
)
.await
}
pub async fn connect_to(addr: &str) -> Result<Self, Error> {
TlsStream::with_connector(
TlsConnector::builder()
.build()
.map_err(error::Stream::NativeTls)?,
Some(addr),
)
.await
}
pub async fn with_connector(connector: TlsConnector, url: Option<&str>) -> Result<Self, Error> {
let url = match url {
Some(url) => utils::url_parse(url),
None => utils::url_parse(&utils::get_env_url()),
}?;
let hostname = utils::host_from_url(&url);
let tcp_stream = TokioTcpStream::connect(&hostname).await?;
Ok(TlsStream::new(tcp_stream, connector, hostname).await?)
}
}
impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
pub async fn default(stream: S, hostname: String) -> io::Result<Self> {
let connector = TlsConnector::builder()
.build()
.map_err(error::Stream::NativeTls)
.unwrap();
Self::new(stream, connector, hostname).await
}
pub async fn new(
stream: S,
connector: impl Into<AsyncTlsConnector>,
hostname: String,
) -> io::Result<Self> {
let connector: AsyncTlsConnector = connector.into();
let tls_stream = connector
.connect(&hostname, stream)
.await
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e))?;
Ok(TlsStream {
connector,
hostname,
stream: tls_stream,
})
}
}
#[async_trait::async_trait]
impl Reconnect for BufStream<TlsStream<tokio::net::TcpStream>> {
async fn reconnect(&mut self) -> io::Result<proto::BoxedConnection> {
let stream = self
.get_mut()
.stream
.get_mut()
.get_mut()
.get_mut()
.reconnect()
.await?;
let res = TlsStream::new(
stream,
self.get_ref().connector.clone(),
self.get_ref().hostname.clone(),
)
.await?;
let buffered = BufStream::new(res);
Ok(Box::new(buffered))
}
}
#[async_trait::async_trait]
impl Reconnect for BufStream<TlsStream<proto::BoxedConnection>> {
async fn reconnect(&mut self) -> io::Result<proto::BoxedConnection> {
let stream = self
.get_mut()
.stream
.get_mut()
.get_mut()
.get_mut()
.reconnect()
.await?;
let res = TlsStream::new(
stream,
self.get_ref().connector.clone(),
self.get_ref().hostname.clone(),
)
.await?;
let buffered = BufStream::new(res);
Ok(Box::new(buffered))
}
}
impl<S> Deref for TlsStream<S> {
type Target = NativeTlsStream<S>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<S> DerefMut for TlsStream<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<io::Result<()>> {
self.project().stream.poll_read(cx, buf)
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
self.project().stream.poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}
}