#[cfg(doc)]
use crate::{Client, WorkerBuilder};
use crate::proto::{self, utils};
use crate::{Error, Reconnect};
use std::io;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, BufStream};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_rustls::client::TlsStream as RustlsStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
/// A TLS-encrypted stream layered on top of an underlying transport `S`.
///
/// Besides the live TLS session, the value keeps the [`TlsConnector`] and the
/// hostname that were used for the handshake, so that the `Reconnect`
/// implementations in this module can repeat the handshake over a fresh
/// transport.
#[pin_project::pin_project]
pub struct TlsStream<S> {
    /// Connector used to perform (and later repeat) the TLS handshake.
    connector: TlsConnector,
    /// Name the server was addressed by during the handshake.
    hostname: String,
    /// The established rustls client session over `S`; structurally pinned
    /// so the async I/O trait impls can forward to it.
    #[pin]
    stream: RustlsStream<S>,
}
impl TlsStream<TokioTcpStream> {
pub async fn connect() -> Result<Self, Error> {
let config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();
let connnector = TlsConnector::from(Arc::new(config));
TlsStream::with_connector(connnector, None).await
}
pub async fn connect_to(addr: &str) -> Result<Self, Error> {
let config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
TlsStream::with_connector(connector, Some(addr)).await
}
pub async fn connect_with_native_certs_to(addr: &str) -> Result<Self, Error> {
let mut store = RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs()? {
store.add(cert).map_err(io::Error::other)?;
}
let config = ClientConfig::builder()
.with_root_certificates(store)
.with_no_client_auth();
TlsStream::with_connector(TlsConnector::from(Arc::new(config)), Some(addr)).await
}
pub async fn with_client_config(conf: ClientConfig, url: Option<&str>) -> Result<Self, Error> {
let con = TlsConnector::from(Arc::new(conf));
TlsStream::with_connector(con, url).await
}
pub async fn with_connector(connector: TlsConnector, url: Option<&str>) -> Result<Self, Error> {
let url = match url {
Some(url) => utils::url_parse(url),
None => utils::url_parse(&utils::get_env_url()),
}?;
let host_and_port = utils::host_from_url(&url);
let tcp_stream = TokioTcpStream::connect(&host_and_port).await?;
let host = url.host_str().unwrap().to_string();
Ok(TlsStream::new(tcp_stream, connector, host).await?)
}
}
impl<S> TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Send + Unpin,
{
    /// Wraps `stream` in TLS using a client configuration with an empty root
    /// certificate store and no client authentication.
    ///
    /// `hostname` is the name the server is addressed by during the
    /// handshake. Use [`new`](Self::new) to supply a customized
    /// [`TlsConnector`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`new`](Self::new) returns.
    pub async fn default(stream: S, hostname: String) -> io::Result<Self> {
        let conf = ClientConfig::builder()
            .with_root_certificates(RootCertStore::empty())
            .with_no_client_auth();
        Self::new(stream, TlsConnector::from(Arc::new(conf)), hostname).await
    }

    /// Performs a TLS handshake over `stream` with the given `connector`,
    /// addressing the server as `hostname`.
    ///
    /// The connector and hostname are stored in the returned value so the
    /// handshake can be repeated later over a new transport.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::InvalidInput`] if `hostname` cannot be converted
    ///   into a server name (previously this case panicked).
    /// * [`io::ErrorKind::ConnectionAborted`] if the TLS handshake fails.
    pub async fn new(stream: S, connector: TlsConnector, hostname: String) -> io::Result<Self> {
        // The hostname typically originates from a user-supplied URL, so an
        // invalid value is a recoverable input error rather than a bug.
        let server_name = hostname
            .clone()
            .try_into()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let tls_stream = connector
            .connect(server_name, stream)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e))?;
        Ok(TlsStream {
            connector,
            hostname,
            stream: tls_stream,
        })
    }
}
#[async_trait::async_trait]
impl Reconnect for BufStream<TlsStream<tokio::net::TcpStream>> {
    /// Re-establishes the connection and wraps it in TLS again.
    ///
    /// Asks the underlying TCP stream to reconnect, then repeats the TLS
    /// handshake over the new transport using clones of the stored connector
    /// and hostname. The result is returned as a freshly buffered, boxed
    /// connection; `self` is not replaced.
    async fn reconnect(&mut self) -> io::Result<proto::BoxedConnection> {
        // `get_mut()` on the rustls stream yields `(transport, session)`;
        // only the transport is needed here.
        let transport = self.get_mut().stream.get_mut().0;
        let fresh_transport = transport.reconnect().await?;

        let previous = self.get_ref();
        let connector = previous.connector.clone();
        let hostname = previous.hostname.clone();

        let renewed = TlsStream::new(fresh_transport, connector, hostname).await?;
        Ok(Box::new(BufStream::new(renewed)))
    }
}
#[async_trait::async_trait]
impl Reconnect for BufStream<TlsStream<proto::BoxedConnection>> {
    /// Re-establishes the connection and wraps it in TLS again.
    ///
    /// Asks the underlying boxed connection to reconnect, then repeats the
    /// TLS handshake over the new transport using clones of the stored
    /// connector and hostname. The result is returned as a freshly buffered,
    /// boxed connection; `self` is not replaced.
    async fn reconnect(&mut self) -> io::Result<proto::BoxedConnection> {
        // `get_mut()` on the rustls stream yields `(transport, session)`;
        // only the transport is needed here.
        let transport = self.get_mut().stream.get_mut().0;
        let fresh_transport = transport.reconnect().await?;

        let previous = self.get_ref();
        let connector = previous.connector.clone();
        let hostname = previous.hostname.clone();

        let renewed = TlsStream::new(fresh_transport, connector, hostname).await?;
        Ok(Box::new(BufStream::new(renewed)))
    }
}
/// Gives shared access to the wrapped rustls client stream.
impl<S> Deref for TlsStream<S> {
    type Target = RustlsStream<S>;

    fn deref(&self) -> &Self::Target {
        let Self { stream, .. } = self;
        stream
    }
}
/// Gives exclusive access to the wrapped rustls client stream.
impl<S> DerefMut for TlsStream<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { stream, .. } = self;
        stream
    }
}
/// Reads are forwarded unchanged to the pinned inner rustls stream.
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let inner = self.project().stream;
        inner.poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdowns are forwarded unchanged to the pinned inner
/// rustls stream.
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        let inner = self.project().stream;
        inner.poll_write(cx, buf)
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_flush(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_shutdown(cx)
    }
}