use byteorder::{ByteOrder, NetworkEndian};
use std::{
io,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Buf, BytesMut};
use futures_core::{ready, Stream};
use prost::Message;
use tokio::io::{AsyncRead, ReadBuf};
use crate::{AsyncDestination, AsyncFrameDestination, Framed};
/// Initial capacity, in bytes, of the internal read buffer.
const BUFFER_SIZE: usize = 8 * 1024;
/// Width, in bytes, of the big-endian `u32` length prefix in front of every frame.
const LEN_SIZE: usize = std::mem::size_of::<u32>();
/// Successful outcome of `AsyncProstReader::fill`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillResult {
    /// The buffer now holds at least the requested number of bytes.
    Filled,
    /// The reader hit end-of-stream while the buffer was empty, i.e. on a frame boundary.
    Eof,
}
/// Buffers bytes from an [`AsyncRead`] source so that complete length-prefixed frames can be
/// decoded from them.
///
/// `T` is the decoded item type and `D` selects the wire format; this file implements
/// [`Stream`] for `D = AsyncDestination` (prost messages) and `D = AsyncFrameDestination`
/// (`Framed` header + body messages).
#[derive(Debug)]
pub struct AsyncProstReader<R, T, D> {
    /// The underlying byte source.
    reader: R,
    /// Bytes already read from `reader` that have not yet been consumed as a complete frame.
    pub(crate) buffer: BytesMut,
    /// Type marker for the produced item type; no `T` is stored.
    into: PhantomData<T>,
    /// Type marker selecting the wire format; no `D` is stored.
    dest: PhantomData<D>,
}
// Only `reader` can ever need a stable address: the buffer and the `PhantomData` markers are
// freely movable. Implementing `Unpin` by hand keeps the auto-impl from also demanding
// `T: Unpin` and `D: Unpin`.
impl<R: Unpin, T, D> Unpin for AsyncProstReader<R, T, D> {}
impl<R, T, D> AsyncProstReader<R, T, D> {
pub fn new(reader: R) -> Self {
Self {
reader,
buffer: BytesMut::with_capacity(BUFFER_SIZE),
into: PhantomData,
dest: PhantomData,
}
}
pub fn get_ref(&self) -> &R {
&self.reader
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
pub fn buffer(&self) -> &[u8] {
&self.buffer[..]
}
pub fn into_inner(self) -> R {
self.reader
}
}
impl<R: Default, T, D> Default for AsyncProstReader<R, T, D> {
    /// Wraps a default-constructed `R`, exactly as `AsyncProstReader::new(R::default())`.
    fn default() -> Self {
        Self::new(R::default())
    }
}
impl<R, T, D> From<R> for AsyncProstReader<R, T, D> {
fn from(reader: R) -> Self {
Self::new(reader)
}
}
impl<R, T> Stream for AsyncProstReader<R, T, AsyncDestination>
where
    T: Message + Default,
    R: AsyncRead + Unpin,
{
    type Item = Result<T, io::Error>;

    /// Yields the next protobuf message from the underlying reader.
    ///
    /// Each frame is a `LEN_SIZE`-byte big-endian (network order) length followed by that many
    /// bytes of prost-encoded `T`. Yields `None` when the reader ends cleanly on a frame
    /// boundary, and `Some(Err(_))` if the reader fails, the stream ends mid-frame, or the
    /// payload fails to decode. A frame that fails to decode is still consumed, so the next
    /// poll starts at the following frame.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only the length prefix is required up front: a message whose fields are all default
        // encodes to zero bytes. Waiting for an extra payload byte would stall such a frame
        // until the next one arrives, or reject it as truncated at end-of-stream.
        if let FillResult::Eof = ready!(self.as_mut().fill(cx, LEN_SIZE))? {
            return Poll::Ready(None);
        }
        let message_size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;
        // If this returns `Pending`, nothing has been consumed yet, so the next poll re-reads
        // the same prefix and resumes filling.
        ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;
        self.buffer.advance(LEN_SIZE);
        // Drop the payload from the buffer before reporting a decode error. Otherwise the
        // next poll would misread payload bytes as a length prefix and desynchronize.
        let decoded = T::decode(&self.buffer[..message_size]);
        self.buffer.advance(message_size);
        let message = decoded?;
        Poll::Ready(Some(Ok(message)))
    }
}
impl<R, T> Stream for AsyncProstReader<R, T, AsyncFrameDestination>
where
    R: AsyncRead + Unpin,
    T: Framed + Default,
{
    type Item = Result<T, io::Error>;

    /// Yields the next `Framed` message from the underlying reader.
    ///
    /// Each frame starts with a `LEN_SIZE`-byte big-endian `u32`. Its top 8 bits hold the
    /// header size and its low 24 bits hold the body size. `header_size + body_size` payload
    /// bytes follow and are passed to `T::decode` together with `header_size`. Yields `None`
    /// on a clean end-of-stream at a frame boundary, and `Some(Err(_))` on reader failure,
    /// truncation, or a decode error. A frame that fails to decode is still consumed.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only the prefix is required up front, so a frame with an empty header and body is
        // neither stalled until the next frame arrives nor rejected as truncated at EOF.
        if let FillResult::Eof = ready!(self.as_mut().fill(cx, LEN_SIZE))? {
            return Poll::Ready(None);
        }
        let size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;
        let header_size = size >> 24;
        let body_size = size & 0x00ff_ffff;
        let message_size = header_size + body_size;
        // If this returns `Pending`, nothing has been consumed yet, so the next poll re-reads
        // the same prefix and resumes filling.
        ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;
        self.buffer.advance(LEN_SIZE);
        // Consume the payload before reporting a decode error so the stream stays aligned on
        // frame boundaries.
        let decoded = T::decode(&self.buffer[..message_size], header_size);
        self.buffer.advance(message_size);
        let message = decoded?;
        Poll::Ready(Some(Ok(message)))
    }
}
impl<R, T, D> AsyncProstReader<R, T, D>
where
R: AsyncRead + Unpin,
{
fn fill(
mut self: Pin<&mut Self>,
cx: &mut Context,
target_buffer_size: usize,
) -> Poll<Result<FillResult, io::Error>> {
if self.buffer.len() >= target_buffer_size {
return Poll::Ready(Ok(FillResult::Filled));
}
if self.buffer.capacity() < target_buffer_size {
let missing = target_buffer_size - self.buffer.capacity();
self.buffer.reserve(missing);
}
let had = self.buffer.len();
let mut rest = self.buffer.split_off(had);
let max = rest.capacity();
unsafe { rest.set_len(max) };
while self.buffer.len() < target_buffer_size {
let mut buf = ReadBuf::new(&mut rest[..]);
ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
if self.buffer.is_empty() {
return Poll::Ready(Ok(FillResult::Eof));
} else {
return Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe)));
}
}
let read = rest.split_to(n);
self.buffer.unsplit(read);
}
Poll::Ready(Ok(FillResult::Filled))
}
}