Skip to main content

dynamo_runtime/pipeline/network/codec/
zero_copy_decoder.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Zero-copy TCP message decoder for high-concurrency scenarios
5//!
6//! This decoder eliminates message reconstruction copies by:
7//! 1. Reading into a reusable buffer
8//! 2. Parsing headers in-place
9//! 3. Splitting off exact message sizes (zero-copy via Bytes::split_to)
10//! 4. Returning Arc-counted Bytes that can be cloned cheaply
11
12use bytes::{Buf, Bytes, BytesMut};
13use std::io;
14use tokio::io::{AsyncRead, AsyncReadExt};
15
/// Default upper bound, in bytes, on one framed message (32MB).
///
/// Used when `DYN_TCP_MAX_MESSAGE_SIZE` is unset or does not hold a usable value.
const MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024; // 32MB
/// Initial capacity, in bytes, of a decoder's read buffer.
const INITIAL_BUFFER_SIZE: usize = 262144; // 256KB

/// Interpret the textual value of `DYN_TCP_MAX_MESSAGE_SIZE`.
///
/// Surrounding whitespace is ignored. Returns `None` when the text is not a
/// base-10 `usize`, or when it is `0`: a zero limit would reject every message,
/// so it is treated as a misconfiguration rather than honoured.
fn parse_max_message_size(raw: &str) -> Option<usize> {
    raw.trim().parse::<usize>().ok().filter(|&limit| limit > 0)
}

