use bytes::Buf;
use futures_core::stream::Stream;
use futures_sink::Sink;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
/// Adapter that wraps a stream `S` and holds at most one buffered chunk `B`.
///
/// The trait impls in this file expose the chunks produced by the stream
/// through `AsyncRead` and `AsyncBufRead`.
#[derive(Debug)]
pub struct StreamReader<S, B> {
/// The wrapped stream; the only field that `project` hands out pinned.
inner: S,
/// The chunk currently being read from; `None` until the stream yields one.
chunk: Option<B>,
}
impl<S, B, E> StreamReader<S, B>
where
S: Stream<Item = Result<B, E>>,
B: Buf,
E: Into<std::io::Error>,
{
pub fn new(stream: S) -> Self {
Self {
inner: stream,
chunk: None,
}
}
fn has_chunk(&self) -> bool {
if let Some(ref chunk) = self.chunk {
chunk.remaining() > 0
} else {
false
}
}
pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
if self.has_chunk() {
(self.inner, self.chunk)
} else {
(self.inner, None)
}
}
}
impl<S, B> StreamReader<S, B> {
    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a pinned mutable reference to the wrapped stream.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        let projected = self.project();
        projected.inner
    }

    /// Consumes the reader and returns the wrapped stream.
    ///
    /// Any buffered chunk is dropped along with the reader.
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, B, E> AsyncRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    /// Copies bytes from the buffered data into `buf`.
    ///
    /// Returns `Ready(Ok(()))` right away when `buf` has no remaining
    /// capacity. Otherwise it delegates to `poll_fill_buf`, forwarding
    /// `Pending` and errors unchanged, copies as many bytes as both the
    /// returned slice and `buf` allow, and consumes exactly that many.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let available = match self.as_mut().poll_fill_buf(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Ready(Ok(bytes)) => bytes,
        };

        // Copy no more than the slice offers and no more than `buf` can hold.
        let count = available.len().min(buf.remaining());
        buf.put_slice(&available[..count]);
        self.consume(count);
        Poll::Ready(Ok(()))
    }
}
impl<S, B, E> AsyncBufRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    /// Returns the slice exposed by the stored chunk, fetching chunks from the
    /// stream until one with remaining bytes is available.
    ///
    /// Chunks with no remaining bytes are replaced by the next stream item.
    /// A stream error is converted into an `io::Error` and returned; the end
    /// of the stream is reported as an empty slice.
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        // Keep pulling from the stream until a chunk with remaining bytes is stored.
        while !self.has_chunk() {
            match self.as_mut().project().inner.poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(Ok(&[])),
                Poll::Ready(Some(Err(failure))) => return Poll::Ready(Err(failure.into())),
                Poll::Ready(Some(Ok(next))) => {
                    *self.as_mut().project().chunk = Some(next);
                }
            }
        }

        // The loop only exits once `has_chunk` is true, so the slot is `Some`.
        let bytes = self.project().chunk.as_ref().unwrap().chunk();
        Poll::Ready(Ok(bytes))
    }

    /// Advances the stored chunk by `amt` bytes; an `amt` of zero does nothing.
    ///
    /// # Panics
    ///
    /// Panics with "No chunk present" if `amt` is non-zero and no chunk is stored.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        if amt == 0 {
            return;
        }
        let slot = self.project().chunk;
        slot.as_mut().expect("No chunk present").advance(amt);
    }
}
// `project` only ever hands out `inner` as pinned, while `chunk` is exposed as
// a plain `&mut`, so the reader is `Unpin` whenever the wrapped stream is,
// regardless of `B`.
impl<S, B> Unpin for StreamReader<S, B> where S: Unpin {}
/// Result of projecting a pinned `StreamReader`: the stream stays pinned,
/// while the chunk slot is handed out as an ordinary mutable reference.
struct StreamReaderProject<'a, S, B> {
/// Pinned access to the wrapped stream.
inner: Pin<&'a mut S>,
/// Unpinned access to the buffered chunk slot.
chunk: &'a mut Option<B>,
}
impl<S, B> StreamReader<S, B> {
#[inline]
fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
let me = unsafe { Pin::into_inner_unchecked(self) };
StreamReaderProject {
inner: unsafe { Pin::new_unchecked(&mut me.inner) },
chunk: &mut me.chunk,
}
}
}
impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> {
type Error = E;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}