use bytes::Bytes;
use tokio_util::{
bytes::{Buf, BufMut, BytesMut},
codec::{Decoder, Encoder},
};
mod two_part;
pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
/// A request frame: the endpoint it is addressed to plus an opaque payload.
///
/// Wire layout (all integers big-endian): `u16` endpoint length, endpoint path
/// bytes (UTF-8), `u32` payload length, payload bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TcpRequestMessage {
    /// Endpoint path the request targets; sent with a `u16` length prefix.
    pub endpoint_path: String,
    /// Raw request body; sent with a `u32` length prefix.
    pub payload: Bytes,
}
impl TcpRequestMessage {
    /// Builds a request addressed to `endpoint_path` carrying `payload`.
    pub fn new(endpoint_path: String, payload: Bytes) -> Self {
        Self {
            endpoint_path,
            payload,
        }
    }

    /// Serializes the message into one contiguous buffer.
    ///
    /// Layout (big-endian integers): `u16` endpoint length, endpoint bytes,
    /// `u32` payload length, payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidInput`] when the endpoint path is longer
    /// than `u16::MAX` bytes or the payload is longer than `u32::MAX` bytes.
    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
        let endpoint_bytes = self.endpoint_path.as_bytes();

        // `try_from` does the range check and the narrowing in one step, so the
        // length written to the wire can never be a silently truncated value.
        let endpoint_len = u16::try_from(endpoint_bytes.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Endpoint path too long: {} bytes", endpoint_bytes.len()),
            )
        })?;
        let payload_len = u32::try_from(self.payload.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Payload too large: {} bytes", self.payload.len()),
            )
        })?;

        let mut buf = BytesMut::with_capacity(2 + endpoint_bytes.len() + 4 + self.payload.len());
        buf.put_u16(endpoint_len);
        buf.put_slice(endpoint_bytes);
        buf.put_u32(payload_len);
        buf.put_slice(&self.payload);
        Ok(buf.freeze())
    }

    /// Parses one message from the start of `bytes`.
    ///
    /// The payload is a zero-copy slice of `bytes`. Any bytes following the
    /// payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::UnexpectedEof`] when `bytes` is shorter than the
    /// lengths it declares, and [`std::io::ErrorKind::InvalidData`] when the endpoint
    /// path is not valid UTF-8.
    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
        const ENDPOINT_LEN_SIZE: usize = 2;
        const PAYLOAD_LEN_SIZE: usize = 4;

        if bytes.len() < ENDPOINT_LEN_SIZE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Not enough bytes for endpoint path length",
            ));
        }
        let endpoint_len = usize::from(u16::from_be_bytes([bytes[0], bytes[1]]));
        let endpoint_start = ENDPOINT_LEN_SIZE;

        // Every bounds check below compares a declared length against the bytes
        // that *remain* rather than computing `offset + len`: the lengths come
        // from the wire, and the sum could wrap on targets with a 32-bit `usize`.
        if bytes.len() - endpoint_start < endpoint_len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Not enough bytes for endpoint path",
            ));
        }
        let endpoint_end = endpoint_start + endpoint_len;

        // Validate the borrowed slice first; allocate only once it is known good.
        let endpoint_path = std::str::from_utf8(&bytes[endpoint_start..endpoint_end])
            .map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Invalid UTF-8: {}", e),
                )
            })?
            .to_owned();

        if bytes.len() - endpoint_end < PAYLOAD_LEN_SIZE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Not enough bytes for payload length",
            ));
        }
        // Widening cast: lossless on targets where `usize` is at least 32 bits.
        let payload_len = u32::from_be_bytes([
            bytes[endpoint_end],
            bytes[endpoint_end + 1],
            bytes[endpoint_end + 2],
            bytes[endpoint_end + 3],
        ]) as usize;
        let payload_start = endpoint_end + PAYLOAD_LEN_SIZE;

        let available = bytes.len() - payload_start;
        if available < payload_len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "Not enough bytes for payload: expected {}, got {}",
                    payload_len, available
                ),
            ));
        }

        // `payload_len <= available`, so this end index cannot overflow.
        let payload = bytes.slice(payload_start..payload_start + payload_len);
        Ok(Self {
            endpoint_path,
            payload,
        })
    }
}
/// Framing codec for [`TcpRequestMessage`] values.
#[derive(Default, Clone)]
pub struct TcpRequestCodec {
    /// Largest whole frame (length prefixes included) accepted, in bytes.
    /// `None` means no limit is enforced.
    max_message_size: Option<usize>,
}

