use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use bytes::{Buf, BytesMut};
use either::Either;
use futures_core::{ready, Stream};
use futures_sink::Sink;
use crate::{AsyncRead, AsyncWrite, Decoder, Encoder};
/// Low-water mark: when a buffer has fewer than this many bytes of spare
/// capacity, more space is reserved before the next read or encode.
const LW: usize = 1_024;
/// High-water mark: the initial capacity of both buffers, the basis for how
/// much space is reserved when a buffer runs low, and the write-buffer length
/// at which the transport stops reporting itself ready for more items.
const HW: usize = 8 * 1_024;
bitflags::bitflags! {
    /// Internal state bits tracked by `Framed`.
    struct Flags: u8 {
        /// The read side reached end-of-stream, or failed after some data was read.
        const EOF = 1 << 0;
        /// A decode attempt should be made against the read buffer.
        const READABLE = 1 << 1;
        /// The transport failed with an I/O error or finished closing.
        const DISCONNECTED = 1 << 2;
        /// `poll_shutdown` on the transport has completed successfully.
        const SHUTDOWN = 1 << 3;
    }
}
/// Adapts a byte transport `T` into a stream and sink of frames using codec `U`.
pub struct Framed<T, U> {
    /// The underlying I/O object.
    io: T,
    /// Codec used to decode inbound frames and encode outbound ones.
    codec: U,
    /// Bytes read from `io` that have not been decoded yet.
    read_buf: BytesMut,
    /// Encoded bytes that have not been written to `io` yet.
    write_buf: BytesMut,
    /// Internal transport state bits.
    flags: Flags,
    /// Read error held back until data read before it has been decoded.
    err: Option<io::Error>,
}
impl<T, U> Framed<T, U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder + Encoder,
{
    /// Wraps `io` with `codec`, pre-allocating `HW` bytes for each buffer.
    #[inline]
    pub fn new(io: T, codec: U) -> Framed<T, U> {
        let read_buf = BytesMut::with_capacity(HW);
        let write_buf = BytesMut::with_capacity(HW);
        Self {
            io,
            codec,
            read_buf,
            write_buf,
            flags: Flags::empty(),
            err: None,
        }
    }
}
impl<T, U> Framed<T, U> {
    /// Rebuilds a `Framed` from its parts.
    ///
    /// Buffers, internal state flags and any deferred read error stored in
    /// `parts` are carried over unchanged.
    #[inline]
    pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
        Framed {
            io: parts.io,
            codec: parts.codec,
            flags: parts.flags,
            write_buf: parts.write_buf,
            read_buf: parts.read_buf,
            err: parts.err,
        }
    }

    /// Returns a shared reference to the codec.
    #[inline]
    pub fn get_codec(&self) -> &U {
        &self.codec
    }

    /// Returns a mutable reference to the codec.
    #[inline]
    pub fn get_codec_mut(&mut self) -> &mut U {
        &mut self.codec
    }

    /// Returns a shared reference to the underlying I/O object.
    #[inline]
    pub fn get_ref(&self) -> &T {
        &self.io
    }

    /// Returns a mutable reference to the underlying I/O object.
    ///
    /// Reading from or writing to it directly bypasses the framed buffers.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.io
    }

    /// Mutable access to bytes read from the transport but not yet decoded.
    #[inline]
    pub fn read_buf(&mut self) -> &mut BytesMut {
        &mut self.read_buf
    }

    /// Mutable access to encoded bytes not yet written to the transport.
    #[inline]
    pub fn write_buf(&mut self) -> &mut BytesMut {
        &mut self.write_buf
    }

    /// Returns `true` when there are no pending outbound bytes.
    #[inline]
    pub fn is_write_buf_empty(&self) -> bool {
        self.write_buf.is_empty()
    }

    /// Returns `true` once the write buffer holds at least `HW` bytes.
    #[inline]
    pub fn is_write_buf_full(&self) -> bool {
        self.write_buf.len() >= HW
    }

    /// Returns `true` once the transport is marked disconnected.
    ///
    /// This happens after an I/O error during flush or shutdown, or after
    /// `close` has completed.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.flags.contains(Flags::DISCONNECTED)
    }

    /// Replaces the codec with `codec`.
    ///
    /// The transport, both buffers and the internal state are kept.
    #[inline]
    pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
        Framed {
            codec,
            io: self.io,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
            err: self.err,
        }
    }

    /// Transforms the I/O object with `f`, keeping codec, buffers and state.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` closure is accepted.
    #[inline]
    pub fn map_io<F, T2>(self, f: F) -> Framed<T2, U>
    where
        F: FnOnce(T) -> T2,
    {
        Framed {
            io: f(self.io),
            codec: self.codec,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
            err: self.err,
        }
    }

    /// Transforms the codec with `f`, keeping I/O object, buffers and state.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` closure is accepted.
    #[inline]
    pub fn map_codec<F, U2>(self, f: F) -> Framed<T, U2>
    where
        F: FnOnce(U) -> U2,
    {
        Framed {
            io: self.io,
            codec: f(self.codec),
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
            err: self.err,
        }
    }

    /// Decomposes the transport into its parts.
    ///
    /// Pass the result to [`Framed::from_parts`] to reassemble it.
    #[inline]
    pub fn into_parts(self) -> FramedParts<T, U> {
        FramedParts {
            io: self.io,
            codec: self.codec,
            flags: self.flags,
            read_buf: self.read_buf,
            write_buf: self.write_buf,
            err: self.err,
        }
    }
}
impl<T, U> Framed<T, U>
where
    T: AsyncWrite + Unpin,
    U: Encoder,
{
    /// Encodes `item` into the write buffer without touching the transport.
    ///
    /// If fewer than `LW` bytes of spare capacity remain, additional space is
    /// reserved first. Call [`flush`](Self::flush) to send the encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns the codec's error if encoding fails.
    #[inline]
    pub fn write(
        &mut self,
        item: <U as Encoder>::Item,
    ) -> Result<(), <U as Encoder>::Error> {
        let remaining = self.write_buf.capacity() - self.write_buf.len();
        if remaining < LW {
            self.write_buf.reserve(HW - remaining);
        }
        self.codec.encode(item, &mut self.write_buf)?;
        Ok(())
    }

    /// Returns `true` while the write buffer holds fewer than `HW` bytes.
    #[inline]
    pub fn is_write_ready(&self) -> bool {
        self.write_buf.len() < HW
    }

    /// Writes as much of the write buffer as the transport accepts, then
    /// flushes the transport.
    ///
    /// Returns `Ready(Ok(()))` only when the write buffer is empty and the
    /// transport's `poll_flush` has completed. Otherwise it returns `Pending`.
    ///
    /// # Errors
    ///
    /// Any of the following is returned and marks the transport as
    /// disconnected:
    /// - a write error;
    /// - a zero-length write (`WriteZero`);
    /// - a `poll_flush` error.
    pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        log::trace!("flushing framed transport");
        let len = self.write_buf.len();
        if len != 0 {
            let mut written = 0;
            while written < len {
                match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
                    // Transport cannot take more right now; keep what was written so far.
                    Poll::Pending => break,
                    Poll::Ready(Ok(0)) => {
                        log::trace!("Disconnected during flush, written {}", written);
                        self.flags.insert(Flags::DISCONNECTED);
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "failed to write frame to transport",
                        )));
                    }
                    Poll::Ready(Ok(n)) => written += n,
                    Poll::Ready(Err(e)) => {
                        log::trace!("Error during flush: {}", e);
                        self.flags.insert(Flags::DISCONNECTED);
                        return Poll::Ready(Err(e));
                    }
                }
            }
            // Drop the bytes the transport accepted.
            if written == len {
                self.write_buf.clear();
            } else {
                self.write_buf.advance(written);
            }
        }
        // Treat a failing `poll_flush` like a failing `poll_write`: the
        // transport is no longer usable, so mark it disconnected.
        if let Err(e) = ready!(Pin::new(&mut self.io).poll_flush(cx)) {
            log::trace!("Error during flush: {}", e);
            self.flags.insert(Flags::DISCONNECTED);
            return Poll::Ready(Err(e));
        }
        if self.write_buf.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            // Only reachable after `poll_write` returned `Pending` above.
            Poll::Pending
        }
    }
}
impl<T, U> Framed<T, U>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Shuts the transport down.
    ///
    /// Steps performed:
    /// 1. Flush the underlying I/O object.
    /// 2. Shut down its write side, once.
    /// 3. Discard inbound bytes until end-of-stream or a read error.
    ///
    /// After that the transport is marked disconnected, and later calls
    /// return `Ready(Ok(()))` immediately.
    ///
    /// Only the I/O object's own `poll_flush` is driven here. Bytes still held
    /// in the framed write buffer are not written.
    ///
    /// # Errors
    ///
    /// Propagates `poll_flush` and `poll_shutdown` errors. A shutdown error
    /// also marks the transport as disconnected. Read errors while draining
    /// simply end the drain.
    #[inline]
    pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        if !self.is_closed() {
            ready!(Pin::new(&mut self.io).poll_flush(cx))?;

            // SHUTDOWN guards against re-issuing shutdown when re-polled while draining.
            if !self.flags.contains(Flags::SHUTDOWN) {
                if let Err(err) = ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
                    self.flags.insert(Flags::DISCONNECTED);
                    return Poll::Ready(Err(err));
                }
                self.flags.insert(Flags::SHUTDOWN);
            }

            // Discard inbound bytes until the peer signals EOF or the read fails.
            let mut scratch = [0u8; 512];
            while let Ok(n) = ready!(Pin::new(&mut self.io).poll_read(cx, &mut scratch)) {
                if n == 0 {
                    break;
                }
            }
            self.flags.insert(Flags::DISCONNECTED);
        }
        log::trace!("framed transport flushed and closed");
        Poll::Ready(Ok(()))
    }
}
/// Item produced by [`Framed::next_item`]: a decoded frame, a codec error
/// (`Either::Left`) or an I/O error (`Either::Right`).
pub type ItemType<U> =
    Result<<U as Decoder>::Item, Either<<U as Decoder>::Error, io::Error>>;

