#![cfg(any(test, feature = "test-support"))]
use std::io;
use bytes::{Buf, BufMut, BytesMut};
use crate::message_assembler::{
ContinuationFrameHeader,
FirstFrameHeader,
FrameHeader,
FrameSequence,
MessageAssembler,
MessageKey,
ParsedFrameHeader,
};
pub mod frame_codec;
#[cfg(feature = "pool")]
pub mod pool_client;
pub use frame_codec::{TestAdapter, TestCodec, TestFrame};
#[cfg(feature = "pool")]
pub use pool_client::{
ClientHello,
Ping,
Pong,
PoolServerBehavior,
PoolTestServer,
TestClientPool,
acquire_and_record,
build_pooled_client,
build_preamble_pool,
};
#[derive(Clone, Copy, Debug, Default)]
pub struct TestAssembler;
impl MessageAssembler for TestAssembler {
fn parse_frame_header(&self, payload: &[u8]) -> Result<ParsedFrameHeader, io::Error> {
parse_frame_header(payload)
}
}
/// Flag byte read from a frame header, immediately after the kind byte.
#[derive(Clone, Copy, Debug)]
struct FrameFlags(u8);

impl FrameFlags {
    /// Bit 0: the frame is the last one of its message.
    const LAST: u8 = 0b01;
    /// Bit 1: an optional trailing header field is present. For first frames
    /// this is the total body length; for continuation frames it is the
    /// sequence number.
    const OPTIONAL_FIELD: u8 = 0b10;

    /// Returns `true` when the last-frame bit is set.
    fn is_last(self) -> bool {
        (self.0 & Self::LAST) != 0
    }

