use std::io;
use bytes::{Bytes, BytesMut};
use rstest::rstest;
use super::{
examples::{HotlineFrameCodec, MysqlFrameCodec},
*,
};
/// Asking for a limit above `MAX_FRAME_LENGTH` must be clamped back down to it.
#[test]
fn length_delimited_codec_clamps_max_frame_length() {
    let requested = MAX_FRAME_LENGTH.saturating_add(1);
    let clamped = LengthDelimitedFrameCodec::new(requested).max_frame_length();
    assert_eq!(clamped, MAX_FRAME_LENGTH);
}
/// Encoding a payload and decoding the encoded bytes must yield the original
/// payload and consume the whole encoded frame from the buffer.
#[test]
fn length_delimited_codec_round_trips_payload() {
    let codec = LengthDelimitedFrameCodec::new(128);
    let mut encoder = codec.encoder();
    let mut decoder = codec.decoder();
    let payload = Bytes::from(vec![1_u8, 2, 3, 4]);
    let frame = codec.wrap_payload(payload.clone());
    let mut buf = BytesMut::new();
    encoder
        .encode(frame, &mut buf)
        .expect("encode should succeed");
    let decoded_frame = decoder
        .decode(&mut buf)
        .expect("decode should succeed")
        .expect("expected a frame");
    assert_eq!(
        LengthDelimitedFrameCodec::frame_payload(&decoded_frame),
        payload.as_ref()
    );
    // Without this check, a decoder that left header or payload bytes behind
    // in the buffer would still pass the payload comparison above.
    assert!(
        buf.is_empty(),
        "decoder should consume the entire encoded frame, {} byte(s) left",
        buf.len()
    );
}
/// A payload one byte longer than the configured limit must be refused by the
/// encoder with `io::ErrorKind::InvalidData`.
#[test]
fn length_delimited_codec_rejects_oversized_payloads() {
    let limit = MIN_FRAME_LENGTH;
    let codec = LengthDelimitedFrameCodec::new(limit);
    let mut encoder = codec.encoder();
    let oversized = codec.wrap_payload(Bytes::from(vec![0_u8; limit.saturating_add(1)]));
    let mut sink = BytesMut::new();
    let failure = encoder
        .encode(oversized, &mut sink)
        .expect_err("expected encode to fail for oversized frame");
    assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
}
/// `wrap_payload` must pass the payload's storage through unchanged: same
/// length and same starting address, i.e. no copy.
#[test]
fn length_delimited_wrap_payload_reuses_bytes() {
    let codec = LengthDelimitedFrameCodec::new(128);
    let payload = Bytes::from(vec![9_u8; 4]);
    // Capture the identity of the input before handing it to the codec.
    let (input_len, input_ptr) = (payload.len(), payload.as_ref().as_ptr());
    let frame = codec.wrap_payload(payload);
    assert_eq!(input_len, frame.len());
    assert_eq!(input_ptr, frame.as_ref().as_ptr());
}
/// At end-of-stream with nothing buffered, the decoder must report a clean
/// close (`Ok(None)`) instead of an error.
#[test]
fn decode_eof_with_empty_buffer_returns_none() {
    let codec = LengthDelimitedFrameCodec::new(128);
    let mut decoder = codec.decoder();
    let result = decoder.decode_eof(&mut BytesMut::new());
    assert!(
        matches!(result, Ok(None)),
        "clean close should return Ok(None), got {result:?}"
    );
}
/// Truncated input at end-of-stream must produce an error of the expected kind
/// whose message mentions `expected_substring`.
///
/// `partial_header` supplies only two bytes. In `partial_payload`, the leading
/// `0x00000010` announces 16 payload bytes (assuming a big-endian `u32` length
/// prefix, which the expected "16" in the message suggests) but only four follow.
#[rstest]
#[case::partial_header(&[0x00, 0x10], io::ErrorKind::UnexpectedEof, "header")]
#[case::partial_payload(&[0x00, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04], io::ErrorKind::UnexpectedEof, "16")]
fn decode_eof_error_cases(
    #[case] initial_buffer: &[u8],
    #[case] expected_kind: io::ErrorKind,
    #[case] expected_substring: &str,
) {
    let codec = LengthDelimitedFrameCodec::new(128);
    let mut buf = BytesMut::from(initial_buffer);
    let err = codec
        .decoder()
        .decode_eof(&mut buf)
        .expect_err("expected error");
    let message = err.to_string();
    assert_eq!(err.kind(), expected_kind, "unexpected error kind");
    assert!(
        message.contains(expected_substring),
        "error message should contain '{expected_substring}', got: {err}"
    );
}
/// `decode_eof` must still return a complete frame that is fully buffered, and
/// must leave the buffer empty afterwards.
#[test]
fn decode_eof_with_complete_frame_succeeds() {
    let codec = LengthDelimitedFrameCodec::new(128);
    let mut encoder = codec.encoder();
    let mut decoder = codec.decoder();
    let payload = Bytes::from(vec![1_u8, 2, 3, 4]);
    let frame = codec.wrap_payload(payload.clone());
    let mut buf = BytesMut::new();
    encoder
        .encode(frame, &mut buf)
        .expect("encode should succeed");
    let decoded_frame = decoder
        .decode_eof(&mut buf)
        .expect("decode should succeed")
        .expect("expected a frame");
    // Read the payload through the codec's accessor instead of relying on the
    // concrete frame type's `AsRef` impl.
    assert_eq!(
        LengthDelimitedFrameCodec::frame_payload(&decoded_frame),
        payload.as_ref()
    );
    // Leftover bytes here would make a subsequent `decode_eof` report a
    // truncated frame even though the stream ended cleanly.
    assert!(
        buf.is_empty(),
        "decode_eof should consume the entire encoded frame, {} byte(s) left",
        buf.len()
    );
}
/// Selects which example codec a parametrised zero-copy test exercises.
#[derive(Clone, Copy, Debug)]
enum TestCodec {
    /// `LengthDelimitedFrameCodec`; its frame exposes the payload bytes directly.
    LengthDelimited,
    /// `HotlineFrameCodec`; its frame holds the payload in a `payload` field.
    Hotline,
    /// `MysqlFrameCodec`; its frame holds the payload in a `payload` field.
    Mysql,
}
impl TestCodec {
    /// Wraps `payload` with a codec of this kind (constructed with 128) and
    /// returns the address of the payload bytes held by the resulting frame.
    ///
    /// The frame is dropped before this returns, so the pointer is only good
    /// for address comparison and must never be dereferenced.
    fn wrap_and_get_payload_ptr(self, payload: Bytes) -> *const u8 {
        match self {
            Self::LengthDelimited => LengthDelimitedFrameCodec::new(128)
                .wrap_payload(payload)
                .as_ptr(),
            Self::Hotline => HotlineFrameCodec::new(128)
                .wrap_payload(payload)
                .payload
                .as_ptr(),
            Self::Mysql => MysqlFrameCodec::new(128)
                .wrap_payload(payload)
                .payload
                .as_ptr(),
        }
    }