impl<T, U> Framed<T, U>
where
    T: AsyncRead + Unpin,
    U: Decoder,
{
    /// Polls for the next decoded frame.
    ///
    /// Return values:
    /// - `Ready(Some(Ok(frame)))` when a frame is decoded.
    /// - `Ready(Some(Err(Either::Left(e))))` on a codec error.
    /// - `Ready(Some(Err(Either::Right(e))))` on an I/O error. This includes
    ///   a read error deferred until earlier data was decoded, and bytes left
    ///   over that `decode_eof` could not turn into a frame at end-of-stream.
    /// - `Ready(None)` at end-of-stream with an empty read buffer.
    /// - `Pending` when more data is needed.
    ///
    /// NOTE(review): the inner read loop keeps reading until the transport
    /// returns `Pending`, EOF or an error. `read_buf` is therefore not bounded
    /// by `HW` — confirm this is acceptable for untrusted peers.
    /// NOTE(review): once EOF is reached, the EOF/READABLE flags stay set.
    /// Later calls re-run `decode_eof`, so the "bytes remaining on stream"
    /// error may be yielded repeatedly — confirm callers stop after it.
    pub fn next_item(&mut self, cx: &mut Context<'_>) -> Poll<Option<ItemType<U>>> {
        // Set when this call has already read data from the transport. A
        // failed decode then returns `Pending` instead of reading again.
        let mut done_read = false;
        loop {
            // READABLE means the buffer may hold a frame, so try decoding first.
            if self.flags.contains(Flags::READABLE) {
                // After EOF only `decode_eof` is used; every outcome returns.
                if self.flags.contains(Flags::EOF) {
                    return match self.codec.decode_eof(&mut self.read_buf) {
                        Ok(Some(frame)) => Poll::Ready(Some(Ok(frame))),
                        Ok(None) => {
                            // Report a read error deferred by the read loop below.
                            if let Some(err) = self.err.take() {
                                Poll::Ready(Some(Err(Either::Right(err))))
                            } else if !self.read_buf.is_empty() {
                                Poll::Ready(Some(Err(Either::Right(io::Error::new(
                                    io::ErrorKind::Other,
                                    "bytes remaining on stream",
                                )))))
                            } else {
                                Poll::Ready(None)
                            }
                        }
                        Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
                    };
                }
                log::trace!("attempting to decode a frame");
                match self.codec.decode(&mut self.read_buf) {
                    Ok(Some(frame)) => {
                        log::trace!("frame decoded from buffer");
                        return Poll::Ready(Some(Ok(frame)));
                    }
                    Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
                    // `Ok(None)`: the decoder needs more bytes.
                    _ => (),
                }
                self.flags.remove(Flags::READABLE);
                // Here `done_read` can only have been set on the `Pending` path
                // of the read loop, so the transport already has the waker.
                if done_read {
                    return Poll::Pending;
                }
            }
            // EOF is always set together with READABLE, and the EOF branch
            // above always returns, so EOF cannot be set at this point.
            debug_assert!(!self.flags.contains(Flags::EOF));
            // Whether this read pass has appended any bytes to `read_buf`.
            let mut updated = false;
            loop {
                // Reserve more room when the spare capacity drops below LW.
                let remaining = self.read_buf.capacity() - self.read_buf.len();
                if remaining < LW {
                    self.read_buf.reserve(HW - remaining)
                }
                match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
                    Poll::Pending => {
                        if updated {
                            // Data arrived before `Pending`: decode it on the next pass.
                            done_read = true;
                            self.flags.insert(Flags::READABLE);
                            break;
                        } else {
                            return Poll::Pending;
                        }
                    }
                    Poll::Ready(Ok(n)) => {
                        if n == 0 {
                            // End of stream: the next pass goes through `decode_eof`.
                            self.flags.insert(Flags::EOF | Flags::READABLE);
                            if updated {
                                done_read = true;
                            }
                            break;
                        } else {
                            updated = true;
                        }
                    }
                    Poll::Ready(Err(e)) => {
                        if updated {
                            // Defer the error so data read before it is decoded first.
                            done_read = true;
                            self.err = Some(e);
                            self.flags.insert(Flags::EOF | Flags::READABLE);
                            break;
                        } else {
                            return Poll::Ready(Some(Err(Either::Right(e))));
                        }
                    }
                }
            }
        }
    }
}
impl<T, U> Stream for Framed<T, U>
where
    T: AsyncRead + Unpin,
    U: Decoder + Unpin,
{
    type Item = Result<U::Item, Either<U::Error, io::Error>>;

    /// Yields the next decoded frame by delegating to [`Framed::next_item`].
    ///
    /// `Framed` is `Unpin` here because `T` and `U` are, so unpinning is sound.
    #[inline]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this: &mut Self = self.get_mut();
        this.next_item(cx)
    }
}
impl<T, U> Sink<U::Item> for Framed<T, U>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Unpin,
{
type Error = Either<U::Error, io::Error>;
#[inline]
fn poll_ready(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
if self.is_write_ready() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
#[inline]
fn start_send(
mut self: Pin<&mut Self>,
item: <U as Encoder>::Item,
) -> Result<(), Self::Error> {
self.write(item).map_err(Either::Left)
}
#[inline]
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.flush(cx).map_err(Either::Right)
}
#[inline]
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.close(cx).map_err(Either::Right)
}
}
impl<T, U> fmt::Debug for Framed<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    /// Formats the transport and codec; buffers and internal state are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Framed");
        out.field("io", &self.io);
        out.field("codec", &self.codec);
        out.finish()
    }
}
/// The decomposed pieces of a [`Framed`] transport.
///
/// Produced by `Framed::into_parts` and consumed by `Framed::from_parts`.
/// Callers can inspect or replace the public fields in between.
#[derive(Debug)]
pub struct FramedParts<T, U> {
    /// The underlying I/O object.
    pub io: T,
    /// The codec used to decode and encode frames.
    pub codec: U,
    /// Bytes read from `io` that have not been decoded yet.
    pub read_buf: BytesMut,
    /// Encoded bytes that have not been written to `io` yet.
    pub write_buf: BytesMut,
    // Internal transport state carried between `into_parts` and `from_parts`.
    flags: Flags,
    // Read error deferred until data read before it has been decoded.
    err: Option<io::Error>,
}

