use std::io;
use super::packets::{self, IncomingPacket, ControlField};
use super::byte_io::ByteReader;
use super::util::panic_in_test;
use crate::transport::Transport;
use crate::errors::PacketDecodeError;
/// Framing information for the packet at the front of a `PacketReader`'s buffer.
#[derive(Clone, Copy)]
struct PacketSizeInfo {
    /// Total length of the packet in bytes: control byte, size field and body.
    size: usize,
    /// Index of the first body byte, i.e. the position right after the size field.
    start_offset: usize,
}
/// Incrementally assembles packets read from a `Transport`.
///
/// A reader starts out empty (via `Default`) and is driven by `PacketReader::recv`.
#[derive(Default)]
pub struct PacketReader {
    /// Bytes received so far for the packet being assembled, starting at its control byte.
    bytes: Vec<u8>,
    /// Framing of that packet, once its size field has been fully received.
    size: Option<PacketSizeInfo>,
}
/// Outcome of a single `PacketReader::recv` call.
pub enum PacketReadState {
    /// A complete packet was received and decoded.
    Incoming(IncomingPacket),
    /// The transport reported end-of-stream (a read returned zero bytes).
    ConnectionClosed,
    /// The current packet is not complete yet; `recv` has to be called again.
    NeedMoreData,
}
impl PacketReader
{
    /// Largest total packet size (control byte + size field + body) accepted by
    /// `parse_packet_size`; larger packets are rejected with
    /// `PacketDecodeError::ReachedMaxSize`.
    const MAX_PACKET_SIZE: usize = 512_000_000;

    /// Upper bound on the number of bytes requested from the transport in one `recv_n` call.
    ///
    /// `recv_n` zero-fills the region it reads into. Without a bound, every call made while
    /// a large packet trickles in would re-zero the whole missing remainder, which is
    /// quadratic work in the packet size. The rest of `recv` already copes with reads that
    /// return fewer bytes than were missing, so capping the request is transparent to it.
    const MAX_READ_CHUNK: usize = 64 * 1024;

    /// Performs one `read` on `transport` and appends the received bytes to `self.bytes`.
    ///
    /// At most `n` bytes are requested, further capped at `MAX_READ_CHUNK`.
    /// Returns `Ok(true)` when the transport signals end-of-stream (a read of zero bytes).
    /// Returns `Ok(false)` when data was received, or when `n == 0`, in which case
    /// nothing is read.
    ///
    /// # Errors
    /// Propagates errors from `transport.read`; `self.bytes` is left as it was.
    fn recv_n(&mut self, transport: &mut dyn Transport, n: usize) -> io::Result<bool>
    {
        if n == 0 {
            return Ok(false);
        }
        // Reserve for the whole remainder up front so the buffer grows at most once per packet.
        self.bytes.reserve(n);
        let chunk = n.min(Self::MAX_READ_CHUNK);
        let pos = self.bytes.len();
        // The destination must be initialised memory: handing the transport a `&mut [u8]`
        // over uninitialised spare capacity is undefined behaviour.
        self.bytes.resize(pos + chunk, 0);
        let result = transport.read(&mut self.bytes[pos..]);
        // Keep only the bytes that were actually written. `min` guards against a transport
        // that reports more bytes than the buffer it was handed.
        let filled = match &result {
            Ok(read) => (*read).min(chunk),
            Err(_) => 0,
        };
        self.bytes.truncate(pos + filled);
        Ok(result? == 0)
    }

