mod types;
use std::{fmt, io};
use bytes::{BufMut, BytesMut};
use rstest::fixture;
pub use types::*;
use wireframe::{
message_assembler::{FrameHeader, MessageAssembler, ParsedFrameHeader},
test_helpers::TestAssembler,
};
pub use wireframe_testing::TestResult;
use crate::TestApp;
/// Description of a first-frame header to encode in a test scenario.
///
/// Start from [`FirstHeaderSpec::new`] and refine with the `with_*` builder
/// methods. By default the metadata length is zero, no total length is
/// present, and the last-frame flag is cleared.
#[derive(Debug, Clone, Copy)]
pub struct FirstHeaderSpec {
    /// Message key written into the header.
    pub key: MessageKey,
    /// Metadata length written into the header.
    pub metadata_len: MetadataLength,
    /// Body length written into the header.
    pub body_len: BodyLength,
    /// Optional total length; when `Some`, it is encoded after the body length.
    pub total_len: Option<BodyLength>,
    /// Whether the header is flagged as the last frame.
    pub is_last: bool,
}

impl FirstHeaderSpec {
    /// Creates a spec for `key` with the given body length.
    ///
    /// The metadata length is zero, there is no total length, and the
    /// last-frame flag is cleared.
    #[must_use]
    pub fn new(key: MessageKey, body_len: BodyLength) -> Self {
        Self {
            key,
            metadata_len: MetadataLength(0),
            body_len,
            total_len: None,
            is_last: false,
        }
    }

    /// Returns a copy of the spec with the metadata length replaced.
    #[must_use]
    pub fn with_metadata_len(mut self, metadata_len: MetadataLength) -> Self {
        self.metadata_len = metadata_len;
        self
    }

    /// Returns a copy of the spec that carries `total_len` as its total length.
    #[must_use]
    pub fn with_total_len(mut self, total_len: BodyLength) -> Self {
        self.total_len = Some(total_len);
        self
    }

    /// Returns a copy of the spec with the last-frame flag set to `is_last`.
    #[must_use]
    pub fn with_last_flag(mut self, is_last: bool) -> Self {
        self.is_last = is_last;
        self
    }
}
/// Description of a continuation-frame header to encode in a test scenario.
///
/// Start from [`ContinuationHeaderSpec::new`] and refine with the `with_*`
/// builder methods. By default no sequence number is present and the
/// last-frame flag is cleared.
#[derive(Debug, Clone, Copy)]
pub struct ContinuationHeaderSpec {
    /// Message key written into the header.
    pub key: MessageKey,
    /// Body length written into the header.
    pub body_len: BodyLength,
    /// Optional sequence number; when `Some`, it is encoded after the body length.
    pub sequence: Option<SequenceNumber>,
    /// Whether the header is flagged as the last frame.
    pub is_last: bool,
}

impl ContinuationHeaderSpec {
    /// Creates a spec for `key` with the given body length.
    ///
    /// There is no sequence number and the last-frame flag is cleared.
    #[must_use]
    pub fn new(key: MessageKey, body_len: BodyLength) -> Self {
        Self {
            key,
            body_len,
            sequence: None,
            is_last: false,
        }
    }

    /// Returns a copy of the spec that carries `sequence` as its sequence number.
    #[must_use]
    pub fn with_sequence(mut self, sequence: SequenceNumber) -> Self {
        self.sequence = Some(sequence);
        self
    }

    /// Returns a copy of the spec with the last-frame flag set to `is_last`.
    #[must_use]
    pub fn with_last_flag(mut self, is_last: bool) -> Self {
        self.is_last = is_last;
        self
    }
}
/// Fixed prefix shared by every encoded header: kind byte, flag byte and
/// message key, written in that order by `set_payload_with_header`.
#[derive(Clone, Copy, Debug)]
struct HeaderEnvelope {
    /// Frame kind byte.
    kind: u8,
    /// Bit flags for this header.
    flags: u8,
    /// Raw message key.
    key: u64,
}
/// Mutable state for message-assembler header scenarios.
///
/// Holds the encoded header payload and the outcome of the most recent parse
/// (a parsed header or an error). It also holds an optional app; when present,
/// its message assembler is used for parsing instead of the default test
/// assembler.
#[derive(Default)]
pub struct MessageAssemblerWorld {
    /// Encoded header bytes awaiting parsing.
    payload: Option<Vec<u8>>,
    /// Header produced by the most recent successful parse.
    parsed: Option<ParsedFrameHeader>,
    /// Error produced by the most recent failed parse.
    error: Option<io::Error>,
    /// App supplying the message assembler, when configured.
    app: Option<TestApp>,
}
impl fmt::Debug for MessageAssemblerWorld {
    /// Formats the world's state.
    ///
    /// The app is rendered only as a fixed type-name label, not through its
    /// own formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stand-in label for the app (presumably because `TestApp` has no
        // useful `Debug` output — TODO confirm).
        let app_label = self
            .app
            .as_ref()
            .map(|_| "wireframe::app::WireframeApp");

