wireframe 0.3.0

Simplify building servers and clients for custom binary protocols.
Documentation
//! Test-only helpers for shared test utilities.
#![cfg(any(test, feature = "test-support"))]

use std::io;

use bytes::{Buf, BufMut, BytesMut};

use crate::message_assembler::{
    ContinuationFrameHeader,
    FirstFrameHeader,
    FrameHeader,
    FrameSequence,
    MessageAssembler,
    MessageKey,
    ParsedFrameHeader,
};

pub mod frame_codec;
#[cfg(feature = "pool")]
pub mod pool_client;

pub use frame_codec::{TestAdapter, TestCodec, TestFrame};
#[cfg(feature = "pool")]
pub use pool_client::{
    ClientHello,
    Ping,
    Pong,
    PoolServerBehavior,
    PoolTestServer,
    TestClientPool,
    acquire_and_record,
    build_pooled_client,
    build_preamble_pool,
};

/// Message assembler for tests that delegates header parsing to the free
/// function [`parse_frame_header`].
///
/// Routing through the shared function keeps code that uses the assembler
/// and code that parses headers directly in agreement.
#[derive(Clone, Copy, Debug, Default)]
pub struct TestAssembler;

impl MessageAssembler for TestAssembler {
    fn parse_frame_header(&self, frame: &[u8]) -> io::Result<ParsedFrameHeader> {
        // Bare name resolves to the module-level function, not this method.
        parse_frame_header(frame)
    }
}

/// Flag byte read immediately after the header kind.
#[derive(Clone, Copy, Debug)]
struct FrameFlags(u8);

impl FrameFlags {
    /// Bit 0: this frame completes the message.
    const LAST: u8 = 0b01;
    /// Bit 1: an optional trailing `u32` header field is present (the total
    /// body length on first frames, the sequence number on continuations).
    const OPTIONAL_FIELD: u8 = 0b10;

    /// Returns `true` when the "last frame" bit is set.
    fn is_last(self) -> bool { (self.0 & Self::LAST) != 0 }

    /// Returns `true` when the "optional field present" bit is set.
    fn has_optional_field(self) -> bool { (self.0 & Self::OPTIONAL_FIELD) != 0 }
}

/// Parse a protocol-specific frame header for tests.
///
/// # Errors
///
/// Returns an error if the payload is too short or contains an invalid header.
pub fn parse_frame_header(payload: &[u8]) -> Result<ParsedFrameHeader, io::Error> {
    let mut buf = payload;
    let initial = buf.remaining();

    let kind = take_u8(&mut buf)?;
    let flags = FrameFlags(take_u8(&mut buf)?);
    let message_key = MessageKey::from(take_u64(&mut buf)?);

    let header = match kind {
        0x01 => parse_first_frame_header(&mut buf, flags, message_key)?,
        0x02 => parse_continuation_frame_header(&mut buf, flags, message_key)?,
        _ => return Err(invalid_data("unknown header kind")),
    };

    let header_len = initial - buf.remaining();
    Ok(ParsedFrameHeader::new(header, header_len))
}

/// Decode the first-frame-specific header fields that follow the common
/// kind/flags/key prefix.
///
/// Wire order: metadata length (`u16`), body length (`u32`), then the total
/// body length (`u32`) only when the optional-field flag is set.
fn parse_first_frame_header(
    buf: &mut &[u8],
    flags: FrameFlags,
    message_key: MessageKey,
) -> Result<FrameHeader, io::Error> {
    let metadata_len = take_u16(buf).map(usize::from)?;
    let body_len = take_usize_u32(buf, "body length too large")?;
    let total_body_len = take_optional_usize_u32(buf, flags, "total length too large")?;

    let header = FirstFrameHeader {
        message_key,
        metadata_len,
        body_len,
        total_body_len,
        is_last: flags.is_last(),
    };
    Ok(FrameHeader::First(header))
}

/// Decode the continuation-specific header fields that follow the common
/// kind/flags/key prefix.
///
/// Wire order: body length (`u32`), then the sequence number (`u32`) only
/// when the optional-field flag is set.
fn parse_continuation_frame_header(
    buf: &mut &[u8],
    flags: FrameFlags,
    message_key: MessageKey,
) -> Result<FrameHeader, io::Error> {
    let body_len = take_usize_u32(buf, "body length too large")?;
    let sequence = take_optional_sequence(buf, flags)?;
    let is_last = flags.is_last();

    let header = ContinuationFrameHeader {
        message_key,
        sequence,
        body_len,
        is_last,
    };
    Ok(FrameHeader::Continuation(header))
}

/// Read a `u32` length and widen it to `usize`, reporting `message` as an
/// `InvalidData` error if it does not fit on this platform.
fn take_usize_u32(buf: &mut &[u8], message: &'static str) -> Result<usize, io::Error> {
    let raw = take_u32(buf)?;
    let Ok(value) = usize::try_from(raw) else {
        return Err(invalid_data(message));
    };
    Ok(value)
}

/// Read a `u32` length as `usize` only when `flags` marks the optional field
/// as present. Otherwise consume nothing and yield `None`.
fn take_optional_usize_u32(
    buf: &mut &[u8],
    flags: FrameFlags,
    message: &'static str,
) -> Result<Option<usize>, io::Error> {
    flags
        .has_optional_field()
        .then(|| take_usize_u32(buf, message))
        .transpose()
}