    /// Decodes the variable-length size field that follows the control byte `self.bytes[0]`.
    ///
    /// The field is read 7 bits per byte, least-significant group first. The high bit of
    /// each byte marks that another byte follows.
    ///
    /// Returns `Ok(None)` when the buffer ends before the field is complete. Otherwise it
    /// returns `Ok(Some(info))`, where `info.size` is the total packet length (control byte,
    /// size field and body) and `info.start_offset` is the index where the body begins.
    ///
    /// # Errors
    /// * `PacketDecodeError::ReachedMaxSize` if the total size exceeds `MAX_PACKET_SIZE`.
    /// * `io::ErrorKind::InvalidData` if the field has not ended within the first
    ///   `packets::MAX_HEADER_SIZE` bytes. `recv` never buffers more header bytes than
    ///   that, so such a field could never complete and would stall the reader forever.
    fn parse_packet_size(&self) -> io::Result<Option<PacketSizeInfo>>
    {
        let mut pos = 1;
        let mut body_len = 0;
        let mut shift = 0;
        loop {
            if pos >= packets::MAX_HEADER_SIZE {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "packet size field does not fit in the maximum header size",
                ));
            }
            let b = match self.bytes.get(pos) {
                Some(&b) => b,
                None => return Ok(None),
            };
            body_len |= ((b & 0x7f) as usize) << shift;
            pos += 1;
            if (b & 0x80) == 0 {
                break;
            }
            shift += 7;
        }
        let size = body_len + pos;
        if size > Self::MAX_PACKET_SIZE {
            return Err(PacketDecodeError::ReachedMaxSize.into());
        }
        Ok(Some(PacketSizeInfo { size, start_offset: pos }))
    }

    /// Discards all buffered bytes and any parsed size, so the next `recv` starts afresh.
    fn reset(&mut self)
    {
        self.bytes.clear();
        self.size = None;
    }

    /// Re-parses the size field of the buffered bytes and stores the result in `self.size`.
    ///
    /// On a framing error the reader is reset, so no stale size survives the error.
    fn update_size(&mut self) -> io::Result<Option<PacketSizeInfo>>
    {
        match self.parse_packet_size() {
            Ok(size) => {
                self.size = size;
                Ok(size)
            }
            Err(err) => {
                self.reset();
                Err(err)
            }
        }
    }

    /// Decodes the complete packet at the front of the buffer and removes it.
    ///
    /// Bytes after the packet (the start of the next one) are kept, and their size field
    /// is parsed. The caller must ensure `self.bytes.len() >= size_info.size`.
    fn take_packet(&mut self, size_info: PacketSizeInfo) -> io::Result<PacketReadState>
    {
        let ctrl_field = ControlField(self.bytes[0]);
        let mut rd = ByteReader::new(&self.bytes[size_info.start_offset..size_info.size]);
        let decoded = IncomingPacket::from_bytes(&mut rd, ctrl_field);
        if let Ok(pkt) = &decoded {
            if rd.remaining() > 0 {
                panic_in_test!("Did not read {:?} packet in its entirety", pkt.packet_type());
            }
        }
        self.bytes.drain(..size_info.size);
        if self.bytes.is_empty() {
            self.size = None;
        } else {
            // A framing error in the next packet takes precedence over this packet's result.
            self.update_size()?;
        }
        decoded.map(PacketReadState::Incoming).map_err(Into::into)
    }

    /// Advances the reader by at most one `read` on `transport`.
    ///
    /// Returns:
    /// * `Incoming(packet)` once a whole packet is buffered. The packet is removed from the
    ///   buffer, and any bytes of the following packet that arrived with it are kept.
    /// * `ConnectionClosed` when the transport reports end-of-stream.
    /// * `NeedMoreData` when the current packet is still incomplete.
    ///
    /// A single call decodes at most one packet, and further packets may already be
    /// buffered. Callers should therefore keep calling `recv` while it returns `Incoming`.
    ///
    /// # Errors
    /// Transport errors are propagated. Framing errors (oversized or malformed size field)
    /// and packet decoding errors are returned as `io::Error`. On a framing error the
    /// buffered bytes are discarded and the reader starts over.
    pub fn recv(&mut self, transport: &mut dyn Transport) -> io::Result<PacketReadState>
    {
        let size_info = match self.size {
            Some(size_info) => {
                if self.bytes.len() < size_info.size {
                    let missing = size_info.size - self.bytes.len();
                    if self.recv_n(transport, missing)? {
                        return Ok(PacketReadState::ConnectionClosed);
                    }
                }
                size_info
            }
            None => {
                let missing = packets::MAX_HEADER_SIZE.saturating_sub(self.bytes.len());
                if self.recv_n(transport, missing)? {
                    return Ok(PacketReadState::ConnectionClosed);
                }
                match self.update_size()? {
                    Some(size_info) => size_info,
                    None => return Ok(PacketReadState::NeedMoreData),
                }
            }
        };
        // The header read may already have fetched a whole short packet. Return it now
        // rather than waiting for more data that may never arrive.
        if self.bytes.len() < size_info.size {
            return Ok(PacketReadState::NeedMoreData);
        }
        self.take_packet(size_info)
    }
}