#![allow(clippy::missing_inline_in_public_items)]
use std::{
io,
task::{Poll, ready},
};
/// Readiness-based, non-blocking I/O primitives plus poll-style helpers
/// layered on top of them.
///
/// Implementors provide the readiness polls and the `try_*` operations; the
/// provided methods combine them into "try, and if the operation would block,
/// wait for readiness and try again" loops.
pub trait AsyncIo {
    /// Polls for read readiness.
    ///
    /// The provided methods call this after a `try_read` that failed with
    /// [`io::ErrorKind::WouldBlock`]. Presumably the implementation arranges
    /// for `cx`'s waker to be notified when it returns `Poll::Pending` —
    /// confirm against the implementor.
    fn poll_read_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>>;

    /// Polls for write readiness.
    ///
    /// The provided methods call this after a `try_write*` that failed with
    /// [`io::ErrorKind::WouldBlock`].
    fn poll_write_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>>;

    /// Attempts to read into `buf` without blocking.
    ///
    /// The provided methods treat an [`io::ErrorKind::WouldBlock`] error as
    /// "not ready yet". [`AsyncIo::poll_read_buf`] may pass a buffer whose
    /// contents are uninitialized, so implementations must only write to
    /// `buf` and never read from it.
    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;

    /// Attempts a vectored read into `bufs` without blocking.
    fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize>;

    /// Attempts to write `buf` without blocking, returning the number of
    /// bytes written.
    fn try_write(&self, buf: &[u8]) -> io::Result<usize>;

    /// Attempts a vectored write of `bufs` without blocking, returning the
    /// number of bytes written.
    fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>;

    /// Reports whether [`AsyncIo::poll_write_all_buf`] should use
    /// [`AsyncIo::try_write_vectored`] instead of [`AsyncIo::try_write`].
    fn is_write_vectored(&self) -> bool;