impl TcpRequestCodec {
    /// Creates a codec that rejects frames larger than `max_message_size` bytes,
    /// or accepts any size when `None` is given.
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpRequestCodec { max_message_size }
    }
}
impl Decoder for TcpRequestCodec {
    type Item = TcpRequestMessage;
    type Error = std::io::Error;

    /// Tries to take one complete request frame off the front of `src`.
    ///
    /// Returns `Ok(None)` while the frame is still incomplete, leaving `src`
    /// untouched. Once a whole frame is buffered it is removed from `src` in one
    /// piece, so `src` always starts on a frame boundary afterwards — also when
    /// the frame turns out to contain an invalid endpoint path.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidData`] when the declared frame size
    /// exceeds the configured maximum (checked as soon as both length prefixes are
    /// available, before the payload arrives), when it does not fit in `usize`,
    /// or when the endpoint path is not valid UTF-8.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const ENDPOINT_LEN_SIZE: usize = 2;
        const PAYLOAD_LEN_SIZE: usize = 4;

        if src.len() < ENDPOINT_LEN_SIZE {
            return Ok(None);
        }
        let endpoint_len = usize::from(u16::from_be_bytes([src[0], src[1]]));

        // At most 2 + 65_535 + 4, so this sum cannot overflow.
        let header_size = ENDPOINT_LEN_SIZE + endpoint_len + PAYLOAD_LEN_SIZE;
        if src.len() < header_size {
            return Ok(None);
        }

        let payload_len_offset = ENDPOINT_LEN_SIZE + endpoint_len;
        // Widening cast: lossless on targets where `usize` is at least 32 bits.
        let payload_len = u32::from_be_bytes([
            src[payload_len_offset],
            src[payload_len_offset + 1],
            src[payload_len_offset + 2],
            src[payload_len_offset + 3],
        ]) as usize;

        // The payload length comes from the peer. On a 32-bit `usize` a plain
        // addition could wrap to a small value and slip past the size limit.
        let total_len = header_size.checked_add(payload_len).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Request too large: payload length {} overflows usize",
                    payload_len
                ),
            )
        })?;

        if let Some(max_size) = self.max_message_size.filter(|&max| total_len > max) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Request too large: {} bytes (max: {} bytes)",
                    total_len, max_size
                ),
            ));
        }
        if src.len() < total_len {
            return Ok(None);
        }

        // Detach the whole frame before parsing it, so an error below cannot
        // leave `src` positioned in the middle of a frame.
        let mut frame = src.split_to(total_len);
        frame.advance(ENDPOINT_LEN_SIZE);
        let endpoint_bytes = frame.split_to(endpoint_len);
        let endpoint_path = std::str::from_utf8(&endpoint_bytes)
            .map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Invalid UTF-8 in endpoint path: {}", e),
                )
            })?
            .to_owned();
        frame.advance(PAYLOAD_LEN_SIZE);

        // What is left of the frame is exactly the `payload_len` payload bytes.
        let payload = frame.freeze();
        Ok(Some(TcpRequestMessage {
            endpoint_path,
            payload,
        }))
    }
}
impl Encoder<TcpRequestMessage> for TcpRequestCodec {
type Error = std::io::Error;
fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
let endpoint_bytes = item.endpoint_path.as_bytes();
let endpoint_len = endpoint_bytes.len();
if endpoint_len > u16::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Endpoint path too long: {} bytes", endpoint_len),
));
}
if item.payload.len() > u32::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Payload too large: {} bytes", item.payload.len()),
));
}
let total_len = 2 + endpoint_len + 4 + item.payload.len();
if let Some(max_size) = self.max_message_size
&& total_len > max_size
{
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"Request too large: {} bytes (max: {} bytes)",
total_len, max_size
),
));
}
dst.reserve(total_len);
dst.put_u16(endpoint_len as u16);
dst.put_slice(endpoint_bytes);
dst.put_u32(item.payload.len() as u32);
dst.put_slice(&item.payload);
Ok(())
}
}
/// A response frame carrying an opaque body.
///
/// Wire layout: big-endian `u32` body length followed by the body bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TcpResponseMessage {
    /// Raw response body; sent with a `u32` length prefix.
    pub data: Bytes,
}
impl TcpResponseMessage {
    /// Builds a response carrying `data`.
    pub fn new(data: Bytes) -> Self {
        Self { data }
    }

