use bytes::{Buf, Bytes, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt};
/// Default upper bound, in bytes, on one complete framed message (32 MiB).
/// `get_max_message_size` lets `DYN_TCP_MAX_MESSAGE_SIZE` override it.
const MAX_MESSAGE_SIZE: usize = 32 << 20;
/// Initial capacity, in bytes, of a decoder's read buffer (256 KiB).
const INITIAL_BUFFER_SIZE: usize = 256 * 1024;
fn get_max_message_size() -> usize {
std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(MAX_MESSAGE_SIZE)
}
/// Incremental decoder for length-prefixed request frames read from a byte stream.
///
/// Incoming bytes accumulate in one growable buffer. Each decoded frame is split
/// off that buffer as a `Bytes` handle instead of being copied out.
pub struct ZeroCopyTcpDecoder {
    /// Upper bound, in bytes, on the size of one complete frame.
    max_message_size: usize,
    /// Bytes received from the stream that have not been returned as a message yet.
    read_buffer: BytesMut,
}
impl ZeroCopyTcpDecoder {
    /// Creates a decoder whose read buffer starts at `INITIAL_BUFFER_SIZE` bytes.
    pub fn new() -> Self {
        Self::with_capacity(INITIAL_BUFFER_SIZE)
    }

    /// Creates a decoder whose read buffer starts with `capacity` bytes.
    ///
    /// The maximum accepted message size is read once, here. It comes from
    /// `DYN_TCP_MAX_MESSAGE_SIZE` when that is set and valid, and otherwise
    /// from `MAX_MESSAGE_SIZE`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            read_buffer: BytesMut::with_capacity(capacity),
            max_message_size: get_max_message_size(),
        }
    }

    /// Reads exactly one framed message from `reader`.
    ///
    /// Wire format (all integers big-endian):
    /// `[path_len: u16][path][headers_len: u16][headers][payload_len: u32][payload]`
    ///
    /// Bytes read beyond the end of this frame stay buffered and are consumed by
    /// the next call. Use one decoder per stream.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` if the stream ends before a full frame arrives. The
    ///   message is "connection closed" when the stream ends cleanly between frames.
    /// - `InvalidData` if `path_len` is 0 or greater than 1024, or if the whole
    ///   frame would exceed the configured maximum size.
    /// - Any I/O error returned by `reader`.
    ///
    /// On error the already-buffered bytes are left in place, so the stream
    /// should be treated as unusable afterwards.
    pub async fn read_message<R: AsyncRead + Unpin>(
        &mut self,
        reader: &mut R,
    ) -> io::Result<TcpRequestMessageZeroCopy> {
        // Widths of the three length prefixes, and the accepted path length bound.
        const PATH_LEN_SIZE: usize = 2;
        const HEADERS_LEN_SIZE: usize = 2;
        const PAYLOAD_LEN_SIZE: usize = 4;
        const MAX_PATH_LEN: usize = 1024;

        // EOF inside the header prefix (after at least one byte) is reported uniformly.
        fn incomplete_header(_buffered: usize) -> io::Error {
            io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete message header")
        }

        // 1. `path_len`. EOF before any byte has arrived is a normal close.
        self.fill_to(reader, PATH_LEN_SIZE, |buffered| {
            let reason = if buffered == 0 {
                "connection closed"
            } else {
                "incomplete message header"
            };
            io::Error::new(io::ErrorKind::UnexpectedEof, reason)
        })
        .await?;
        let path_len = usize::from(u16::from_be_bytes([
            self.read_buffer[0],
            self.read_buffer[1],
        ]));
        if path_len == 0 || path_len > MAX_PATH_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid endpoint path length: {}", path_len),
            ));
        }

        // 2. Path bytes followed by `headers_len`.
        let headers_len_offset = PATH_LEN_SIZE + path_len;
        self.fill_to(reader, headers_len_offset + HEADERS_LEN_SIZE, incomplete_header)
            .await?;
        let headers_len = usize::from(u16::from_be_bytes([
            self.read_buffer[headers_len_offset],
            self.read_buffer[headers_len_offset + 1],
        ]));

        // 3. Header bytes followed by `payload_len`.
        let payload_len_offset = headers_len_offset + HEADERS_LEN_SIZE + headers_len;
        let header_size = payload_len_offset + PAYLOAD_LEN_SIZE;
        self.fill_to(reader, header_size, incomplete_header).await?;
        let payload_len = u32::from_be_bytes([
            self.read_buffer[payload_len_offset],
            self.read_buffer[payload_len_offset + 1],
            self.read_buffer[payload_len_offset + 2],
            self.read_buffer[payload_len_offset + 3],
        ]);

        // Sum in u64. `payload_len` is peer-controlled, so `header_size + payload_len`
        // can exceed usize::MAX on 32-bit targets. A wrapped total would slip past
        // the size check and produce a frame shorter than its own header claims.
        // (usize -> u64 is lossless on every target Rust supports.)
        let total_len = header_size as u64 + u64::from(payload_len);
        if total_len > self.max_message_size as u64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "message too large: {} bytes (max: {} bytes)",
                    total_len, self.max_message_size
                ),
            ));
        }
        // Cannot truncate: total_len <= max_message_size, which is a usize.
        let total_len = total_len as usize;

        // 4. Payload.
        self.fill_to(reader, total_len, |buffered| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!(
                    "incomplete message: expected {} bytes, got {}",
                    total_len, buffered
                ),
            )
        })
        .await?;

        let message_bytes = self.read_buffer.split_to(total_len).freeze();
        Ok(TcpRequestMessageZeroCopy::new(message_bytes))
    }

    /// Reads from `reader` until at least `target` bytes are buffered.
    ///
    /// If the stream reaches EOF first, returns the error built by `on_eof`.
    /// `on_eof` receives the number of bytes buffered at that point.
    async fn fill_to<R, F>(&mut self, reader: &mut R, target: usize, on_eof: F) -> io::Result<()>
    where
        R: AsyncRead + Unpin,
        F: FnOnce(usize) -> io::Error,
    {
        // Reserve the shortfall once, so a large frame arrives with at most one
        // reallocation instead of repeated grow-and-copy. Callers only pass targets
        // already bounded by `max_message_size` or the fixed header limits.
        let missing = target.saturating_sub(self.read_buffer.len());
        self.read_buffer.reserve(missing);
        while self.read_buffer.len() < target {
            if reader.read_buf(&mut self.read_buffer).await? == 0 {
                return Err(on_eof(self.read_buffer.len()));
            }
        }
        Ok(())
    }

    /// Current capacity, in bytes, of the internal read buffer.
    pub fn buffer_capacity(&self) -> usize {
        self.read_buffer.capacity()
    }

    /// Number of bytes read from the stream but not yet returned as a message.
    pub fn buffered_len(&self) -> usize {
        self.read_buffer.len()
    }
}
impl Default for ZeroCopyTcpDecoder {
fn default() -> Self {
Self::new()
}
}
/// One complete request frame, kept as the raw bytes it arrived in.
///
/// Accessors decode the length prefixes on demand and return views into
/// `raw` rather than copying fields out.
#[derive(Clone)]
pub struct TcpRequestMessageZeroCopy {
    /// The whole frame: `[path_len][path][headers_len][headers][payload_len][payload]`.
    raw: Bytes,
}
impl TcpRequestMessageZeroCopy {
    /// Wraps one complete frame.
    ///
    /// The accessors below index into `raw` without re-validating the length
    /// prefixes. They rely on `raw` holding exactly one well-formed frame, as
    /// produced by `ZeroCopyTcpDecoder::read_message`.
    fn new(raw: Bytes) -> Self {
        Self { raw }
    }

