use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream as ClientTlsStream;
use tokio_rustls::server::TlsStream as ServerTlsStream;
/// A connection that is either plain TCP or TLS over TCP (client side or server side).
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by delegating to the wrapped stream, so callers
/// can drive all three transports through a single type.
#[derive(Debug)]
pub enum UnifiedStream {
    /// An unencrypted TCP connection.
    Plain(TcpStream),
    /// A TLS session over TCP created by `tokio_rustls`'s client side.
    ClientTls(ClientTlsStream<TcpStream>),
    /// A TLS session over TCP created by `tokio_rustls`'s server side.
    ServerTls(ServerTlsStream<TcpStream>),
}
impl UnifiedStream {
    /// Consumes the stream and returns independently owned read and write halves.
    ///
    /// A plain TCP stream is split with [`TcpStream::into_split`]. TLS streams are split with
    /// [`tokio::io::split`], whose two halves share the one underlying TLS stream.
    pub fn into_split(self) -> (UnifiedReadHalf, UnifiedWriteHalf) {
        match self {
            Self::Plain(tcp) => {
                let (rd, wr) = tcp.into_split();
                (UnifiedReadHalf::Plain(rd), UnifiedWriteHalf::Plain(wr))
            }
            Self::ClientTls(tls) => {
                let (rd, wr) = tokio::io::split(tls);
                (UnifiedReadHalf::ClientTls(rd), UnifiedWriteHalf::ClientTls(wr))
            }
            Self::ServerTls(tls) => {
                let (rd, wr) = tokio::io::split(tls);
                (UnifiedReadHalf::ServerTls(rd), UnifiedWriteHalf::ServerTls(wr))
            }
        }
    }
}
impl AsyncRead for UnifiedStream {
    /// Reads into `buf` from whichever transport this stream wraps.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // All variants are `Unpin`, so the outer pin can be dropped and each inner
        // stream re-pinned with `Pin::new`.
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::ClientTls(tls) => Pin::new(tls).poll_read(cx, buf),
            Self::ServerTls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Forwards every `AsyncWrite` method to the wrapped transport.
///
/// The vectored-write pair is forwarded too. With `AsyncWrite`'s default implementations,
/// `is_write_vectored` would report `false` and `poll_write_vectored` would write only the
/// first non-empty buffer per call, even when the wrapped stream supports vectored writes.
impl AsyncWrite for UnifiedStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            UnifiedStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
            UnifiedStream::ClientTls(stream) => Pin::new(stream).poll_write(cx, buf),
            UnifiedStream::ServerTls(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            UnifiedStream::Plain(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            UnifiedStream::ClientTls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            UnifiedStream::ServerTls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the wrapped transport has an efficient vectored-write implementation.
    fn is_write_vectored(&self) -> bool {
        match self {
            UnifiedStream::Plain(stream) => stream.is_write_vectored(),
            UnifiedStream::ClientTls(stream) => stream.is_write_vectored(),
            UnifiedStream::ServerTls(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            UnifiedStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
            UnifiedStream::ClientTls(stream) => Pin::new(stream).poll_flush(cx),
            UnifiedStream::ServerTls(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            UnifiedStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
            UnifiedStream::ClientTls(stream) => Pin::new(stream).poll_shutdown(cx),
            UnifiedStream::ServerTls(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}
/// The owned read half returned by [`UnifiedStream::into_split`].
#[derive(Debug)]
pub enum UnifiedReadHalf {
    /// Read half of a plain TCP stream, produced by `TcpStream::into_split`.
    Plain(OwnedReadHalf),
    /// Read half of a client-side TLS stream, produced by [`tokio::io::split`].
    ClientTls(tokio::io::ReadHalf<ClientTlsStream<TcpStream>>),
    /// Read half of a server-side TLS stream, produced by [`tokio::io::split`].
    ServerTls(tokio::io::ReadHalf<ServerTlsStream<TcpStream>>),
}
impl AsyncRead for UnifiedReadHalf {
    // Delegates to the concrete half. Every half is `Unpin`, so it can be re-pinned directly.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match *self.get_mut() {
            Self::Plain(ref mut half) => Pin::new(half).poll_read(cx, buf),
            Self::ClientTls(ref mut half) => Pin::new(half).poll_read(cx, buf),
            Self::ServerTls(ref mut half) => Pin::new(half).poll_read(cx, buf),
        }
    }
}
/// The owned write half returned by [`UnifiedStream::into_split`].
#[derive(Debug)]
pub enum UnifiedWriteHalf {
    /// Write half of a plain TCP stream, produced by `TcpStream::into_split`.
    Plain(OwnedWriteHalf),
    /// Write half of a client-side TLS stream, produced by [`tokio::io::split`].
    ClientTls(tokio::io::WriteHalf<ClientTlsStream<TcpStream>>),
    /// Write half of a server-side TLS stream, produced by [`tokio::io::split`].
    ServerTls(tokio::io::WriteHalf<ServerTlsStream<TcpStream>>),
}
/// Forwards every `AsyncWrite` method to the concrete write half.
///
/// The vectored-write pair is forwarded too. Both `OwnedWriteHalf` and `tokio::io::WriteHalf`
/// implement it. With `AsyncWrite`'s defaults, `is_write_vectored` would report `false` and
/// only the first non-empty buffer would be written per `poll_write_vectored` call.
impl AsyncWrite for UnifiedWriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            UnifiedWriteHalf::Plain(write) => Pin::new(write).poll_write(cx, buf),
            UnifiedWriteHalf::ClientTls(write) => Pin::new(write).poll_write(cx, buf),
            UnifiedWriteHalf::ServerTls(write) => Pin::new(write).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            UnifiedWriteHalf::Plain(write) => Pin::new(write).poll_write_vectored(cx, bufs),
            UnifiedWriteHalf::ClientTls(write) => Pin::new(write).poll_write_vectored(cx, bufs),
            UnifiedWriteHalf::ServerTls(write) => Pin::new(write).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the underlying half has an efficient vectored-write implementation.
    fn is_write_vectored(&self) -> bool {
        match self {
            UnifiedWriteHalf::Plain(write) => write.is_write_vectored(),
            UnifiedWriteHalf::ClientTls(write) => write.is_write_vectored(),
            UnifiedWriteHalf::ServerTls(write) => write.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            UnifiedWriteHalf::Plain(write) => Pin::new(write).poll_flush(cx),
            UnifiedWriteHalf::ClientTls(write) => Pin::new(write).poll_flush(cx),
            UnifiedWriteHalf::ServerTls(write) => Pin::new(write).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            UnifiedWriteHalf::Plain(write) => Pin::new(write).poll_shutdown(cx),
            UnifiedWriteHalf::ClientTls(write) => Pin::new(write).poll_shutdown(cx),
            UnifiedWriteHalf::ServerTls(write) => Pin::new(write).poll_shutdown(cx),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks on the auto traits and I/O impls these wrappers rely on.
    ///
    /// `Unpin` is required because the `AsyncRead`/`AsyncWrite` impls call `Pin::get_mut` on
    /// `self` and `Pin::new` on the wrapped streams. `Send` is checked so the types can be moved
    /// across task/thread boundaries.
    #[test]
    fn test_unified_stream_variants() {
        fn assert_send<T: Send>() {}
        fn assert_unpin<T: Unpin>() {}
        fn assert_async_read<T: AsyncRead>() {}
        fn assert_async_write<T: AsyncWrite>() {}

        assert_send::<UnifiedStream>();
        assert_send::<UnifiedReadHalf>();
        assert_send::<UnifiedWriteHalf>();

        assert_unpin::<UnifiedStream>();
        assert_unpin::<UnifiedReadHalf>();
        assert_unpin::<UnifiedWriteHalf>();

        assert_async_read::<UnifiedStream>();
        assert_async_write::<UnifiedStream>();
        assert_async_read::<UnifiedReadHalf>();
        assert_async_write::<UnifiedWriteHalf>();
    }
}