    /// Builds a response with a zero-length body.
    pub fn empty() -> Self {
        Self { data: Bytes::new() }
    }

    /// Serializes the response as a big-endian `u32` length followed by the body.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidInput`] when the body is longer than
    /// `u32::MAX` bytes.
    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
        // `try_from` does the range check and the narrowing in one step, so the
        // length written to the wire can never be a silently truncated value.
        let data_len = u32::try_from(self.data.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Response too large: {} bytes", self.data.len()),
            )
        })?;

        let mut buf = BytesMut::with_capacity(4 + self.data.len());
        buf.put_u32(data_len);
        buf.put_slice(&self.data);
        Ok(buf.freeze())
    }

    /// Parses one response from the start of `bytes`.
    ///
    /// The body is a zero-copy slice of `bytes`. Any bytes following the body
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::UnexpectedEof`] when `bytes` is shorter than
    /// the length prefix, or shorter than the body length it declares.
    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
        const LEN_PREFIX_SIZE: usize = 4;

        if bytes.len() < LEN_PREFIX_SIZE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Not enough bytes for response length",
            ));
        }
        // Widening cast: lossless on targets where `usize` is at least 32 bits.
        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;

        // Compare against the bytes that remain instead of computing `4 + len`:
        // `len` comes from the wire and the sum could wrap on a 32-bit `usize`.
        let available = bytes.len() - LEN_PREFIX_SIZE;
        if available < len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "Not enough bytes for response: expected {}, got {}",
                    len, available
                ),
            ));
        }

        // `len <= available`, so this end index cannot overflow.
        let data = bytes.slice(LEN_PREFIX_SIZE..LEN_PREFIX_SIZE + len);
        Ok(Self { data })
    }
}
/// Framing codec for [`TcpResponseMessage`] values.
#[derive(Default, Clone)]
pub struct TcpResponseCodec {
    /// Largest whole frame (length prefix included) accepted, in bytes.
    /// `None` means no limit is enforced.
    max_message_size: Option<usize>,
}

impl TcpResponseCodec {
    /// Creates a codec that rejects frames larger than `max_message_size` bytes,
    /// or accepts any size when `None` is given.
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpResponseCodec { max_message_size }
    }
}
impl Decoder for TcpResponseCodec {
    type Item = TcpResponseMessage;
    type Error = std::io::Error;

    /// Tries to take one complete response frame off the front of `src`.
    ///
    /// Returns `Ok(None)` while the frame is still incomplete, leaving `src`
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidData`] when the declared frame size
    /// exceeds the configured maximum (checked as soon as the length prefix is
    /// available, before the body arrives) or does not fit in `usize`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const LEN_PREFIX_SIZE: usize = 4;

        if src.len() < LEN_PREFIX_SIZE {
            return Ok(None);
        }
        // Widening cast: lossless on targets where `usize` is at least 32 bits.
        let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;

        // The body length comes from the peer. On a 32-bit `usize` a plain
        // addition could wrap to a small value and slip past the size limit.
        let total_len = LEN_PREFIX_SIZE.checked_add(data_len).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Response too large: body length {} overflows usize",
                    data_len
                ),
            )
        })?;

        if let Some(max_size) = self.max_message_size.filter(|&max| total_len > max) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Response too large: {} bytes (max: {} bytes)",
                    total_len, max_size
                ),
            ));
        }
        if src.len() < total_len {
            return Ok(None);
        }

        src.advance(LEN_PREFIX_SIZE);
        let data = src.split_to(data_len).freeze();
        Ok(Some(TcpResponseMessage { data }))
    }
}
impl Encoder<TcpResponseMessage> for TcpResponseCodec {
    type Error = std::io::Error;

    /// Appends `item` to `dst` as one frame: big-endian `u32` body length
    /// followed by the body bytes.
    ///
    /// Fails with [`std::io::ErrorKind::InvalidInput`] when the body does not fit
    /// its length prefix or the whole frame is larger than the configured
    /// maximum. Nothing is written to `dst` on failure.
    fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = item.data;