    /// Returns `true` when the optional-field bit is set.
    fn has_optional_field(self) -> bool {
        (self.0 & Self::OPTIONAL_FIELD) != 0
    }
}
/// Parses the test-protocol frame header at the start of `payload`.
///
/// The common prefix is a kind byte (`0x01` = first frame, `0x02` =
/// continuation frame), a flags byte, and a big-endian `u64` message key. The
/// kind-specific fields follow. The returned [`ParsedFrameHeader`] records how
/// many leading bytes of `payload` the header occupied.
///
/// # Errors
///
/// Returns an `InvalidData` error when `payload` is too short, when the kind
/// byte is unknown, or when a length field does not fit in `usize`.
pub fn parse_frame_header(payload: &[u8]) -> Result<ParsedFrameHeader, io::Error> {
    let mut cursor = payload;

    let kind = take_u8(&mut cursor)?;
    let flags = FrameFlags(take_u8(&mut cursor)?);
    let message_key = MessageKey::from(take_u64(&mut cursor)?);

    let header = match kind {
        0x01 => parse_first_frame_header(&mut cursor, flags, message_key),
        0x02 => parse_continuation_frame_header(&mut cursor, flags, message_key),
        _ => Err(invalid_data("unknown header kind")),
    }?;

    // The cursor only ever advances, so the shrink in its length is exactly
    // the number of header bytes consumed.
    let consumed = payload.len() - cursor.len();
    Ok(ParsedFrameHeader::new(header, consumed))
}
/// Reads the first-frame fields that follow the common header prefix.
///
/// The fields are a `u16` metadata length, a `u32` body length, and, only when
/// the optional-field flag is set, a `u32` total body length. `buf` is
/// advanced past everything read.
fn parse_first_frame_header(
    buf: &mut &[u8],
    flags: FrameFlags,
    message_key: MessageKey,
) -> Result<FrameHeader, io::Error> {
    let metadata_len = take_u16(buf).map(usize::from)?;
    let body_len = take_usize_u32(buf, "body length too large")?;
    let total_body_len = take_optional_usize_u32(buf, flags, "total length too large")?;

    let header = FirstFrameHeader {
        is_last: flags.is_last(),
        total_body_len,
        body_len,
        metadata_len,
        message_key,
    };
    Ok(FrameHeader::First(header))
}
/// Reads the continuation-frame fields that follow the common header prefix.
///
/// The fields are a `u32` body length and, only when the optional-field flag
/// is set, a `u32` sequence number. `buf` is advanced past everything read.
fn parse_continuation_frame_header(
    buf: &mut &[u8],
    flags: FrameFlags,
    message_key: MessageKey,
) -> Result<FrameHeader, io::Error> {
    let body_len = take_usize_u32(buf, "body length too large")?;
    let sequence = take_optional_sequence(buf, flags)?;

    let header = ContinuationFrameHeader {
        is_last: flags.is_last(),
        body_len,
        sequence,
        message_key,
    };
    Ok(FrameHeader::Continuation(header))
}
/// Reads a `u32` length field and widens it to `usize`.
///
/// A value that does not fit in `usize` produces an `InvalidData` error
/// carrying `message`.
fn take_usize_u32(buf: &mut &[u8], message: &'static str) -> Result<usize, io::Error> {
    let raw = take_u32(buf)?;
    raw.try_into().map_err(|_| invalid_data(message))
}
/// Reads a `u32` length field only when the optional-field flag is set.
///
/// When the flag is clear, returns `Ok(None)` without consuming any bytes.
fn take_optional_usize_u32(
    buf: &mut &[u8],
    flags: FrameFlags,
    message: &'static str,
) -> Result<Option<usize>, io::Error> {
    flags
        .has_optional_field()
        .then(|| take_usize_u32(buf, message))
        .transpose()
}
/// Reads a `u32` sequence number only when the optional-field flag is set.
///
/// When the flag is clear, returns `Ok(None)` without consuming any bytes.
fn take_optional_sequence(
    buf: &mut &[u8],
    flags: FrameFlags,
) -> Result<Option<FrameSequence>, io::Error> {
    if flags.has_optional_field() {
        let raw = take_u32(buf)?;
        Ok(Some(FrameSequence::from(raw)))
    } else {
        Ok(None)
    }
}
/// Pops one byte off the front of `buf`.
///
/// Returns `InvalidData` if `buf` is empty.
fn take_u8(buf: &mut &[u8]) -> Result<u8, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u8>()).map(|()| buf.get_u8())
}
/// Pops a big-endian `u16` off the front of `buf`.
///
/// Returns `InvalidData` if fewer than 2 bytes remain.
fn take_u16(buf: &mut &[u8]) -> Result<u16, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u16>()).map(|()| buf.get_u16())
}
/// Pops a big-endian `u32` off the front of `buf`.
///
/// Returns `InvalidData` if fewer than 4 bytes remain.
fn take_u32(buf: &mut &[u8]) -> Result<u32, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u32>()).map(|()| buf.get_u32())
}
/// Pops a big-endian `u64` off the front of `buf`.
///
/// Returns `InvalidData` if fewer than 8 bytes remain.
fn take_u64(buf: &mut &[u8]) -> Result<u64, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u64>()).map(|()| buf.get_u64())
}
fn ensure_remaining(buf: &mut &[u8], needed: usize) -> Result<(), io::Error> {
if buf.remaining() < needed {
return Err(invalid_data("header too short"));
}
Ok(())
}
/// Builds an `io::Error` of kind `InvalidData` whose text is `message`.
fn invalid_data(message: &'static str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, message)
}
pub fn first_frame_payload(
key: MessageKey,
body: &[u8],
is_last: bool,
total: Option<u32>,
) -> Result<Vec<u8>, io::Error> {
let mut payload = BytesMut::new();
payload.put_u8(0x01);
let mut flags = 0u8;
if is_last {
flags |= 0b1;
}
if total.is_some() {
flags |= 0b10;
}
payload.put_u8(flags);
payload.put_u64(u64::from(key));
payload.put_u16(0);
let body_len = u32::try_from(body.len()).map_err(|_| invalid_data("body length too large"))?;
payload.put_u32(body_len);
if let Some(total) = total {
payload.put_u32(total);
}
payload.extend_from_slice(body);
Ok(payload.to_vec())
}
pub fn continuation_frame_payload(
key: MessageKey,
sequence: FrameSequence,
body: &[u8],
is_last: bool,
) -> Result<Vec<u8>, io::Error> {
let mut payload = BytesMut::new();
payload.put_u8(0x02);
let mut flags = 0b10;
if is_last {
flags |= 0b1;
}
payload.put_u8(flags);
payload.put_u64(u64::from(key));
let body_len = u32::try_from(body.len()).map_err(|_| invalid_data("body length too large"))?;
payload.put_u32(body_len);
payload.put_u32(u32::from(sequence));
payload.extend_from_slice(body);
Ok(payload.to_vec())
}