use super::{
check_tcp_request_max_message_size, parse_tcp_request_frame_header, tcp_request_endpoint_len,
tcp_request_header_size, tcp_request_headers_len,
};
use crate::pipeline::network::get_tcp_max_message_size;
use bytes::{Bytes, BytesMut};
use std::io;
use std::sync::OnceLock;
use tokio::io::{AsyncRead, AsyncReadExt};
/// Capacity, in bytes, of a freshly created read buffer (256 KiB).
///
/// Also serves as the lower bound that the resolved shrink threshold is
/// raised to, and as the capacity of the replacement buffer allocated when
/// the decoder shrinks.
const INITIAL_BUFFER_SIZE: usize = 256 * 1024;

/// Shrink threshold, in bytes, used when `DYN_TCP_SHRINK_MESSAGE_SIZE` is
/// unset or cannot be parsed (8 MiB).
const DEFAULT_SHRINK_SIZE: usize = 8 * 1024 * 1024;

/// Process-wide cache of the resolved shrink threshold; filled on first use
/// by `get_shrink_message_size`.
static SHRINK_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();
/// Returns the shrink threshold in bytes, resolving it once per process.
///
/// On first call this reads the configured maximum message size and the
/// `DYN_TCP_SHRINK_MESSAGE_SIZE` environment variable, passes both through
/// `resolve_shrink_message_size`, and caches the result in
/// `SHRINK_MESSAGE_SIZE`. Later calls return the cached value.
///
/// A warning is logged when the variable is set but is not a valid `usize`,
/// and when a valid value had to be adjusted to fit the allowed range.
fn get_shrink_message_size() -> usize {
    *SHRINK_MESSAGE_SIZE.get_or_init(|| {
        let max_size = get_tcp_max_message_size();

        // An unset (or otherwise unreadable) variable is silently treated as
        // "not configured"; only an unparsable value is worth a warning.
        let env_shrink_size = match std::env::var("DYN_TCP_SHRINK_MESSAGE_SIZE") {
            Ok(raw) => match raw.parse::<usize>() {
                Ok(bytes) => Some(bytes),
                Err(_) => {
                    tracing::warn!(
                        env_var = "DYN_TCP_SHRINK_MESSAGE_SIZE",
                        value = %raw,
                        "Invalid value for DYN_TCP_SHRINK_MESSAGE_SIZE, using default"
                    );
                    None
                }
            },
            Err(_) => None,
        };

        let resolved = resolve_shrink_message_size(max_size, env_shrink_size);

        // Tell the operator when their explicit setting was not honoured as-is.
        match env_shrink_size {
            Some(configured) if configured != resolved => {
                tracing::warn!(
                    configured_size = configured,
                    resolved_size = resolved,
                    max_size = max_size,
                    initial_buffer_size = INITIAL_BUFFER_SIZE,
                    "DYN_TCP_SHRINK_MESSAGE_SIZE was clamped to valid range. Note the size is in bytes."
                );
            }
            _ => {}
        }

        resolved
    })
}
/// Computes the effective shrink threshold in bytes.
///
/// Starts from `env_shrink_size`, or `DEFAULT_SHRINK_SIZE` when it is `None`,
/// caps that at `max_size`, and finally raises the result to at least
/// `INITIAL_BUFFER_SIZE`.
///
/// The floor is applied last on purpose: when `max_size` is smaller than
/// `INITIAL_BUFFER_SIZE` the result is `INITIAL_BUFFER_SIZE`, so this is not
/// interchangeable with `clamp` (which would panic on such bounds).
fn resolve_shrink_message_size(max_size: usize, env_shrink_size: Option<usize>) -> usize {
    let requested = match env_shrink_size {
        Some(bytes) => bytes,
        None => DEFAULT_SHRINK_SIZE,
    };
    let capped = std::cmp::min(requested, max_size);
    std::cmp::max(capped, INITIAL_BUFFER_SIZE)
}
/// Stateful decoder that reads framed TCP request messages from an async
/// reader into a reusable buffer and hands each message out as `Bytes`.
pub struct ZeroCopyTcpDecoder {
// Bytes read from the stream so far. Each decoded message is split off the
// front; anything already read beyond it stays here for the next call.
read_buffer: BytesMut,
// Limit handed to `check_tcp_request_max_message_size` for every frame.
max_message_size: usize,
// When the buffer is empty after a message and its capacity is above this
// value, it is replaced by a fresh `INITIAL_BUFFER_SIZE` buffer.
shrink_threshold: usize,
}
impl ZeroCopyTcpDecoder {
    /// Creates a decoder whose read buffer starts with `INITIAL_BUFFER_SIZE`
    /// bytes of capacity.
    pub fn new() -> Self {
        Self::with_capacity(INITIAL_BUFFER_SIZE)
    }

