use std::io;
use std::io::Write;
use circular::Buffer;
use nom_derive::Parse;
use crate::packet::Packet;
use crate::packet::PacketHeader;
use crate::Error;
/// Incremental parser that reassembles [`Packet`]s from a byte stream that
/// may arrive in arbitrary chunks.
///
/// Bytes are fed in with [`Parser::write`] and complete packets are taken
/// out with [`Parser::get_packet`].
#[derive(Clone, Debug)]
pub struct Parser {
    /// Bytes written but not yet consumed as part of a returned packet.
    buffer: Buffer,
    /// Header of the frame at the front of `buffer`, cached once parsed so it
    /// is not re-parsed while waiting for the rest of the frame.
    current_header: Option<PacketHeader>,
}
impl Default for Parser {
    /// Builds a parser backed by a 1 MiB buffer.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 1 << 20;
        Self::with_capacity(DEFAULT_CAPACITY)
    }
}
impl Parser {
pub fn with_capacity(capacity: usize) -> Self {
Self{
buffer: Buffer::with_capacity(capacity),
current_header: None
}
}
#[inline]
pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.write(buf)
}
fn check_packet(&mut self) -> bool {
match &self.current_header {
None => {
if(self.buffer.available_data() >= 5) {
let header_bytes = &self.buffer.data()[..5];
self.current_header = Some(PacketHeader::parse(header_bytes).unwrap().1);
self.check_packet()
} else {
false
}
},
Some(header) => {
self.buffer.available_data() >= header.length as usize + 4
}
}
}
pub fn get_packet(&mut self) -> Result<Option<Packet>, Error> {
match (self.check_packet(), &self.current_header) {
(false, _) => Ok(None),
(true, None) => Err(Error::NoHeader),
(true, Some(header)) => {
let len = header.length as usize + 4; let data = &self.buffer.data()[..len];
let packet = Packet::parse(data)?.1;
self.buffer.consume(len);
self.current_header = None;
Ok(Some(packet))
}
}
}
}
/// Serializes `packet` with bincode using big-endian, fixed-width integer
/// encoding. Tests use it to produce the byte stream fed into a [`Parser`].
///
/// # Panics
/// Panics if bincode fails to serialize the packet.
#[cfg(test)]
pub(crate) fn encode(packet: &Packet) -> Vec<u8> {
    use bincode::Options;
    bincode::DefaultOptions::new()
        .with_big_endian()
        .with_fixint_encoding()
        .serialize(packet)
        .unwrap()
}
#[cfg(test)]
mod test {
    use rand::Rng;
    use test_strategy::proptest;
    use crate::packet::*;
    use super::*;

    /// Splits `slice` into randomly sized, contiguous, non-empty chunks that
    /// reproduce `slice` exactly when concatenated in order.
    ///
    /// Simulates a byte stream arriving in arbitrary pieces. The requested
    /// number of chunks is drawn from `1..=len / 2`. Inputs shorter than two
    /// elements are returned as a single chunk instead of panicking on an
    /// empty sampling range.
    ///
    /// NOTE: uses `thread_rng`, so the chunking is not reproducible from
    /// proptest's seed.
    fn random_slices<T>(slice: &[T]) -> Vec<&[T]> {
        if slice.len() < 2 {
            return vec![slice];
        }
        let mut rng = rand::thread_rng();
        let count = rng.gen_range(1..=slice.len() / 2);
        // `count - 1` interior cut points in `1..len`. Dropping duplicates only
        // merges two neighbouring chunks, so every chunk stays non-empty.
        let mut cuts: Vec<usize> = (1..count).map(|_| rng.gen_range(1..slice.len())).collect();
        cuts.sort_unstable();
        cuts.dedup();

        let mut subslices = Vec::with_capacity(cuts.len() + 1);
        let mut start = 0;
        for end in cuts.into_iter().chain(std::iter::once(slice.len())) {
            subslices.push(&slice[start..end]);
            start = end;
        }
        subslices
    }

    /// Layout check: `Result<Packet, Error>` should be one pointer larger
    /// than `Packet`, and wrapping it in `Option` should add nothing.
    /// Ignored by default.
    #[ignore]
    #[test]
    fn validate_packet_result_size() {
        assert_eq!(
            std::mem::size_of::<Result<Packet, Error>>(),
            std::mem::size_of::<Packet>() + std::mem::size_of::<*const ()>()
        );
        assert_eq!(
            std::mem::size_of::<Result<Option<Packet>, Error>>(),
            std::mem::size_of::<Result<Packet, Error>>()
        );
    }

    #[test]
    fn single_init() {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        let packet = Payload::init(1, vec![]).into_packet();
        stream.write(&encode(&packet)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(packet)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[test]
    fn multipart_init() {
        let mut stream = Parser::default();
        let packet = Payload::init(2, vec![]).into_packet();
        let bytes = encode(&packet);
        stream.write(&bytes[0..3]).unwrap();
        assert_eq!(stream.get_packet(), Ok(None));
        stream.write(&bytes[3..bytes.len()]).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(packet)));
    }

    #[test]
    fn handshake() {
        let mut stream = Parser::default();
        let init = Payload::init(32768, (0..100).collect()).into_packet();
        stream.write(&encode(&init)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(init)));
        let version = Payload::version(3, (100..150).collect()).into_packet();
        stream.write(&encode(&version)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(version)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[test]
    fn handshake_queued() {
        let mut stream = Parser::default();
        let init = Payload::init(32768, (0..100).collect()).into_packet();
        stream.write(&encode(&init)).unwrap();
        let version = Payload::version(3, (100..150).collect()).into_packet();
        stream.write(&encode(&version)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(init)));
        assert_eq!(stream.get_packet(), Ok(Some(version)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[proptest]
    fn arbitrary_sequence(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in input {
            stream.write(&encode(&packet)).unwrap();
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
            assert_eq!(stream.get_packet(), Ok(None));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[proptest]
    fn arbitrary_sequence_queued(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in &input {
            stream.write(&encode(packet)).unwrap();
        }
        for packet in input {
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[proptest]
    fn arbitrary_sequence_multipart(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in input {
            let bytes = encode(&packet);
            for slice in random_slices(&bytes) {
                stream.write(slice).unwrap();
            }
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
            assert_eq!(stream.get_packet(), Ok(None));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    #[proptest]
    fn arbitrary_sequence_multipart_queued(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in &input {
            let bytes = encode(packet);
            for slice in random_slices(&bytes) {
                stream.write(slice).unwrap();
            }
        }
        for packet in input {
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }
}