use std::future::Future;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::*;
use crate::constants::MAX_FRAME_BODY_LEN;
use crate::errors::ConnError;
use crate::frame::{FrameBody, FrameType};
use crate::traits::{AsyncRecvFrame, AsyncSendFrame};
/// Framing adapter that reads and writes length-prefixed frames over a byte stream.
///
/// Owns the wrapped stream plus two scratch buffers that are reused across
/// calls, so receiving or sending a frame does not allocate a fresh buffer
/// each time.
///
/// Trait bounds are placed on the `impl` blocks that need them rather than on
/// the struct itself (every impl must restate them anyway), so the type can be
/// named in signatures without repeating `Sync + Send + Unpin`.
pub struct StreamFramer<T> {
    /// The wrapped stream that frames are read from / written to.
    inner: T,
    /// Scratch buffer holding the body of the frame currently being received.
    read_buf: Vec<u8>,
    /// Scratch buffer holding the encoded body of the frame currently being sent.
    write_buf: Vec<u8>,
}
impl<T: Sync + Send + Unpin> StreamFramer<T> {
    /// Wraps `inner` in a framer whose read and write scratch buffers start empty.
    pub fn new(inner: T) -> Self {
        let read_buf = Vec::new();
        let write_buf = Vec::new();
        Self {
            inner,
            read_buf,
            write_buf,
        }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// Note: I/O performed directly through this reference is not framed.
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Consumes the framer and hands back the wrapped stream, dropping the
    /// scratch buffers.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T: AsyncRead + Sync + Send + Unpin> AsyncRecvFrame for StreamFramer<T> {
    /// Reads one frame from the wrapped stream and decodes its body.
    ///
    /// The expected header is a `u8` flags byte (must be 0), a `u8` frame-type
    /// tag, and a `u32` body length as read by tokio's `read_u32`. The header is
    /// followed by exactly that many body bytes.
    ///
    /// # Errors
    /// - `ConnError::UnknownFlags` if the flags byte is non-zero.
    /// - `ConnError::UnknownFrameType` if the tag does not convert to a `FrameType`.
    /// - `ConnError::FrameTooLarge` if the announced length exceeds `MAX_FRAME_BODY_LEN`.
    /// - `ConnError::MalformedFrame` if the decoded body's type disagrees with the tag.
    /// - Any error from the underlying reader or from `FrameBody::from_buf`.
    fn recv_frame_async(&mut self) -> impl Future<Output = Result<FrameBody, ConnError>> + Send {
        async {
            // Header byte 1: flags; only 0 is accepted.
            match self.inner.read_u8().await? {
                0 => {}
                bad_flags => return Err(ConnError::UnknownFlags(bad_flags)),
            }

            // Header byte 2: frame-type tag, rejected before any body bytes are read.
            let tag = self.inner.read_u8().await?;
            FrameType::try_from(tag).map_err(|_| ConnError::UnknownFrameType(tag))?;

            // Header bytes 3..6: body length, bounded before allocating.
            let body_len = self.inner.read_u32().await? as usize;
            if body_len > MAX_FRAME_BODY_LEN {
                return Err(ConnError::FrameTooLarge(body_len));
            }

            // Fill the reusable scratch buffer with the body and decode it.
            self.read_buf.resize(body_len, 0);
            self.inner.read_exact(&mut self.read_buf).await?;
            let frame = FrameBody::from_buf(&self.read_buf)?;
            let ty = frame.ty();
            trace!(?ty, len = %body_len, "recvd frame");
            self.read_buf.clear();

            // The decoded body must agree with the type announced in the header.
            if u8::from(ty) == tag {
                Ok(frame)
            } else {
                Err(ConnError::MalformedFrame)
            }
        }
    }
}
impl<T: AsyncWrite + Sync + Send + Unpin> AsyncSendFrame for StreamFramer<T> {
    /// Encodes `body` and writes it to the wrapped stream as a single frame.
    ///
    /// Wire format: `flags: u8` (always 0), `type tag: u8`, `body length: u32`
    /// (big-endian), then the encoded body bytes.
    ///
    /// # Errors
    /// - `ConnError::FrameTooLarge` if the encoded body exceeds `MAX_FRAME_BODY_LEN`
    ///   or does not fit in the `u32` length field.
    /// - Any error from `FrameBody::into_vec` or from the underlying writer.
    fn send_frame_async(
        &mut self,
        body: &FrameBody,
    ) -> impl Future<Output = Result<(), ConnError>> + Send {
        async {
            let ty = body.ty();
            let flags = 0u8;
            let ty_tag: u8 = ty.into();

            // Start from an empty buffer. A previous send that bailed out early
            // (FrameTooLarge, an I/O error, or a dropped future) may have left
            // encoded bytes here; if `into_vec` appends, they would otherwise
            // silently prefix this frame's body and corrupt the stream.
            self.write_buf.clear();
            body.into_vec(&mut self.write_buf)?;
            let len = self.write_buf.len();
            if len > MAX_FRAME_BODY_LEN {
                return Err(ConnError::FrameTooLarge(len));
            }
            // Guard the narrowing instead of letting `as u32` truncate silently.
            let wire_len = u32::try_from(len).map_err(|_| ConnError::FrameTooLarge(len))?;
            trace!(?ty, %len, "sending frame");

            // Emit the 6-byte header in one write rather than three tiny ones;
            // the bytes match write_u8 + write_u8 + write_u32 (big-endian).
            let mut header = [0u8; 6];
            header[0] = flags;
            header[1] = ty_tag;
            header[2..].copy_from_slice(&wire_len.to_be_bytes());
            self.inner.write_all(&header).await?;
            self.inner.write_all(&self.write_buf).await?;
            self.write_buf.clear();
            Ok(())
        }
    }
}