use std::{
fmt::Debug,
io::{self, BufRead},
mem::MaybeUninit,
pin::Pin,
task::{Context, Poll},
};
use crate::{PinBoxFuture, compat::SyncStream};
/// Adapter exposing a stream `S` through the `futures_util` I/O traits
/// (`AsyncRead`, `AsyncBufRead`, `AsyncWrite`), buffered by a [`SyncStream`].
pub struct AsyncStream<S> {
    // The futures below capture a `&'static mut SyncStream<S>` that is forged
    // from `inner` with `unsafe` lifetime extension. Struct fields are dropped
    // in declaration order, so these fields are declared *before* `inner`.
    // That way an in-flight future is dropped while the stream it borrows is
    // still alive, instead of after the boxed stream has been freed.
    /// Parked `fill_read_buf` future, set when a read reported `WouldBlock`.
    read_future: Option<PinBoxFuture<io::Result<usize>>>,
    /// Parked `flush_write_buf` future (from a blocked write or a flush).
    write_future: Option<PinBoxFuture<io::Result<usize>>>,
    /// Parked shutdown future started by `poll_close`.
    shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
    /// The buffered stream, pinned on the heap so its address stays stable
    /// while the futures above reference it. Must remain the last field.
    inner: Pin<Box<SyncStream<S>>>,
}
impl<S> AsyncStream<S> {
    /// Wraps `stream` in a [`SyncStream`] built with `SyncStream::new`.
    pub fn new(stream: S) -> Self {
        let buffered = SyncStream::new(stream);
        Self::new_impl(buffered)
    }

    /// Wraps `stream` in a [`SyncStream`], passing `cap` through to
    /// `SyncStream::with_capacity`.
    pub fn with_capacity(cap: usize, stream: S) -> Self {
        let buffered = SyncStream::with_capacity(cap, stream);
        Self::new_impl(buffered)
    }

