mod decoder_ops;
use bytes::BytesMut;
use rstest::fixture;
use wireframe::codec::{CodecError, EofError, FramingError, ProtocolError, RecoveryPolicy};
pub use wireframe_testing::TestResult;
/// Category of `CodecError` that `CodecErrorWorld` constructs.
///
/// Selected from step text by `CodecErrorWorld::set_error_type`; the default
/// is [`ErrorType::Framing`].
// `PartialEq`/`Eq` let scenario code compare selections directly instead of
// pattern-matching.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ErrorType {
    /// Build a `CodecError::Framing` from the selected `FramingVariant`.
    #[default]
    Framing,
    /// Build a `CodecError::Protocol` (an unknown-message-type error).
    Protocol,
    /// Build a `CodecError::Io` wrapping a synthetic I/O error.
    Io,
    /// Build a `CodecError::Eof` from the selected `EofVariant`.
    Eof,
}
/// Which `FramingError` is built when the error type is framing.
///
/// Selected from step text by `CodecErrorWorld::set_framing_variant`; the
/// default is [`FramingVariant::Oversized`].
// `PartialEq`/`Eq` let scenario code compare selections directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FramingVariant {
    /// Maps to `FramingError::OversizedFrame`.
    #[default]
    Oversized,
    /// Maps to `FramingError::InvalidLengthEncoding`.
    InvalidEncoding,
    /// Maps to `FramingError::IncompleteHeader`.
    IncompleteHeader,
    /// Maps to `FramingError::ChecksumMismatch`.
    ChecksumMismatch,
    /// Maps to `FramingError::EmptyFrame`.
    Empty,
}
/// Which `EofError` is built when the error type is EOF.
///
/// Selected from step text by `CodecErrorWorld::set_eof_variant`; the default
/// is [`EofVariant::CleanClose`].
// `PartialEq`/`Eq` let scenario code compare selections directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum EofVariant {
    /// Maps to `EofError::CleanClose`.
    #[default]
    CleanClose,
    /// Maps to `EofError::MidFrame`.
    MidFrame,
    /// Maps to `EofError::MidHeader`.
    MidHeader,
}
/// Shared state for the codec error test scenarios.
///
/// The private selector fields drive which `CodecError` is built; the
/// `pub(crate)` fields hold decoder-related state (presumably read and written
/// by the `decoder_ops` module — TODO confirm).
#[derive(Default, Debug)]
pub struct CodecErrorWorld {
    /// Category of error to build.
    error_type: ErrorType,
    /// Framing error to build when `error_type` is framing.
    framing_variant: FramingVariant,
    /// EOF error to build when `error_type` is EOF.
    eof_variant: EofVariant,
    /// Most recently built error; `None` until a selector setter succeeds.
    current_error: Option<CodecError>,
    /// EOF error observed from a decoder, if any.
    pub(crate) detected_eof: Option<EofError>,
    /// Maximum frame length; defaults to `0`.
    pub(crate) max_frame_length: usize,
    /// Input bytes buffer.
    pub(crate) buffer: BytesMut,
    /// I/O error returned by a decoder, if any.
    pub(crate) decoder_error: Option<std::io::Error>,
    /// Whether a clean close was detected; defaults to `false`.
    pub(crate) clean_close_detected: bool,
}
#[rustfmt::skip]
#[fixture]
pub fn codec_error_world() -> CodecErrorWorld {
CodecErrorWorld::default()
}
impl CodecErrorWorld {
    /// Selects the category of error to build and rebuilds the current error.
    ///
    /// Accepted values: `"framing"`, `"protocol"`, `"io"`, `"eof"`.
    ///
    /// # Errors
    ///
    /// Returns an error naming the input when it is not an accepted value; the
    /// world is left unchanged in that case.
    pub fn set_error_type(&mut self, error_type: &str) -> TestResult {
        let parsed = match error_type {
            "framing" => Some(ErrorType::Framing),
            "protocol" => Some(ErrorType::Protocol),
            "io" => Some(ErrorType::Io),
            "eof" => Some(ErrorType::Eof),
            _ => None,
        };
        let Some(kind) = parsed else {
            return Err(format!("unknown error type: {error_type}").into());
        };
        self.error_type = kind;
        self.build_error();
        Ok(())
    }