        let mut builder = f.debug_struct("MessageAssemblerWorld");
        builder.field("payload", &self.payload);
        builder.field("parsed", &self.parsed);
        builder.field("error", &self.error);
        builder.field("app", &app_label);
        builder.finish()
    }
}
#[rustfmt::skip]
#[fixture]
pub fn message_assembler_world() -> MessageAssemblerWorld {
MessageAssemblerWorld::default()
}
impl MessageAssemblerWorld {
    /// Kind byte that opens a first-frame header.
    const KIND_FIRST: u8 = 0x01;
    /// Kind byte that opens a continuation-frame header.
    const KIND_CONTINUATION: u8 = 0x02;
    /// Flag bit set when the header is flagged as the last frame.
    const FLAG_LAST: u8 = 0b01;
    /// Flag bit set on a first header that carries an explicit total length.
    const FLAG_HAS_TOTAL_LEN: u8 = 0b10;
    /// Flag bit set on a continuation header that carries a sequence number.
    const FLAG_HAS_SEQUENCE: u8 = 0b10;

    /// Compares one field of the parsed header against `expected`.
    ///
    /// `extractor` pulls the field out of the parsed [`FrameHeader`]. It may
    /// fail (for example when the header is of the wrong kind), and its error
    /// message is then returned unchanged.
    ///
    /// # Errors
    ///
    /// Fails when no header has been parsed, when `extractor` fails, or when
    /// the extracted value differs from `expected`.
    fn assert_field<T, F>(&self, field_name: &str, expected: &T, extractor: F) -> TestResult
    where
        T: PartialEq + fmt::Display + Copy,
        F: FnOnce(&FrameHeader) -> Result<T, String>,
    {
        let parsed = self.parsed.as_ref().ok_or("no parsed header")?;
        let actual = extractor(parsed.header())?;
        if actual != *expected {
            return Err(format!("expected {field_name} {expected}, got {actual}").into());
        }
        Ok(())
    }

    /// Compares a field of a parsed first-frame header against `expected`.
    ///
    /// # Errors
    ///
    /// Fails as [`Self::assert_field`] does, and additionally with
    /// `expected first header` when the parsed header is not a first header.
    fn assert_first_field<T, F>(&self, field_name: &str, expected: &T, extractor: F) -> TestResult
    where
        T: PartialEq + fmt::Display + Copy,
        F: FnOnce(&wireframe::message_assembler::FirstFrameHeader) -> T,
    {
        self.assert_field(field_name, expected, |header| {
            if let FrameHeader::First(header) = header {
                Ok(extractor(header))
            } else {
                Err("expected first header".to_string())
            }
        })
    }

    /// Compares a field of a parsed continuation-frame header against `expected`.
    ///
    /// # Errors
    ///
    /// Fails as [`Self::assert_field`] does, and additionally with
    /// `expected continuation header` when the parsed header is not a
    /// continuation header.
    fn assert_continuation_field<T, F>(
        &self,
        field_name: &str,
        expected: &T,
        extractor: F,
    ) -> TestResult
    where
        T: PartialEq + fmt::Display + Copy,
        F: FnOnce(&wireframe::message_assembler::ContinuationFrameHeader) -> T,
    {
        self.assert_field(field_name, expected, |header| {
            if let FrameHeader::Continuation(header) = header {
                Ok(extractor(header))
            } else {
                Err("expected continuation header".to_string())
            }
        })
    }

    /// Encodes the first-frame header described by `spec` as the payload.
    ///
    /// The layout is: kind (`u8`), flags (`u8`), key (`u64`), metadata length
    /// (`u16`), body length (`u32`), and then, only when `spec.total_len` is
    /// set, the total length (`u32`). Multi-byte values are written with
    /// `BufMut`'s big-endian `put_*` methods.
    ///
    /// # Errors
    ///
    /// Fails when a length does not fit its wire width.
    pub fn set_first_header(&mut self, spec: FirstHeaderSpec) -> TestResult {
        let mut flags = 0u8;
        if spec.is_last {
            flags |= Self::FLAG_LAST;
        }
        if spec.total_len.is_some() {
            flags |= Self::FLAG_HAS_TOTAL_LEN;
        }
        self.set_payload_with_header(
            HeaderEnvelope {
                kind: Self::KIND_FIRST,
                flags,
                key: spec.key.0,
            },
            |bytes| {
                let metadata_len =
                    u16::try_from(spec.metadata_len.0).map_err(|_| "metadata length too large")?;
                bytes.put_u16(metadata_len);
                let body_len =
                    u32::try_from(spec.body_len.0).map_err(|_| "body length too large")?;
                bytes.put_u32(body_len);
                if let Some(total) = spec.total_len {
                    let total = u32::try_from(total.0).map_err(|_| "total length too large")?;
                    bytes.put_u32(total);
                }
                Ok(())
            },
        )
    }

