Skip to main content

dynamo_runtime/pipeline/network/
codec.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Codec Module
5//!
//! Codecs map structured messages into blobs of bytes and streams of bytes.
//!
//! In this module, we define three primary codecs used to issue single, two-part, or multi-part
//! messages on a byte stream.
10
11use bytes::Bytes;
12use tokio_util::{
13    bytes::{Buf, BufMut, BytesMut},
14    codec::{Decoder, Encoder},
15};
16
17mod two_part;
18pub mod zero_copy_decoder;
19
20pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
21pub use zero_copy_decoder::{TcpRequestMessageZeroCopy, ZeroCopyTcpDecoder};
22
/// Width in bytes of the endpoint-path length prefix (read and written as a big-endian `u16`).
const TCP_REQUEST_ENDPOINT_LEN_WIDTH: usize = 2;
/// Width in bytes of the headers length prefix (read and written as a big-endian `u16`).
const TCP_REQUEST_HEADERS_LEN_WIDTH: usize = 2;
/// Width in bytes of the payload length prefix (read and written as a big-endian `u32`).
const TCP_REQUEST_PAYLOAD_LEN_WIDTH: usize = 4;
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq)]
28struct TcpRequestWireHeader {
29    endpoint_len: usize,
30    headers_len: usize,
31    payload_len: usize,
32    header_size: usize,
33    total_len: usize,
34}
35
36impl TcpRequestWireHeader {
37    fn endpoint_start(&self) -> usize {
38        TCP_REQUEST_ENDPOINT_LEN_WIDTH
39    }
40
41    fn endpoint_end(&self) -> usize {
42        self.endpoint_start() + self.endpoint_len
43    }
44
45    fn headers_start(&self) -> usize {
46        self.endpoint_end() + TCP_REQUEST_HEADERS_LEN_WIDTH
47    }
48
49    fn headers_end(&self) -> usize {
50        self.headers_start() + self.headers_len
51    }
52
53    fn payload_start(&self) -> usize {
54        self.header_size
55    }
56}
57
58fn tcp_request_header_size(endpoint_len: usize, headers_len: usize) -> usize {
59    TCP_REQUEST_ENDPOINT_LEN_WIDTH
60        + endpoint_len
61        + TCP_REQUEST_HEADERS_LEN_WIDTH
62        + headers_len
63        + TCP_REQUEST_PAYLOAD_LEN_WIDTH
64}
65
66fn tcp_request_total_len(
67    endpoint_len: usize,
68    headers_len: usize,
69    payload_len: usize,
70) -> Result<TcpRequestWireHeader, std::io::Error> {
71    let header_size = tcp_request_header_size(endpoint_len, headers_len);
72    let total_len = header_size.checked_add(payload_len).ok_or_else(|| {
73        std::io::Error::new(
74            std::io::ErrorKind::InvalidData,
75            "TCP request message length overflow",
76        )
77    })?;
78
79    Ok(TcpRequestWireHeader {
80        endpoint_len,
81        headers_len,
82        payload_len,
83        header_size,
84        total_len,
85    })
86}
87
88fn validate_tcp_request_encode_lengths(
89    endpoint_len: usize,
90    headers_len: usize,
91    payload_len: usize,
92) -> Result<TcpRequestWireHeader, std::io::Error> {
93    if endpoint_len > u16::MAX as usize {
94        return Err(std::io::Error::new(
95            std::io::ErrorKind::InvalidInput,
96            format!("Endpoint path too long: {} bytes", endpoint_len),
97        ));
98    }
99
100    if headers_len > u16::MAX as usize {
101        return Err(std::io::Error::new(
102            std::io::ErrorKind::InvalidInput,
103            format!("Headers too large: {} bytes", headers_len),
104        ));
105    }
106
107    if payload_len > u32::MAX as usize {
108        return Err(std::io::Error::new(
109            std::io::ErrorKind::InvalidInput,
110            format!("Payload too large: {} bytes", payload_len),
111        ));
112    }
113
114    tcp_request_total_len(endpoint_len, headers_len, payload_len)
115}
116
117fn tcp_request_endpoint_len(bytes: &[u8]) -> Result<usize, std::io::Error> {
118    if bytes.len() < TCP_REQUEST_ENDPOINT_LEN_WIDTH {
119        return Err(std::io::Error::new(
120            std::io::ErrorKind::UnexpectedEof,
121            "Not enough bytes for endpoint path length",
122        ));
123    }
124
125    Ok(u16::from_be_bytes([bytes[0], bytes[1]]) as usize)
126}
127
128fn tcp_request_headers_len(bytes: &[u8], endpoint_len: usize) -> Result<usize, std::io::Error> {
129    let endpoint_end = TCP_REQUEST_ENDPOINT_LEN_WIDTH + endpoint_len;
130    if bytes.len() < endpoint_end {
131        return Err(std::io::Error::new(
132            std::io::ErrorKind::UnexpectedEof,
133            "Not enough bytes for endpoint path",
134        ));
135    }
136
137    if bytes.len() < endpoint_end + TCP_REQUEST_HEADERS_LEN_WIDTH {
138        return Err(std::io::Error::new(
139            std::io::ErrorKind::UnexpectedEof,
140            "Not enough bytes for headers length",
141        ));
142    }
143
144    Ok(u16::from_be_bytes([bytes[endpoint_end], bytes[endpoint_end + 1]]) as usize)
145}
146
147fn parse_tcp_request_frame_header(bytes: &[u8]) -> Result<TcpRequestWireHeader, std::io::Error> {
148    let endpoint_len = tcp_request_endpoint_len(bytes)?;
149    let headers_len = tcp_request_headers_len(bytes, endpoint_len)?;
150
151    let headers_end =
152        TCP_REQUEST_ENDPOINT_LEN_WIDTH + endpoint_len + TCP_REQUEST_HEADERS_LEN_WIDTH + headers_len;
153    if bytes.len() < headers_end {
154        return Err(std::io::Error::new(
155            std::io::ErrorKind::UnexpectedEof,
156            "Not enough bytes for headers",
157        ));
158    }
159
160    if bytes.len() < headers_end + TCP_REQUEST_PAYLOAD_LEN_WIDTH {
161        return Err(std::io::Error::new(
162            std::io::ErrorKind::UnexpectedEof,
163            "Not enough bytes for payload length",
164        ));
165    }
166
167    let payload_len = u32::from_be_bytes([
168        bytes[headers_end],
169        bytes[headers_end + 1],
170        bytes[headers_end + 2],
171        bytes[headers_end + 3],
172    ]) as usize;
173
174    tcp_request_total_len(endpoint_len, headers_len, payload_len)
175}
176
177fn parse_tcp_request_frame(bytes: &[u8]) -> Result<TcpRequestWireHeader, std::io::Error> {
178    let parsed = parse_tcp_request_frame_header(bytes)?;
179    if bytes.len() < parsed.total_len {
180        return Err(std::io::Error::new(
181            std::io::ErrorKind::UnexpectedEof,
182            format!(
183                "Not enough bytes for payload: expected {}, got {}",
184                parsed.payload_len,
185                bytes.len().saturating_sub(parsed.payload_start())
186            ),
187        ));
188    }
189
190    Ok(parsed)
191}
192
/// Verifies that a frame of `total_len` bytes does not exceed `max_message_size`.
///
/// # Errors
///
/// Returns `InvalidData` when `total_len` is strictly greater than `max_message_size`.
fn check_tcp_request_max_message_size(
    total_len: usize,
    max_message_size: usize,
) -> Result<(), std::io::Error> {
    if total_len <= max_message_size {
        return Ok(());
    }

    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        format!(
            "message too large: {} bytes (max: {} bytes)",
            total_len, max_message_size
        ),
    ))
}
209
210/// TCP request plane protocol message with endpoint routing and trace headers
211///
212/// Wire format:
213/// - endpoint_path_len: u16 (big-endian)
214/// - endpoint_path: UTF-8 string
215/// - headers_len: u16 (big-endian)
216/// - headers: JSON-encoded HashMap<String, String>
217/// - payload_len: u32 (big-endian)
218/// - payload: bytes
219#[derive(Debug, Clone, PartialEq, Eq)]
220pub struct TcpRequestMessage {
221    pub endpoint_path: String,
222    pub headers: std::collections::HashMap<String, String>,
223    pub payload: Bytes,
224}
225
226impl TcpRequestMessage {
227    pub fn new(endpoint_path: String, payload: Bytes) -> Self {
228        Self {
229            endpoint_path,
230            headers: std::collections::HashMap::new(),
231            payload,
232        }
233    }
234
235    pub fn with_headers(
236        endpoint_path: String,
237        headers: std::collections::HashMap<String, String>,
238        payload: Bytes,
239    ) -> Self {
240        Self {
241            endpoint_path,
242            headers,
243            payload,
244        }
245    }
246
247    /// Encode message to bytes
248    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
249        let endpoint_bytes = self.endpoint_path.as_bytes();
250        let endpoint_len = endpoint_bytes.len();
251
252        // Encode headers as JSON
253        let headers_json = serde_json::to_vec(&self.headers).map_err(|e| {
254            std::io::Error::new(
255                std::io::ErrorKind::InvalidInput,
256                format!("Failed to encode headers: {}", e),
257            )
258        })?;
259        let headers_len = headers_json.len();
260
261        let parsed =
262            validate_tcp_request_encode_lengths(endpoint_len, headers_len, self.payload.len())?;
263
264        // Use BytesMut for efficient buffer building
265        let mut buf = BytesMut::with_capacity(parsed.total_len);
266
267        // Write endpoint path length (2 bytes)
268        buf.put_u16(endpoint_len as u16);
269
270        // Write endpoint path
271        buf.put_slice(endpoint_bytes);
272
273        // Write headers length (2 bytes)
274        buf.put_u16(headers_len as u16);
275
276        // Write headers
277        buf.put_slice(&headers_json);
278
279        // Write payload length (4 bytes)
280        buf.put_u32(self.payload.len() as u32);
281
282        // Write payload
283        buf.put_slice(&self.payload);
284
285        // Zero-copy conversion to Bytes
286        Ok(buf.freeze())
287    }
288
289    /// Decode message from bytes (for backward compatibility, zero-copy when possible)
290    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
291        let parsed = parse_tcp_request_frame(bytes)?;
292
293        // Read endpoint path (requires copy for UTF-8 validation)
294        let endpoint_path =
295            String::from_utf8(bytes[parsed.endpoint_start()..parsed.endpoint_end()].to_vec())
296                .map_err(|e| {
297                    std::io::Error::new(
298                        std::io::ErrorKind::InvalidData,
299                        format!("Invalid UTF-8 in endpoint path: {}", e),
300                    )
301                })?;
302
303        // Read and parse headers
304        let headers: std::collections::HashMap<String, String> = serde_json::from_slice(
305            &bytes[parsed.headers_start()..parsed.headers_end()],
306        )
307        .map_err(|e| {
308            std::io::Error::new(
309                std::io::ErrorKind::InvalidData,
310                format!("Invalid JSON in headers: {}", e),
311            )
312        })?;
313
314        // Read payload (zero-copy slice)
315        let payload = bytes.slice(parsed.payload_start()..parsed.total_len);
316
317        Ok(Self {
318            endpoint_path,
319            headers,
320            payload,
321        })
322    }
323}
324
325/// TCP response message (acknowledgment or error)
326///
327/// Wire format:
328/// - length: u32 (big-endian)
329/// - data: bytes
330#[derive(Debug, Clone, PartialEq, Eq)]
331pub struct TcpResponseMessage {
332    pub data: Bytes,
333}
334
335impl TcpResponseMessage {
336    pub fn new(data: Bytes) -> Self {
337        Self { data }
338    }
339
340    pub fn empty() -> Self {
341        Self { data: Bytes::new() }
342    }
343
344    /// Encode response to bytes (for backward compatibility)
345    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
346        if self.data.len() > u32::MAX as usize {
347            return Err(std::io::Error::new(
348                std::io::ErrorKind::InvalidInput,
349                format!("Response too large: {} bytes", self.data.len()),
350            ));
351        }
352
353        // Use BytesMut for efficient buffer building
354        let mut buf = BytesMut::with_capacity(4 + self.data.len());
355
356        // Write length (4 bytes)
357        buf.put_u32(self.data.len() as u32);
358
359        // Write data
360        buf.put_slice(&self.data);
361
362        // Zero-copy conversion to Bytes
363        Ok(buf.freeze())
364    }
365
366    /// Decode response from bytes (for backward compatibility, zero-copy when possible)
367    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
368        if bytes.len() < 4 {
369            return Err(std::io::Error::new(
370                std::io::ErrorKind::UnexpectedEof,
371                "Not enough bytes for response length",
372            ));
373        }
374
375        // Read length (4 bytes)
376        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
377
378        if bytes.len() < 4 + len {
379            return Err(std::io::Error::new(
380                std::io::ErrorKind::UnexpectedEof,
381                format!(
382                    "Not enough bytes for response: expected {}, got {}",
383                    len,
384                    bytes.len() - 4
385                ),
386            ));
387        }
388
389        // Read data (zero-copy slice)
390        let data = bytes.slice(4..4 + len);
391
392        Ok(Self { data })
393    }
394}
395
/// Codec for encoding/decoding `TcpResponseMessage` frames on a byte stream.
///
/// When `max_message_size` is `Some`, both encoding and decoding reject frames whose total
/// size (length prefix plus data) exceeds it; `None` disables the check.
#[derive(Clone, Default)]
pub struct TcpResponseCodec {
    /// Upper bound on the total frame size in bytes, or `None` for no limit.
    max_message_size: Option<usize>,
}

