dynamo_runtime/pipeline/network/
codec.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Codec Module
5//!
//! Codecs that map structures into blobs of bytes and streams of bytes.
7//!
//! This module defines the three primary codecs used to issue single, two-part, or multi-part
//! messages on a byte stream.
10
11use bytes::Bytes;
12use tokio_util::{
13    bytes::{Buf, BufMut, BytesMut},
14    codec::{Decoder, Encoder},
15};
16
17mod two_part;
18
19pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
20
21/// TCP request plane protocol message with endpoint routing
22///
23/// Wire format:
24/// - endpoint_path_len: u16 (big-endian)
25/// - endpoint_path: UTF-8 string
26/// - payload_len: u32 (big-endian)
27/// - payload: bytes
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct TcpRequestMessage {
30    pub endpoint_path: String,
31    pub payload: Bytes,
32}
33
34impl TcpRequestMessage {
35    pub fn new(endpoint_path: String, payload: Bytes) -> Self {
36        Self {
37            endpoint_path,
38            payload,
39        }
40    }
41
42    /// Encode message to bytes
43    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
44        let endpoint_bytes = self.endpoint_path.as_bytes();
45        let endpoint_len = endpoint_bytes.len();
46
47        if endpoint_len > u16::MAX as usize {
48            return Err(std::io::Error::new(
49                std::io::ErrorKind::InvalidInput,
50                format!("Endpoint path too long: {} bytes", endpoint_len),
51            ));
52        }
53
54        if self.payload.len() > u32::MAX as usize {
55            return Err(std::io::Error::new(
56                std::io::ErrorKind::InvalidInput,
57                format!("Payload too large: {} bytes", self.payload.len()),
58            ));
59        }
60
61        // Use BytesMut for efficient buffer building
62        let mut buf = BytesMut::with_capacity(2 + endpoint_len + 4 + self.payload.len());
63
64        // Write endpoint path length (2 bytes)
65        buf.put_u16(endpoint_len as u16);
66
67        // Write endpoint path
68        buf.put_slice(endpoint_bytes);
69
70        // Write payload length (4 bytes)
71        buf.put_u32(self.payload.len() as u32);
72
73        // Write payload
74        buf.put_slice(&self.payload);
75
76        // Zero-copy conversion to Bytes
77        Ok(buf.freeze())
78    }
79
80    /// Decode message from bytes (for backward compatibility, zero-copy when possible)
81    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
82        if bytes.len() < 2 {
83            return Err(std::io::Error::new(
84                std::io::ErrorKind::UnexpectedEof,
85                "Not enough bytes for endpoint path length",
86            ));
87        }
88
89        // Read endpoint path length (2 bytes)
90        let endpoint_len = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
91        let mut offset = 2;
92
93        if bytes.len() < offset + endpoint_len {
94            return Err(std::io::Error::new(
95                std::io::ErrorKind::UnexpectedEof,
96                "Not enough bytes for endpoint path",
97            ));
98        }
99
100        // Read endpoint path (requires copy for UTF-8 validation)
101        let endpoint_path = String::from_utf8(bytes[offset..offset + endpoint_len].to_vec())
102            .map_err(|e| {
103                std::io::Error::new(
104                    std::io::ErrorKind::InvalidData,
105                    format!("Invalid UTF-8: {}", e),
106                )
107            })?;
108        offset += endpoint_len;
109
110        if bytes.len() < offset + 4 {
111            return Err(std::io::Error::new(
112                std::io::ErrorKind::UnexpectedEof,
113                "Not enough bytes for payload length",
114            ));
115        }
116
117        // Read payload length (4 bytes)
118        let payload_len = u32::from_be_bytes([
119            bytes[offset],
120            bytes[offset + 1],
121            bytes[offset + 2],
122            bytes[offset + 3],
123        ]) as usize;
124        offset += 4;
125
126        if bytes.len() < offset + payload_len {
127            return Err(std::io::Error::new(
128                std::io::ErrorKind::UnexpectedEof,
129                format!(
130                    "Not enough bytes for payload: expected {}, got {}",
131                    payload_len,
132                    bytes.len() - offset
133                ),
134            ));
135        }
136
137        // Read payload (zero-copy slice)
138        let payload = bytes.slice(offset..offset + payload_len);
139
140        Ok(Self {
141            endpoint_path,
142            payload,
143        })
144    }
145}
146
/// [`Decoder`]/[`Encoder`] pair for [`TcpRequestMessage`] frames.
///
/// When `max_message_size` is `Some(limit)`, any frame whose total encoded length
/// (length prefixes included) exceeds `limit` bytes is rejected, both when decoding
/// and when encoding. `None` disables the limit.
#[derive(Clone, Default)]
pub struct TcpRequestCodec {
    max_message_size: Option<usize>,
}