        let body_len = u32::try_from(body.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Response too large: {} bytes", body.len()),
            )
        })?;

        let frame_len = 4 + body.len();
        if let Some(limit) = self.max_message_size.filter(|&limit| frame_len > limit) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "Response too large: {} bytes (max: {} bytes)",
                    frame_len, limit
                ),
            ));
        }

        dst.reserve(frame_len);
        dst.put_u32(body_len);
        dst.put_slice(&body);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TcpRequestMessage: standalone encode/decode ----

    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );
        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());
        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_large_payload() {
        // 1 MiB payload.
        let payload = Bytes::from(vec![42u8; 1024 * 1024]);
        let msg = TcpRequestMessage::new("large".to_string(), payload);
        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        // Drop the last two payload bytes.
        let truncated = encoded.slice(..encoded.len() - 2);
        let result = TcpRequestMessage::decode(&truncated);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));
        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
    }

    // ---- TcpResponseMessage: standalone encode/decode ----

    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();
        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();
        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        // Shorter than the 4-byte length prefix.
        let truncated = encoded.slice(..3);
        let result = TcpResponseMessage::decode(&truncated);
        assert!(result.is_err());
    }

    // ---- TcpRequestCodec ----

    #[test]
    fn test_tcp_request_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );
        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );
        let encoded = msg.encode().unwrap();
        let mut codec = TcpRequestCodec::new(None);

        // Only part of the header is available: the decoder must ask for more.
        let mut buf = BytesMut::from(&encoded[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Once the rest arrives the full message decodes.
        buf.extend_from_slice(&encoded[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        // The frame is 2 + 4 + 4 + 5 = 15 bytes, over the 10-byte limit.
        let mut codec = TcpRequestCodec::new(Some(10));
        let mut buf = BytesMut::new();
        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        let mut codec = TcpRequestCodec::new(Some(10));

        // A complete 15-byte frame exceeds the 10-byte limit.
        let mut full = BytesMut::from(&encoded[..]);
        assert!(codec.decode(&mut full).is_err());

        // The limit is enforced from the 10 header bytes alone, before any
        // payload has been received.
        let mut header_only = BytesMut::from(&encoded[..10]);
        assert!(codec.decode(&mut header_only).is_err());
    }

    // ---- TcpResponseCodec ----

    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();
        codec.encode(msg.clone(), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        // Fewer bytes than the length prefix: the decoder must ask for more.
        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        // The frame is 4 + 5 = 9 bytes, over the 5-byte limit.
        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::new();
        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_response_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();
        // The frame is 9 bytes, over the 5-byte limit.
        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::from(&encoded[..]);
        assert!(codec.decode(&mut buf).is_err());
    }

    // ---- Framed integration ----

    #[tokio::test]
    async fn test_framed_codec_integration() {
        use futures::{SinkExt, StreamExt};
        use std::io::Cursor;
        use tokio_util::codec::{FramedRead, FramedWrite};

        let mut buffer = Vec::new();

        // Write two messages back to back into an in-memory buffer.
        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None));
            let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1"));
            let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2"));
            writer.send(msg1).await.unwrap();
            writer.send(msg2).await.unwrap();
        }

        // Read them back and check that the frame boundaries were preserved.
        {
            let cursor = Cursor::new(&buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));

            let decoded1 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded1.endpoint_path, "endpoint1");
            assert_eq!(decoded1.payload, Bytes::from("data1"));

            let decoded2 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded2.endpoint_path, "endpoint2");
            assert_eq!(decoded2.payload, Bytes::from("data2"));
        }
    }

    #[tokio::test]
    async fn test_framed_codec_partial_messages() {
        use futures::StreamExt;
        use std::io::Cursor;
        use tokio_util::codec::FramedRead;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello"));
        let encoded = msg.encode().unwrap();
        let chunk1 = &encoded[..5];
        let chunk2 = &encoded[5..];

        let mut full_buffer = Vec::new();
        full_buffer.extend_from_slice(chunk1);

        // A stream that ends in the middle of a frame must surface an error
        // rather than yield a message or end silently.
        {
            let cursor = Cursor::new(&full_buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
            assert!(matches!(reader.next().await, Some(Err(_))));
        }

        full_buffer.extend_from_slice(chunk2);

        // With the remaining bytes present the message decodes normally.
        {
            let cursor = Cursor::new(&full_buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
            let decoded = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded.endpoint_path, "test");
            assert_eq!(decoded.payload, Bytes::from("hello"));
        }
    }
}