pub mod connector;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll}
};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, Result},
net::TcpStream
};
use futures::future::BoxFuture;
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(feature = "tls")]
use tokio_rustls::client::TlsStream;
pub use connector::{Connector, TcpConnInfo};
#[cfg(unix)]
pub use connector::UdsConnInfo;
#[cfg(feature = "tls")]
pub use connector::TlsTcpConnInfo;
/// A connected byte stream over one of the supported transports.
///
/// The set of variants depends on the build: `Uds` exists only on Unix
/// targets and `TlsTcp` only when the `tls` feature is enabled.
// NOTE(review): the lint is presumably silenced because the TLS variant is
// much larger than the plain socket variants; confirm that boxing it was
// rejected deliberately.
#[allow(clippy::large_enum_variant)]
pub enum Stream {
    /// A plain TCP socket.
    Tcp(TcpStream),
    /// A Unix-domain socket.
    #[cfg(unix)]
    Uds(UnixStream),
    /// A TLS client session running over a TCP socket.
    #[cfg(feature = "tls")]
    TlsTcp(TlsStream<TcpStream>),
}
impl Stream {
    /// Unwraps the underlying TCP socket.
    ///
    /// # Errors
    ///
    /// Returns `self` unchanged when this is not a [`Stream::Tcp`], so the
    /// caller can try another conversion.
    // The `Err` carries the whole stream back, which makes it as large as
    // `Stream` itself; that is intentional.
    #[allow(clippy::result_large_err)]
    pub fn try_into_tcp(self) -> std::result::Result<TcpStream, Self> {
        if let Self::Tcp(strm) = self {
            Ok(strm)
        } else {
            Err(self)
        }
    }

    /// Unwraps the underlying Unix-domain socket.
    ///
    /// # Errors
    ///
    /// Returns `self` unchanged when this is not a [`Stream::Uds`].
    #[cfg(unix)]
    #[allow(clippy::result_large_err)]
    pub fn try_into_uds(self) -> std::result::Result<UnixStream, Self> {
        if let Self::Uds(strm) = self {
            Ok(strm)
        } else {
            Err(self)
        }
    }

