use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures::Stream;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio_util::io::StreamReader;
use tokio_util::sync::PollSender;
/// Creates an in-memory asynchronous byte pipe.
///
/// Bytes written to the returned [`ChannelWriter`] come out, in order, of the returned
/// [`ChannelStream`]. You can consume them either as a stream of [`Bytes`] chunks or, via
/// [`ChannelStream::reader`], through a [`StreamReader`].
///
/// `buffer_size` bounds the number of queued *chunks* (one per successful write), not
/// bytes. Once that many chunks are waiting, further writes return `Pending` until the
/// reader catches up. A value of `0` is treated as `1`, because a `tokio::sync::mpsc`
/// channel panics when created with zero capacity.
pub fn pipe(buffer_size: usize) -> (ChannelWriter, ChannelStream) {
    let (sender, receiver) = mpsc::channel(buffer_size.max(1));
    (ChannelWriter::new(sender), ChannelStream::new(receiver))
}
/// Reading half of a [`pipe`]. It yields the chunks written to the matching [`ChannelWriter`].
///
/// The stream ends (yields `None`) once the channel has been closed from the writing side,
/// for example by dropping the [`ChannelWriter`], and every queued chunk has been received.
#[derive(Debug)]
pub struct ChannelStream(mpsc::Receiver<io::Result<Bytes>>);

impl ChannelStream {
    fn new(rx: mpsc::Receiver<io::Result<Bytes>>) -> Self {
        Self(rx)
    }

    /// Wraps this stream in a [`StreamReader`] so the data can be consumed with the
    /// byte-oriented async read APIs instead of chunk by chunk.
    pub fn reader(self) -> ChannelReader {
        ChannelReader::new(self)
    }
}

impl Stream for ChannelStream {
    type Item = io::Result<Bytes>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegates to the receiver, which registers `cx`'s waker while no chunk is queued.
        self.0.poll_recv(cx)
    }
}
/// Byte-reader adapter over a [`ChannelStream`], as returned by [`ChannelStream::reader`].
///
/// Public so that callers outside this module can name the type `reader()` returns.
pub type ChannelReader = StreamReader<ChannelStream, Bytes>;
pub struct ChannelWriter(mpsc::Sender<io::Result<Bytes>>);
impl ChannelWriter {
fn new(tx: mpsc::Sender<io::Result<Bytes>>) -> Self {
Self(tx)
}
}
impl AsyncWrite for ChannelWriter {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let perm = match self.0.try_reserve() {
Ok(p) => p,
Err(TrySendError::Closed(_)) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, "receiver closed")));
},
Err(TrySendError::Full(_)) => return Poll::Pending,
};
let data = Bytes::copy_from_slice(buf);
let len = data.len();
perm.send(Ok(data));
Poll::Ready(Ok(len))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Data from successive writes is read back in order, and the reader reaches EOF once
    /// the writer is dropped.
    #[tokio::test]
    async fn test_channel_read_write() {
        let (mut writer, stream) = pipe(10);
        let mut reader = stream.reader();
        writer.write_all(b"Hello, ").await.unwrap();
        writer.write_all(b"World!").await.unwrap();
        drop(writer);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"Hello, World!");
    }

    /// An empty write reports zero bytes and does not disturb the data around it.
    #[tokio::test]
    async fn test_empty_write() {
        let (mut writer, stream) = pipe(10);
        let mut reader = stream.reader();
        assert_eq!(writer.write(b"").await.unwrap(), 0);
        writer.write_all(b"abc").await.unwrap();
        drop(writer);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"abc");
    }

    /// Writing after the reading half has been dropped fails with `BrokenPipe`.
    #[tokio::test]
    async fn test_write_after_reader_dropped() {
        let (mut writer, stream) = pipe(1);
        drop(stream);
        let err = writer.write_all(b"data").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }
}