dynamo_runtime/pipeline/network/codec/
zero_copy_decoder.rs1use super::{
13 check_tcp_request_max_message_size, parse_tcp_request_frame_header, tcp_request_endpoint_len,
14 tcp_request_header_size, tcp_request_headers_len,
15};
16use crate::pipeline::network::get_tcp_max_message_size;
17use bytes::{Bytes, BytesMut};
18use std::io;
19use std::sync::OnceLock;
20use tokio::io::{AsyncRead, AsyncReadExt};
21
/// Capacity, in bytes, that a decoder's read buffer starts with (256 KiB).
///
/// Also used as the lower bound of the shrink threshold and as the capacity
/// of the replacement buffer allocated when a decoder shrinks.
const INITIAL_BUFFER_SIZE: usize = 256 * 1024;

/// Shrink threshold, in bytes, used when `DYN_TCP_SHRINK_MESSAGE_SIZE` does
/// not supply a usable value (8 MiB).
const DEFAULT_SHRINK_SIZE: usize = 8 * 1024 * 1024;

/// Process-wide cache of the resolved shrink threshold.
///
/// Filled exactly once, on the first call to `get_shrink_message_size`.
static SHRINK_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();
26
27fn get_shrink_message_size() -> usize {
29 *SHRINK_MESSAGE_SIZE.get_or_init(|| {
30 let max_size = get_tcp_max_message_size();
31 let env_result = std::env::var("DYN_TCP_SHRINK_MESSAGE_SIZE");
33 let env_shrink_size = env_result.as_ref().ok().and_then(|s| {
34 s.parse::<usize>().ok().or_else(|| {
35 tracing::warn!(
36 env_var = "DYN_TCP_SHRINK_MESSAGE_SIZE",
37 value = %s,
38 "Invalid value for DYN_TCP_SHRINK_MESSAGE_SIZE, using default"
39 );
40 None
41 })
42 });
43
44 let resolved = resolve_shrink_message_size(max_size, env_shrink_size);
45
46 if let Some(configured) = env_shrink_size
48 && configured != resolved
49 {
50 tracing::warn!(
51 configured_size = configured,
52 resolved_size = resolved,
53 max_size = max_size,
54 initial_buffer_size = INITIAL_BUFFER_SIZE,
55 "DYN_TCP_SHRINK_MESSAGE_SIZE was clamped to valid range. Note the size is in bytes."
56 );
57 }
58
59 resolved
60 })
61}
62
63fn resolve_shrink_message_size(max_size: usize, env_shrink_size: Option<usize>) -> usize {
66 let configured_size = env_shrink_size.unwrap_or(DEFAULT_SHRINK_SIZE);
67
68 configured_size
70 .min(max_size) .max(INITIAL_BUFFER_SIZE) }
73
/// Stateful decoder that reads framed TCP request messages from an
/// `AsyncRead` source into a reusable buffer.
///
/// Each decoded message is handed out as the frozen prefix split off the
/// internal buffer; bytes read past the end of a frame stay buffered for the
/// next call to `read_message`.
pub struct ZeroCopyTcpDecoder {
    /// Accumulates bytes read from the stream; decoded frames are split off
    /// its front.
    read_buffer: BytesMut,
    /// Limit passed to `check_tcp_request_max_message_size` together with
    /// each frame's total length.
    max_message_size: usize,
    /// Once the buffer is empty after a decode, a remaining capacity above
    /// this value causes it to be replaced by a fresh
    /// `INITIAL_BUFFER_SIZE`-byte buffer.
    shrink_threshold: usize,
}
88
89impl ZeroCopyTcpDecoder {
90 pub fn new() -> Self {
92 Self::with_capacity(INITIAL_BUFFER_SIZE)
93 }
94
95 pub fn with_capacity(capacity: usize) -> Self {
97 Self {
98 read_buffer: BytesMut::with_capacity(capacity),
99 max_message_size: get_tcp_max_message_size(),
100 shrink_threshold: get_shrink_message_size(),
101 }
102 }
103
104 pub async fn read_message<R: AsyncRead + Unpin>(
113 &mut self,
114 reader: &mut R,
115 ) -> io::Result<TcpRequestMessageZeroCopy> {
116 while self.read_buffer.len() < super::TCP_REQUEST_ENDPOINT_LEN_WIDTH {
118 let n = reader.read_buf(&mut self.read_buffer).await?;
119 if n == 0 {
120 if self.read_buffer.is_empty() {
121 return Err(io::Error::new(
122 io::ErrorKind::UnexpectedEof,
123 "connection closed",
124 ));
125 } else {
126 return Err(io::Error::new(
127 io::ErrorKind::UnexpectedEof,
128 "incomplete message header",
129 ));
130 }
131 }
132 }
133
134 let path_len = tcp_request_endpoint_len(&self.read_buffer)?;
136
137 let initial_header_size =
139 super::TCP_REQUEST_ENDPOINT_LEN_WIDTH + path_len + super::TCP_REQUEST_HEADERS_LEN_WIDTH;
140 while self.read_buffer.len() < initial_header_size {
141 let n = reader.read_buf(&mut self.read_buffer).await?;
142 if n == 0 {
143 return Err(io::Error::new(
144 io::ErrorKind::UnexpectedEof,
145 "incomplete message header",
146 ));
147 }
148 }
149
150 let headers_len = tcp_request_headers_len(&self.read_buffer, path_len)?;
152
153 let full_header_size = tcp_request_header_size(path_len, headers_len);
155 while self.read_buffer.len() < full_header_size {
156 let n = reader.read_buf(&mut self.read_buffer).await?;
157 if n == 0 {
158 return Err(io::Error::new(
159 io::ErrorKind::UnexpectedEof,
160 "incomplete message header",
161 ));
162 }
163 }
164
165 let parsed = parse_tcp_request_frame_header(&self.read_buffer)?;
166
167 check_tcp_request_max_message_size(parsed.total_len, self.max_message_size)?;
169
170 while self.read_buffer.len() < parsed.total_len {
172 let n = reader.read_buf(&mut self.read_buffer).await?;
173 if n == 0 {
174 return Err(io::Error::new(
175 io::ErrorKind::UnexpectedEof,
176 format!(
177 "incomplete message: expected {} bytes, got {}",
178 parsed.total_len,
179 self.read_buffer.len()
180 ),
181 ));
182 }
183 }
184
185 let message_bytes = self.read_buffer.split_to(parsed.total_len).freeze();
188
189 if self.read_buffer.is_empty() && self.read_buffer.capacity() > self.shrink_threshold {
191 self.read_buffer = BytesMut::with_capacity(INITIAL_BUFFER_SIZE);
192 }
193
194 Ok(TcpRequestMessageZeroCopy::new(message_bytes, parsed))
196 }
197
198 pub fn buffer_capacity(&self) -> usize {
200 self.read_buffer.capacity()
201 }
202
203 pub fn buffered_len(&self) -> usize {
205 self.read_buffer.len()
206 }
207}
208
209impl Default for ZeroCopyTcpDecoder {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
/// One decoded request frame: the frame's raw bytes together with the parsed
/// header that locates each section inside them.
///
/// Accessors return views into `raw` rather than separately stored copies of
/// the endpoint, headers, or payload. Cloning clones the `Bytes` handle and
/// the parsed header.
#[derive(Clone)]
pub struct TcpRequestMessageZeroCopy {
    /// The complete frame exactly as it was split off the decoder's buffer.
    raw: Bytes,
    /// Parsed header providing the start/end offsets of the endpoint,
    /// headers, and payload within `raw`.
    parsed: super::TcpRequestWireHeader,
}
226
227impl TcpRequestMessageZeroCopy {
228 fn new(raw: Bytes, parsed: super::TcpRequestWireHeader) -> Self {
230 Self { raw, parsed }
231 }
232
233 pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
237 std::str::from_utf8(&self.raw[self.parsed.endpoint_start()..self.parsed.endpoint_end()])
238 }
239
240 pub fn endpoint_path_bytes(&self) -> &[u8] {
242 &self.raw[self.parsed.endpoint_start()..self.parsed.endpoint_end()]
243 }
244
245 pub fn headers_bytes(&self) -> &[u8] {
247 &self.raw[self.parsed.headers_start()..self.parsed.headers_end()]
248 }
249
250 pub fn headers(&self) -> std::collections::HashMap<String, String> {
252 let headers_bytes = self.headers_bytes();
253 if headers_bytes.is_empty() {
254 return std::collections::HashMap::new();
255 }
256
257 serde_json::from_slice(headers_bytes).unwrap_or_default()
259 }
260
261 #[inline]
263 fn payload_len(&self) -> usize {
264 self.parsed.payload_len
265 }
266
267 pub fn payload(&self) -> Bytes {
272 self.raw.slice(self.parsed.payload_start()..) }
274
275 pub fn total_size(&self) -> usize {
277 self.raw.len()
278 }
279
280 pub fn raw_bytes(&self) -> &Bytes {
282 &self.raw
283 }
284}
285
286impl std::fmt::Debug for TcpRequestMessageZeroCopy {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("TcpRequestMessageZeroCopy")
289 .field("total_size", &self.total_size())
290 .field("endpoint_path", &self.endpoint_path().ok())
291 .field("payload_len", &self.payload_len())
292 .finish()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one request frame in the layout these tests feed the decoder:
    ///
    /// `[endpoint_len: u16 BE][endpoint][headers_len: u16 BE][headers][payload_len: u32 BE][payload]`
    ///
    /// Panics if a section is too long for its length prefix, so a test can
    /// never silently produce a truncated length field.
    fn encode_frame(endpoint: &str, headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let endpoint_len =
            u16::try_from(endpoint.len()).expect("endpoint length must fit in u16");
        let headers_len = u16::try_from(headers.len()).expect("headers length must fit in u16");
        let payload_len = u32::try_from(payload.len()).expect("payload length must fit in u32");

        let mut frame =
            Vec::with_capacity(2 + endpoint.len() + 2 + headers.len() + 4 + payload.len());
        frame.extend_from_slice(&endpoint_len.to_be_bytes());
        frame.extend_from_slice(endpoint.as_bytes());
        frame.extend_from_slice(&headers_len.to_be_bytes());
        frame.extend_from_slice(headers);
        frame.extend_from_slice(&payload_len.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn test_resolve_shrink_message_size_edge_cases() {
        let max_size_10mb = 10 * 1024 * 1024;

        // No override, default fits under the maximum.
        let result = resolve_shrink_message_size(max_size_10mb, None);
        assert_eq!(
            result, DEFAULT_SHRINK_SIZE,
            "10MB max should return default 8MB"
        );

        // No override, default is larger than the maximum.
        let max_size_1mb = 1024 * 1024;
        let result = resolve_shrink_message_size(max_size_1mb, None);
        assert_eq!(result, max_size_1mb, "1MB max should be capped to 1MB");

        // Maximum equal to the default.
        let result = resolve_shrink_message_size(DEFAULT_SHRINK_SIZE, None);
        assert_eq!(
            result, DEFAULT_SHRINK_SIZE,
            "exact match should return default"
        );

        // Override inside the valid range is used as-is.
        let env_size = 2 * 1024 * 1024;
        let result = resolve_shrink_message_size(max_size_10mb, Some(env_size));
        assert_eq!(
            result, env_size,
            "env var should be used when within bounds"
        );

        // Override above the maximum is capped.
        let env_size_large = 20 * 1024 * 1024;
        let result = resolve_shrink_message_size(max_size_10mb, Some(env_size_large));
        assert_eq!(
            result, max_size_10mb,
            "env var should be capped to max_size"
        );

        // Override below the initial buffer size is raised.
        let env_size_small = 100 * 1024;
        let result = resolve_shrink_message_size(max_size_10mb, Some(env_size_small));
        assert_eq!(
            result, INITIAL_BUFFER_SIZE,
            "env var should be clamped to INITIAL_BUFFER_SIZE"
        );

        // Maximum below the initial buffer size: the floor wins.
        let max_size_small = 100 * 1024;
        let result = resolve_shrink_message_size(max_size_small, None);
        assert_eq!(
            result, INITIAL_BUFFER_SIZE,
            "result should be clamped to INITIAL_BUFFER_SIZE"
        );
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let message = encode_frame(endpoint, &[], payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        assert_eq!(msg.headers().len(), 0);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_allows_empty_and_long_endpoint_paths() {
        for endpoint in [String::new(), "x".repeat(2048)] {
            let payload = b"payload";
            let message = encode_frame(&endpoint, &[], payload);

            let mut reader = &message[..];
            let mut decoder = ZeroCopyTcpDecoder::new();
            let msg = decoder.read_message(&mut reader).await.unwrap();

            assert_eq!(msg.endpoint_path().unwrap(), endpoint.as_str());
            assert_eq!(msg.payload().as_ref(), payload);
        }
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        let max_size = 1024;
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;

        // The payload alone equals the limit, so framing overhead pushes the
        // total over it: 2 + 13 + 2 + 0 + 4 + 1024 = 1045 bytes.
        let endpoint = "test/endpoint";
        let payload = vec![0x42u8; max_size];
        let message = encode_frame(endpoint, &[], &payload);

        let mut reader = &message[..];
        let result = decoder.read_message(&mut reader).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("message too large"));
        assert!(err.to_string().contains("1045"));
        assert!(err.to_string().contains("1024"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";

        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("traceparent".to_string(), "00-abc123-def456-01".to_string());
        headers_map.insert("user-agent".to_string(), "test-client/1.0".to_string());
        headers_map.insert("request-id".to_string(), "req-12345".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message = encode_frame(endpoint, &headers_json, payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());

        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        assert_eq!(
            decoded_headers.get("traceparent").unwrap(),
            "00-abc123-def456-01"
        );
        assert_eq!(
            decoded_headers.get("user-agent").unwrap(),
            "test-client/1.0"
        );
        assert_eq!(decoded_headers.get("request-id").unwrap(), "req-12345");

        // The raw headers section must be exactly the JSON that was sent.
        let headers_bytes = msg.headers_bytes();
        assert_eq!(headers_bytes, &headers_json[..]);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        let endpoint = "test/endpoint";
        let payload = b"test data";

        // First frame: no headers.
        let message_empty = encode_frame(endpoint, &[], payload);

        let mut reader = &message_empty[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        // Second frame through the same decoder: one header.
        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message_with_headers = encode_frame(endpoint, &headers_json, payload);

        let mut reader = &message_with_headers[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 1);
        assert_eq!(msg.headers().get("x-test-header").unwrap(), "test-value");
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_buffer_shrinking() {
        let endpoint = "test/endpoint";
        let small_payload = b"small";
        let large_payload = vec![0x42u8; 1024 * 1024];

        let mut decoder = ZeroCopyTcpDecoder::with_capacity(INITIAL_BUFFER_SIZE);
        decoder.max_message_size = 2 * 1024 * 1024;
        decoder.shrink_threshold = 512 * 1024;
        assert!(decoder.buffer_capacity() <= INITIAL_BUFFER_SIZE);

        // A frame far larger than the initial buffer forces it to grow.
        let large_message = encode_frame(endpoint, &[], &large_payload);
        let mut reader = &large_message[..];
        decoder.read_message(&mut reader).await.unwrap();

        assert!(
            decoder.buffer_capacity() <= INITIAL_BUFFER_SIZE,
            "buffer should shrink after large message, got capacity {}",
            decoder.buffer_capacity()
        );
        assert!(
            decoder.buffered_len() == 0,
            "buffer should be empty after read"
        );

        // The decoder must remain usable after the large frame.
        let small_message = encode_frame(endpoint, &[], small_payload);
        let mut reader = &small_message[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.payload().as_ref(), small_payload);
    }
}