/// Read a `u32` sequence number only when `flags` marks the optional field as
/// present. Otherwise consume nothing and yield `None`.
fn take_optional_sequence(
    buf: &mut &[u8],
    flags: FrameFlags,
) -> Result<Option<FrameSequence>, io::Error> {
    flags
        .has_optional_field()
        .then(|| take_u32(buf).map(FrameSequence::from))
        .transpose()
}

/// Consume one byte from the front of `buf`, or fail with "header too short".
fn take_u8(buf: &mut &[u8]) -> Result<u8, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u8>()).map(|()| buf.get_u8())
}

/// Consume a big-endian `u16` from the front of `buf`, or fail with
/// "header too short".
fn take_u16(buf: &mut &[u8]) -> Result<u16, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u16>()).map(|()| buf.get_u16())
}

/// Consume a big-endian `u32` from the front of `buf`, or fail with
/// "header too short".
fn take_u32(buf: &mut &[u8]) -> Result<u32, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u32>()).map(|()| buf.get_u32())
}

/// Consume a big-endian `u64` from the front of `buf`, or fail with
/// "header too short".
fn take_u64(buf: &mut &[u8]) -> Result<u64, io::Error> {
    ensure_remaining(buf, std::mem::size_of::<u64>()).map(|()| buf.get_u64())
}

fn ensure_remaining(buf: &mut &[u8], needed: usize) -> Result<(), io::Error> {
    if buf.remaining() < needed {
        return Err(invalid_data("header too short"));
    }
    Ok(())
}

/// Build an `InvalidData` I/O error carrying the given static description.
fn invalid_data(message: &'static str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, message)
}

/// Build a first-frame payload for the test protocol.
///
/// The layout is the one [`parse_frame_header`] decodes for a first frame
/// (all integers big-endian):
///
/// 1. kind: `u8`, always `0x01`
/// 2. flags: `u8`, bit 0 = `is_last`, bit 1 = `total` present
/// 3. message key: `u64`
/// 4. metadata length: `u16`, always `0` (no metadata is emitted)
/// 5. body length: `u32`
/// 6. total body length: `u32`, only when `total` is `Some`
/// 7. body bytes
///
/// # Errors
///
/// Returns an error if the body length exceeds `u32::MAX`.
pub fn first_frame_payload(
    key: MessageKey,
    body: &[u8],
    is_last: bool,
    total: Option<u32>,
) -> Result<Vec<u8>, io::Error> {
    // kind (1) + flags (1) + key (8) + metadata len (2) + body len (4).
    const FIXED_HEADER_LEN: usize = 16;

    // Validate first so an oversized body fails before any buffer is built.
    let body_len = u32::try_from(body.len()).map_err(|_| invalid_data("body length too large"))?;

    let mut flags = 0u8;
    if is_last {
        flags |= 0b1;
    }
    if total.is_some() {
        flags |= 0b10;
    }

    // Size the buffer exactly once instead of growing it field by field.
    // Slices never exceed `isize::MAX` bytes, so this sum cannot overflow.
    let optional_len = if total.is_some() { 4 } else { 0 };
    let mut payload = BytesMut::with_capacity(FIXED_HEADER_LEN + optional_len + body.len());

    payload.put_u8(0x01);
    payload.put_u8(flags);
    payload.put_u64(u64::from(key));
    payload.put_u16(0);
    payload.put_u32(body_len);
    if let Some(total) = total {
        payload.put_u32(total);
    }
    payload.extend_from_slice(body);
    Ok(payload.to_vec())
}

/// Build a continuation-frame payload for the test protocol.
///
/// The layout is the one [`parse_frame_header`] decodes for a continuation
/// frame (all integers big-endian):
///
/// 1. kind: `u8`, always `0x02`
/// 2. flags: `u8`, bit 0 = `is_last`, bit 1 always set because a sequence
///    number is always written
/// 3. message key: `u64`
/// 4. body length: `u32`
/// 5. sequence number: `u32`
/// 6. body bytes
///
/// # Errors
///
/// Returns an error if the body length exceeds `u32::MAX`.
pub fn continuation_frame_payload(
    key: MessageKey,
    sequence: FrameSequence,
    body: &[u8],
    is_last: bool,
) -> Result<Vec<u8>, io::Error> {
    // kind (1) + flags (1) + key (8) + body len (4) + sequence (4).
    const HEADER_LEN: usize = 18;

    // Validate first so an oversized body fails before any buffer is built.
    let body_len = u32::try_from(body.len()).map_err(|_| invalid_data("body length too large"))?;

    let mut flags = 0b10u8;
    if is_last {
        flags |= 0b1;
    }

    // Size the buffer exactly once instead of growing it field by field.
    // Slices never exceed `isize::MAX` bytes, so this sum cannot overflow.
    let mut payload = BytesMut::with_capacity(HEADER_LEN + body.len());

    payload.put_u8(0x02);
    payload.put_u8(flags);
    payload.put_u64(u64::from(key));
    payload.put_u32(body_len);
    payload.put_u32(u32::from(sequence));
    payload.extend_from_slice(body);
    Ok(payload.to_vec())
}