    /// Reads the big-endian `u16` stored at byte offset `at`.
    #[inline]
    fn be_u16_at(&self, at: usize) -> usize {
        usize::from(u16::from_be_bytes([self.raw[at], self.raw[at + 1]]))
    }

    /// Length of the endpoint path (the leading `u16` prefix).
    #[inline]
    fn path_len(&self) -> usize {
        self.be_u16_at(0)
    }

    /// Offset of the `headers_len` prefix, immediately after the path.
    #[inline]
    fn headers_len_offset(&self) -> usize {
        2 + self.path_len()
    }

    /// Length of the header blob.
    #[inline]
    fn headers_len(&self) -> usize {
        self.be_u16_at(self.headers_len_offset())
    }

    /// Offset of the first header byte.
    #[inline]
    fn headers_start(&self) -> usize {
        self.headers_len_offset() + 2
    }

    /// Offset of the `payload_len` prefix, immediately after the headers.
    #[inline]
    fn payload_len_offset(&self) -> usize {
        self.headers_start() + self.headers_len()
    }

    /// The endpoint path decoded as UTF-8.
    pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(self.endpoint_path_bytes())
    }

    /// The raw endpoint path bytes.
    pub fn endpoint_path_bytes(&self) -> &[u8] {
        &self.raw[2..2 + self.path_len()]
    }

    /// The raw header bytes (JSON when non-empty; possibly empty).
    pub fn headers_bytes(&self) -> &[u8] {
        let start = self.headers_start();
        &self.raw[start..start + self.headers_len()]
    }

    /// Headers parsed as a JSON string-to-string map.
    ///
    /// Best effort: empty or unparsable header bytes yield an empty map.
    pub fn headers(&self) -> std::collections::HashMap<String, String> {
        match self.headers_bytes() {
            [] => std::collections::HashMap::new(),
            bytes => serde_json::from_slice(bytes).unwrap_or_default(),
        }
    }

    /// Payload length as declared by the `u32` prefix.
    #[inline]
    fn payload_len(&self) -> usize {
        let at = self.payload_len_offset();
        let mut word = [0u8; 4];
        word.copy_from_slice(&self.raw[at..at + 4]);
        u32::from_be_bytes(word) as usize
    }

    /// The payload: everything after the `payload_len` prefix, shared with
    /// the frame rather than copied.
    pub fn payload(&self) -> Bytes {
        self.raw.slice(self.payload_len_offset() + 4..)
    }

    /// Size of the entire frame, prefixes included.
    pub fn total_size(&self) -> usize {
        self.raw.len()
    }

    /// The entire frame as received.
    pub fn raw_bytes(&self) -> &Bytes {
        &self.raw
    }
}
impl std::fmt::Debug for TcpRequestMessageZeroCopy {
    /// Summarises the frame (size, path, payload length) without dumping header
    /// or payload bytes. The path prints as `None` when it is not valid UTF-8.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total_size = self.total_size();
        let endpoint_path = self.endpoint_path().ok();
        let payload_len = self.payload_len();
        let mut summary = f.debug_struct("TcpRequestMessageZeroCopy");
        summary.field("total_size", &total_size);
        summary.field("endpoint_path", &endpoint_path);
        summary.field("payload_len", &payload_len);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};

    /// Encodes one frame in the wire format the decoder expects.
    fn encode_message(endpoint: &str, headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut message =
            Vec::with_capacity(2 + endpoint.len() + 2 + headers.len() + 4 + payload.len());
        message.extend_from_slice(&(endpoint.len() as u16).to_be_bytes());
        message.extend_from_slice(endpoint.as_bytes());
        message.extend_from_slice(&(headers.len() as u16).to_be_bytes());
        message.extend_from_slice(headers);
        message.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        message.extend_from_slice(payload);
        message
    }

    /// `AsyncRead` that yields at most one byte per poll. This forces the
    /// decoder through every partial-read loop.
    struct OneByteReader {
        data: Vec<u8>,
        pos: usize,
    }

    impl AsyncRead for OneByteReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            if buf.remaining() > 0 {
                if let Some(&byte) = this.data.get(this.pos) {
                    buf.put_slice(&[byte]);
                    this.pos += 1;
                }
            }
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let message = encode_message(endpoint, &[], payload);
        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        assert_eq!(msg.headers().len(), 0);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let message = encode_message(endpoint, &[], &payload);
        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        let max_size = 1024;
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;
        let endpoint = "test/endpoint";
        // Payload alone equals the limit, so the full frame exceeds it.
        let payload = vec![0x42u8; max_size];
        let message = encode_message(endpoint, &[], &payload);
        let mut reader = &message[..];
        let result = decoder.read_message(&mut reader).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("message too large"));
        // 2 + 13 + 2 + 0 + 4 + 1024 = 1045 total bytes.
        assert!(err.to_string().contains("1045"));
        assert!(err.to_string().contains("1024"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";
        let mut headers_map = HashMap::new();
        headers_map.insert("traceparent".to_string(), "00-abc123-def456-01".to_string());
        headers_map.insert("user-agent".to_string(), "test-client/1.0".to_string());
        headers_map.insert("request-id".to_string(), "req-12345".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();
        let message = encode_message(endpoint, &headers_json, payload);
        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        assert_eq!(
            decoded_headers.get("traceparent").unwrap(),
            "00-abc123-def456-01"
        );
        assert_eq!(
            decoded_headers.get("user-agent").unwrap(),
            "test-client/1.0"
        );
        assert_eq!(decoded_headers.get("request-id").unwrap(), "req-12345");
        assert_eq!(msg.headers_bytes(), &headers_json[..]);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        let endpoint = "test/endpoint";
        let payload = b"test data";
        let mut decoder = ZeroCopyTcpDecoder::new();

        let message_empty = encode_message(endpoint, &[], payload);
        let mut reader = &message_empty[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        let mut headers_map = HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();
        let message_with_headers = encode_message(endpoint, &headers_json, payload);
        let mut reader = &message_with_headers[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 1);
        assert_eq!(msg.headers().get("x-test-header").unwrap(), "test-value");
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_pipelined_messages() {
        // Two frames back-to-back on one stream; leftover bytes must carry over.
        let mut stream = encode_message("first", &[], b"one");
        stream.extend_from_slice(&encode_message("second", b"{\"k\":\"v\"}", b"two"));
        let mut reader = &stream[..];
        let mut decoder = ZeroCopyTcpDecoder::new();

        let first = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(first.endpoint_path().unwrap(), "first");
        assert_eq!(first.payload().as_ref(), b"one");

        let second = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(second.endpoint_path().unwrap(), "second");
        assert_eq!(second.payload().as_ref(), b"two");
        assert_eq!(second.headers().get("k").map(String::as_str), Some("v"));
        assert_eq!(decoder.buffered_len(), 0);

        // A clean EOF between frames is reported as a closed connection.
        let err = decoder.read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(err.to_string().contains("connection closed"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_truncated_header() {
        let message = encode_message("test/endpoint", &[], b"data");
        // Cut inside the path_len prefix, the path, and the payload_len prefix.
        for cut in [1usize, 5, 18] {
            let mut reader = &message[..cut];
            let mut decoder = ZeroCopyTcpDecoder::new();
            let err = decoder.read_message(&mut reader).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
            assert!(err.to_string().contains("incomplete message header"));
        }
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_truncated_payload() {
        let message = encode_message("test/endpoint", &[], b"data");
        let mut reader = &message[..message.len() - 1];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let err = decoder.read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(err
            .to_string()
            .contains("incomplete message: expected 25 bytes, got 24"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_invalid_path_length() {
        for path_len in [0u16, 1025] {
            let mut raw = path_len.to_be_bytes().to_vec();
            raw.extend_from_slice(&[0u8; 8]);
            let mut reader = &raw[..];
            let mut decoder = ZeroCopyTcpDecoder::new();
            let err = decoder.read_message(&mut reader).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(err
                .to_string()
                .contains(&format!("invalid endpoint path length: {}", path_len)));
        }
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_one_byte_reads() {
        let headers_json = br#"{"a":"b"}"#;
        let mut data = encode_message("api/v1/x", headers_json, b"payload bytes");
        data.extend_from_slice(&encode_message("api/v1/y", &[], b"second"));
        let mut reader = OneByteReader { data, pos: 0 };
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(4);

        let first = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(first.endpoint_path().unwrap(), "api/v1/x");
        assert_eq!(first.headers_bytes(), &headers_json[..]);
        assert_eq!(first.payload().as_ref(), b"payload bytes");

        let second = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(second.endpoint_path().unwrap(), "api/v1/y");
        assert_eq!(second.payload().as_ref(), b"second");
    }
}