dynamo_runtime/pipeline/network/codec/
zero_copy_decoder.rs1use bytes::{Buf, Bytes, BytesMut};
13use std::io;
14use tokio::io::{AsyncRead, AsyncReadExt};
15
/// Fallback cap on the total size of one message (32 MiB), used when
/// `DYN_TCP_MAX_MESSAGE_SIZE` is unset or cannot be parsed as a `usize`.
const MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Capacity, in bytes, that `ZeroCopyTcpDecoder::new` passes to
/// `ZeroCopyTcpDecoder::with_capacity` (256 KiB).
const INITIAL_BUFFER_SIZE: usize = 256 * 1024;

/// Resolves the maximum accepted message size in bytes.
///
/// Reads the `DYN_TCP_MAX_MESSAGE_SIZE` environment variable and parses it as a
/// `usize`. A missing, non-unicode, or unparsable value falls back to
/// [`MAX_MESSAGE_SIZE`].
fn get_max_message_size() -> usize {
    match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(MAX_MESSAGE_SIZE),
        Err(_) => MAX_MESSAGE_SIZE,
    }
}
26
27pub struct ZeroCopyTcpDecoder {
32 read_buffer: BytesMut,
34 max_message_size: usize,
36}
37
38impl ZeroCopyTcpDecoder {
39 pub fn new() -> Self {
41 Self::with_capacity(INITIAL_BUFFER_SIZE)
42 }
43
44 pub fn with_capacity(capacity: usize) -> Self {
46 Self {
47 read_buffer: BytesMut::with_capacity(capacity),
48 max_message_size: get_max_message_size(),
49 }
50 }
51
52 pub async fn read_message<R: AsyncRead + Unpin>(
61 &mut self,
62 reader: &mut R,
63 ) -> io::Result<TcpRequestMessageZeroCopy> {
64 const MIN_HEADER_SIZE: usize = 2;
67
68 while self.read_buffer.len() < MIN_HEADER_SIZE {
70 let n = reader.read_buf(&mut self.read_buffer).await?;
71 if n == 0 {
72 if self.read_buffer.is_empty() {
73 return Err(io::Error::new(
74 io::ErrorKind::UnexpectedEof,
75 "connection closed",
76 ));
77 } else {
78 return Err(io::Error::new(
79 io::ErrorKind::UnexpectedEof,
80 "incomplete message header",
81 ));
82 }
83 }
84 }
85
86 let path_len = u16::from_be_bytes([self.read_buffer[0], self.read_buffer[1]]) as usize;
88
89 if path_len == 0 || path_len > 1024 {
91 return Err(io::Error::new(
92 io::ErrorKind::InvalidData,
93 format!("invalid endpoint path length: {}", path_len),
94 ));
95 }
96
97 let initial_header_size = 2 + path_len + 2; while self.read_buffer.len() < initial_header_size {
100 let n = reader.read_buf(&mut self.read_buffer).await?;
101 if n == 0 {
102 return Err(io::Error::new(
103 io::ErrorKind::UnexpectedEof,
104 "incomplete message header",
105 ));
106 }
107 }
108
109 let headers_len_offset = 2 + path_len;
111 let headers_len = u16::from_be_bytes([
112 self.read_buffer[headers_len_offset],
113 self.read_buffer[headers_len_offset + 1],
114 ]) as usize;
115
116 let full_header_size = 2 + path_len + 2 + headers_len + 4; while self.read_buffer.len() < full_header_size {
119 let n = reader.read_buf(&mut self.read_buffer).await?;
120 if n == 0 {
121 return Err(io::Error::new(
122 io::ErrorKind::UnexpectedEof,
123 "incomplete message header",
124 ));
125 }
126 }
127
128 let payload_len_offset = 2 + path_len + 2 + headers_len;
130 let payload_len = u32::from_be_bytes([
131 self.read_buffer[payload_len_offset],
132 self.read_buffer[payload_len_offset + 1],
133 self.read_buffer[payload_len_offset + 2],
134 self.read_buffer[payload_len_offset + 3],
135 ]) as usize;
136
137 let total_len = 2 + path_len + 2 + headers_len + 4 + payload_len;
139
140 if total_len > self.max_message_size {
142 return Err(io::Error::new(
143 io::ErrorKind::InvalidData,
144 format!(
145 "message too large: {} bytes (max: {} bytes)",
146 total_len, self.max_message_size
147 ),
148 ));
149 }
150
151 while self.read_buffer.len() < total_len {
153 let n = reader.read_buf(&mut self.read_buffer).await?;
154 if n == 0 {
155 return Err(io::Error::new(
156 io::ErrorKind::UnexpectedEof,
157 format!(
158 "incomplete message: expected {} bytes, got {}",
159 total_len,
160 self.read_buffer.len()
161 ),
162 ));
163 }
164 }
165
166 let message_bytes = self.read_buffer.split_to(total_len).freeze();
169
170 Ok(TcpRequestMessageZeroCopy::new(message_bytes))
172 }
173
174 pub fn buffer_capacity(&self) -> usize {
176 self.read_buffer.capacity()
177 }
178
179 pub fn buffered_len(&self) -> usize {
181 self.read_buffer.len()
182 }
183}
184
185impl Default for ZeroCopyTcpDecoder {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191#[derive(Clone)]
196pub struct TcpRequestMessageZeroCopy {
197 raw: Bytes,
200}
201
202impl TcpRequestMessageZeroCopy {
203 fn new(raw: Bytes) -> Self {
205 Self { raw }
206 }
207
208 #[inline]
210 fn path_len(&self) -> usize {
211 u16::from_be_bytes([self.raw[0], self.raw[1]]) as usize
212 }
213
214 pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
218 let path_len = self.path_len();
219 std::str::from_utf8(&self.raw[2..2 + path_len])
220 }
221
222 pub fn endpoint_path_bytes(&self) -> &[u8] {
224 let path_len = self.path_len();
225 &self.raw[2..2 + path_len]
226 }
227
228 #[inline]
230 fn headers_len(&self) -> usize {
231 let path_len = self.path_len();
232 let offset = 2 + path_len;
233 u16::from_be_bytes([self.raw[offset], self.raw[offset + 1]]) as usize
234 }
235
236 pub fn headers_bytes(&self) -> &[u8] {
238 let path_len = self.path_len();
239 let headers_len = self.headers_len();
240 let headers_start = 2 + path_len + 2;
241 &self.raw[headers_start..headers_start + headers_len]
242 }
243
244 pub fn headers(&self) -> std::collections::HashMap<String, String> {
246 let headers_bytes = self.headers_bytes();
247 if headers_bytes.is_empty() {
248 return std::collections::HashMap::new();
249 }
250
251 serde_json::from_slice(headers_bytes).unwrap_or_default()
253 }
254
255 #[inline]
257 fn payload_len(&self) -> usize {
258 let path_len = self.path_len();
259 let headers_len = self.headers_len();
260 let offset = 2 + path_len + 2 + headers_len;
261 u32::from_be_bytes([
262 self.raw[offset],
263 self.raw[offset + 1],
264 self.raw[offset + 2],
265 self.raw[offset + 3],
266 ]) as usize
267 }
268
269 pub fn payload(&self) -> Bytes {
274 let path_len = self.path_len();
275 let headers_len = self.headers_len();
276 let payload_start = 2 + path_len + 2 + headers_len + 4;
277 self.raw.slice(payload_start..) }
279
280 pub fn total_size(&self) -> usize {
282 self.raw.len()
283 }
284
285 pub fn raw_bytes(&self) -> &Bytes {
287 &self.raw
288 }
289}
290
291impl std::fmt::Debug for TcpRequestMessageZeroCopy {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("TcpRequestMessageZeroCopy")
294 .field("total_size", &self.total_size())
295 .field("endpoint_path", &self.endpoint_path().ok())
296 .field("payload_len", &self.payload_len())
297 .finish()
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    /// Serializes one frame in the wire layout the decoder expects:
    /// `u16 path_len | path | u16 headers_len | headers | u32 payload_len | payload`
    /// (big-endian integers).
    fn encode_frame(endpoint: &str, headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut frame = Vec::new();
        frame.extend_from_slice(&(endpoint.len() as u16).to_be_bytes());
        frame.extend_from_slice(endpoint.as_bytes());
        frame.extend_from_slice(&(headers.len() as u16).to_be_bytes());
        frame.extend_from_slice(headers);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// A small frame with no headers round-trips through the decoder.
    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let message = encode_frame(endpoint, &[], payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        assert_eq!(msg.headers().len(), 0);
    }

    /// A 200 KiB payload is decoded in full.
    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    /// The limit applies to the total frame size, not just the payload: a
    /// payload of exactly `max_size` bytes is rejected because the header
    /// pushes the frame to 1045 bytes.
    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        let max_size = 1024;
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;

        let endpoint = "test/endpoint";
        let payload = vec![0x42u8; max_size];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let result = decoder.read_message(&mut reader).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("message too large"));
        assert!(err.to_string().contains("1045"));
        assert!(err.to_string().contains("1024"));
    }

    /// JSON headers are exposed both decoded and as raw bytes.
    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";

        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("traceparent".to_string(), "00-abc123-def456-01".to_string());
        headers_map.insert("user-agent".to_string(), "test-client/1.0".to_string());
        headers_map.insert("request-id".to_string(), "req-12345".to_string());

        let headers_json = serde_json::to_vec(&headers_map).unwrap();
        let message = encode_frame(endpoint, &headers_json, payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());

        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        assert_eq!(
            decoded_headers.get("traceparent").unwrap(),
            "00-abc123-def456-01"
        );
        assert_eq!(
            decoded_headers.get("user-agent").unwrap(),
            "test-client/1.0"
        );
        assert_eq!(decoded_headers.get("request-id").unwrap(), "req-12345");

        let headers_bytes = msg.headers_bytes();
        assert_eq!(headers_bytes, &headers_json[..]);
    }

    /// One decoder instance handles a header-less frame followed by a frame
    /// that carries headers.
    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        let endpoint = "test/endpoint";
        let payload = b"test data";

        let message_empty = encode_frame(endpoint, &[], payload);

        let mut reader = &message_empty[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message_with_headers = encode_frame(endpoint, &headers_json, payload);

        let mut reader = &message_with_headers[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 1);
        assert_eq!(msg.headers().get("x-test-header").unwrap(), "test-value");
    }
}