/// Resolve the maximum message size for a new decoder.
///
/// Reads `DYN_TCP_MAX_MESSAGE_SIZE` from the environment on every call and falls
/// back to [`MAX_MESSAGE_SIZE`] when the variable is missing, not valid Unicode,
/// unparseable, or zero. The result is therefore always greater than zero.
fn get_max_message_size() -> usize {
    std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
        .ok()
        .and_then(|raw| parse_max_message_size(&raw))
        .unwrap_or(MAX_MESSAGE_SIZE)
}
26
27/// Zero-copy streaming decoder that reuses buffers
28///
29/// This decoder maintains an internal buffer and only allocates when necessary.
30/// Messages are returned as Arc-counted Bytes slices, making cloning extremely cheap.
31pub struct ZeroCopyTcpDecoder {
32    /// Reusable read buffer - grows as needed but never shrinks
33    read_buffer: BytesMut,
34    /// Maximum allowed message size
35    max_message_size: usize,
36}
37
38impl ZeroCopyTcpDecoder {
39    /// Create a new decoder with default buffer size
40    pub fn new() -> Self {
41        Self::with_capacity(INITIAL_BUFFER_SIZE)
42    }
43
44    /// Create a new decoder with specific initial capacity
45    pub fn with_capacity(capacity: usize) -> Self {
46        Self {
47            read_buffer: BytesMut::with_capacity(capacity),
48            max_message_size: get_max_message_size(),
49        }
50    }
51
52    /// Read one complete message with ZERO copies
53    ///
54    /// This method:
55    /// 1. Ensures headers are buffered
56    /// 2. Parses headers in-place (no allocation)
57    /// 3. Ensures entire message is buffered
58    /// 4. Splits off exact message size (zero-copy pointer arithmetic)
59    /// 5. Returns Arc-counted Bytes (cheap to clone)
60    pub async fn read_message<R: AsyncRead + Unpin>(
61        &mut self,
62        reader: &mut R,
63    ) -> io::Result<TcpRequestMessageZeroCopy> {
64        // Ensure we have at least enough bytes to start parsing
65        // Wire format: [path_len(2)][path][headers_len(2)][headers][payload_len(4)][payload]
66        const MIN_HEADER_SIZE: usize = 2;
67
68        // Fill buffer if needed
69        while self.read_buffer.len() < MIN_HEADER_SIZE {
70            let n = reader.read_buf(&mut self.read_buffer).await?;
71            if n == 0 {
72                if self.read_buffer.is_empty() {
73                    return Err(io::Error::new(
74                        io::ErrorKind::UnexpectedEof,
75                        "connection closed",
76                    ));
77                } else {
78                    return Err(io::Error::new(
79                        io::ErrorKind::UnexpectedEof,
80                        "incomplete message header",
81                    ));
82                }
83            }
84        }
85
86        // Parse endpoint path length (first 2 bytes) - NO COPY
87        let path_len = u16::from_be_bytes([self.read_buffer[0], self.read_buffer[1]]) as usize;
88
89        // Sanity check path length
90        if path_len == 0 || path_len > 1024 {
91            return Err(io::Error::new(
92                io::ErrorKind::InvalidData,
93                format!("invalid endpoint path length: {}", path_len),
94            ));
95        }
96
97        // Ensure we have path + headers_len
98        let initial_header_size = 2 + path_len + 2; // path_len(2) + path + headers_len(2)
99        while self.read_buffer.len() < initial_header_size {
100            let n = reader.read_buf(&mut self.read_buffer).await?;
101            if n == 0 {
102                return Err(io::Error::new(
103                    io::ErrorKind::UnexpectedEof,
104                    "incomplete message header",
105                ));
106            }
107        }
108
109        // Parse headers length (2 bytes after path) - NO COPY
110        let headers_len_offset = 2 + path_len;
111        let headers_len = u16::from_be_bytes([
112            self.read_buffer[headers_len_offset],
113            self.read_buffer[headers_len_offset + 1],
114        ]) as usize;
115
116        // Ensure we have headers + payload length
117        let full_header_size = 2 + path_len + 2 + headers_len + 4; // path_len(2) + path + headers_len(2) + headers + payload_len(4)
118        while self.read_buffer.len() < full_header_size {
119            let n = reader.read_buf(&mut self.read_buffer).await?;
120            if n == 0 {
121                return Err(io::Error::new(
122                    io::ErrorKind::UnexpectedEof,
123                    "incomplete message header",
124                ));
125            }
126        }
127
128        // Parse payload length (4 bytes after headers) - NO COPY
129        let payload_len_offset = 2 + path_len + 2 + headers_len;
130        let payload_len = u32::from_be_bytes([
131            self.read_buffer[payload_len_offset],
132            self.read_buffer[payload_len_offset + 1],
133            self.read_buffer[payload_len_offset + 2],
134            self.read_buffer[payload_len_offset + 3],
135        ]) as usize;
136
137        // Calculate total message size
138        let total_len = 2 + path_len + 2 + headers_len + 4 + payload_len;
139
140        // Sanity check total message length (including all overhead)
141        if total_len > self.max_message_size {
142            return Err(io::Error::new(
143                io::ErrorKind::InvalidData,
144                format!(
145                    "message too large: {} bytes (max: {} bytes)",
146                    total_len, self.max_message_size
147                ),
148            ));
149        }
150
151        // Ensure entire message is buffered
152        while self.read_buffer.len() < total_len {
153            let n = reader.read_buf(&mut self.read_buffer).await?;
154            if n == 0 {
155                return Err(io::Error::new(
156                    io::ErrorKind::UnexpectedEof,
157                    format!(
158                        "incomplete message: expected {} bytes, got {}",
159                        total_len,
160                        self.read_buffer.len()
161                    ),
162                ));
163            }
164        }
165
166        // Split off exactly what we need - ZERO COPY!
167        // split_to() just advances the internal pointer, doesn't allocate or copy
168        let message_bytes = self.read_buffer.split_to(total_len).freeze();
169
170        // Return zero-copy message wrapper
171        Ok(TcpRequestMessageZeroCopy::new(message_bytes))
172    }
173
174    /// Get the current buffer capacity
175    pub fn buffer_capacity(&self) -> usize {
176        self.read_buffer.capacity()
177    }
178
179    /// Get the current buffered data size
180    pub fn buffered_len(&self) -> usize {
181        self.read_buffer.len()
182    }
183}
184
185impl Default for ZeroCopyTcpDecoder {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191/// Zero-copy message representation
192///
193/// This struct holds an Arc-counted Bytes buffer containing the entire message.
194/// All accessors return zero-copy slices or references into this buffer.
195#[derive(Clone)]
196pub struct TcpRequestMessageZeroCopy {
197    /// Entire message as Arc-counted buffer
198    /// Format: [path_len(2)][path(var)][headers_len(2)][headers(var)][payload_len(4)][payload(var)]
199    raw: Bytes,
200}
201
202impl TcpRequestMessageZeroCopy {
203    /// Create a new zero-copy message from raw bytes
204    fn new(raw: Bytes) -> Self {
205        Self { raw }
206    }
207
208    /// Get the endpoint path length
209    #[inline]
210    fn path_len(&self) -> usize {
211        u16::from_be_bytes([self.raw[0], self.raw[1]]) as usize
212    }
213
214    /// Get endpoint path as a string slice (zero-copy)
215    ///
216    /// This returns a reference into the message buffer, no allocation.
217    pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
218        let path_len = self.path_len();
219        std::str::from_utf8(&self.raw[2..2 + path_len])
220    }
221
222    /// Get endpoint path as bytes (zero-copy)
223    pub fn endpoint_path_bytes(&self) -> &[u8] {
224        let path_len = self.path_len();
225        &self.raw[2..2 + path_len]
226    }
227
228    /// Get the headers length
229    #[inline]
230    fn headers_len(&self) -> usize {
231        let path_len = self.path_len();
232        let offset = 2 + path_len;
233        u16::from_be_bytes([self.raw[offset], self.raw[offset + 1]]) as usize
234    }
235
236    /// Get headers as bytes (zero-copy)
237    pub fn headers_bytes(&self) -> &[u8] {
238        let path_len = self.path_len();
239        let headers_len = self.headers_len();
240        let headers_start = 2 + path_len + 2;
241        &self.raw[headers_start..headers_start + headers_len]
242    }
243
244    /// Get headers as a HashMap (requires parsing)
245    pub fn headers(&self) -> std::collections::HashMap<String, String> {
246        let headers_bytes = self.headers_bytes();
247        if headers_bytes.is_empty() {
248            return std::collections::HashMap::new();
249        }
250
251        // Parse headers from JSON format
252        serde_json::from_slice(headers_bytes).unwrap_or_default()
253    }
254
255    /// Get the payload length
256    #[inline]
257    fn payload_len(&self) -> usize {
258        let path_len = self.path_len();
259        let headers_len = self.headers_len();
260        let offset = 2 + path_len + 2 + headers_len;
261        u32::from_be_bytes([
262            self.raw[offset],
263            self.raw[offset + 1],
264            self.raw[offset + 2],
265            self.raw[offset + 3],
266        ]) as usize
267    }
268
269    /// Get payload as zero-copy Bytes
270    ///
271    /// This returns an Arc-counted slice of the message buffer.
272    /// Cloning the returned Bytes is extremely cheap (just Arc clone).
273    pub fn payload(&self) -> Bytes {
274        let path_len = self.path_len();
275        let headers_len = self.headers_len();
276        let payload_start = 2 + path_len + 2 + headers_len + 4;
277        self.raw.slice(payload_start..) // ZERO COPY! Just Arc clone + offset
278    }
279
280    /// Get total message size in bytes
281    pub fn total_size(&self) -> usize {
282        self.raw.len()
283    }
284
285    /// Get the raw message bytes (for debugging)
286    pub fn raw_bytes(&self) -> &Bytes {
287        &self.raw
288    }
289}
290
291impl std::fmt::Debug for TcpRequestMessageZeroCopy {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.debug_struct("TcpRequestMessageZeroCopy")
294            .field("total_size", &self.total_size())
295            .field("endpoint_path", &self.endpoint_path().ok())
296            .field("payload_len", &self.payload_len())
297            .finish()
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    /// Encode one message in the decoder's wire layout:
    /// `[path_len(2)][path][headers_len(2)][headers][payload_len(4)][payload]`.
    fn frame(path: &str, headers: &[u8], body: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        wire.extend_from_slice(&(path.len() as u16).to_be_bytes());
        wire.extend_from_slice(path.as_bytes());
        wire.extend_from_slice(&(headers.len() as u16).to_be_bytes());
        wire.extend_from_slice(headers);
        wire.extend_from_slice(&(body.len() as u32).to_be_bytes());
        wire.extend_from_slice(body);
        wire
    }

    /// Feed `wire` to `decoder` through an in-memory reader and decode one message.
    async fn decode_one(
        decoder: &mut ZeroCopyTcpDecoder,
        wire: &[u8],
    ) -> io::Result<TcpRequestMessageZeroCopy> {
        let mut reader = wire;
        decoder.read_message(&mut reader).await
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        // Small message without headers.
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let wire = frame(endpoint, &[], payload);

        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decode_one(&mut decoder, &wire).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), wire.len());
        assert_eq!(msg.headers().len(), 0); // Empty headers
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        // 200KB payload, no headers.
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let wire = frame(endpoint, &[], &payload);

        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decode_one(&mut decoder, &wire).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        // The limit applies to the whole message, not just the payload: a
        // payload exactly at the limit must be rejected once overhead is added.
        let max_size = 1024; // 1KB limit
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;

        let payload = vec![0x42u8; max_size]; // Payload equals max
        let wire = frame("test/endpoint", &[], &payload);

        // total_len = 2 + 13 + 2 + 0 + 4 + 1024 = 1045 bytes > 1024 max
        let result = decode_one(&mut decoder, &wire).await;

        // Should fail with InvalidData error
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let text = err.to_string();
        assert!(text.contains("message too large"));
        assert!(text.contains("1045")); // total_len
        assert!(text.contains("1024")); // max_message_size
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        // Message carrying a JSON headers section.
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";

        let headers_map: std::collections::HashMap<String, String> = [
            ("traceparent", "00-abc123-def456-01"),
            ("user-agent", "test-client/1.0"),
            ("request-id", "req-12345"),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect();
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let wire = frame(endpoint, &headers_json, payload);

        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decode_one(&mut decoder, &wire).await.unwrap();

        // Endpoint, payload and total size
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), wire.len());

        // Every header that was sent comes back with the same value
        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        for (key, value) in &headers_map {
            assert_eq!(decoded_headers.get(key).unwrap(), value);
        }

        // headers_bytes returns the raw JSON
        assert_eq!(msg.headers_bytes(), &headers_json[..]);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        // One decoder handles an empty-headers message and then a populated one.
        let endpoint = "test/endpoint";
        let payload = b"test data";
        let mut decoder = ZeroCopyTcpDecoder::new();

        // Test 1: Empty headers (headers_len = 0, no header bytes)
        let wire_empty = frame(endpoint, &[], payload);
        let msg = decode_one(&mut decoder, &wire_empty).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        // Test 2: Populated headers with same decoder
        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let wire_with_headers = frame(endpoint, &headers_json, payload);
        let msg = decode_one(&mut decoder, &wire_with_headers).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 1);
        assert_eq!(decoded_headers.get("x-test-header").unwrap(), "test-value");
    }
}