    /// Unwraps the underlying TLS-over-TCP session.
    ///
    /// # Errors
    ///
    /// Returns `self` unchanged when this is not a [`Stream::TlsTcp`].
    // Gated on the `tls` feature, like the `TlsTcp` variant and the
    // `TlsStream` import. Gating it on `unix` broke Unix builds without `tls`
    // and hid the method on non-Unix builds with `tls`.
    #[cfg(feature = "tls")]
    #[allow(clippy::result_large_err)]
    pub fn try_into_tlstcp(
        self,
    ) -> std::result::Result<TlsStream<TcpStream>, Self> {
        if let Self::TlsTcp(strm) = self {
            Ok(strm)
        } else {
            Err(self)
        }
    }
}
impl Stream {
    /// Reports whether this transport should be flushed explicitly.
    ///
    /// Only the TLS variant answers `true`; plain TCP and Unix-domain sockets
    /// answer `false`.
    #[inline]
    pub const fn reqflush(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            #[cfg(feature = "tls")]
            Self::TlsTcp(..) => true,
            #[cfg(unix)]
            Self::Uds(..) => false,
            Self::Tcp(..) => false,
        }
    }

    /// Returns the `Debug` rendering of the negotiated TLS cipher suite.
    ///
    /// Yields `None` for non-TLS transports, and for a TLS session that has
    /// not negotiated a suite yet.
    pub fn ciphersuite(&self) -> Option<String> {
        match self {
            #[cfg(feature = "tls")]
            Self::TlsTcp(tls) => {
                let (_, session) = tls.get_ref();
                session
                    .negotiated_cipher_suite()
                    .map(|negotiated| format!("{:?}", negotiated.suite()))
            }
            _ => None,
        }
    }
}
// Forwards a pinned I/O call, e.g. `delegate_call!(self.poll_read(cx, buf))`,
// to whichever transport the `Stream` wraps. Only usable inside an `impl`
// for an enum with the `Tcp` / `Uds` / `TlsTcp` variants, with `Pin` in scope.
//
// The projection uses the safe `Pin::get_mut` + `Pin::new` pair. Both require
// `Unpin`, so the compiler checks that `Stream` and every wrapped transport
// are `Unpin`. The `unsafe` `get_unchecked_mut` / `new_unchecked` pair it
// replaces was therefore unnecessary.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {
        match $self.get_mut() {
            Self::Tcp(s) => Pin::new(s).$method($($args),+),
            #[cfg(unix)]
            Self::Uds(s) => Pin::new(s).$method($($args),+),
            #[cfg(feature = "tls")]
            Self::TlsTcp(s) => Pin::new(s).$method($($args),+),
        }
    };
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>
) -> Poll<Result<()>> {
delegate_call!(self.poll_read(cx, buf))
}
}
/// Writes, flushes and shutdowns are served by whichever transport the
/// stream wraps.
impl AsyncWrite for Stream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        delegate_call!(self.poll_write(cx, buf))
    }

    // Forwarded explicitly. The trait's default `poll_write_vectored` writes
    // only the first non-empty buffer, which would hide the inner stream's
    // own vectored-write support behind this wrapper.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        delegate_call!(self.poll_write_vectored(cx, bufs))
    }

    // Reports the wrapped transport's answer rather than the trait default
    // (`false`), so callers can pick the vectored path when it is efficient.
    // Takes `&self`, so it cannot go through the pinned `delegate_call!`.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Tcp(s) => s.is_write_vectored(),
            #[cfg(unix)]
            Self::Uds(s) => s.is_write_vectored(),
            #[cfg(feature = "tls")]
            Self::TlsTcp(s) => s.is_write_vectored(),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<()>> {
        delegate_call!(self.poll_flush(cx))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<()>> {
        delegate_call!(self.poll_shutdown(cx))
    }
}
/// Runs `f` with a mutable borrow of `strm`, then flushes and shuts it down.
///
/// The teardown (flush, then shutdown if the flush succeeded) is attempted
/// whether or not `f` succeeded. The stream is consumed in either case.
///
/// # Errors
///
/// If `f` fails, its error is returned, even if the teardown also failed.
/// Otherwise the first error from flushing or shutting down is returned.
pub async fn with_conn<F, T>(mut strm: Stream, f: F) -> Result<T>
where
    F: for<'a> FnOnce(&'a mut Stream) -> BoxFuture<'a, Result<T>>,
    T: Send,
{
    let res = f(&mut strm).await;
    // Shut down only after a successful flush, as the original sequence did.
    let teardown = match strm.flush().await {
        Ok(()) => strm.shutdown().await,
        Err(e) => Err(e),
    };
    // Report the closure's error first. A teardown failure on a connection
    // the closure already broke would otherwise mask the root cause.
    let value = res?;
    teardown?;
    Ok(value)
}
/// Hands ownership of `strm` to `f`, then flushes and shuts down the stream
/// that `f` gives back.
///
/// `f` must return the stream alongside its result. The teardown (flush,
/// then shutdown if the flush succeeded) is attempted whether or not `f`
/// succeeded.
///
/// # Errors
///
/// If `f` fails, its error is returned, even if the teardown also failed.
/// Otherwise the first error from flushing or shutting down is returned.
pub async fn with_conn_owned<F, Fut, T>(strm: Stream, f: F) -> Result<T>
where
    F: FnOnce(Stream) -> Fut,
    Fut: Future<Output = (Stream, Result<T>)>,
    T: Send,
{
    let (mut strm, res) = f(strm).await;
    // Shut down only after a successful flush, as the original sequence did.
    let teardown = match strm.flush().await {
        Ok(()) => strm.shutdown().await,
        Err(e) => Err(e),
    };
    // Report the closure's error first, so a teardown failure cannot hide it.
    let value = res?;
    teardown?;
    Ok(value)
}