impl TcpResponseCodec {
    /// Creates a codec with the given optional frame-size limit.
    pub fn new(max_message_size: Option<usize>) -> Self {
        Self { max_message_size }
    }
}
408
409impl Decoder for TcpResponseCodec {
410    type Item = TcpResponseMessage;
411    type Error = std::io::Error;
412
413    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
414        // Need at least 4 bytes for length
415        if src.len() < 4 {
416            return Ok(None);
417        }
418
419        // Peek at message length without consuming
420        let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
421        let total_len = 4 + data_len;
422
423        // Check max message size
424        if let Some(max_size) = self.max_message_size
425            && total_len > max_size
426        {
427            return Err(std::io::Error::new(
428                std::io::ErrorKind::InvalidData,
429                format!(
430                    "Response too large: {} bytes (max: {} bytes)",
431                    total_len, max_size
432                ),
433            ));
434        }
435
436        // Check if we have the full message
437        if src.len() < total_len {
438            return Ok(None);
439        }
440
441        // Advance past the length prefix
442        src.advance(4);
443
444        // Read data
445        let data = src.split_to(data_len).freeze();
446
447        Ok(Some(TcpResponseMessage { data }))
448    }
449}
450
451impl Encoder<TcpResponseMessage> for TcpResponseCodec {
452    type Error = std::io::Error;
453
454    fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
455        if item.data.len() > u32::MAX as usize {
456            return Err(std::io::Error::new(
457                std::io::ErrorKind::InvalidInput,
458                format!("Response too large: {} bytes", item.data.len()),
459            ));
460        }
461
462        let total_len = 4 + item.data.len();
463
464        // Check max message size
465        if let Some(max_size) = self.max_message_size
466            && total_len > max_size
467        {
468            return Err(std::io::Error::new(
469                std::io::ErrorKind::InvalidInput,
470                format!(
471                    "Response too large: {} bytes (max: {} bytes)",
472                    total_len, max_size
473                ),
474            ));
475        }
476
477        // Reserve space
478        dst.reserve(total_len);
479
480        // Write length
481        dst.put_u32(item.data.len() as u32);
482
483        // Write data
484        dst.put_slice(&item.data);
485
486        Ok(())
487    }
488}
489
// Unit tests for the request/response message types and the response codec.
#[cfg(test)]
mod tests {
    use super::*;