    /// Wraps `payload` with a codec of this kind (constructed with 128) and
    /// returns `(frame payload address, address of frame_payload_bytes(frame))`.
    ///
    /// As with `wrap_and_get_payload_ptr`, both pointers are for comparison
    /// only; the frame they point into is dropped on return.
    fn wrap_and_extract_payload_bytes_ptrs(self, payload: Bytes) -> (*const u8, *const u8) {
        match self {
            Self::LengthDelimited => {
                let codec = LengthDelimitedFrameCodec::new(128);
                let wrapped = codec.wrap_payload(payload);
                let view = LengthDelimitedFrameCodec::frame_payload_bytes(&wrapped);
                let frame_ptr = wrapped.as_ptr();
                (frame_ptr, view.as_ptr())
            }
            Self::Hotline => {
                let codec = HotlineFrameCodec::new(128);
                let wrapped = codec.wrap_payload(payload);
                let view = HotlineFrameCodec::frame_payload_bytes(&wrapped);
                let frame_ptr = wrapped.payload.as_ptr();
                (frame_ptr, view.as_ptr())
            }
            Self::Mysql => {
                let codec = MysqlFrameCodec::new(128);
                let wrapped = codec.wrap_payload(payload);
                let view = MysqlFrameCodec::frame_payload_bytes(&wrapped);
                let frame_ptr = wrapped.payload.as_ptr();
                (frame_ptr, view.as_ptr())
            }
        }
    }
}
#[rstest]
#[case::length_delimited(TestCodec::LengthDelimited, vec![9_u8; 4])]
#[case::hotline(TestCodec::Hotline, vec![5_u8; 8])]
#[case::mysql(TestCodec::Mysql, vec![3_u8; 10])]
fn wrap_payload_reuses_bytes(#[case] codec: TestCodec, #[case] payload_data: Vec<u8>) {
let payload = Bytes::from(payload_data);
let input_ptr = payload.as_ptr();
let frame_ptr = codec.wrap_and_get_payload_ptr(payload);
assert_eq!(
input_ptr, frame_ptr,
"wrap_payload should reuse the Bytes without copying"
);
}
/// Every example codec's `frame_payload_bytes` must hand back a view of the
/// frame's existing payload memory, not a copy of it.
#[rstest]
#[case::length_delimited(TestCodec::LengthDelimited, vec![1_u8, 2, 3, 4])]
#[case::hotline(TestCodec::Hotline, vec![7_u8; 6])]
#[case::mysql(TestCodec::Mysql, vec![9_u8; 5])]
fn frame_payload_bytes_reuses_memory(#[case] codec: TestCodec, #[case] payload_data: Vec<u8>) {
    let pointers = codec.wrap_and_extract_payload_bytes_ptrs(payload_data.into());
    let (frame_ptr, extracted_ptr) = pointers;
    assert_eq!(
        frame_ptr, extracted_ptr,
        "frame_payload_bytes should return the same memory region"
    );
}
/// Asserts that a decoded payload starts at the same address the payload
/// occupied in the input buffer before decoding (i.e. decoding did not copy).
///
/// `codec_name` prefixes the failure message so the failing codec is obvious.
///
/// # Panics
///
/// Panics if the two pointers differ. `#[track_caller]` makes the reported
/// panic location the calling test instead of this helper.
#[track_caller]
fn assert_decode_zero_copy(
    payload_ptr_before_decode: *const u8,
    decoded_payload_ptr: *const u8,
    codec_name: &str,
) {
    assert_eq!(
        payload_ptr_before_decode, decoded_payload_ptr,
        "{codec_name}: decoded payload should reuse buffer memory (zero-copy)"
    );
}
/// Decoding a Hotline frame must return a payload that aliases the input
/// buffer's memory instead of a freshly allocated copy.
#[test]
fn hotline_decode_produces_zero_copy_payload() {
    use bytes::BufMut;
    use tokio_util::codec::Decoder;

    // Length of the header written below: three big-endian `u32` fields
    // followed by eight zero bytes.
    const HEADER_LEN: usize = 20;

    let payload_data: &[u8] = &[0xde, 0xad, 0xbe, 0xef];
    // Derive both size fields from the payload so they cannot drift out of
    // sync with the bytes actually appended; `total_size` counts header plus
    // payload.
    let frame_len = HEADER_LEN + payload_data.len();
    let data_size = u32::try_from(payload_data.len()).expect("payload length fits in u32");
    let total_size = u32::try_from(frame_len).expect("frame length fits in u32");

    let mut buf = BytesMut::with_capacity(frame_len);
    buf.put_u32(data_size);
    buf.put_u32(total_size);
    // Arbitrary value for the third header field (presumably a transaction id).
    buf.put_u32(42);
    buf.extend_from_slice(&[0_u8; 8]);
    buf.extend_from_slice(payload_data);

    // Address of the payload inside the input buffer, taken after all writes
    // so no later growth can move it.
    let payload_ptr = buf
        .get(HEADER_LEN..)
        .expect("buffer should have at least 20 bytes")
        .as_ptr();

    let codec = HotlineFrameCodec::new(128);
    let mut decoder = codec.decoder();
    let frame = decoder
        .decode(&mut buf)
        .expect("decode should succeed")
        .expect("expected a frame");
    assert_decode_zero_copy(payload_ptr, frame.payload.as_ptr(), "hotline");
}
/// Decoding a MySQL-style frame must return a payload that aliases the input
/// buffer's memory instead of a freshly allocated copy.
#[test]
fn mysql_decode_produces_zero_copy_payload() {
    use bytes::BufMut;
    use tokio_util::codec::Decoder;

    // Header written below: a 3-byte little-endian payload length followed by
    // one more byte (presumably the MySQL packet sequence id).
    const HEADER_LEN: usize = 4;

    let payload_data: &[u8] = &[0xca, 0xfe, 0xba, 0xbe];
    // Derived from the payload so the length prefix always matches the data.
    let payload_len = u64::try_from(payload_data.len()).expect("payload length fits in u64");

    let mut buf = BytesMut::with_capacity(HEADER_LEN + payload_data.len());
    // Writes the low three bytes of the length, least significant byte first.
    buf.put_uint_le(payload_len, 3);
    buf.put_u8(1);
    buf.extend_from_slice(payload_data);

    // Address of the payload inside the input buffer, taken after all writes.
    let payload_ptr = buf
        .get(HEADER_LEN..)
        .expect("buffer should have at least 4 bytes")
        .as_ptr();

    let codec = MysqlFrameCodec::new(128);
    let mut decoder = codec.decoder();
    let frame = decoder
        .decode(&mut buf)
        .expect("decode should succeed")
        .expect("expected a frame");
    assert_decode_zero_copy(payload_ptr, frame.payload.as_ptr(), "mysql");
}
mod property;