impl TcpRequestCodec {
    /// Creates a codec with the given frame-size cap (`None` means unlimited).
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpRequestCodec { max_message_size }
    }
}
159
160impl Decoder for TcpRequestCodec {
161    type Item = TcpRequestMessage;
162    type Error = std::io::Error;
163
164    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
165        // Need at least 2 bytes for endpoint_path_len
166        if src.len() < 2 {
167            return Ok(None);
168        }
169
170        // Peek at endpoint path length without consuming
171        let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize;
172        let header_size = 2 + endpoint_len + 4; // path_len + path + payload_len
173
174        if src.len() < header_size {
175            return Ok(None);
176        }
177
178        // Peek at payload length
179        let payload_len_offset = 2 + endpoint_len;
180        let payload_len = u32::from_be_bytes([
181            src[payload_len_offset],
182            src[payload_len_offset + 1],
183            src[payload_len_offset + 2],
184            src[payload_len_offset + 3],
185        ]) as usize;
186
187        let total_len = header_size + payload_len;
188
189        // Check max message size
190        if let Some(max_size) = self.max_message_size
191            && total_len > max_size
192        {
193            return Err(std::io::Error::new(
194                std::io::ErrorKind::InvalidData,
195                format!(
196                    "Request too large: {} bytes (max: {} bytes)",
197                    total_len, max_size
198                ),
199            ));
200        }
201
202        // Check if we have the full message
203        if src.len() < total_len {
204            return Ok(None);
205        }
206
207        // We have a complete message, advance past length prefix
208        src.advance(2);
209
210        // Read endpoint path
211        let endpoint_bytes = src.split_to(endpoint_len);
212        let endpoint_path = String::from_utf8(endpoint_bytes.to_vec()).map_err(|e| {
213            std::io::Error::new(
214                std::io::ErrorKind::InvalidData,
215                format!("Invalid UTF-8 in endpoint path: {}", e),
216            )
217        })?;
218
219        // Advance past payload length
220        src.advance(4);
221
222        // Read payload
223        let payload = src.split_to(payload_len).freeze();
224
225        Ok(Some(TcpRequestMessage {
226            endpoint_path,
227            payload,
228        }))
229    }
230}
231
232impl Encoder<TcpRequestMessage> for TcpRequestCodec {
233    type Error = std::io::Error;
234
235    fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
236        let endpoint_bytes = item.endpoint_path.as_bytes();
237        let endpoint_len = endpoint_bytes.len();
238
239        if endpoint_len > u16::MAX as usize {
240            return Err(std::io::Error::new(
241                std::io::ErrorKind::InvalidInput,
242                format!("Endpoint path too long: {} bytes", endpoint_len),
243            ));
244        }
245
246        if item.payload.len() > u32::MAX as usize {
247            return Err(std::io::Error::new(
248                std::io::ErrorKind::InvalidInput,
249                format!("Payload too large: {} bytes", item.payload.len()),
250            ));
251        }
252
253        let total_len = 2 + endpoint_len + 4 + item.payload.len();
254
255        // Check max message size
256        if let Some(max_size) = self.max_message_size
257            && total_len > max_size
258        {
259            return Err(std::io::Error::new(
260                std::io::ErrorKind::InvalidInput,
261                format!(
262                    "Request too large: {} bytes (max: {} bytes)",
263                    total_len, max_size
264                ),
265            ));
266        }
267
268        // Reserve space
269        dst.reserve(total_len);
270
271        // Write endpoint path length
272        dst.put_u16(endpoint_len as u16);
273
274        // Write endpoint path
275        dst.put_slice(endpoint_bytes);
276
277        // Write payload length
278        dst.put_u32(item.payload.len() as u32);
279
280        // Write payload
281        dst.put_slice(&item.payload);
282
283        Ok(())
284    }
285}
286
287/// TCP response message (acknowledgment or error)
288///
289/// Wire format:
290/// - length: u32 (big-endian)
291/// - data: bytes
292#[derive(Debug, Clone, PartialEq, Eq)]
293pub struct TcpResponseMessage {
294    pub data: Bytes,
295}
296
297impl TcpResponseMessage {
298    pub fn new(data: Bytes) -> Self {
299        Self { data }
300    }
301
302    pub fn empty() -> Self {
303        Self { data: Bytes::new() }
304    }
305
306    /// Encode response to bytes (for backward compatibility)
307    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
308        if self.data.len() > u32::MAX as usize {
309            return Err(std::io::Error::new(
310                std::io::ErrorKind::InvalidInput,
311                format!("Response too large: {} bytes", self.data.len()),
312            ));
313        }
314
315        // Use BytesMut for efficient buffer building
316        let mut buf = BytesMut::with_capacity(4 + self.data.len());
317
318        // Write length (4 bytes)
319        buf.put_u32(self.data.len() as u32);
320
321        // Write data
322        buf.put_slice(&self.data);
323
324        // Zero-copy conversion to Bytes
325        Ok(buf.freeze())
326    }
327
328    /// Decode response from bytes (for backward compatibility, zero-copy when possible)
329    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
330        if bytes.len() < 4 {
331            return Err(std::io::Error::new(
332                std::io::ErrorKind::UnexpectedEof,
333                "Not enough bytes for response length",
334            ));
335        }
336
337        // Read length (4 bytes)
338        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
339
340        if bytes.len() < 4 + len {
341            return Err(std::io::Error::new(
342                std::io::ErrorKind::UnexpectedEof,
343                format!(
344                    "Not enough bytes for response: expected {}, got {}",
345                    len,
346                    bytes.len() - 4
347                ),
348            ));
349        }
350
351        // Read data (zero-copy slice)
352        let data = bytes.slice(4..4 + len);
353
354        Ok(Self { data })
355    }
356}
357
/// [`Decoder`]/[`Encoder`] pair for [`TcpResponseMessage`] frames.
///
/// An optional `max_message_size` caps the total frame length, including the 4-byte
/// length prefix. Frames above the cap are rejected both when decoding and when
/// encoding.
#[derive(Clone, Default)]
pub struct TcpResponseCodec {
    max_message_size: Option<usize>,
}