impl<T, U> FramedParts<T, U> {
    /// Creates parts with empty buffers and fresh state.
    pub fn new(io: T, codec: U) -> FramedParts<T, U> {
        Self::with_read_buf(io, codec, BytesMut::new())
    }

    /// Creates parts whose read buffer is pre-filled with `read_buf`.
    ///
    /// The write buffer is empty and the state is fresh.
    pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts<T, U> {
        Self {
            io,
            codec,
            read_buf,
            write_buf: BytesMut::new(),
            flags: Flags::empty(),
            err: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use futures::future::lazy;
    use futures::Sink;
    use ntex::testing::Io;
    use super::*;
    use crate::BytesCodec;

    // Smoke test: accessors, an into_parts/from_parts round trip, and `Debug` output.
    #[ntex::test]
    async fn test_basics() {
        let (_, server) = Io::create();
        let mut server = Framed::new(server, BytesCodec);
        server.get_codec_mut();
        server.get_ref();
        server.get_mut();
        let parts = server.into_parts();
        let server = Framed::from_parts(FramedParts::new(parts.io, parts.codec));
        assert!(format!("{:?}", server).contains("Framed"));
    }

    // Sink path: `start_send` only buffers, `poll_flush` delivers the bytes, and
    // `poll_close` stays pending until the peer closes its side.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = Io::create();
        // Presumably lets the test transport accept up to 1024 bytes per write.
        client.remote_buffer_cap(1024);
        let mut server = Framed::new(server, BytesCodec);
        assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
            .await
            .is_ready());
        let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
        Pin::new(&mut server).start_send(data).unwrap();
        // Nothing has reached the peer yet; the frame sits in the write buffer.
        assert_eq!(client.read_any(), b"".as_ref());
        assert_eq!(server.read_buf(), b"".as_ref());
        assert_eq!(server.write_buf(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
        assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
            .await
            .is_ready());
        assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
        // Close drains reads until EOF, so it waits for the client to close.
        assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
            .await
            .is_pending());
        client.close().await;
        assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
            .await
            .is_ready());
        assert!(client.is_closed());
    }

    // Partial writes: with a 3-byte peer buffer, flush writes "GET" and stays
    // pending; raising the cap lets the remainder go through.
    #[ntex::test]
    async fn test_write_pending() {
        let (client, server) = Io::create();
        let mut server = Framed::new(server, BytesCodec);
        assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
            .await
            .is_ready());
        let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
        Pin::new(&mut server).start_send(data).unwrap();
        client.remote_buffer_cap(3);
        assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
            .await
            .is_pending());
        assert_eq!(client.read_any(), b"GET".as_ref());
        client.remote_buffer_cap(1024);
        assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
            .await
            .is_ready());
        assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());
        assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
            .await
            .is_pending());
        client.close().await;
        assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
            .await
            .is_ready());
        assert!(client.is_closed());
        assert!(server.is_closed());
    }

    // Read path: a pending read yields `Pending`; after data and EOF the frame
    // is returned, followed by end-of-stream (`None`).
    #[ntex::test]
    async fn test_read_pending() {
        let (client, server) = Io::create();
        let mut server = Framed::new(server, BytesCodec);
        // Presumably makes the next read on the server side return `Pending`.
        client.read_pending();
        assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
            .await
            .is_pending());
        client.write(b"GET /test HTTP/1.1\r\n\r\n");
        client.close().await;
        let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
            .await
            .map(|i| i.unwrap().unwrap().freeze());
        assert_eq!(
            item,
            Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
        );
        let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
            .await
            .map(|i| i.is_none());
        assert_eq!(item, Poll::Ready(true));
    }

    // A read error that follows data is deferred: the frame comes first, then the error.
    #[ntex::test]
    async fn test_read_error() {
        let (client, server) = Io::create();
        let mut server = Framed::new(server, BytesCodec);
        client.read_pending();
        assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
            .await
            .is_pending());
        client.write(b"GET /test HTTP/1.1\r\n\r\n");
        client.read_error(io::Error::new(io::ErrorKind::Other, "error"));
        let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
            .await
            .map(|i| i.unwrap().unwrap().freeze());
        assert_eq!(
            item,
            Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
        );
        assert_eq!(
            lazy(|cx| Pin::new(&mut server).next_item(cx))
                .await
                .map(|i| i.unwrap().is_err()),
            Poll::Ready(true)
        );
    }
}