wireframe 0.3.0

Simplify building servers and clients for custom binary protocols.
Documentation
//! Shared helpers for codec-error regression tests.

use bytes::BytesMut;
use tokio_util::codec::Decoder;
use wireframe::{
    byte_order::write_network_u32,
    codec::{
        CodecError,
        CodecErrorContext,
        EofError,
        FrameCodec,
        LENGTH_HEADER_SIZE,
        LengthDelimitedFrameCodec,
        RecoveryPolicy,
        RecoveryPolicyHook,
    },
};
use wireframe_testing::{TestResult, encode_frame, new_test_codec};

pub(crate) struct StrictRecoveryHook;

impl RecoveryPolicyHook for StrictRecoveryHook {
    fn recovery_policy(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy {
        RecoveryPolicy::Disconnect
    }
}

pub(crate) fn extract_eof_error(error: &std::io::Error) -> TestResult<EofError> {
    let mut current = error
        .get_ref()
        .map(|inner| inner as &(dyn std::error::Error + 'static));

    while let Some(err) = current {
        if let Some(eof) = err.downcast_ref::<EofError>() {
            return Ok(*eof);
        }
        current = err.source();
    }

    Err(format!("expected wrapped EofError, got {error}").into())
}

pub(crate) fn classify_eof(bytes: &[u8], consume_frame_before_eof: bool) -> TestResult<EofError> {
    let codec = LengthDelimitedFrameCodec::new(1024);
    let mut decoder = codec.decoder();
    let mut buffer = BytesMut::from(bytes);

    if consume_frame_before_eof {
        let frame = decoder.decode(&mut buffer)?;
        if frame.is_none() {
            return Err("expected a complete frame before EOF".into());
        }
    }

    match decoder.decode_eof(&mut buffer) {
        Ok(None) => Ok(EofError::CleanClose),
        Ok(Some(_)) => Err("unexpected extra frame while classifying EOF".into()),
        Err(error) => extract_eof_error(&error),
    }
}

pub(crate) fn encoded_default_frame(payload: &[u8]) -> TestResult<Vec<u8>> {
    let mut codec = new_test_codec(1024);
    encode_frame(&mut codec, payload.to_vec()).map_err(Into::into)
}

pub(crate) fn truncated_payload_wire(payload: &[u8]) -> TestResult<Vec<u8>> {
    let expected = u32::try_from(payload.len())?;
    let partial_len = payload.len().saturating_sub(1);
    let mut wire = Vec::with_capacity(LENGTH_HEADER_SIZE + partial_len);
    wire.extend_from_slice(&write_network_u32(expected));
    let partial = payload
        .get(..partial_len)
        .ok_or("partial payload slice should be available")?;
    wire.extend_from_slice(partial);
    Ok(wire)
}