    /// Selects which framing error to build and rebuilds the current error.
    ///
    /// Accepted values: `"oversized"`, `"invalid_encoding"`,
    /// `"incomplete_header"`, `"checksum_mismatch"`, `"empty"`. The selection
    /// only affects the built error while the error type is framing.
    ///
    /// # Errors
    ///
    /// Returns an error naming the input when it is not an accepted value; the
    /// world is left unchanged in that case.
    pub fn set_framing_variant(&mut self, variant: &str) -> TestResult {
        let parsed = match variant {
            "oversized" => Some(FramingVariant::Oversized),
            "invalid_encoding" => Some(FramingVariant::InvalidEncoding),
            "incomplete_header" => Some(FramingVariant::IncompleteHeader),
            "checksum_mismatch" => Some(FramingVariant::ChecksumMismatch),
            "empty" => Some(FramingVariant::Empty),
            _ => None,
        };
        let Some(chosen) = parsed else {
            return Err(format!("unknown framing variant: {variant}").into());
        };
        self.framing_variant = chosen;
        self.build_error();
        Ok(())
    }

    /// Selects which EOF error to build and rebuilds the current error.
    ///
    /// Accepted values: `"clean_close"`, `"mid_frame"`, `"mid_header"`. The
    /// selection only affects the built error while the error type is EOF.
    ///
    /// # Errors
    ///
    /// Returns an error naming the input when it is not an accepted value; the
    /// world is left unchanged in that case.
    pub fn set_eof_variant(&mut self, variant: &str) -> TestResult {
        let parsed = match variant {
            "clean_close" => Some(EofVariant::CleanClose),
            "mid_frame" => Some(EofVariant::MidFrame),
            "mid_header" => Some(EofVariant::MidHeader),
            _ => None,
        };
        let Some(chosen) = parsed else {
            return Err(format!("unknown eof variant: {variant}").into());
        };
        self.eof_variant = chosen;
        self.build_error();
        Ok(())
    }

    /// Replaces `current_error` with a freshly built error for the current
    /// selections.
    fn build_error(&mut self) {
        let error = match self.error_type {
            ErrorType::Framing => CodecError::Framing(self.build_framing_error()),
            ErrorType::Protocol => {
                let unknown = ProtocolError::UnknownMessageType { type_id: 99 };
                CodecError::Protocol(unknown)
            }
            ErrorType::Io => CodecError::Io(std::io::Error::other("test error")),
            ErrorType::Eof => CodecError::Eof(self.build_eof_error()),
        };
        self.current_error = Some(error);
    }

    /// Builds the `FramingError` for the selected framing variant, using fixed
    /// sample values for any payload fields.
    fn build_framing_error(&self) -> FramingError {
        let chosen = self.framing_variant;
        match chosen {
            FramingVariant::Oversized => FramingError::OversizedFrame {
                size: 2000,
                max: 1024,
            },
            FramingVariant::InvalidEncoding => FramingError::InvalidLengthEncoding,
            FramingVariant::IncompleteHeader => {
                FramingError::IncompleteHeader { have: 2, need: 4 }
            }
            FramingVariant::ChecksumMismatch => FramingError::ChecksumMismatch {
                expected: 0xdead,
                actual: 0xbeef,
            },
            FramingVariant::Empty => FramingError::EmptyFrame,
        }
    }

    /// Builds the `EofError` for the selected EOF variant, using fixed sample
    /// byte counts for the partial-read variants.
    fn build_eof_error(&self) -> EofError {
        let chosen = self.eof_variant;
        match chosen {
            EofVariant::CleanClose => EofError::CleanClose,
            EofVariant::MidFrame => EofError::MidFrame {
                bytes_received: 100,
                expected: 200,
            },
            EofVariant::MidHeader => EofError::MidHeader {
                bytes_received: 2,
                header_size: 4,
            },
        }
    }

    /// Asserts that the current error's default recovery policy equals
    /// `expected`.
    ///
    /// Accepted values: `"drop"`, `"quarantine"`, `"disconnect"`.
    ///
    /// # Errors
    ///
    /// Returns an error when `expected` is not an accepted value, when no error
    /// has been built yet, or when the actual policy differs from `expected`.
    pub fn verify_recovery_policy(&self, expected: &str) -> TestResult {
        // The policy name is validated before the error is looked up.
        let expected_policy = match expected {
            "drop" => RecoveryPolicy::Drop,
            "quarantine" => RecoveryPolicy::Quarantine,
            "disconnect" => RecoveryPolicy::Disconnect,
            _ => return Err(format!("unknown recovery policy: {expected}").into()),
        };
        let Some(error) = self.current_error.as_ref() else {
            return Err("no error has been set".into());
        };
        let actual_policy = error.default_recovery_policy();
        if actual_policy == expected_policy {
            return Ok(());
        }
        Err(format!("expected policy {expected_policy:?}, got {actual_policy:?}").into())
    }
}