use crate::{
FrameDecodeError,
connection::ConnectionId,
frame::{Frame, header},
};
use bytes::{Bytes, BytesMut};
use futures::{AsyncRead, AsyncWrite, Sink, Stream, ready};
use std::{
fmt, io,
pin::Pin,
task::{Context, Poll},
};
const MAX_FRAME_BODY_LEN: usize = 1024 * 1024;
enum ReadState {
Init,
Header {
offset: usize,
buffer: [u8; header::HEADER_SIZE],
},
Body {
header: header::Header,
offset: usize,
buffer: BytesMut,
},
}
impl fmt::Debug for ReadState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ReadState::Init => write!(f, "ReadState::Init"),
ReadState::Header { offset, .. } => {
write!(f, "ReadState::Header ( offset: {} )", offset)
}
ReadState::Body { header, offset, .. } => {
write!(f, "ReadState::Body(header: {}, offset: {})", header, offset)
}
}
}
}
pub(crate) struct Framed<T> {
id: ConnectionId,
inner: T,
read_state: ReadState,
write_state: WriteState,
}
impl<T: AsyncRead + AsyncWrite + Unpin> Framed<T> {
pub fn new(id: ConnectionId, inner: T) -> Self {
Framed {
id,
inner,
read_state: ReadState::Init,
write_state: WriteState::Init,
}
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Framed<T> {
type Item = Result<Frame, FrameDecodeError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
loop {
tracing::trace!("Connection({}): read {:?}", this.id, this.read_state);
match &mut this.read_state {
ReadState::Init => {
this.read_state = ReadState::Header {
offset: 0,
buffer: [0; header::HEADER_SIZE],
};
}
ReadState::Header { offset, buffer } => {
if *offset == header::HEADER_SIZE {
let header = header::decode(buffer)?;
tracing::trace!("Connection({}): decoded header: {}", this.id, header);
let body_len = header.length() as usize;
if body_len > MAX_FRAME_BODY_LEN {
return Poll::Ready(Some(Err(FrameDecodeError::TooLarge(
body_len,
MAX_FRAME_BODY_LEN,
))));
}
this.read_state = ReadState::Body {
header,
offset: 0,
buffer: BytesMut::zeroed(body_len),
};
continue;
}
let buf = &mut buffer[*offset..header::HEADER_SIZE];
match ready!(Pin::new(&mut this.inner).poll_read(cx, buf))? {
0 => {
if *offset == 0 {
return Poll::Ready(None);
} else {
return Poll::Ready(Some(Err(io::Error::from(
io::ErrorKind::UnexpectedEof,
)
.into())));
}
}
n => *offset += n,
}
}
ReadState::Body {
header,
offset,
buffer,
} => {
let body_len = header.length() as usize;
if *offset == body_len {
let header = header.clone();
let buffer = buffer.split().freeze();
this.read_state = ReadState::Init; return Poll::Ready(Some(Ok(Frame {
header,
body: buffer,
})));
}
let buf = &mut buffer[*offset..body_len];
match ready!(Pin::new(&mut this.inner).poll_read(cx, buf))? {
0 => {
return Poll::Ready(Some(Err(io::Error::from(
io::ErrorKind::UnexpectedEof,
)
.into())));
}
n => *offset += n,
}
}
}
}
}
}
enum WriteState {
Init,
Header {
header: [u8; header::HEADER_SIZE],
buffer: Bytes,
offset: usize,
},
Body {
offset: usize,
buffer: Bytes,
},
}
impl fmt::Debug for WriteState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WriteState::Init => write!(f, "WriteState::Init"),
WriteState::Header { offset, .. } => {
write!(f, "WriteState::Header ( offset: {} )", offset)
}
WriteState::Body { offset, buffer } => {
write!(
f,
"WriteState::Body(offset: {}, length: {})",
offset,
buffer.len()
)
}
}
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame> for Framed<T> {
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = &mut *self;
loop {
match &mut this.write_state {
WriteState::Init => {
return Poll::Ready(Ok(()));
}
WriteState::Header {
header,
buffer,
offset,
} => match Pin::new(&mut this.inner).poll_write(cx, &header[*offset..]) {
Poll::Ready(Ok(n)) => {
*offset += n;
if *offset == header.len() {
if buffer.is_empty() {
this.write_state = WriteState::Init;
} else {
this.write_state = WriteState::Body {
offset: 0,
buffer: buffer.clone(),
};
}
}
return Poll::Ready(Ok(()));
}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
},
WriteState::Body { offset, buffer } => {
match Pin::new(&mut this.inner).poll_write(cx, &buffer[*offset..]) {
Poll::Ready(Ok(n)) => {
*offset += n;
if *offset == buffer.len() {
this.write_state = WriteState::Init;
}
return Poll::Ready(Ok(()));
}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
}
}
}
}
}
fn start_send(self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
tracing::trace!("Connection({}): start_send {:?}", self.id, item);
let Frame { header, body } = item;
self.get_mut().write_state = WriteState::Header {
header: header::encode(&header),
buffer: body,
offset: 0,
};
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_ready(cx))?;
Pin::new(&mut self.as_mut().inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
tracing::trace!("Connection({}): poll_close", self.id);
ready!(self.as_mut().poll_ready(cx))?;
Pin::new(&mut self.as_mut().inner).poll_close(cx)
}
}