use byteorder::{ByteOrder, NetworkEndian};
use bytes::BytesMut;
use futures_core::{ready, Stream};
use serde::Deserialize;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_io::AsyncRead;
/// Decodes a sequence of bincode-serialized `T` values from an asynchronous reader.
///
/// Each value is expected on the wire as a 4-byte big-endian (network order) length followed
/// by that many bytes of bincode payload. Bytes read from `reader` are accumulated in `buffer`
/// until a whole frame is available.
#[derive(Debug)]
pub struct AsyncBincodeReader<R, T> {
    /// The wrapped byte source.
    reader: R,
    /// Bytes already read from `reader` that have not been decoded yet.
    pub(crate) buffer: BytesMut,
    /// Records the decoded item type `T`; no `T` value is ever stored here.
    into: PhantomData<T>,
}
// The struct only holds `PhantomData<T>`, never an actual `T`, so whether it may be moved
// after pinning depends solely on the wrapped reader `R`.
impl<R: Unpin, T> Unpin for AsyncBincodeReader<R, T> {}
/// Builds a reader around `R::default()`, using the same setup as the `From<R>` conversion.
impl<R, T> Default for AsyncBincodeReader<R, T>
where
    R: Default,
{
    fn default() -> Self {
        R::default().into()
    }
}
impl<R, T> AsyncBincodeReader<R, T> {
    /// Borrows the wrapped reader.
    pub fn get_ref(&self) -> &R {
        let Self { reader, .. } = self;
        reader
    }

    /// Mutably borrows the wrapped reader.
    ///
    /// Bytes read through this reference bypass the internal buffer and are not seen by the
    /// decoder.
    pub fn get_mut(&mut self) -> &mut R {
        let Self { reader, .. } = self;
        reader
    }

    /// Returns the bytes that have been read but not yet decoded.
    pub fn buffer(&self) -> &[u8] {
        &*self.buffer
    }

    /// Consumes the adapter and returns the wrapped reader.
    ///
    /// Any bytes still held in the internal buffer are dropped.
    pub fn into_inner(self) -> R {
        let Self { reader, .. } = self;
        reader
    }
}
impl<R, T> From<R> for AsyncBincodeReader<R, T> {
    /// Wraps `reader`, starting with an empty buffer that has room for at least 8192 bytes.
    fn from(reader: R) -> Self {
        const INITIAL_BUFFER_CAPACITY: usize = 8192;

        Self {
            buffer: BytesMut::with_capacity(INITIAL_BUFFER_CAPACITY),
            reader,
            into: PhantomData,
        }
    }
}
impl<R, T> Stream for AsyncBincodeReader<R, T>
where
for<'a> T: Deserialize<'a>,
R: AsyncRead + Unpin,
{
type Item = Result<T, bincode::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if let FillResult::EOF = ready!(self.as_mut().fill(cx, 5).map_err(bincode::Error::from))? {
return Poll::Ready(None);
}
let message_size: u32 = NetworkEndian::read_u32(&self.buffer[..4]);
let target_buffer_size = message_size as usize;
ready!(self
.as_mut()
.fill(cx, target_buffer_size + 4)
.map_err(bincode::Error::from))?;
self.buffer.advance(4);
let message = bincode::deserialize(&self.buffer[..target_buffer_size])?;
self.buffer.advance(target_buffer_size);
Poll::Ready(Some(Ok(message)))
}
}
/// Successful outcome of `AsyncBincodeReader::fill`.
enum FillResult {
    /// The buffer now holds at least the requested number of bytes.
    Filled,
    /// The reader reported end-of-stream (a zero-byte read) while the buffer was empty.
    EOF,
}
impl<R, T> AsyncBincodeReader<R, T>
where
    for<'a> T: Deserialize<'a>,
    R: AsyncRead + Unpin,
{
    /// Reads from the wrapped reader until `self.buffer` holds at least `target_size` bytes.
    ///
    /// Returns:
    /// - `Poll::Ready(Ok(FillResult::Filled))` once the buffer is large enough (immediately if
    ///   it already is);
    /// - `Poll::Ready(Ok(FillResult::EOF))` if the reader returns a zero-byte read while the
    ///   buffer is empty;
    /// - `Poll::Ready(Err(_))` with `io::ErrorKind::BrokenPipe` if a zero-byte read happens
    ///   after some bytes were buffered, or with the reader's own error if a read fails;
    /// - `Poll::Pending` if the reader is not ready. Bytes read so far stay in `self.buffer`,
    ///   so the next call continues from there.
    fn fill(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        target_size: usize,
    ) -> Poll<Result<FillResult, io::Error>> {
        if self.buffer.len() >= target_size {
            return Poll::Ready(Ok(FillResult::Filled));
        }

        // `reserve(additional)` guarantees room for `additional` bytes beyond the current
        // *length*, not beyond the current capacity, so the shortfall is measured from `len()`.
        // Reserving only `target_size - capacity()` could leave the buffer too small. `rest`
        // would then run out of space before `target_size` is reached, and the resulting
        // zero-byte read would be misreported as a broken pipe. `reserve` is a no-op when
        // enough space is already available.
        let missing = target_size - self.buffer.len();
        self.buffer.reserve(missing);

        // Split off the spare capacity as a separate writable region to read into.
        let had = self.buffer.len();
        let mut rest = self.buffer.split_off(had);
        let max = rest.capacity();
        // SAFETY: `max` is `rest`'s capacity, so the new length stays inside its allocation.
        // The exposed bytes are uninitialized; only the first `n` bytes reported by each read
        // are moved into `self.buffer`. NOTE(review): this also relies on the reader never
        // reading from the slice it is handed, which `AsyncRead` does not promise. Consider
        // zero-initializing `rest`, or using the reader's uninitialized-buffer preparation hook.
        unsafe { rest.set_len(max) };

        while self.buffer.len() < target_size {
            let n = ready!(Pin::new(&mut self.reader).poll_read(cx, &mut rest[..]))?;
            if n == 0 {
                // `rest` is large enough to reach `target_size`, so a zero-byte read here is a
                // genuine end-of-stream, not a full destination slice.
                if self.buffer.is_empty() {
                    return Poll::Ready(Ok(FillResult::EOF));
                } else {
                    return Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe)));
                }
            }
            // Move the bytes just read from the front of `rest` onto the end of the buffer.
            let read = rest.split_to(n);
            self.buffer.unsplit(read);
        }

        Poll::Ready(Ok(FillResult::Filled))
    }
}