use super::fuse::Fuse;
use super::Decoder;
use bytes::BytesMut;
use futures_sink::Sink;
use futures_util::io::AsyncRead;
use futures_util::ready;
use futures_util::stream::{Stream, TryStreamExt};
use pin_project_lite::pin_project;
use std::io;
use std::marker::Unpin;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
/// A [`Stream`] of frames decoded from the bytes of an [`AsyncRead`] source.
///
/// Dereferences (mutably and immutably) to the underlying I/O object `T`.
#[derive(Debug)]
pub struct FramedRead<T, D> {
    // The reader `T` and decoder `D` paired in a `Fuse`, wrapped in the
    // buffering read adapter that does the actual reading and decoding.
    inner: FramedRead2<Fuse<T, D>>,
}
/// Shared access to the underlying I/O object.
impl<T, D> Deref for FramedRead<T, D> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // Two explicit derefs: `FramedRead2<Fuse<T, D>>` -> `Fuse<T, D>` -> `T`.
        &**self.inner
    }
}
/// Exclusive access to the underlying I/O object.
impl<T, D> DerefMut for FramedRead<T, D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Two explicit derefs: `FramedRead2<Fuse<T, D>>` -> `Fuse<T, D>` -> `T`.
        &mut **self.inner
    }
}
impl<T, D> FramedRead<T, D>
where
T: AsyncRead,
D: Decoder,
{
pub fn new(inner: T, decoder: D) -> Self {
Self {
inner: framed_read_2(Fuse::new(inner, decoder), None),
}
}
pub fn from_parts(
FramedReadParts {
io,
decoder,
buffer,
..
}: FramedReadParts<T, D>,
) -> Self {
Self {
inner: framed_read_2(Fuse::new(io, decoder), Some(buffer)),
}
}
pub fn into_parts(self) -> FramedReadParts<T, D> {
let (fuse, buffer) = self.inner.into_parts();
FramedReadParts {
io: fuse.t,
decoder: fuse.u,
buffer,
_priv: (),
}
}
pub fn into_inner(self) -> T {
self.into_parts().io
}
pub fn decoder(&self) -> &D {
&self.inner.u
}
pub fn decoder_mut(&mut self) -> &mut D {
&mut self.inner.u
}
pub fn read_buffer(&self) -> &BytesMut {
&self.inner.buffer
}
}
/// Yields decoded frames by delegating to the inner [`FramedRead2`].
impl<T, D> Stream for FramedRead<T, D>
where
    T: AsyncRead + Unpin,
    D: Decoder,
{
    type Item = Result<D::Item, D::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` requires `Self: Unpin`, which lets us poll the inner
        // adapter through a plain `&mut` borrow.
        let this = self.get_mut();
        this.inner.try_poll_next_unpin(cx)
    }
}
pin_project! {
    /// Buffering read adapter: reads bytes from `inner` into `buffer` and
    /// decodes frames out of it. Its `Stream` impl requires `inner` to be both
    /// `AsyncRead` and `Decoder`.
    #[derive(Debug)]
    pub struct FramedRead2<T> {
        // Structurally pinned: `project()` exposes it as `Pin<&mut T>`.
        #[pin]
        inner: T,
        // Bytes read from `inner` and held until the decoder consumes them.
        buffer: BytesMut,
    }
}
/// Shared access to the wrapped value.
impl<T> Deref for FramedRead2<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Exclusive access to the wrapped value.
impl<T> DerefMut for FramedRead2<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Capacity of a freshly allocated read buffer. It is also the size of the
/// stack chunk that each `poll_read` in `FramedRead2::poll_next` fills.
const INITIAL_CAPACITY: usize = 8 * 1024;

/// Wraps `inner` in a [`FramedRead2`].
///
/// When `buffer` is `Some`, it is reused as-is, including any bytes it holds.
/// Otherwise a new buffer with [`INITIAL_CAPACITY`] bytes of capacity is
/// allocated.
pub fn framed_read_2<T>(inner: T, buffer: Option<BytesMut>) -> FramedRead2<T> {
    let buffer = match buffer {
        Some(existing) => existing,
        None => BytesMut::with_capacity(INITIAL_CAPACITY),
    };
    FramedRead2 { inner, buffer }
}
impl<T> Stream for FramedRead2<T>
where
T: AsyncRead + Decoder + Unpin,
{
type Item = Result<T::Item, T::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
if let Some(item) = this.inner.decode(&mut this.buffer)? {
return Poll::Ready(Some(Ok(item)));
}
let mut buf = [0u8; INITIAL_CAPACITY];
loop {
let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
this.buffer.extend_from_slice(&buf[..n]);
let ended = n == 0;
match this.inner.decode(&mut this.buffer)? {
Some(item) => return Poll::Ready(Some(Ok(item))),
None if ended => {
if this.buffer.is_empty() {
return Poll::Ready(None);
} else {
match this.inner.decode_eof(&mut this.buffer)? {
Some(item) => return Poll::Ready(Some(Ok(item))),
None if this.buffer.is_empty() => return Poll::Ready(None),
None => {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"bytes remaining in stream",
)
.into())));
}
}
}
}
_ => continue,
}
}
}
}
/// Forwards every [`Sink`] method to the wrapped value through a pin
/// projection. The read buffer plays no part in writing.
impl<T, I> Sink<I> for FramedRead2<T>
where
    // No `Unpin` bound is needed: `project()` hands out a `Pin<&mut T>` for any
    // `T`, so pinned (`!Unpin`) sinks are supported as well.
    T: Sink<I>,
{
    type Error = T::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
impl<T> FramedRead2<T> {
    /// Consumes the adapter and returns the wrapped value plus its read buffer.
    pub fn into_parts(self) -> (T, BytesMut) {
        let Self { inner, buffer } = self;
        (inner, buffer)
    }

    /// The read buffer: bytes read from the wrapped value and held for decoding.
    pub fn buffer(&self) -> &BytesMut {
        let Self { buffer, .. } = self;
        buffer
    }
}
/// The pieces of a [`FramedRead`], as produced by [`FramedRead::into_parts`]
/// and accepted by [`FramedRead::from_parts`].
#[derive(Debug)]
pub struct FramedReadParts<T, D> {
    /// The underlying I/O object.
    pub io: T,
    /// The decoder used to turn buffered bytes into frames.
    pub decoder: D,
    /// The read buffer, including any bytes read from `io` but still held for
    /// the decoder.
    pub buffer: BytesMut,
    // Private field: prevents construction by struct literal outside this
    // module (presumably so fields can be added later without breakage).
    _priv: (),
}
impl<T, D> FramedReadParts<T, D> {
    /// Swaps in a new decoder built by applying `f` to the current one.
    ///
    /// The I/O object and the read buffer (with any bytes it holds) are
    /// carried over unchanged.
    pub fn map_decoder<E, F>(self, f: F) -> FramedReadParts<T, E>
    where
        E: Decoder,
        F: FnOnce(D) -> E,
    {
        let Self {
            io, decoder, buffer, ..
        } = self;
        let decoder = f(decoder);
        FramedReadParts {
            io,
            decoder,
            buffer,
            _priv: (),
        }
    }
}