    /// Creates a decoder whose read buffer starts with `capacity` bytes
    /// reserved.
    ///
    /// The message size limit comes from `get_tcp_max_message_size()` and the
    /// shrink threshold from `get_shrink_message_size()`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            read_buffer: BytesMut::with_capacity(capacity),
            max_message_size: get_tcp_max_message_size(),
            shrink_threshold: get_shrink_message_size(),
        }
    }

    /// Reads one complete request frame from `reader`.
    ///
    /// The frame is consumed in stages, each reading only until enough bytes
    /// are buffered: the endpoint-length prefix, then the endpoint plus the
    /// headers-length field, then the full header, then the whole message.
    /// Bytes read past the end of the message remain buffered for the next
    /// call.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` / `"connection closed"` when the stream ends before
    ///   any byte of a new message was buffered.
    /// * `UnexpectedEof` / `"incomplete message header"` when the stream ends
    ///   part-way through the header.
    /// * `UnexpectedEof` / `"incomplete message: expected … bytes, got …"`
    ///   when the stream ends part-way through the body.
    /// * Any error returned by the reader or by the frame-parsing helpers
    ///   (including the size check against `max_message_size`) is propagated
    ///   unchanged.
    pub async fn read_message<R: AsyncRead + Unpin>(
        &mut self,
        reader: &mut R,
    ) -> io::Result<TcpRequestMessageZeroCopy> {
        // Stage 1: the endpoint-length prefix. Only at this point can EOF mean
        // a clean close, namely when nothing at all has been buffered.
        if !self
            .fill_to(&mut *reader, super::TCP_REQUEST_ENDPOINT_LEN_WIDTH)
            .await?
        {
            let reason = if self.read_buffer.is_empty() {
                "connection closed"
            } else {
                "incomplete message header"
            };
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, reason));
        }
        let path_len = tcp_request_endpoint_len(&self.read_buffer)?;

        // Stage 2: endpoint bytes plus the headers-length field.
        let initial_header_size =
            super::TCP_REQUEST_ENDPOINT_LEN_WIDTH + path_len + super::TCP_REQUEST_HEADERS_LEN_WIDTH;
        if !self.fill_to(&mut *reader, initial_header_size).await? {
            return Err(Self::incomplete_header_error());
        }
        let headers_len = tcp_request_headers_len(&self.read_buffer, path_len)?;

        // Stage 3: the complete header, whose size is now known.
        let full_header_size = tcp_request_header_size(path_len, headers_len);
        if !self.fill_to(&mut *reader, full_header_size).await? {
            return Err(Self::incomplete_header_error());
        }
        let parsed = parse_tcp_request_frame_header(&self.read_buffer)?;

        // Reject oversized frames before buffering any more of the body.
        let total_len = parsed.total_len;
        check_tcp_request_max_message_size(total_len, self.max_message_size)?;

        // Stage 4: the rest of the message.
        if !self.fill_to(&mut *reader, total_len).await? {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!(
                    "incomplete message: expected {} bytes, got {}",
                    total_len,
                    self.read_buffer.len()
                ),
            ));
        }

        let message_bytes = self.read_buffer.split_to(total_len).freeze();

        // With nothing left to carry over, swap an over-threshold buffer for a
        // fresh one of the initial size.
        // NOTE(review): `capacity()` is sampled after `split_to`, so it
        // presumably reflects only the tail left behind the message rather
        // than the whole allocation — confirm this is the intended measure.
        if self.read_buffer.is_empty() && self.read_buffer.capacity() > self.shrink_threshold {
            self.read_buffer = BytesMut::with_capacity(INITIAL_BUFFER_SIZE);
        }

        Ok(TcpRequestMessageZeroCopy::new(message_bytes, parsed))
    }

    /// Current capacity of the internal read buffer, in bytes.
    pub fn buffer_capacity(&self) -> usize {
        self.read_buffer.capacity()
    }

    /// Number of bytes currently held in the internal read buffer.
    pub fn buffered_len(&self) -> usize {
        self.read_buffer.len()
    }

    /// Reads from `reader` until at least `needed` bytes are buffered.
    ///
    /// Returns `Ok(true)` once the target is met (without reading at all if it
    /// already is) and `Ok(false)` if a read returned zero bytes first. Reader
    /// errors are propagated.
    async fn fill_to<R: AsyncRead + Unpin>(
        &mut self,
        reader: &mut R,
        needed: usize,
    ) -> io::Result<bool> {
        while self.read_buffer.len() < needed {
            if reader.read_buf(&mut self.read_buffer).await? == 0 {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Builds the error reported when the stream ends inside a frame header.
    fn incomplete_header_error() -> io::Error {
        io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete message header")
    }
}
impl Default for ZeroCopyTcpDecoder {
fn default() -> Self {
Self::new()
}
}
/// One decoded TCP request frame: the raw frame bytes plus the parsed header
/// that locates the endpoint, headers and payload inside them.
#[derive(Clone)]
pub struct TcpRequestMessageZeroCopy {
// The complete frame exactly as it was split off the decoder's read buffer.
raw: Bytes,
// Parsed frame header; supplies the start/end offsets used to slice `raw`.
parsed: super::TcpRequestWireHeader,
}
impl TcpRequestMessageZeroCopy {
    /// Pairs the raw frame bytes with the header parsed from them.
    fn new(raw: Bytes, parsed: super::TcpRequestWireHeader) -> Self {
        Self { raw, parsed }
    }

    /// Endpoint path as text, or a `Utf8Error` if the bytes are not valid
    /// UTF-8.
    pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(self.endpoint_path_bytes())
    }

    /// Endpoint path as the raw bytes found in the frame.
    pub fn endpoint_path_bytes(&self) -> &[u8] {
        let start = self.parsed.endpoint_start();
        let end = self.parsed.endpoint_end();
        &self.raw[start..end]
    }

    /// Headers section of the frame as raw bytes (empty when the frame
    /// carries no headers).
    pub fn headers_bytes(&self) -> &[u8] {
        let start = self.parsed.headers_start();
        let end = self.parsed.headers_end();
        &self.raw[start..end]
    }

    /// Headers decoded from JSON into a string map.
    ///
    /// Returns an empty map when the headers section is empty, and also when
    /// it cannot be deserialized — a decode failure is not reported.
    pub fn headers(&self) -> std::collections::HashMap<String, String> {
        match self.headers_bytes() {
            [] => std::collections::HashMap::new(),
            encoded => serde_json::from_slice(encoded).unwrap_or_default(),
        }
    }

    /// Payload length as recorded in the parsed header.
    #[inline]
    fn payload_len(&self) -> usize {
        self.parsed.payload_len
    }

    /// Payload as a `Bytes` slice of the raw frame, running from the payload
    /// start offset to the end of the frame.
    pub fn payload(&self) -> Bytes {
        let start = self.parsed.payload_start();
        self.raw.slice(start..)
    }

    /// Size of the whole frame in bytes.
    pub fn total_size(&self) -> usize {
        self.raw.len()
    }

    /// The complete frame, header included.
    pub fn raw_bytes(&self) -> &Bytes {
        &self.raw
    }
}
/// Compact debug view: frame size, endpoint path (if valid UTF-8) and payload
/// length. Header and payload contents are deliberately not printed.
impl std::fmt::Debug for TcpRequestMessageZeroCopy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total_size = self.total_size();
        // `None` when the endpoint bytes are not valid UTF-8.
        let endpoint_path = self.endpoint_path().ok();
        let payload_len = self.payload_len();

        let mut out = f.debug_struct("TcpRequestMessageZeroCopy");
        out.field("total_size", &total_size);
        out.field("endpoint_path", &endpoint_path);
        out.field("payload_len", &payload_len);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes one request frame the way the decoder expects it:
    /// big-endian `u16` endpoint length, endpoint bytes, big-endian `u16`
    /// headers length, headers bytes, big-endian `u32` payload length,
    /// payload bytes.
    fn encode_frame(endpoint: &str, headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut frame =
            Vec::with_capacity(2 + endpoint.len() + 2 + headers.len() + 4 + payload.len());
        frame.extend_from_slice(&(endpoint.len() as u16).to_be_bytes());
        frame.extend_from_slice(endpoint.as_bytes());
        frame.extend_from_slice(&(headers.len() as u16).to_be_bytes());
        frame.extend_from_slice(headers);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Default, capping, pass-through and floor behaviour of the threshold
    /// resolution.
    #[test]
    fn test_resolve_shrink_message_size_edge_cases() {
        const MIB: usize = 1024 * 1024;
        // (max_size, env override, expected result, failure message)
        let cases: [(usize, Option<usize>, usize, &str); 7] = [
            (
                10 * MIB,
                None,
                DEFAULT_SHRINK_SIZE,
                "10MB max should return default 8MB",
            ),
            (MIB, None, MIB, "1MB max should be capped to 1MB"),
            (
                DEFAULT_SHRINK_SIZE,
                None,
                DEFAULT_SHRINK_SIZE,
                "exact match should return default",
            ),
            (
                10 * MIB,
                Some(2 * MIB),
                2 * MIB,
                "env var should be used when within bounds",
            ),
            (
                10 * MIB,
                Some(20 * MIB),
                10 * MIB,
                "env var should be capped to max_size",
            ),
            (
                10 * MIB,
                Some(100 * 1024),
                INITIAL_BUFFER_SIZE,
                "env var should be clamped to INITIAL_BUFFER_SIZE",
            ),
            (
                100 * 1024,
                None,
                INITIAL_BUFFER_SIZE,
                "result should be clamped to INITIAL_BUFFER_SIZE",
            ),
        ];
        for (max_size, env_size, expected, why) in cases {
            assert_eq!(
                resolve_shrink_message_size(max_size, env_size),
                expected,
                "{why}"
            );
        }
    }

    /// A minimal frame without headers round-trips through the decoder.
    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let message = encode_frame(endpoint, &[], payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        assert_eq!(msg.headers().len(), 0);
    }

    /// Both an empty endpoint and a 2048-byte endpoint are accepted.
    #[tokio::test]
    async fn test_zero_copy_decoder_allows_empty_and_long_endpoint_paths() {
        for endpoint in [String::new(), "x".repeat(2048)] {
            let payload = b"payload";
            let message = encode_frame(&endpoint, &[], payload);

            let mut reader = &message[..];
            let mut decoder = ZeroCopyTcpDecoder::new();
            let msg = decoder.read_message(&mut reader).await.unwrap();

            assert_eq!(msg.endpoint_path().unwrap(), endpoint.as_str());
            assert_eq!(msg.payload().as_ref(), payload);
        }
    }

    /// A 200 KiB payload is decoded in full.
    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    /// The limit applies to the whole frame: a payload of exactly `max_size`
    /// bytes overshoots once the header is counted.
    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        let max_size = 1024;
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;

        let endpoint = "test/endpoint";
        let payload = vec![0x42u8; max_size];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let result = decoder.read_message(&mut reader).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let text = err.to_string();
        assert!(text.contains("message too large"));
        // 2 + 13 + 2 + 4 header bytes + 1024 payload bytes = 1045.
        assert!(text.contains("1045"));
        assert!(text.contains("1024"));
    }

    /// JSON headers are exposed both decoded and as the original bytes.
    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";

        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("traceparent".to_string(), "00-abc123-def456-01".to_string());
        headers_map.insert("user-agent".to_string(), "test-client/1.0".to_string());
        headers_map.insert("request-id".to_string(), "req-12345".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message = encode_frame(endpoint, &headers_json, payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());

        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        assert_eq!(
            decoded_headers.get("traceparent").unwrap(),
            "00-abc123-def456-01"
        );
        assert_eq!(
            decoded_headers.get("user-agent").unwrap(),
            "test-client/1.0"
        );
        assert_eq!(decoded_headers.get("request-id").unwrap(), "req-12345");

        let headers_bytes = msg.headers_bytes();
        assert_eq!(headers_bytes, &headers_json[..]);
    }

    /// One decoder handles a header-less frame followed by a frame with
    /// headers.
    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        let endpoint = "test/endpoint";
        let payload = b"test data";

        let message_empty = encode_frame(endpoint, &[], payload);
        let mut reader = &message_empty[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message_with_headers = encode_frame(endpoint, &headers_json, payload);
        let mut reader = &message_with_headers[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 1);
        assert_eq!(msg.headers().get("x-test-header").unwrap(), "test-value");
    }

    /// After a message larger than the shrink threshold the buffer is back at
    /// (or below) its initial capacity and still usable.
    #[tokio::test]
    async fn test_zero_copy_decoder_buffer_shrinking() {
        let endpoint = "test/endpoint";
        let small_payload = b"small";
        let large_payload = vec![0x42u8; 1024 * 1024];

        let mut decoder = ZeroCopyTcpDecoder::with_capacity(INITIAL_BUFFER_SIZE);
        decoder.max_message_size = 2 * 1024 * 1024;
        decoder.shrink_threshold = 512 * 1024;
        assert!(decoder.buffer_capacity() <= INITIAL_BUFFER_SIZE);

        let large_message = encode_frame(endpoint, &[], &large_payload);
        let mut reader = &large_message[..];
        decoder.read_message(&mut reader).await.unwrap();

        assert!(
            decoder.buffer_capacity() <= INITIAL_BUFFER_SIZE,
            "buffer should shrink after large message, got capacity {}",
            decoder.buffer_capacity()
        );
        assert!(
            decoder.buffered_len() == 0,
            "buffer should be empty after read"
        );

        let small_message = encode_frame(endpoint, &[], small_payload);
        let mut reader = &small_message[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.payload().as_ref(), small_payload);
    }
}