use std::{io, pin::Pin, task::Poll};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpStream, ToSocketAddrs, UnixStream},
};
use crate::{futures::map, io::AsyncIo};
/// A connected byte stream backed by either a TCP socket or a Unix domain socket.
///
/// Build one with [`IoStream::connect`], [`IoStream::connect_unix`], or by
/// converting an existing `TcpStream` / `UnixStream` via `From`.
#[derive(Debug)]
pub struct IoStream {
// The concrete socket; the I/O trait impls for `IoStream` dispatch on this.
repr: Repr,
}
impl IoStream {
#[inline]
pub fn connect<A>(addr: A) -> impl Future<Output = io::Result<Self>>
where
A: ToSocketAddrs,
{
map(TcpStream::connect(addr), |e| match e {
Ok(ok) => Ok(Self {
repr: Repr::Tcp(ok),
}),
Err(err) => Err(err),
})
}
#[inline]
pub fn connect_unix<P>(path: P) -> impl Future<Output = io::Result<Self>>
where
P: AsRef<std::path::Path>,
{
map(UnixStream::connect(path), |e| match e {
Ok(ok) => Ok(Self {
repr: Repr::Unix(ok),
}),
Err(err) => Err(err),
})
}
}
impl From<TcpStream> for IoStream {
#[inline]
fn from(value: TcpStream) -> Self {
Self { repr: Repr::Tcp(value) }
}
}
impl From<UnixStream> for IoStream {
#[inline]
fn from(value: UnixStream) -> Self {
Self { repr: Repr::Unix(value) }
}
}
/// The concrete socket held by an [`IoStream`].
#[derive(Debug)]
enum Repr {
/// A TCP stream (from `IoStream::connect` or `From<TcpStream>`).
Tcp(TcpStream),
/// A Unix domain socket stream (from `IoStream::connect_unix` or `From<UnixStream>`).
Unix(UnixStream),
}
/// Readiness-based I/O: every method forwards to the same-named method
/// of the wrapped TCP or Unix socket.
impl AsyncIo for IoStream {
    /// Polls the underlying socket for read readiness.
    #[inline]
    fn poll_read_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.poll_read_ready(cx),
            Repr::Unix(ref stream) => stream.poll_read_ready(cx),
        }
    }

    /// Polls the underlying socket for write readiness.
    #[inline]
    fn poll_write_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.poll_write_ready(cx),
            Repr::Unix(ref stream) => stream.poll_write_ready(cx),
        }
    }

    /// Attempts a non-blocking read into `buf`.
    #[inline]
    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.try_read(buf),
            Repr::Unix(ref stream) => stream.try_read(buf),
        }
    }

    /// Attempts a non-blocking scatter read into `bufs`.
    #[inline]
    fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.try_read_vectored(bufs),
            Repr::Unix(ref stream) => stream.try_read_vectored(bufs),
        }
    }

    /// Attempts a non-blocking write of `buf`.
    #[inline]
    fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.try_write(buf),
            Repr::Unix(ref stream) => stream.try_write(buf),
        }
    }

    /// Attempts a non-blocking gather write of `bufs`.
    #[inline]
    fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        match self.repr {
            Repr::Tcp(ref stream) => stream.try_write_vectored(bufs),
            Repr::Unix(ref stream) => stream.try_write_vectored(bufs),
        }
    }

    /// Reports the wrapped socket's `AsyncWrite::is_write_vectored` answer.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        match self.repr {
            Repr::Tcp(ref stream) => AsyncWrite::is_write_vectored(stream),
            Repr::Unix(ref stream) => AsyncWrite::is_write_vectored(stream),
        }
    }
}
/// Reads are forwarded to the wrapped TCP or Unix socket.
impl AsyncRead for IoStream {
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The inner sockets are `Unpin` (they are re-pinned with `Pin::new`),
        // so `IoStream` is `Unpin` and can be unpinned freely.
        let this = self.get_mut();
        match this.repr {
            Repr::Tcp(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
            Repr::Unix(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
        }
    }
}
/// Writes, flushes and shutdowns are forwarded to the wrapped TCP or Unix socket.
impl AsyncWrite for IoStream {
    /// Writes `buf` to the underlying socket.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `IoStream` is `Unpin` because both inner sockets are, so unpinning is free.
        let this = self.get_mut();
        match this.repr {
            Repr::Tcp(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
            Repr::Unix(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    /// Flushes the underlying socket.
    #[inline]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.repr {
            Repr::Tcp(ref mut stream) => Pin::new(stream).poll_flush(cx),
            Repr::Unix(ref mut stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    /// Shuts down the write side of the underlying socket.
    #[inline]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.repr {
            Repr::Tcp(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
            Repr::Unix(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }

    /// Reports whether the underlying socket has an efficient vectored write.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        match self.repr {
            Repr::Tcp(ref stream) => AsyncWrite::is_write_vectored(stream),
            Repr::Unix(ref stream) => AsyncWrite::is_write_vectored(stream),
        }
    }

    /// Writes the slices in `bufs` to the underlying socket as one vectored write.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match this.repr {
            Repr::Tcp(ref mut stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            Repr::Unix(ref mut stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }
}