    /// Pins `inner` on the heap and starts with no I/O future in flight.
    fn new_impl(inner: SyncStream<S>) -> Self {
        let pinned = Box::pin(inner);
        Self {
            read_future: None,
            write_future: None,
            shutdown_future: None,
            inner: pinned,
        }
    }

    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        self.inner.get_ref()
    }
}
// Drives the boxed future kept in the `Option` place `$f`, creating it from
// `$e` when the slot is empty.
//
// Evaluates to the future's output once it is ready, leaving `$f` empty. If
// the future is still pending, it is stored back into `$f` and the enclosing
// function returns `Poll::Pending`.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pending_op = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        if let Poll::Ready(output) = pending_op.as_mut().poll($cx) {
            output
        } else {
            $f = Some(pending_op);
            return Poll::Pending;
        }
    }};
}
// Adapts a synchronous, possibly-`WouldBlock` I/O call `$io` to polling.
//
// First drives any refill/flush future that an earlier call left in the
// `Option` place `$f`:
// * still pending: it is put back and the enclosing function returns
//   `Poll::Pending`;
// * finished with an error: that error is returned to the caller;
// * finished successfully: `$io` is retried.
//
// Then evaluates `$io`:
// * `Ok(n)` becomes `Poll::Ready(Ok(n))`;
// * `WouldBlock`: the future built from `$e` is parked in `$f`, the task is
//   woken so it gets polled again, and the result is `Poll::Pending`;
// * any other error becomes `Poll::Ready(Err(..))`.
macro_rules! poll_future_would_block {
    ($f:expr, $cx:expr, $e:expr, $io:expr) => {{
        if let Some(mut f) = $f.take() {
            match f.as_mut().poll($cx) {
                Poll::Pending => {
                    $f.replace(f);
                    return Poll::Pending;
                }
                // A failed refill/flush must reach the caller; discarding it
                // would silently lose the error and retry `$io` regardless.
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Ready(Ok(_)) => {}
            }
        }
        match $io {
            Ok(len) => Poll::Ready(Ok(len)),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                $f.replace(Box::pin($e));
                // The new future has not been polled yet, so no waker is
                // registered; wake now so the task comes back and polls it.
                $cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }};
}
impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
    /// Copies buffered bytes into `buf`, parking a `fill_read_buf` future in
    /// `read_future` whenever the synchronous read reports `WouldBlock`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // NOTE(review): the `'static` lifetime is fabricated. `inner` is
        // heap-pinned, so its address is stable; soundness relies on the
        // future that captures `stream` (stored in `self.read_future`) never
        // outliving `self.inner`.
        let raw: *mut SyncStream<S> = unsafe { self.inner.as_mut().get_unchecked_mut() };
        let stream: &'static mut SyncStream<S> = unsafe { &mut *raw };
        poll_future_would_block!(
            self.read_future,
            cx,
            stream.fill_read_buf(),
            io::Read::read(stream, buf)
        )
    }
}
impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
    /// Reads into a possibly uninitialized buffer via
    /// `SyncStream::read_buf_uninit`, sharing `read_future` with `poll_read`
    /// for refills after `WouldBlock`.
    pub fn poll_read_uninit(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [MaybeUninit<u8>],
    ) -> Poll<io::Result<usize>> {
        // NOTE(review): fabricated `'static` borrow of the heap-pinned
        // `inner`; sound only while the future holding it (kept in
        // `self.read_future`) never outlives `self.inner`.
        let raw: *mut SyncStream<S> = unsafe { self.inner.as_mut().get_unchecked_mut() };
        let stream: &'static mut SyncStream<S> = unsafe { &mut *raw };
        poll_future_would_block!(
            self.read_future,
            cx,
            stream.fill_read_buf(),
            stream.read_buf_uninit(buf)
        )
    }
}
impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
    /// Returns the currently buffered bytes, parking a `fill_read_buf` future
    /// in `read_future` when the synchronous `fill_buf` reports `WouldBlock`.
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        // NOTE(review): fabricated `'static` borrow of the heap-pinned
        // `inner`; sound only while the future holding it (kept in
        // `self.read_future`) never outlives `self.inner`.
        let raw: *mut SyncStream<S> = unsafe { self.inner.as_mut().get_unchecked_mut() };
        let stream: &'static mut SyncStream<S> = unsafe { &mut *raw };
        poll_future_would_block!(
            self.read_future,
            cx,
            stream.fill_read_buf(),
            // Detach the slice's lifetime from `stream` so returning it does
            // not conflict with the reborrow in the `WouldBlock` arm.
            io::BufRead::fill_buf(stream).map(|bytes| unsafe { &*(bytes as *const [u8]) })
        )
    }

    /// Marks `amt` buffered bytes as consumed.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        // Only a short-lived `&mut` is taken; nothing is moved out of the pin.
        let stream = unsafe { self.inner.as_mut().get_unchecked_mut() };
        stream.consume(amt);
    }
}
impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
    /// Copies `buf` into the write buffer, parking a `flush_write_buf` future
    /// in `write_future` when the synchronous write reports `WouldBlock`.
    ///
    /// Fails with `BrokenPipe` while a shutdown started by `poll_close` is
    /// still in progress.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.shutdown_future.is_some() {
            debug_assert!(self.write_future.is_none());
            // Returning `Pending` here would register no waker, so the task
            // would never be polled again.
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "cannot write while the stream is shutting down",
            )));
        }
        // NOTE(review): fabricated `'static` borrow of the heap-pinned
        // `inner`; sound only while the future holding it (kept in
        // `self.write_future`) never outlives `self.inner`.
        let inner: &'static mut SyncStream<S> =
            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
        poll_future_would_block!(
            self.write_future,
            cx,
            inner.flush_write_buf(),
            io::Write::write(inner, buf)
        )
    }

    /// Writes buffered data to the underlying stream via `flush_write_buf`,
    /// resuming a flush parked by `poll_write` if there is one.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.shutdown_future.is_some() {
            debug_assert!(self.write_future.is_none());
            // `poll_close` only starts a shutdown once buffered data has been
            // flushed, and `poll_write` rejects new data meanwhile, so there
            // is nothing to flush. Returning `Pending` would register no
            // waker and hang the task.
            return Poll::Ready(Ok(()));
        }
        // NOTE(review): fabricated `'static` borrow; see `poll_write`.
        let inner: &'static mut SyncStream<S> =
            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
        let res = poll_future!(self.write_future, cx, inner.flush_write_buf());
        Poll::Ready(res.map(|_| ()))
    }

    /// Flushes any buffered data, then shuts down the underlying stream.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.shutdown_future.is_none()
            && (self.write_future.is_some() || self.inner.has_pending_write())
        {
            // Continue to the shutdown once the flush succeeds; returning the
            // flush result directly would skip shutting the stream down.
            match futures_util::AsyncWrite::poll_flush(self.as_mut(), cx) {
                Poll::Ready(Ok(())) => {}
                other => return other,
            }
        }
        // NOTE(review): fabricated `'static` borrow; sound only while the
        // future in `self.shutdown_future` never outlives `self.inner`.
        let inner: &'static mut SyncStream<S> =
            unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
        let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown());
        Poll::Ready(res)
    }
}
impl<S: Debug> Debug for AsyncStream<S> {
    /// Shows the wrapped buffered stream; the in-flight futures are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AsyncStream");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod test {
    use futures_executor::block_on;
    use futures_util::AsyncWriteExt;

    use super::AsyncStream;

    /// Bytes written before `close` must end up in the wrapped `Vec<u8>`.
    #[test]
    fn close() {
        let mut stream = AsyncStream::new(Vec::<u8>::new());
        block_on(async {
            let written = stream.write(b"hello").await.unwrap();
            assert_eq!(written, 5);
            stream.close().await.unwrap();
        });
        assert_eq!(stream.get_ref(), b"hello");
    }
}