use futures_channel::mpsc;
use futures_core::Stream;
use futures_io::{AsyncRead, AsyncWrite};
use std::io::{self, Cursor};
use std::pin::Pin;
use std::task::{Context, Poll};
/// Creates a connected in-memory pipe backed by `count` reusable buffers.
///
/// Every write moves one buffer from the pool to the reader, and the reader
/// hands the buffer back to the pool once it has been fully read. With
/// `count == 0` no buffer is ever available, so writes stay pending until the
/// reader is dropped.
pub(crate) fn new(count: usize) -> (Reader, Writer) {
    let (buf_stream_tx, buf_stream_rx) = mpsc::channel(count);
    let (mut buf_pool_tx, buf_pool_rx) = mpsc::channel(count);

    // Seed the pool with `count` empty buffers; they are recycled forever after.
    (0..count).for_each(|_| {
        buf_pool_tx
            .try_send(Cursor::new(Vec::new()))
            .expect("buffer pool overflow")
    });

    (
        Reader {
            buf_pool_tx,
            buf_stream_rx,
            chunk: None,
        },
        Writer {
            buf_pool_rx,
            buf_stream_tx,
        },
    )
}
/// Read half of the pipe created by [`new`].
pub(crate) struct Reader {
    /// Returns fully consumed buffers to the writer's pool for reuse.
    buf_pool_tx: mpsc::Sender<Cursor<Vec<u8>>>,
    /// Receives filled buffers from the writer.
    buf_stream_rx: mpsc::Receiver<Cursor<Vec<u8>>>,
    /// Partially consumed buffer carried over between reads, if any.
    chunk: Option<Cursor<Vec<u8>>>,
}
impl AsyncRead for Reader {
    /// Copies bytes from the pending chunk, or from the next chunk sent by the
    /// writer, into `buf`.
    ///
    /// A partially consumed chunk is kept for the next call. A fully consumed
    /// chunk is cleared and returned to the buffer pool so the writer can reuse
    /// it. Empty chunks are skipped rather than reported as `Ok(0)`, because
    /// `Ok(0)` means end-of-stream. Returns `Ok(0)` once the writer has closed
    /// or been dropped and every chunk has been drained.
    ///
    /// # Panics
    ///
    /// Panics if the buffer pool is full when a chunk is returned to it. That
    /// would mean more buffers exist than the pool was created with.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        loop {
            let mut chunk = match self.chunk.take() {
                Some(chunk) => chunk,
                None => match Pin::new(&mut self.buf_stream_rx).poll_next(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(None) => return Poll::Ready(Ok(0)),
                    Poll::Ready(Some(chunk)) => chunk,
                },
            };

            // Reading from an in-memory cursor is synchronous and never fails.
            let len = io::Read::read(&mut chunk, buf)
                .expect("reading from an in-memory cursor cannot fail");

            if chunk.position() < chunk.get_ref().len() as u64 {
                // Bytes remain: keep the chunk for the next read.
                self.chunk = Some(chunk);
                return Poll::Ready(Ok(len));
            }

            // Chunk exhausted: reset it and hand it back to the writer.
            chunk.set_position(0);
            chunk.get_mut().clear();
            if let Err(e) = self.buf_pool_tx.try_send(chunk) {
                assert!(!e.is_full(), "buffer pool overflow");
                // Otherwise the writer is gone; the buffer is simply dropped.
            }

            // A chunk that yielded nothing into a non-empty `buf` was empty.
            // Returning `Ok(0)` for it would look like end-of-stream, so
            // move on to the next chunk instead.
            if len > 0 || buf.is_empty() {
                return Poll::Ready(Ok(len));
            }
        }
    }
}
/// Write half of the pipe created by [`new`].
pub(crate) struct Writer {
    /// Supplies empty buffers; yields `None` once the reader (which holds the
    /// only sender) is dropped.
    buf_pool_rx: mpsc::Receiver<Cursor<Vec<u8>>>,
    /// Sends filled buffers to the reader.
    buf_stream_tx: mpsc::Sender<Cursor<Vec<u8>>>,
}
impl AsyncWrite for Writer {
    /// Copies `buf` into a buffer taken from the pool and sends it to the reader.
    ///
    /// Each successful non-empty write transfers all of `buf` as one chunk and
    /// returns `Ok(buf.len())`. It returns `Poll::Pending` while every pooled
    /// buffer is in flight. A zero-length write returns `Ok(0)` without sending
    /// anything; an empty chunk would make the reader report end-of-stream.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::BrokenPipe`] if the reader has been dropped or
    /// the pipe has been closed.
    ///
    /// # Panics
    ///
    /// Panics if the stream channel is full, which cannot happen while only the
    /// buffers created by `new` are in circulation.
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        if self.buf_stream_tx.is_closed() {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let mut chunk = match Pin::new(&mut self.buf_pool_rx).poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            // The reader, which owns the pool's only sender, is gone.
            Poll::Ready(None) => return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
            Poll::Ready(Some(chunk)) => chunk,
        };
        chunk.get_mut().extend_from_slice(buf);

        match self.buf_stream_tx.try_send(chunk) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) if e.is_full() => panic!("buffer stream overflow"),
            Err(_) => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
        }
    }

    /// Chunks are handed to the reader as soon as they are written, so there is
    /// nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the stream channel from the sending side. Chunks already sent can
    /// still be drained by the reader; later writes fail with `BrokenPipe`.
    fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.buf_stream_tx.close_channel();
        Poll::Ready(Ok(()))
    }
}
#[cfg(all(test, feature = "nightly"))]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::prelude::*;
    use futures::task::noop_waker;

    /// Bytes written into the pipe come out of the reader unchanged.
    #[test]
    fn read_then_write() {
        block_on(async {
            let (mut rx, mut tx) = new(1);
            tx.write_all(b"hello").await.unwrap();

            let mut out = [0; 5];
            let n = rx.read(&mut out).await.unwrap();
            assert_eq!(n, 5);
            assert_eq!(&out, b"hello");
        })
    }

    /// Dropping the writer still lets the reader drain what was sent, then EOF.
    #[test]
    fn reader_still_drainable_after_writer_disconnects() {
        block_on(async {
            let (mut rx, mut tx) = new(1);
            tx.write_all(b"hello").await.unwrap();
            drop(tx);

            let mut out = [0; 5];
            let first = rx.read(&mut out).await.unwrap();
            assert_eq!(first, 5);
            assert_eq!(&out, b"hello");

            let eof = rx.read(&mut out).await.unwrap();
            assert_eq!(eof, 0);
        })
    }

    /// Writing after the reader is gone fails immediately with `BrokenPipe`.
    #[test]
    fn writer_errors_if_reader_is_dropped() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let (rx, mut tx) = new(2);
        drop(rx);

        let outcome = tx.write(b"hello").poll_unpin(&mut cx);
        if let Poll::Ready(Err(err)) = outcome {
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        } else {
            panic!("expected poll to be ready");
        }
    }
}