use bytes::BytesMut;
use tokio_util::codec::Decoder;
use wireframe::{
byte_order::{read_network_u32, write_network_u32},
codec::{EofError, FrameCodec, LENGTH_HEADER_SIZE, LengthDelimitedFrameCodec},
};
use super::{CodecErrorWorld, TestResult};
impl CodecErrorWorld {
/// Returns the world to a pristine codec state.
///
/// Replaces the byte buffer with an empty one and clears the captured codec
/// error, the clean-close flag, and the recorded EOF classification, so that
/// results from an earlier setup cannot satisfy a later verification.
fn reset_codec_state(&mut self) {
    self.buffer = BytesMut::new();
    self.decoder_error = None;
    self.clean_close_detected = false;
    // `verify_clean_eof` and `verify_incomplete_eof` read `detected_eof`;
    // leaving it set across a reset would let a stale classification pass.
    self.detected_eof = None;
}
/// Configures the world with the default maximum frame length of 1024 and
/// resets the codec state.
pub fn setup_default_codec(&mut self) {
    const DEFAULT_MAX_FRAME_LENGTH: usize = 1024;

    self.max_frame_length = DEFAULT_MAX_FRAME_LENGTH;
    self.reset_codec_state();
}
/// Configures the world with a caller-chosen maximum frame length and resets
/// the codec state.
///
/// `max_len` is stored as `max_frame_length` and is later passed to
/// `LengthDelimitedFrameCodec::new` by the send/decode helpers.
pub fn setup_codec_with_max_length(&mut self, max_len: usize) {
    // The reset does not touch `max_frame_length`, so the order of these two
    // statements is interchangeable.
    self.reset_codec_state();
    self.max_frame_length = max_len;
}
/// Encodes `payload` as one frame and appends the encoded bytes to the
/// world's buffer.
///
/// The encoder comes from a `LengthDelimitedFrameCodec` built with the
/// current `max_frame_length`.
///
/// # Errors
///
/// Propagates any error returned by the encoder.
pub fn send_complete_frame(&mut self, payload: &[u8]) -> TestResult {
    use tokio_util::codec::Encoder;

    let frame = bytes::Bytes::copy_from_slice(payload);
    let frame_codec = LengthDelimitedFrameCodec::new(self.max_frame_length);
    let mut frame_encoder = frame_codec.encoder();
    frame_encoder.encode(frame, &mut self.buffer)?;
    Ok(())
}
/// Appends only a network-order `u32` with the value 100 to the buffer.
///
/// No payload bytes follow, so the buffer holds a length prefix announcing
/// 100 bytes that never arrive.
pub fn send_partial_frame_header_only(&mut self) {
    let announced_length = write_network_u32(100);
    self.buffer.extend_from_slice(&announced_length);
}
/// Drains every decodable frame from the buffer, then signals EOF to the
/// decoder and expects a clean close.
///
/// On `Ok(None)` from `decode_eof`, sets `clean_close_detected` and records
/// `EofError::CleanClose` in `detected_eof`.
///
/// # Errors
///
/// * Propagates any error raised while draining frames with `decode`.
/// * Fails if `decode_eof` yields a frame.
/// * Fails if `decode_eof` returns an error; that error is stored in
///   `decoder_error` first.
pub fn decode_eof_clean_close(&mut self) -> TestResult {
    let frame_codec = LengthDelimitedFrameCodec::new(self.max_frame_length);
    let mut frame_decoder = frame_codec.decoder();

    // Discard complete frames until the decoder reports nothing more to yield.
    while frame_decoder.decode(&mut self.buffer)?.is_some() {}

    match frame_decoder.decode_eof(&mut self.buffer) {
        Err(failure) => {
            self.decoder_error = Some(failure);
            Err("expected clean close, got error".into())
        }
        Ok(Some(_)) => Err("unexpected frame after EOF".into()),
        Ok(None) => {
            self.detected_eof = Some(EofError::CleanClose);
            self.clean_close_detected = true;
            Ok(())
        }
    }
}
/// Reads the first `LENGTH_HEADER_SIZE` bytes of the buffer as a
/// network-order `u32` and returns it as `usize`.
///
/// Returns 0 when the buffer holds fewer than `LENGTH_HEADER_SIZE` bytes.
fn extract_expected_length(&self) -> usize {
    let Some(prefix) = self.buffer.get(..LENGTH_HEADER_SIZE) else {
        return 0;
    };
    let Ok(header) = <[u8; LENGTH_HEADER_SIZE]>::try_from(prefix) else {
        return 0;
    };
    read_network_u32(header) as usize
}
/// Records which kind of EOF an I/O error represents.
///
/// Errors whose kind is not `UnexpectedEof` are ignored and leave
/// `detected_eof` unchanged. Otherwise an `EofError` found in the error's
/// inner source chain is used; if there is none, the classification is
/// inferred from the current buffer contents.
fn classify_eof_error(&mut self, e: &std::io::Error) {
    if e.kind() == std::io::ErrorKind::UnexpectedEof {
        let classified = match e.get_ref().and_then(Self::find_eof_error) {
            Some(embedded) => embedded,
            None => self.infer_eof_from_buffer(),
        };
        self.detected_eof = Some(classified);
    }
}
/// Walks `error` and its `source()` chain, returning a copy of the first
/// `EofError` encountered.
///
/// Returns `None` when no error in the chain downcasts to `EofError`.
fn find_eof_error(error: &(dyn std::error::Error + Send + Sync + 'static)) -> Option<EofError> {
    let mut cursor: &(dyn std::error::Error + 'static) = error;
    loop {
        if let Some(found) = cursor.downcast_ref::<EofError>() {
            return Some(*found);
        }
        // `?` ends the search with `None` once the chain has no further source.
        cursor = cursor.source()?;
    }
}
/// Derives an EOF classification purely from what is buffered.
///
/// * Fewer than `LENGTH_HEADER_SIZE` bytes: `MidHeader`, reporting the buffer
///   length as `bytes_received`.
/// * Otherwise: `MidFrame`, where `bytes_received` is the number of bytes
///   beyond the header and `expected` comes from `extract_expected_length`.
fn infer_eof_from_buffer(&self) -> EofError {
    let buffered = self.buffer.len();
    // `checked_sub` yields `None` exactly when the header itself is incomplete.
    match buffered.checked_sub(LENGTH_HEADER_SIZE) {
        None => EofError::MidHeader {
            bytes_received: buffered,
            header_size: LENGTH_HEADER_SIZE,
        },
        Some(payload_bytes) => EofError::MidFrame {
            bytes_received: payload_bytes,
            expected: self.extract_expected_length(),
        },
    }
}
/// Signals EOF to a fresh decoder over the current buffer and expects it to
/// fail.
///
/// The resulting error is classified via `classify_eof_error` and then stored
/// in `decoder_error`.
///
/// # Errors
///
/// Fails if `decode_eof` succeeds, whether it yields a frame or `None`.
pub fn decode_eof_with_partial_data(&mut self) -> TestResult {
    let frame_codec = LengthDelimitedFrameCodec::new(self.max_frame_length);
    let mut frame_decoder = frame_codec.decoder();

    let failure = match frame_decoder.decode_eof(&mut self.buffer) {
        Err(failure) => failure,
        Ok(Some(_)) => return Err("expected EOF error, got frame".into()),
        Ok(None) => return Err("expected EOF error, got Ok(None)".into()),
    };

    // Classify before storing: classification only borrows the error.
    self.classify_eof_error(&failure);
    self.decoder_error = Some(failure);
    Ok(())
}
/// Attempts to encode a zero-filled payload of `size` bytes and expects the
/// encoder to reject it.
///
/// The rejection error is stored in `decoder_error` so that
/// `verify_oversized_error` can inspect it.
///
/// # Errors
///
/// Fails if the encoder accepts the payload.
pub fn encode_oversized_frame(&mut self, size: usize) -> TestResult {
    use tokio_util::codec::Encoder;

    let frame_codec = LengthDelimitedFrameCodec::new(self.max_frame_length);
    let mut frame_encoder = frame_codec.encoder();
    let zeroes = bytes::Bytes::from(vec![0_u8; size]);

    let Err(rejection) = frame_encoder.encode(zeroes, &mut self.buffer) else {
        return Err("expected oversized error, got Ok".into());
    };
    self.decoder_error = Some(rejection);
    Ok(())
}
/// Asserts that a clean close was observed.
///
/// Passes when `clean_close_detected` is set or `detected_eof` is
/// `EofError::CleanClose`.
///
/// # Errors
///
/// Fails if a different EOF variant was recorded, or if none was recorded.
pub fn verify_clean_eof(&self) -> TestResult {
    match (self.clean_close_detected, self.detected_eof.as_ref()) {
        (true, _) | (false, Some(EofError::CleanClose)) => Ok(()),
        (false, Some(other)) => Err(format!("expected clean close, got {other:?}").into()),
        (false, None) => Err("no EOF was detected".into()),
    }
}
/// Asserts that the recorded EOF happened part-way through a header or frame.
///
/// # Errors
///
/// Fails if no EOF was recorded, or if the recorded variant is neither
/// `MidFrame` nor `MidHeader`.
pub fn verify_incomplete_eof(&self) -> TestResult {
    let Some(recorded) = self.detected_eof.as_ref() else {
        return Err("no EOF was detected".into());
    };

    if matches!(
        recorded,
        EofError::MidFrame { .. } | EofError::MidHeader { .. }
    ) {
        Ok(())
    } else {
        Err(format!("expected incomplete EOF, got {recorded:?}").into())
    }
}
/// Asserts that the captured codec error has kind `InvalidData`.
///
/// # Errors
///
/// Fails if no error was captured, or if the captured error has any other
/// kind.
pub fn verify_oversized_error(&self) -> TestResult {
    let Some(captured) = self.decoder_error.as_ref() else {
        return Err("no decoder error captured".into());
    };

    match captured.kind() {
        std::io::ErrorKind::InvalidData => Ok(()),
        kind => Err(format!("expected InvalidData error, got {kind:?}").into()),
    }
}
}
#[cfg(test)]
mod tests {
use bytes::BufMut;
use rstest::{fixture, rstest};
use super::*;
/// Fixture providing a default-constructed world whose codec state has been
/// reset.
#[fixture]
fn codec_error_world() -> CodecErrorWorld {
    let mut fresh = CodecErrorWorld::default();
    fresh.reset_codec_state();
    fresh
}
/// An `EofError` wrapped inside an `UnexpectedEof` I/O error must be recorded
/// verbatim, whichever variant it is.
#[rstest]
#[case::clean_close(EofError::CleanClose)]
#[case::mid_header(EofError::MidHeader {
    bytes_received: 1,
    header_size: LENGTH_HEADER_SIZE,
})]
#[case::mid_frame(EofError::MidFrame {
    bytes_received: 2,
    expected: 3,
})]
fn classify_eof_error_uses_inner_eof_error_variant(
    #[case] variant: EofError,
    mut codec_error_world: CodecErrorWorld,
) {
    let wrapped = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, variant);

    codec_error_world.classify_eof_error(&wrapped);

    let classified = codec_error_world.detected_eof;
    assert_eq!(classified, Some(variant));
}
/// When the `UnexpectedEof` error carries no `EofError`, the classification
/// must match what the buffer contents imply.
#[rstest]
fn classify_eof_error_falls_back_to_buffer_classification(
    mut codec_error_world: CodecErrorWorld,
) {
    // Two buffered bytes: fewer than a full length header.
    codec_error_world.buffer.extend_from_slice(&[0x01, 0x02]);
    let from_buffer = codec_error_world.infer_eof_from_buffer();
    let opaque = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "not an eof error");

    codec_error_world.classify_eof_error(&opaque);

    assert_eq!(codec_error_world.detected_eof, Some(from_buffer));
}
/// Errors of a kind other than `UnexpectedEof` must leave an existing
/// classification untouched.
#[rstest]
fn classify_eof_error_ignores_non_unexpected_eof(mut codec_error_world: CodecErrorWorld) {
    let previous = Some(EofError::CleanClose);
    codec_error_world.detected_eof = previous;
    let unrelated = std::io::Error::other("other error");

    codec_error_world.classify_eof_error(&unrelated);

    assert_eq!(codec_error_world.detected_eof, previous);
}
/// A buffer shorter than the length header must be classified as `MidHeader`
/// with the buffered byte count.
#[rstest]
fn infer_eof_from_buffer_reports_mid_header(mut codec_error_world: CodecErrorWorld) {
    codec_error_world.buffer.extend_from_slice(&[0x01, 0x02]);

    let inferred = codec_error_world.infer_eof_from_buffer();
    let EofError::MidHeader {
        bytes_received,
        header_size,
    } = inferred
    else {
        panic!("expected MidHeader, got {inferred:?}");
    };

    assert_eq!(bytes_received, 2);
    assert_eq!(header_size, LENGTH_HEADER_SIZE);
}
/// A buffer holding a full length header plus part of the payload must be
/// classified as `MidFrame`, reporting the payload bytes seen and the length
/// announced by the header.
#[rstest]
fn infer_eof_from_buffer_reports_mid_frame(mut codec_error_world: CodecErrorWorld) {
    let announced: u32 = 42;
    let announced_usize = usize::try_from(announced).expect("expected length fits in usize");
    codec_error_world.buffer.put_u32(announced);
    codec_error_world.buffer.extend_from_slice(&[0x11, 0x22]);
    let payload_seen = codec_error_world
        .buffer
        .len()
        .saturating_sub(LENGTH_HEADER_SIZE);

    let inferred = codec_error_world.infer_eof_from_buffer();
    let EofError::MidFrame {
        bytes_received,
        expected,
    } = inferred
    else {
        panic!("expected MidFrame, got {inferred:?}");
    };

    assert_eq!(bytes_received, payload_seen);
    assert_eq!(expected, announced_usize);
}
}