impl TcpResponseCodec {
    /// Builds a codec with the given frame-size cap (`None` means unlimited).
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpResponseCodec { max_message_size }
    }
}
370
371impl Decoder for TcpResponseCodec {
372    type Item = TcpResponseMessage;
373    type Error = std::io::Error;
374
375    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
376        // Need at least 4 bytes for length
377        if src.len() < 4 {
378            return Ok(None);
379        }
380
381        // Peek at message length without consuming
382        let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
383        let total_len = 4 + data_len;
384
385        // Check max message size
386        if let Some(max_size) = self.max_message_size
387            && total_len > max_size
388        {
389            return Err(std::io::Error::new(
390                std::io::ErrorKind::InvalidData,
391                format!(
392                    "Response too large: {} bytes (max: {} bytes)",
393                    total_len, max_size
394                ),
395            ));
396        }
397
398        // Check if we have the full message
399        if src.len() < total_len {
400            return Ok(None);
401        }
402
403        // Advance past the length prefix
404        src.advance(4);
405
406        // Read data
407        let data = src.split_to(data_len).freeze();
408
409        Ok(Some(TcpResponseMessage { data }))
410    }
411}
412
413impl Encoder<TcpResponseMessage> for TcpResponseCodec {
414    type Error = std::io::Error;
415
416    fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
417        if item.data.len() > u32::MAX as usize {
418            return Err(std::io::Error::new(
419                std::io::ErrorKind::InvalidInput,
420                format!("Response too large: {} bytes", item.data.len()),
421            ));
422        }
423
424        let total_len = 4 + item.data.len();
425
426        // Check max message size
427        if let Some(max_size) = self.max_message_size
428            && total_len > max_size
429        {
430            return Err(std::io::Error::new(
431                std::io::ErrorKind::InvalidInput,
432                format!(
433                    "Response too large: {} bytes (max: {} bytes)",
434                    total_len, max_size
435                ),
436            ));
437        }
438
439        // Reserve space
440        dst.reserve(total_len);
441
442        // Write length
443        dst.put_u32(item.data.len() as u32);
444
445        // Write data
446        dst.put_slice(&item.data);
447
448        Ok(())
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_large_payload() {
        let payload = Bytes::from(vec![42u8; 1024 * 1024]); // 1MB
        let msg = TcpRequestMessage::new("large".to_string(), payload);

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Truncate the encoded message
        let truncated = encoded.slice(..encoded.len() - 2);
        let result = TcpRequestMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Truncate the encoded message
        let truncated = encoded.slice(..3);
        let result = TcpResponseMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();

        // Encode
        codec.encode(msg.clone(), &mut buf).unwrap();

        // Decode
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let mut codec = TcpRequestCodec::new(None);

        // Feed partial data
        let mut buf = BytesMut::from(&encoded[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Feed rest of data
        buf.extend_from_slice(&encoded[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_back_to_back_frames() {
        use tokio_util::codec::{Decoder, Encoder};

        let first = TcpRequestMessage::new("a".to_string(), Bytes::from("one"));
        let second = TcpRequestMessage::new("b".to_string(), Bytes::new());

        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();
        codec.encode(first.clone(), &mut buf).unwrap();
        codec.encode(second.clone(), &mut buf).unwrap();

        // Each call consumes exactly one frame.
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(first));
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(second));
        assert!(buf.is_empty());
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn test_tcp_request_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpRequestCodec::new(Some(10)); // Too small
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // The frame is 2 + 4 + 4 + 5 = 15 bytes, above the 10-byte limit.
        let mut codec = TcpRequestCodec::new(Some(10));
        let mut buf = BytesMut::from(&encoded[..]);

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_request_codec_invalid_utf8() {
        use tokio_util::codec::Decoder;

        // path_len = 2, path = [0xff, 0xfe] (not UTF-8), payload_len = 1, payload = [9]
        let mut buf = BytesMut::from(&[0u8, 2, 0xff, 0xfe, 0, 0, 0, 1, 9][..]);
        let mut codec = TcpRequestCodec::new(None);

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();

        // Encode
        codec.encode(msg.clone(), &mut buf).unwrap();

        // Decode
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        // Feed partial data
        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Feed rest of data
        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(Some(5)); // Too small
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_response_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let encoded = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]))
            .encode()
            .unwrap();

        // The frame is 4 + 5 = 9 bytes, above the 5-byte limit.
        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::from(&encoded[..]);

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    /// Demonstrates how framed codec enables testability without actual TCP connections
    #[tokio::test]
    async fn test_framed_codec_integration() {
        use futures::{SinkExt, StreamExt};
        use std::io::Cursor;
        use tokio_util::codec::{FramedRead, FramedWrite};

        // Simulate a duplex connection using in-memory buffer
        let mut buffer = Vec::new();

        // Writer side: encode requests
        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None));

            let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1"));
            let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2"));

            writer.send(msg1).await.unwrap();
            writer.send(msg2).await.unwrap();
        }

        // Reader side: decode requests
        {
            let cursor = Cursor::new(&buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));

            let decoded1 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded1.endpoint_path, "endpoint1");
            assert_eq!(decoded1.payload, Bytes::from("data1"));

            let decoded2 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded2.endpoint_path, "endpoint2");
            assert_eq!(decoded2.payload, Bytes::from("data2"));

            // Both frames were consumed exactly, so the stream ends cleanly.
            assert!(reader.next().await.is_none());
        }
    }

    /// Demonstrates testing partial message handling
    #[tokio::test]
    async fn test_framed_codec_partial_messages() {
        use futures::StreamExt;
        use std::io::Cursor;
        use tokio_util::codec::FramedRead;

        // Create a message and encode it
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello"));
        let encoded = msg.encode().unwrap();

        // Split the encoded message into chunks
        let (chunk1, chunk2) = encoded.split_at(5);

        // Create a buffer that simulates receiving data in chunks
        let mut full_buffer = Vec::new();
        full_buffer.extend_from_slice(chunk1);

        // With only part of a frame available, the decoder returns Ok(None) and waits.
        // A Cursor hits EOF instead of delivering more data. FramedRead therefore reports
        // the leftover bytes as an error rather than yielding a message.
        {
            let cursor = Cursor::new(&full_buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
            assert!(matches!(reader.next().await, Some(Err(_))));
        }

        // Add the rest of the data
        full_buffer.extend_from_slice(chunk2);

        // Now decoding succeeds
        {
            let cursor = Cursor::new(&full_buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));

            let decoded = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded.endpoint_path, "test");
            assert_eq!(decoded.payload, Bytes::from("hello"));
        }
    }
}