use super::{Packet, PacketHeader, PacketBytes, PacketError};
use std::{mem, io};
use bytes::{BytesRead, BytesWrite, BytesSeek};
use tokio::io::AsyncReadExt;
/// Incrementally receives packets of type `P` into a buffer of type `B`.
///
/// A packet is received in two stages: first the header via `read_header`,
/// then the body via `read_body`. Progress inside the current stage is kept
/// in `read`, so each stage continues from where the previous call left off.
pub struct PacketReceiver<P, B>
where
    P: Packet<B>,
    B: PacketBytes,
{
    /// Buffer the header and body bytes are read into.
    bytes: B,
    /// Header of the packet in progress; `None` until `read_header` has
    /// parsed one, and taken out again when the packet is handed over.
    header: Option<P::Header>,
    /// Number of bytes of the current section (header or body) read so far.
    read: usize,
    /// Largest accepted body length; `0` turns the check off.
    body_limit: u32,
}
impl<P, B> PacketReceiver<P, B>
where
    P: Packet<B>,
    B: PacketBytes,
{
    /// Creates a receiver with an empty state and the given body limit.
    ///
    /// The buffer is created with `P::Header::LEN`. A `body_limit` of `0`
    /// disables the body length check in `read_body`.
    pub fn new(body_limit: u32) -> Self {
        let bytes = B::new(P::Header::LEN as usize);

        Self {
            bytes,
            header: None,
            read: 0,
            body_limit,
        }
    }

    /// Replaces the body limit; `0` disables the check.
    ///
    /// The limit is only consulted by `read_body` before the first body byte
    /// of a packet has been read.
    pub fn set_body_limit(&mut self, body_limit: u32) {
        self.body_limit = body_limit;
    }

    /// Reads the packet header from `reader`.
    ///
    /// Returns `Ok(())` immediately if a header is already stored. Otherwise
    /// the header section of the buffer is filled, starting at the offset
    /// recorded by earlier calls. Once it is full, `mutate` is run on the
    /// buffer, the header is parsed with `P::Header::from_bytes` and stored.
    ///
    /// # Errors
    ///
    /// - `Io` if reading fails, or if the reader returns zero bytes before the
    ///   header section is full.
    /// - `Hard` if `mutate` or the header parsing fails.
    pub async fn read_header<R, F>(
        &mut self,
        reader: &mut R,
        mutate: F,
    ) -> Result<(), PacketReceiverError<P::Header>>
    where
        R: AsyncReadExt + Unpin,
        F: FnOnce(&mut B) -> Result<(), PacketError>,
    {
        // The header was already received, only the body is missing.
        if self.header.is_some() {
            return Ok(());
        }

        // Continue where a previous call stopped.
        let mut cursor = self.bytes.full_header_mut();
        cursor.seek(self.read);

        loop {
            let count = reader
                .read(cursor.remaining_mut())
                .await
                .map_err(PacketReceiverError::Io)?;
            cursor.advance(count);
            self.read += count;

            if cursor.remaining().is_empty() {
                break;
            } else if count == 0 {
                // Nothing was read although bytes are still missing.
                return Err(PacketReceiverError::Io(eof()));
            }
        }

        mutate(&mut self.bytes).map_err(PacketReceiverError::Hard)?;

        let parsed = P::Header::from_bytes(self.bytes.header())
            .map_err(PacketReceiverError::Hard)?;
        self.header = Some(parsed);

        // The counter now tracks the body section.
        self.read = 0;

        Ok(())
    }

    /// Reads the packet body from `reader` and returns the finished packet.
    ///
    /// The body length is taken from the stored header. A length of zero
    /// skips both reading and `mutate` and builds the packet right away.
    /// Otherwise the body section is resized (on the first call for this
    /// body), filled from `reader`, passed to `mutate`, and then converted
    /// into a packet.
    ///
    /// # Errors
    ///
    /// - `Hard(PacketError::BodyLimitReached)` if the body limit is non-zero
    ///   and the body length exceeds it.
    /// - `Io` if reading fails, or if the reader returns zero bytes before the
    ///   body is complete.
    /// - `Soft` if `mutate` or `P::from_bytes_and_header` fails; the receiver
    ///   is reset in that case.
    ///
    /// # Panics
    ///
    /// Panics if no header has been read yet, or if the body section is empty
    /// although the header announced a non-zero length.
    pub async fn read_body<R, F>(
        &mut self,
        reader: &mut R,
        mutate: F,
    ) -> Result<P, PacketReceiverError<P::Header>>
    where
        R: AsyncReadExt + Unpin,
        F: FnOnce(&mut B) -> Result<(), PacketError>,
    {
        let len = self
            .header
            .as_ref()
            .expect("read the header first")
            .body_len();

        if len == 0 {
            return self.take_message();
        }

        // Nothing of this body was read yet: validate and allocate.
        if self.read == 0 {
            let limit = self.body_limit;
            if limit != 0 && len > limit {
                return Err(PacketReceiverError::Hard(
                    PacketError::BodyLimitReached(len),
                ));
            }

            let mut body = self.bytes.body_mut();
            body.resize(len as usize);
            debug_assert_eq!(body.as_mut().len(), len as usize);
        }

        // Continue where a previous call stopped.
        let mut cursor = self.bytes.full_body_mut();
        cursor.seek(self.read);
        assert!(cursor.len() > 0, "a body should never be empty");

        loop {
            let count = reader
                .read(cursor.remaining_mut())
                .await
                .map_err(PacketReceiverError::Io)?;
            cursor.advance(count);
            self.read += count;

            if cursor.remaining().is_empty() {
                break;
            } else if count == 0 {
                // Nothing was read although bytes are still missing.
                return Err(PacketReceiverError::Io(eof()));
            }
        }

        if let Err(e) = mutate(&mut self.bytes) {
            return Err(self.soft_error(e));
        }

        self.take_message()
    }

    /// Converts the buffered bytes and the stored header into a packet and
    /// resets the receiver for the next packet.
    ///
    /// A conversion failure is reported as `Soft` together with the header.
    ///
    /// # Panics
    ///
    /// Panics if no header is stored.
    fn take_message(&mut self) -> Result<P, PacketReceiverError<P::Header>> {
        let fresh = B::new(P::Header::LEN as usize);
        let bytes = mem::replace(&mut self.bytes, fresh);
        let header = self.header.take().unwrap();
        self.read = 0;

        // The header is cloned because it is needed again if the conversion
        // fails.
        match P::from_bytes_and_header(bytes, header.clone()) {
            Ok(packet) => Ok(packet),
            Err(e) => Err(PacketReceiverError::Soft(header, e)),
        }
    }

    /// Resets the receiver for the next packet and wraps `e` into a `Soft`
    /// error carrying the header of the discarded packet.
    ///
    /// # Panics
    ///
    /// Panics if no header is stored.
    fn soft_error(&mut self, e: PacketError) -> PacketReceiverError<P::Header> {
        self.bytes = B::new(P::Header::LEN as usize);
        let header = self.header.take().unwrap();
        self.read = 0;

        PacketReceiverError::Soft(header, e)
    }
}
/// Errors returned while receiving a packet; `H` is the packet header type.
pub enum PacketReceiverError<H> {
    /// Reading from the underlying reader failed, or the reader returned
    /// zero bytes before the current section was complete.
    Io(io::Error),
    /// A failure after which the receiver state is not reset: the header
    /// could not be processed or parsed, or the body limit was exceeded.
    Hard(PacketError),
    /// A failure while finishing a body; the receiver has been reset and the
    /// header of the discarded packet is returned alongside the error.
    Soft(H, PacketError),
}

// Module-private shorthand for the error type above.
type Error<H> = PacketReceiverError<H>;
/// Builds the I/O error reported when the reader yields no more bytes
/// although the current section is still incomplete.
fn eof() -> io::Error {
    let kind = io::ErrorKind::UnexpectedEof;
    let message = "early eof";
    io::Error::new(kind, message)
}