    // A request with a small payload survives an encode/decode round trip.
    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // A request with an empty payload round-trips.
    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // A request with a 1 MiB payload round-trips.
    #[test]
    fn test_tcp_request_large_payload() {
        let payload = Bytes::from(vec![42u8; 1024 * 1024]); // 1MB
        let msg = TcpRequestMessage::new("large".to_string(), payload);

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // Decoding a request whose last two bytes are missing fails.
    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Truncate the encoded message
        let truncated = encoded.slice(..encoded.len() - 2);
        let result = TcpRequestMessage::decode(&truncated);

        assert!(result.is_err());
    }

    // A hand-built frame with a non-UTF-8 endpoint path is rejected as InvalidData.
    #[test]
    fn test_tcp_request_decode_invalid_endpoint_utf8() {
        let mut encoded = BytesMut::new();
        encoded.put_u16(2);
        encoded.put_slice(&[0xff, 0xff]);
        encoded.put_u16(2);
        encoded.put_slice(b"{}");
        encoded.put_u32(0);

        let result = TcpRequestMessage::decode(&encoded.freeze());

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid UTF-8"));
    }

    // A hand-built frame whose headers are malformed JSON is rejected as InvalidData.
    #[test]
    fn test_tcp_request_decode_invalid_headers_json() {
        let mut encoded = BytesMut::new();
        encoded.put_u16(4);
        encoded.put_slice(b"test");
        encoded.put_u16(1);
        encoded.put_slice(b"{");
        encoded.put_u32(0);

        let result = TcpRequestMessage::decode(&encoded.freeze());

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid JSON"));
    }

    // A request with an empty endpoint path round-trips.
    #[test]
    fn test_tcp_request_empty_endpoint_path() {
        let msg = TcpRequestMessage::new(String::new(), Bytes::from_static(b"payload"));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // A response with a small body round-trips through encode/decode.
    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // An empty response round-trips and decodes to zero data bytes.
    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    // Decoding only the first three bytes of a response (less than the length prefix) fails.
    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Truncate the encoded message
        let truncated = encoded.slice(..3);
        let result = TcpResponseMessage::decode(&truncated);

        assert!(result.is_err());
    }

    // A request whose endpoint path contains multi-byte UTF-8 characters round-trips.
    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    // The codec's Encoder output can be read back by its Decoder.
    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();

        // Encode
        codec.encode(msg.clone(), &mut buf).unwrap();

        // Decode
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    // The decoder yields None for a partial frame and the message once the rest arrives.
    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        // Feed partial data
        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Feed rest of data
        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    // Encoding fails when the frame (4-byte prefix + 5 data bytes) exceeds a 5-byte limit.
    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(Some(5)); // Too small
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }
}