    /// Encodes the continuation-frame header described by `spec` as the payload.
    ///
    /// The layout is: kind (`u8`), flags (`u8`), key (`u64`), body length
    /// (`u32`), and then, only when `spec.sequence` is set, the sequence
    /// number (`u32`). Multi-byte values are written with `BufMut`'s
    /// big-endian `put_*` methods.
    ///
    /// # Errors
    ///
    /// Fails when the body length does not fit in a `u32`.
    pub fn set_continuation_header(&mut self, spec: ContinuationHeaderSpec) -> TestResult {
        let mut flags = 0u8;
        if spec.is_last {
            flags |= Self::FLAG_LAST;
        }
        if spec.sequence.is_some() {
            flags |= Self::FLAG_HAS_SEQUENCE;
        }
        self.set_payload_with_header(
            HeaderEnvelope {
                kind: Self::KIND_CONTINUATION,
                flags,
                key: spec.key.0,
            },
            |bytes| {
                let body_len =
                    u32::try_from(spec.body_len.0).map_err(|_| "body length too large")?;
                bytes.put_u32(body_len);
                if let Some(seq) = spec.sequence {
                    bytes.put_u32(seq.0);
                }
                Ok(())
            },
        )
    }

    /// Writes the common envelope prefix, lets `encode` append the
    /// kind-specific fields, and stores the result as the payload.
    ///
    /// # Errors
    ///
    /// Propagates any error from `encode`. The payload is left unchanged in
    /// that case.
    fn set_payload_with_header<F>(&mut self, envelope: HeaderEnvelope, encode: F) -> TestResult
    where
        F: FnOnce(&mut BytesMut) -> TestResult,
    {
        let mut bytes = BytesMut::new();
        bytes.put_u8(envelope.kind);
        bytes.put_u8(envelope.flags);
        bytes.put_u64(envelope.key);
        encode(&mut bytes)?;
        self.payload = Some(bytes.to_vec());
        Ok(())
    }

    /// Stores a payload consisting of the single byte `0x01`.
    ///
    /// This is far shorter than any header produced by the `set_*_header`
    /// methods.
    pub fn set_invalid_payload(&mut self) { self.payload = Some(vec![0x01]); }

    /// Parses the stored payload and records the outcome.
    ///
    /// The app's message assembler is used when an app is configured;
    /// otherwise [`TestAssembler`] is used. On success the parsed header is
    /// stored and any previous error is cleared. On failure the error is
    /// stored and any previous header is cleared.
    ///
    /// # Errors
    ///
    /// Fails when no payload has been set, or when an app is configured
    /// without a message assembler. A parse failure is recorded in the world,
    /// not returned.
    pub fn parse_header(&mut self) -> TestResult {
        let payload = self.payload.as_deref().ok_or("payload not set")?;
        let fallback = TestAssembler;
        let assembler: &dyn MessageAssembler = match self.app.as_ref() {
            Some(app) => app
                .message_assembler()
                .ok_or("message assembler not set")?
                .as_ref(),
            None => &fallback,
        };
        match assembler.parse_frame_header(payload) {
            Ok(parsed) => {
                self.parsed = Some(parsed);
                self.error = None;
            }
            Err(err) => {
                self.parsed = None;
                self.error = Some(err);
            }
        }
        Ok(())
    }

    /// Checks that the parsed header is of kind `expected`.
    ///
    /// `expected` must be `"first"` or `"continuation"`.
    ///
    /// # Errors
    ///
    /// Fails when no header has been parsed, when `expected` is not a
    /// recognised kind name, or when the parsed header is of another kind. In
    /// the last case the parsed header is included in the message.
    pub fn assert_header_kind(&self, expected: &str) -> TestResult {
        let parsed = self.parsed.as_ref().ok_or("no parsed header")?;
        let matches_kind = match expected {
            "first" => matches!(parsed.header(), FrameHeader::First(_)),
            "continuation" => matches!(parsed.header(), FrameHeader::Continuation(_)),
            // Catch typos in step arguments instead of reporting a misleading
            // "expected <typo> header" mismatch.
            other => {
                return Err(format!(
                    "unknown header kind '{other}' (expected 'first' or 'continuation')"
                )
                .into());
            }
        };
        if matches_kind {
            Ok(())
        } else {
            Err(format!("expected {expected} header, got {parsed:?}").into())
        }
    }
}
mod message_assembler_asserts;