use crate::{
codec::{Decoder, Encoder},
error::Error,
};
use bytes::{Buf, BytesMut};
use futures_core::{ready, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::{
borrow::{Borrow, BorrowMut},
io,
ops::Deref,
pin::Pin,
task::{Context, Poll},
};
/// Size of the per-poll read chunk and the initial capacity of the read/write buffers.
const INITIAL_CAPACITY: usize = 8192;
/// Once the write buffer holds at least this many bytes, `poll_ready` writes it out
/// before another item is accepted.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
pin_project! {
    /// Shared core of the framed transports: an I/O object, a codec, and buffer state.
    ///
    /// The `Stream` impl requires `State: BorrowMut<ReadFrame>` and the `Sink` impl
    /// requires `State: BorrowMut<WriteFrame>`; `RWFrames` satisfies both.
    #[derive(Debug)]
    pub(super) struct FramedInner<T, U, State> {
        // The wrapped I/O object; structurally pinned so it can be polled through `Pin`.
        #[pin]
        pub(super) inner: T,
        // Read and/or write buffer state (see the bounds on the `Stream`/`Sink` impls).
        pub(super) state: State,
        // Codec that decodes frames from, and encodes items into, the buffers.
        pub(super) codec: U,
    }
}
impl<T, U, State> Deref for FramedInner<T, U, State> {
type Target = T;
fn deref(&self) -> &T {
&self.inner
}
}
// Read half: pulls bytes from `inner` and decodes them with `codec`. Any state type that
// can lend out a `ReadFrame` (e.g. `ReadFrame` itself or `RWFrames`) can drive it.
impl<T, U, R> Stream for FramedInner<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, Error<U::Error>>;

    /// Polls for the next decoded frame.
    ///
    /// Returns `Ready(Some(Ok(frame)))` when the codec yields a frame and
    /// `Ready(Some(Err(_)))` on a codec or I/O error. Returns `Ready(None)`:
    /// on the poll right after an error was returned; when the reader has hit end of
    /// input and `decode_eof` yields no further frame; and when a read returns zero
    /// bytes while already at end of input and the buffer is not marked readable.
    /// A later poll after `Ready(None)` starts by reading from `inner` again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // Stack scratch buffer for a single read. It is zero-initialised on every call,
        // and each chunk read into it is copied into `state.buffer`.
        let mut buf = [0u8; INITIAL_CAPACITY];
        loop {
            // An error was returned on a previous poll: end the stream once. Then reset
            // to the paused state (not readable), so a later poll starts by reading.
            if state.has_errored {
                log::trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }
            if state.is_readable {
                // At end of input, let the codec drain whatever is left in the buffer.
                if state.eof {
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        log::trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        Error::Codec(err)
                    })?;
                    // No frame produced: stop decoding until new bytes are read.
                    if frame.is_none() {
                        state.is_readable = false;
                    }
                    return Poll::Ready(frame.map(Ok));
                }
                log::trace!("Attempting to decode a frame");
                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|err| {
                    log::trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    Error::Codec(err)
                })? {
                    log::trace!("Frame decoded from buffer");
                    return Poll::Ready(Some(Ok(frame)));
                }
                // The codec could not produce a frame from the buffered bytes; read more.
                state.is_readable = false;
            }
            // `ready!` returns early on `Pending`. An I/O error is recorded in
            // `has_errored` and then propagated by `?`, which converts it into
            // `Error<U::Error>` through a `From` impl.
            let bytes_read = ready!(pinned.inner.as_mut().poll_read(cx, &mut buf).map_err(|err| {
                log::trace!("Got an error, going to errored state");
                state.has_errored = true;
                err
            }))?;
            state.buffer.extend_from_slice(&buf[.. bytes_read]);
            if bytes_read == 0 {
                // Already at end of input and nothing marked for decoding: end of stream.
                if state.eof {
                    return Poll::Ready(None);
                }
                state.eof = true;
            } else {
                state.eof = false;
            }
            // New bytes (or a fresh end of input) make the buffer worth decoding again.
            state.is_readable = true;
        }
    }
}
impl<T, U, W> FramedInner<T, U, W>
where
    T: AsyncWrite,
    U: Encoder,
    W: BorrowMut<WriteFrame>,
{
    /// Writes buffered bytes to `inner` until at most `limit` bytes remain in the buffer.
    ///
    /// Afterwards the inner writer is flushed in two cases:
    /// - this call wrote anything; or
    /// - `limit == 0`, meaning a full flush was requested (as `Sink::poll_flush` does).
    ///
    /// The second case matters when an earlier call drained the buffer but the inner
    /// `poll_flush` returned `Pending`. Without it, the retry would see an empty buffer
    /// and report success without ever flushing the writer.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner `poll_write`/`poll_flush`. Returns
    /// `io::ErrorKind::UnexpectedEof` if the writer accepts zero bytes.
    fn poll_flush_until(self: Pin<&mut Self>, cx: &mut Context<'_>, limit: usize) -> Poll<Result<(), io::Error>> {
        let mut pinned = self.project();
        let state: &mut WriteFrame = pinned.state.borrow_mut();
        let orig_len = state.buffer.len();
        log::trace!("Flushing framed transport");
        while state.buffer.len() > limit {
            log::trace!("Writing; remaining = {}", state.buffer.len());
            let num_write = ready!(pinned.inner.as_mut().poll_write(cx, &state.buffer))?;
            if num_write == 0 {
                // A zero-length write means the writer cannot make progress.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "FramedWrite: end of input",
                )));
            }
            state.buffer.advance(num_write);
        }
        log::trace!("Framed transport flushed");
        let full_flush_requested = limit == 0;
        let wrote_something = orig_len != state.buffer.len();
        if full_flush_requested || wrote_something {
            pinned.inner.poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }
}
impl<T, U, W> Sink<U::Item> for FramedInner<T, U, W>
where
T: AsyncWrite,
U: Encoder,
W: BorrowMut<WriteFrame>,
{
type Error = Error<U::Error>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_flush_until(cx, BACKPRESSURE_BOUNDARY - 1).map_err(Into::into)
}
fn start_send(self: Pin<&mut Self>, item: U::Item) -> Result<(), Self::Error> {
let pinned = self.project();
pinned
.codec
.encode(item, &mut pinned.state.borrow_mut().buffer)
.map_err(Error::Codec)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_flush_until(cx, 0).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?;
self.project().inner.poll_close(cx).map_err(Into::into)
}
}
/// Read-side state of a framed transport.
pub(super) struct ReadFrame {
    /// Bytes read from the transport that have not yet been consumed by the decoder.
    pub(super) buffer: BytesMut,
    /// Set when the last read returned zero bytes; decoding then goes through `decode_eof`.
    pub(super) eof: bool,
    /// Set after an error was returned; the next poll yields `None` and clears it.
    pub(super) has_errored: bool,
    /// Whether `buffer` should be passed to the decoder before reading more bytes.
    pub(super) is_readable: bool,
}
impl Default for ReadFrame {
fn default() -> Self {
Self {
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
eof: false,
has_errored: false,
is_readable: false,
}
}
}
impl From<BytesMut> for ReadFrame {
    /// Builds read state around an existing buffer, such as one holding bytes that were
    /// already read from the transport.
    ///
    /// The buffer's capacity is raised to at least `INITIAL_CAPACITY`. The state is
    /// marked readable only when the buffer already contains data, so those bytes are
    /// decoded before the transport is read again.
    fn from(mut buffer: BytesMut) -> Self {
        if buffer.capacity() < INITIAL_CAPACITY {
            // `reserve` counts from `len()`, not `capacity()`, so reserve the gap between
            // the current length and the target.
            // `len() <= capacity() < INITIAL_CAPACITY`, so this cannot underflow.
            buffer.reserve(INITIAL_CAPACITY - buffer.len());
        }
        // Readability is about buffered content; an allocated but empty buffer has
        // nothing to decode.
        let is_readable = !buffer.is_empty();
        Self {
            buffer,
            eof: false,
            has_errored: false,
            is_readable,
        }
    }
}
/// Write-side state of a framed transport.
pub(super) struct WriteFrame {
    /// Encoded bytes waiting to be written to the transport.
    pub(super) buffer: BytesMut,
}
impl Default for WriteFrame {
fn default() -> Self {
Self {
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
}
impl From<BytesMut> for WriteFrame {
    /// Builds write state around an existing buffer, keeping any bytes it already holds.
    ///
    /// The buffer's capacity is raised to at least `INITIAL_CAPACITY`.
    fn from(mut buffer: BytesMut) -> Self {
        if buffer.capacity() < INITIAL_CAPACITY {
            // `reserve` counts from `len()`, not `capacity()`, so reserve the gap between
            // the current length and the target.
            // `len() <= capacity() < INITIAL_CAPACITY`, so this cannot underflow.
            buffer.reserve(INITIAL_CAPACITY - buffer.len());
        }
        Self { buffer }
    }
}
/// Combined read and write state for a transport that is both a `Stream` and a `Sink`.
///
/// Either half is lent out through the `Borrow`/`BorrowMut` impls for
/// `ReadFrame` and `WriteFrame`.
#[derive(Default)]
pub(super) struct RWFrames {
    /// Read-side buffer and decoder state.
    pub(super) read: ReadFrame,
    /// Write-side buffer.
    pub(super) write: WriteFrame,
}
impl Borrow<ReadFrame> for RWFrames {
fn borrow(&self) -> &ReadFrame {
&self.read
}
}
impl BorrowMut<ReadFrame> for RWFrames {
fn borrow_mut(&mut self) -> &mut ReadFrame {
&mut self.read
}
}
impl Borrow<WriteFrame> for RWFrames {
fn borrow(&self) -> &WriteFrame {
&self.write
}
}
impl BorrowMut<WriteFrame> for RWFrames {
fn borrow_mut(&mut self) -> &mut WriteFrame {
&mut self.write
}
}