use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use bytes::{Buf, BytesMut};
use futures_core::{ready, Stream};
use futures_sink::Sink;
use crate::{AsyncRead, AsyncWrite, Decoder, Encoder};
/// Low watermark: when a buffer's spare capacity (`capacity() - len()`) drops
/// below this many bytes, `write` and `next_item` reserve more space first.
const LW: usize = 1024;
/// High watermark: the initial capacity of both buffers in `Framed::new`, the
/// base of the amount reserved when growing, and the write-buffer length at
/// which the transport reports itself as full / not write-ready.
const HW: usize = 8 * 1024;
// Internal read-side state bits carried by `Framed` and `FramedParts`.
bitflags::bitflags! {
struct Flags: u8 {
// Set by `next_item` once a read from `io` reported zero bytes.
const EOF = 0b0001;
// Set by `next_item` after every completed read; cleared when `decode`
// finds no complete frame in the read buffer.
const READABLE = 0b0010;
}
}
/// Couples an I/O object with a codec so that frames, rather than raw bytes,
/// can be read from and written to it.
///
/// Incoming bytes are accumulated in `read_buf` and handed to the codec's
/// decoder; outgoing items are encoded into `write_buf` and written to `io`
/// when the transport is flushed.
pub struct Framed<T, U> {
/// The underlying I/O object.
io: T,
/// The codec used to encode and/or decode frames.
codec: U,
/// Read-side state bits (`EOF`, `READABLE`).
flags: Flags,
/// Bytes read from `io` that have not yet been consumed by the decoder.
read_buf: BytesMut,
/// Encoded bytes that have not yet been written to `io`.
write_buf: BytesMut,
}
// `Framed` is declared `Unpin` for every `T` and `U`. The methods in this file
// only ever pin `io` through `Pin::new`, inside impl blocks bounded by
// `T: Unpin`.
impl<T, U> Unpin for Framed<T, U> {}
impl<T, U> Framed<T, U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder + Encoder,
{
    /// Wraps `io` with `codec`, producing a framed transport.
    ///
    /// Both internal buffers start out empty with `HW` bytes of capacity and
    /// no state flags are set.
    #[inline]
    pub fn new(io: T, codec: U) -> Framed<T, U> {
        let read_buf = BytesMut::with_capacity(HW);
        let write_buf = BytesMut::with_capacity(HW);
        Self {
            io,
            codec,
            flags: Flags::empty(),
            read_buf,
            write_buf,
        }
    }
}
impl<T, U> Framed<T, U> {
    /// Reassembles a `Framed` from the pieces held in `parts`.
    ///
    /// The I/O object, codec, state flags and both buffers are moved over
    /// unchanged.
    #[inline]
    pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
        Framed {
            io: parts.io,
            codec: parts.codec,
            flags: parts.flags,
            write_buf: parts.write_buf,
            read_buf: parts.read_buf,
        }
    }

    /// Returns a shared reference to the codec.
    #[inline]
    pub fn get_codec(&self) -> &U {
        &self.codec
    }

    /// Returns a mutable reference to the codec.
    #[inline]
    pub fn get_codec_mut(&mut self) -> &mut U {
        &mut self.codec
    }

    /// Returns a shared reference to the underlying I/O object.
    #[inline]
    pub fn get_ref(&self) -> &T {
        &self.io
    }

    /// Returns a mutable reference to the underlying I/O object.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.io
    }

    /// Returns `true` when no encoded bytes are waiting in the write buffer.
    #[inline]
    pub fn is_write_buf_empty(&self) -> bool {
        self.write_buf.is_empty()
    }

    /// Returns `true` when the write buffer holds at least `HW` bytes.
    #[inline]
    pub fn is_write_buf_full(&self) -> bool {
        self.write_buf.len() >= HW
    }

    /// Consumes the transport and returns a new one that uses `codec`.
    ///
    /// The I/O object, state flags and both buffers (including any bytes
    /// still stored in them) are carried over.
    #[inline]
    pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
        Framed {
            codec,
            io: self.io,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
        }
    }

    /// Consumes the transport and returns a new one whose I/O object is the
    /// result of applying `f` to the current one.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` closure is accepted
    /// (every `Fn` closure is also `FnOnce`). Codec, flags and buffers are
    /// carried over unchanged.
    #[inline]
    pub fn map_io<F, T2>(self, f: F) -> Framed<T2, U>
    where
        F: FnOnce(T) -> T2,
    {
        Framed {
            io: f(self.io),
            codec: self.codec,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
        }
    }

    /// Consumes the transport and returns a new one whose codec is the
    /// result of applying `f` to the current one.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` closure is accepted
    /// (every `Fn` closure is also `FnOnce`). I/O object, flags and buffers
    /// are carried over unchanged.
    #[inline]
    pub fn map_codec<F, U2>(self, f: F) -> Framed<T, U2>
    where
        F: FnOnce(U) -> U2,
    {
        Framed {
            io: self.io,
            codec: f(self.codec),
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
        }
    }

    /// Consumes the transport and returns its I/O object, codec, state flags
    /// and both buffers as a [`FramedParts`].
    #[inline]
    pub fn into_parts(self) -> FramedParts<T, U> {
        FramedParts {
            io: self.io,
            codec: self.codec,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
        }
    }
}
impl<T, U> Framed<T, U>
where
    T: AsyncWrite + Unpin,
    U: Encoder,
{
    /// Encodes `item` into the write buffer without performing any I/O.
    ///
    /// If fewer than `LW` bytes of spare capacity remain, the buffer is
    /// grown before encoding.
    ///
    /// # Errors
    ///
    /// Returns the codec's error if encoding fails.
    #[inline]
    pub fn write(
        &mut self,
        item: <U as Encoder>::Item,
    ) -> Result<(), <U as Encoder>::Error> {
        let remaining = self.write_buf.capacity() - self.write_buf.len();
        if remaining < LW {
            // `remaining < LW < HW`, so this subtraction cannot underflow.
            self.write_buf.reserve(HW - remaining);
        }
        self.codec.encode(item, &mut self.write_buf)?;
        Ok(())
    }

    /// Returns `true` while the write buffer holds fewer than `HW` bytes.
    #[inline]
    pub fn is_write_ready(&self) -> bool {
        self.write_buf.len() < HW
    }

    /// Writes all buffered bytes to the I/O object and then flushes it.
    ///
    /// Returns `Poll::Pending` if the I/O object is not ready; bytes already
    /// written are removed from the buffer, so the next call resumes where
    /// this one stopped.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors, and reports `io::ErrorKind::WriteZero` if the
    /// I/O object accepts zero bytes while data is still pending.
    pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
        log::trace!("flushing framed transport");
        while !self.write_buf.is_empty() {
            log::trace!("writing; remaining={}", self.write_buf.len());
            let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.write_buf))?;
            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                )
                .into()));
            }
            self.write_buf.advance(n);
        }
        ready!(Pin::new(&mut self.io).poll_flush(cx))?;
        log::trace!("framed transport flushed");
        Poll::Ready(Ok(()))
    }

    /// Flushes all buffered frames and then shuts down the I/O object.
    ///
    /// Frames that were encoded with [`write`](Self::write) but not yet
    /// flushed are written out first, so closing never discards them.
    ///
    /// # Errors
    ///
    /// Propagates any error from flushing or from shutting down the I/O
    /// object.
    #[inline]
    pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
        // Drain `write_buf` before shutting down; `flush` also flushes `io`
        // once the buffer is empty.
        ready!(self.flush(cx))?;
        ready!(Pin::new(&mut self.io).poll_shutdown(cx))?;
        log::trace!("framed transport flushed and closed");
        Poll::Ready(Ok(()))
    }
}
impl<T, U> Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder,
{
/// Tries to produce the next decoded frame, reading from the I/O object
/// as needed.
///
/// Returns `Poll::Ready(Some(Ok(frame)))` when a frame was decoded,
/// `Poll::Ready(Some(Err(e)))` on a decode or read error,
/// `Poll::Ready(None)` when end of input was reached and `decode_eof`
/// yields no further frame, and `Poll::Pending` when the read is pending.
pub fn next_item(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<U::Item, U::Error>>>
where
T: AsyncRead,
U: Decoder,
{
loop {
// READABLE means the buffer may hold undecoded data (or EOF was
// seen), so try decoding before touching the I/O object again.
if self.flags.contains(Flags::READABLE) {
if self.flags.contains(Flags::EOF) {
// End of input: let the codec handle whatever is left.
// `Ok(None)` from `decode_eof` terminates the stream.
match self.codec.decode_eof(&mut self.read_buf) {
Ok(Some(frame)) => return Poll::Ready(Some(Ok(frame))),
Ok(None) => return Poll::Ready(None),
Err(e) => return Poll::Ready(Some(Err(e))),
}
}
log::trace!("attempting to decode a frame");
match self.codec.decode(&mut self.read_buf) {
Ok(Some(frame)) => {
log::trace!("frame decoded from buffer");
return Poll::Ready(Some(Ok(frame)));
}
Err(e) => return Poll::Ready(Some(Err(e))),
// `Ok(None)`: no complete frame yet; fall through to read more.
_ => (),
}
self.flags.remove(Flags::READABLE);
}
// Every EOF path above returns, so EOF cannot be set here.
debug_assert!(!self.flags.contains(Flags::EOF));
// Grow the read buffer when its spare capacity falls below `LW`.
let remaining = self.read_buf.capacity() - self.read_buf.len();
if remaining < LW {
self.read_buf.reserve(HW - remaining)
}
let cnt = match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf)
{
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Ok(cnt)) => cnt,
};
// A read that reports zero bytes is treated as end of input.
if cnt == 0 {
self.flags.insert(Flags::EOF);
}
// Mark the buffer readable so the next iteration attempts a decode.
self.flags.insert(Flags::READABLE);
}
}
}
impl<T, U> Stream for Framed<T, U>
where
    T: AsyncRead + Unpin,
    U: Decoder,
{
    type Item = Result<U::Item, U::Error>;

    /// Yields decoded frames by delegating to [`Framed::next_item`].
    #[inline]
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `Framed` is `Unpin`, so the pinned reference can be unwrapped.
        let framed: &mut Self = self.get_mut();
        framed.next_item(cx)
    }
}
impl<T, U> Sink<U::Item> for Framed<T, U>
where
    T: AsyncWrite + Unpin,
    U: Encoder,
    U::Error: From<io::Error>,
{
    type Error = U::Error;

    /// Reports whether another item may be sent.
    ///
    /// While the write buffer is below the high watermark this returns
    /// `Ready` immediately. Otherwise it drives a flush so that the buffer
    /// drains. A `Pending` result therefore always comes from the I/O
    /// object, which has registered `cx`'s waker; returning `Pending`
    /// without doing so would leave the task without a wake-up.
    #[inline]
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        if self.is_write_ready() {
            Poll::Ready(Ok(()))
        } else {
            // A completed flush leaves the buffer empty, i.e. write-ready.
            self.flush(cx)
        }
    }

    /// Encodes `item` into the write buffer; see [`Framed::write`].
    #[inline]
    fn start_send(
        mut self: Pin<&mut Self>,
        item: <U as Encoder>::Item,
    ) -> Result<(), Self::Error> {
        self.write(item)
    }

    /// Writes out all buffered bytes; see [`Framed::flush`].
    #[inline]
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.flush(cx)
    }

    /// Closes the transport by delegating to [`Framed::close`].
    #[inline]
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.close(cx)
    }
}
impl<T, U> fmt::Debug for Framed<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    /// Formats the transport as a `Framed` struct showing only the `io` and
    /// `codec` fields; flags and buffers are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Framed");
        out.field("io", &self.io);
        out.field("codec", &self.codec);
        out.finish()
    }
}
/// The constituent pieces of a `Framed`, as produced by `Framed::into_parts`
/// and accepted by `Framed::from_parts`.
///
/// The I/O object, codec and both buffers are public; the state flags are
/// private to this module.
#[derive(Debug)]
pub struct FramedParts<T, U> {
/// The underlying I/O object.
pub io: T,
/// The codec used to encode and/or decode frames.
pub codec: U,
/// Bytes that were read but not yet consumed by the decoder.
pub read_buf: BytesMut,
/// Encoded bytes that were not yet written to the I/O object.
pub write_buf: BytesMut,
// Read-side state bits carried along with the buffers.
flags: Flags,
}
impl<T, U> FramedParts<T, U> {
    /// Builds parts from `io` and `codec` with empty read and write buffers
    /// and no state flags set.
    pub fn new(io: T, codec: U) -> FramedParts<T, U> {
        Self::with_read_buf(io, codec, BytesMut::new())
    }

    /// Builds parts from `io` and `codec` using `read_buf` as the initial
    /// read buffer; the write buffer starts empty and no state flags are set.
    pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts<T, U> {
        let write_buf = BytesMut::new();
        let flags = Flags::empty();
        Self {
            io,
            codec,
            read_buf,
            write_buf,
            flags,
        }
    }
}