    /// Reads into `buf`, waiting for read readiness whenever the read would
    /// block.
    ///
    /// Returns `Poll::Pending` when [`AsyncIo::poll_read_ready`] is pending,
    /// and `Poll::Ready(Err(_))` for any error other than `WouldBlock` (or
    /// for an error from the readiness poll itself).
    fn poll_read(
        &self,
        buf: &mut [u8],
        cx: &mut std::task::Context,
    ) -> Poll<io::Result<usize>> {
        // Retry in a loop rather than by recursion so repeated spurious
        // readiness cannot grow the stack.
        loop {
            match self.try_read(buf) {
                Ok(read) => return Poll::Ready(Ok(read)),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    tri!(ready!(self.poll_read_ready(cx)));
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Reads into the next writable chunk of `buf` and advances it by the
    /// number of bytes read.
    ///
    /// Returns `Ok(0)` immediately when `buf` has no remaining capacity.
    fn poll_read_buf<B>(
        &self,
        buf: &mut B,
        cx: &mut std::task::Context,
    ) -> Poll<io::Result<usize>>
    where
        B: bytes::BufMut + ?Sized,
    {
        if !buf.has_remaining_mut() {
            return Poll::Ready(Ok(0));
        }

        let read = {
            // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, so the
            // pointer cast preserves address and length. The bytes behind it
            // may be uninitialized; this relies on `try_read` only ever
            // writing to the slice and never reading from it.
            // NOTE(review): confirm every implementor upholds that contract.
            let dst = unsafe {
                &mut *(buf.chunk_mut().as_uninit_slice_mut() as *mut [std::mem::MaybeUninit<u8>]
                    as *mut [u8])
            };
            tri!(ready!(self.poll_read(dst, cx)))
        };

        // SAFETY: `poll_read` reported `read` bytes written to the start of
        // the chunk obtained above. This assumes `try_read` never reports
        // more bytes than the slice it was given can hold.
        unsafe {
            buf.advance_mut(read);
        }
        Poll::Ready(Ok(read))
    }

    /// Writes `buf`, waiting for write readiness whenever the write would
    /// block. Returns the number of bytes written.
    fn poll_write(&self, buf: &[u8], cx: &mut std::task::Context) -> Poll<io::Result<usize>> {
        // Loop instead of recursing; see `poll_read`.
        loop {
            match self.try_write(buf) {
                Ok(written) => return Poll::Ready(Ok(written)),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    tri!(ready!(self.poll_write_ready(cx)));
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Vectored variant of [`AsyncIo::poll_write`].
    fn poll_write_vectored(
        &self,
        bufs: &[io::IoSlice<'_>],
        cx: &mut std::task::Context,
    ) -> Poll<io::Result<usize>> {
        // Loop instead of recursing; see `poll_read`.
        loop {
            match self.try_write_vectored(bufs) {
                Ok(written) => return Poll::Ready(Ok(written)),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    tri!(ready!(self.poll_write_ready(cx)));
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Writes the current chunk of `buf` and advances `buf` past the bytes
    /// that were written.
    fn poll_write_buf<B>(
        &self,
        buf: &mut B,
        cx: &mut std::task::Context,
    ) -> Poll<io::Result<usize>>
    where
        B: bytes::Buf + ?Sized,
    {
        let written = tri!(ready!(self.poll_write(buf.chunk(), cx)));
        buf.advance(written);
        Poll::Ready(Ok(written))
    }

    /// Writes until `buf` has no remaining bytes, using vectored writes when
    /// [`AsyncIo::is_write_vectored`] returns `true`.
    ///
    /// This method takes no task context, so it cannot wait for readiness:
    /// every error from the underlying `try_write*` call — including
    /// [`io::ErrorKind::WouldBlock`] — is returned as `Poll::Ready(Err(_))`,
    /// and the method never returns `Poll::Pending`. Bytes written before the
    /// error have already been consumed from `buf`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::WriteZero`] if a write reports zero bytes
    /// while data is still remaining.
    fn poll_write_all_buf<B>(&self, buf: &mut B) -> Poll<io::Result<()>>
    where
        B: bytes::Buf + ?Sized,
    {
        // Upper bound on the number of slices handed to one vectored write.
        const MAX_VECTOR_ELEMENTS: usize = 64;

        while buf.has_remaining() {
            let written = if self.is_write_vectored() {
                let mut slices = [io::IoSlice::new(&[]); MAX_VECTOR_ELEMENTS];
                let cnt = buf.chunks_vectored(&mut slices);
                tri!(self.try_write_vectored(&slices[..cnt]))
            } else {
                tri!(self.try_write(buf.chunk()))
            };

            // A zero-length write with data remaining would loop forever.
            if written == 0 {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
            buf.advance(written);
        }
        Poll::Ready(Ok(()))
    }
}
// Unwraps a `Result` inside a function that returns `Poll<Result<_, E>>`.
//
// `Ok(value)` evaluates to `value`; `Err(error)` makes the enclosing function
// return `Poll::Ready(Err(error))` immediately. `Poll` must be in scope at the
// call site.
macro_rules! tri {
    ($result:expr) => {
        match $result {
            Err(error) => return Poll::Ready(Err(error)),
            Ok(value) => value,
        }
    };
}
use tri;
/// [`AsyncIo`] implementations for tokio's socket types.
///
/// Every method forwards to the tokio method of the same name. The calls are
/// written as explicit `Type::method(self, ..)` paths so that it is clear
/// they resolve to tokio's inherent methods rather than recursing into the
/// trait method being defined.
#[cfg(feature = "tokio")]
mod tokio_io {
    use super::*;
    use tokio::{io::AsyncWrite, net::TcpStream};
    // `tokio::net::UnixStream` only exists on Unix targets; gating the import
    // and the impl keeps the `tokio` feature buildable elsewhere.
    #[cfg(unix)]
    use tokio::net::UnixStream;

    impl AsyncIo for TcpStream {
        #[inline]
        fn poll_read_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
            TcpStream::poll_read_ready(self, cx)
        }

        #[inline]
        fn poll_write_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
            TcpStream::poll_write_ready(self, cx)
        }

        #[inline]
        fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
            TcpStream::try_read(self, buf)
        }

        #[inline]
        fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
            TcpStream::try_read_vectored(self, bufs)
        }

        #[inline]
        fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
            TcpStream::try_write(self, buf)
        }

        #[inline]
        fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
            TcpStream::try_write_vectored(self, bufs)
        }

        #[inline]
        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(self)
        }
    }

    #[cfg(unix)]
    impl AsyncIo for UnixStream {
        #[inline]
        fn poll_read_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
            UnixStream::poll_read_ready(self, cx)
        }

        #[inline]
        fn poll_write_ready(&self, cx: &mut std::task::Context) -> Poll<io::Result<()>> {
            UnixStream::poll_write_ready(self, cx)
        }

        #[inline]
        fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
            UnixStream::try_read(self, buf)
        }

        #[inline]
        fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
            UnixStream::try_read_vectored(self, bufs)
        }

        #[inline]
        fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
            UnixStream::try_write(self, buf)
        }

        #[inline]
        fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
            UnixStream::try_write_vectored(self, bufs)
        }

        #[inline]
        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(self)
        }
    }
}