dynamo_runtime/pipeline/network/codec/
two_part.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
5use tokio_util::codec::{Decoder, Encoder};
6use xxhash_rust::xxh3::xxh3_64;
7
8use crate::pipeline::error::TwoPartCodecError;
9
10#[derive(Clone, Default)]
11pub struct TwoPartCodec {
12 max_message_size: Option<usize>,
13}
14
15impl TwoPartCodec {
16 pub fn new(max_message_size: Option<usize>) -> Self {
17 TwoPartCodec { max_message_size }
18 }
19
20 pub fn encode_message(&self, msg: TwoPartMessage) -> Result<Bytes, TwoPartCodecError> {
22 let mut buf = BytesMut::new();
23 let mut codec = self.clone();
24 codec.encode(msg, &mut buf)?;
25 Ok(buf.freeze())
26 }
27
28 pub fn decode_message(&self, data: Bytes) -> Result<TwoPartMessage, TwoPartCodecError> {
30 let mut buf = BytesMut::from(&data[..]);
31 let mut codec = self.clone();
32 match codec.decode(&mut buf)? {
33 Some(msg) => Ok(msg),
34 None => Err(TwoPartCodecError::InvalidMessage(
35 "No message decoded".to_string(),
36 )),
37 }
38 }
39}
40
41impl Decoder for TwoPartCodec {
42 type Item = TwoPartMessage;
43 type Error = TwoPartCodecError;
44
45 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
46 if src.len() < 24 {
48 return Ok(None);
49 }
50
51 let mut cursor = &src[..];
53
54 let header_len = cursor.get_u64() as usize;
55 let body_len = cursor.get_u64() as usize;
56 let _checksum = cursor.get_u64();
57
58 let total_len = 24 + header_len + body_len;
59
60 if let Some(max_size) = self.max_message_size
62 && total_len > max_size
63 {
64 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
65 }
66
67 if src.len() < total_len {
69 return Ok(None);
70 }
71
72 src.advance(24);
74
75 #[cfg(debug_assertions)]
76 {
77 if _checksum != 0 {
79 let bytes_to_hash =
80 header_len
81 .checked_add(body_len)
82 .ok_or(TwoPartCodecError::InvalidMessage(
83 "Message exceeds max allowed length.".to_string(),
84 ))?;
85
86 let data_to_hash = &src[..bytes_to_hash];
87
88 let computed_checksum = xxh3_64(data_to_hash);
89
90 if _checksum != computed_checksum {
92 return Err(TwoPartCodecError::ChecksumMismatch);
93 }
94 }
95 }
96
97 let header = src.split_to(header_len).freeze();
99 let data = src.split_to(body_len).freeze();
100
101 Ok(Some(TwoPartMessage { header, data }))
102 }
103}
104
105impl Encoder<TwoPartMessage> for TwoPartCodec {
106 type Error = TwoPartCodecError;
107
108 fn encode(&mut self, item: TwoPartMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
109 let header_len = item.header.len();
110 let body_len = item.data.len();
111
112 let total_len = 24 + header_len + body_len; if let Some(max_size) = self.max_message_size
116 && total_len > max_size
117 {
118 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
119 }
120
121 dst.put_u64(header_len as u64);
122 dst.put_u64(body_len as u64);
123
124 #[cfg(debug_assertions)]
127 {
128 let mut data_to_hash = BytesMut::with_capacity(header_len + body_len);
130 data_to_hash.extend_from_slice(&item.header);
131 data_to_hash.extend_from_slice(&item.data);
132 let checksum = xxh3_64(&data_to_hash);
133
134 dst.put_u64(checksum);
135 }
136 #[cfg(not(debug_assertions))]
137 {
138 dst.put_u64(0);
139 }
140
141 dst.put_slice(&item.header);
143 dst.put_slice(&item.data);
144
145 Ok(())
146 }
147}
148
149pub enum TwoPartMessageType {
150 HeaderOnly(Bytes),
151 DataOnly(Bytes),
152 HeaderAndData(Bytes, Bytes),
153 Empty,
154}
155
156#[derive(Clone, Debug)]
157pub struct TwoPartMessage {
158 pub header: Bytes,
159 pub data: Bytes,
160}
161
162impl TwoPartMessage {
163 pub fn new(header: Bytes, data: Bytes) -> Self {
164 TwoPartMessage { header, data }
165 }
166
167 pub fn from_header(header: Bytes) -> Self {
168 TwoPartMessage {
169 header,
170 data: Bytes::new(),
171 }
172 }
173
174 pub fn from_data(data: Bytes) -> Self {
175 TwoPartMessage {
176 header: Bytes::new(),
177 data,
178 }
179 }
180
181 pub fn from_parts(header: Bytes, data: Bytes) -> Self {
182 TwoPartMessage { header, data }
183 }
184
185 pub fn parts(&self) -> (&Bytes, &Bytes) {
186 (&self.header, &self.data)
187 }
188
189 pub fn optional_parts(&self) -> (Option<&Bytes>, Option<&Bytes>) {
190 (self.header(), self.data())
191 }
192
193 pub fn into_parts(self) -> (Bytes, Bytes) {
194 (self.header, self.data)
195 }
196
197 pub fn header(&self) -> Option<&Bytes> {
198 if self.header.is_empty() {
199 None
200 } else {
201 Some(&self.header)
202 }
203 }
204
205 pub fn data(&self) -> Option<&Bytes> {
206 if self.data.is_empty() {
207 None
208 } else {
209 Some(&self.data)
210 }
211 }
212
213 pub fn into_message_type(self) -> TwoPartMessageType {
214 if self.header.is_empty() && self.data.is_empty() {
215 TwoPartMessageType::Empty
216 } else if self.header.is_empty() {
217 TwoPartMessageType::DataOnly(self.data)
218 } else if self.data.is_empty() {
219 TwoPartMessageType::HeaderOnly(self.header)
220 } else {
221 TwoPartMessageType::HeaderAndData(self.header, self.data)
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;
    use tokio::io::AsyncRead;
    use tokio::io::ReadBuf;
    use tokio_util::codec::{Decoder, FramedRead};

    use super::*;

    /// Bytes occupied by the fixed frame prefix (header_len, body_len, checksum: 3 x u64).
    const PREFIX_LEN: usize = 24;

    /// Splits `data_size` payload bytes as evenly as possible into a header and a body.
    fn split_payload(data_size: usize) -> (Bytes, Bytes) {
        let header_size = data_size / 2;
        let body_size = data_size - header_size;
        (
            Bytes::from(vec![b'h'; header_size]),
            Bytes::from(vec![b'd'; body_size]),
        )
    }

    #[test]
    fn test_message_with_header_and_data() {
        let header_data = Bytes::from("header data");
        let data = Bytes::from("body data");
        let message = TwoPartMessage::from_parts(header_data.clone(), data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert_eq!(decoded.header, header_data);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_message_with_only_header() {
        let header_data = Bytes::from("header only");
        let message = TwoPartMessage::from_header(header_data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert_eq!(decoded.header, header_data);
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_message_with_only_data() {
        let data = Bytes::from("data only");
        let message = TwoPartMessage::from_data(data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert!(decoded.header.is_empty());
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_empty_message() {
        let message = TwoPartMessage::from_parts(Bytes::new(), Bytes::new());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert!(decoded.header.is_empty());
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_message_under_max_size() {
        let max_size = 1024;
        let header_data = Bytes::from(vec![b'h'; 100]);
        let body_data = Bytes::from(vec![b'd'; 200]);
        let message = TwoPartMessage::from_parts(header_data.clone(), body_data.clone());

        let codec = TwoPartCodec::new(Some(max_size));
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert_eq!(decoded.header, header_data);
        assert_eq!(decoded.data, body_data);
    }

    #[test]
    fn test_message_exactly_at_max_size() {
        let max_size = 1024;
        let (header_data, body_data) = split_payload(max_size - PREFIX_LEN);
        let message = TwoPartMessage::from_parts(header_data.clone(), body_data.clone());

        let codec = TwoPartCodec::new(Some(max_size));
        let encoded = codec.encode_message(message).unwrap();
        assert_eq!(encoded.len(), max_size);

        let decoded = codec.decode_message(encoded).unwrap();
        assert_eq!(decoded.header, header_data);
        assert_eq!(decoded.data, body_data);
    }

    #[test]
    fn test_message_over_max_size() {
        let max_size = 1024;
        // One payload byte more than fits once the prefix is accounted for.
        let data_size = max_size - PREFIX_LEN + 1;
        let (header_data, body_data) = split_payload(data_size);

        let codec = TwoPartCodec::new(Some(max_size));
        match codec.encode_message(TwoPartMessage::from_parts(header_data, body_data)) {
            Err(TwoPartCodecError::MessageTooLarge(size, max)) => {
                assert_eq!(size, data_size + PREFIX_LEN);
                assert_eq!(max, max_size);
            }
            other => panic!("Expected MessageTooLarge error, got {other:?}"),
        }
    }

    #[test]
    fn test_decoding_message_over_max_size() {
        let max_size = 1024;
        let data_size = max_size - PREFIX_LEN + 1;
        let (header_data, body_data) = split_payload(data_size);

        // Encode without a limit so the oversized frame reaches the limited decoder.
        let encoded = TwoPartCodec::new(None)
            .encode_message(TwoPartMessage::from_parts(header_data, body_data))
            .unwrap();

        match TwoPartCodec::new(Some(max_size)).decode_message(encoded) {
            Err(TwoPartCodecError::MessageTooLarge(size, max)) => {
                assert_eq!(size, data_size + PREFIX_LEN);
                assert_eq!(max, max_size);
            }
            other => panic!("Expected MessageTooLarge error, got {other:?}"),
        }
    }

    #[test]
    #[cfg(debug_assertions)]
    fn test_checksum_mismatch() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));

        let codec = TwoPartCodec::new(None);
        let mut encoded = BytesMut::from(codec.encode_message(message).unwrap());

        // Corrupt the final body byte; lengths stay valid so only the checksum can catch it.
        let last = encoded.len() - 1;
        encoded[last] ^= 0xFF;

        let result = codec.decode_message(encoded.freeze());
        assert!(
            matches!(result, Err(TwoPartCodecError::ChecksumMismatch)),
            "Expected ChecksumMismatch error, got {result:?}"
        );
    }

    #[test]
    fn test_partial_data() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();

        // Drop the last five body bytes so the frame is incomplete.
        let partial_encoded = encoded.slice(..encoded.len() - 5);

        let result = codec.decode_message(partial_encoded);
        assert!(
            matches!(result, Err(TwoPartCodecError::InvalidMessage(_))),
            "Expected InvalidMessage error, got {result:?}"
        );
    }

    #[test]
    fn test_incomplete_frame_is_not_consumed() {
        let mut codec = TwoPartCodec::new(None);
        let encoded = codec
            .encode_message(TwoPartMessage::from_parts(
                Bytes::from("header"),
                Bytes::from("body"),
            ))
            .unwrap();

        // Shorter than the fixed prefix.
        let mut buf = BytesMut::from(&encoded[..PREFIX_LEN - 1]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), PREFIX_LEN - 1);

        // Complete prefix, truncated payload: the prefix must not be consumed either.
        let mut buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), encoded.len() - 1);
    }

    #[test]
    fn test_into_message_type() {
        let header = Bytes::from("header");
        let data = Bytes::from("data");

        assert!(matches!(
            TwoPartMessage::from_parts(Bytes::new(), Bytes::new()).into_message_type(),
            TwoPartMessageType::Empty
        ));
        assert!(matches!(
            TwoPartMessage::from_header(header.clone()).into_message_type(),
            TwoPartMessageType::HeaderOnly(h) if h == header
        ));
        assert!(matches!(
            TwoPartMessage::from_data(data.clone()).into_message_type(),
            TwoPartMessageType::DataOnly(d) if d == data
        ));
        assert!(matches!(
            TwoPartMessage::from_parts(header.clone(), data.clone()).into_message_type(),
            TwoPartMessageType::HeaderAndData(h, d) if h == header && d == data
        ));
    }

    #[test]
    fn test_multiple_messages_in_buffer() {
        let header_data1 = Bytes::from("header1");
        let data1 = Bytes::from("data1");
        let message1 = TwoPartMessage::from_parts(header_data1.clone(), data1.clone());

        let header_data2 = Bytes::from("header2");
        let data2 = Bytes::from("data2");
        let message2 = TwoPartMessage::from_parts(header_data2.clone(), data2.clone());

        let mut codec = TwoPartCodec::new(None);

        let mut combined = BytesMut::new();
        combined.extend_from_slice(&codec.encode_message(message1).unwrap());
        combined.extend_from_slice(&codec.encode_message(message2).unwrap());

        let decoded_msg1 = codec.decode(&mut combined).unwrap().unwrap();
        let decoded_msg2 = codec.decode(&mut combined).unwrap().unwrap();

        assert_eq!(decoded_msg1.header, header_data1);
        assert_eq!(decoded_msg1.data, data1);

        assert_eq!(decoded_msg2.header, header_data2);
        assert_eq!(decoded_msg2.data, data2);

        // Both frames were consumed exactly.
        assert!(combined.is_empty());
    }

    #[tokio::test]
    async fn test_streaming_read() {
        let header_data = Bytes::from("header data");
        let data = Bytes::from("body data");
        let message = TwoPartMessage::from_parts(header_data.clone(), data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();

        let mut framed_read = FramedRead::new(Cursor::new(encoded), codec.clone());

        match framed_read.next().await {
            Some(Ok(decoded_message)) => {
                assert_eq!(decoded_message.header, header_data);
                assert_eq!(decoded_message.data, data);
            }
            _ => panic!("Failed to decode message from stream"),
        }
    }

    #[tokio::test]
    async fn test_streaming_partial_reads() {
        let header_data = Bytes::from("header data");
        let data = Bytes::from("body data");
        let message = TwoPartMessage::from_parts(header_data.clone(), data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();

        /// Reader that yields at most `chunk_size` bytes per poll, forcing the decoder to
        /// reassemble a frame across several reads.
        struct ChunkedReader {
            data: Bytes,
            pos: usize,
            chunk_size: usize,
        }

        impl AsyncRead for ChunkedReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                // Bound by the caller's free space too: `ReadBuf::put_slice` panics on overflow.
                // Writing zero bytes once `data` is exhausted signals EOF.
                let n = self
                    .chunk_size
                    .min(self.data.len() - self.pos)
                    .min(buf.remaining());
                let start = self.pos;
                buf.put_slice(&self.data[start..start + n]);
                self.pos += n;
                Poll::Ready(Ok(()))
            }
        }

        let reader = ChunkedReader {
            data: encoded,
            pos: 0,
            chunk_size: 5,
        };

        let mut framed_read = FramedRead::new(reader, codec.clone());

        match framed_read.next().await {
            Some(Ok(decoded_message)) => {
                assert_eq!(decoded_message.header, header_data);
                assert_eq!(decoded_message.data, data);
            }
            _ => panic!("Failed to decode message from stream"),
        }
    }

    #[tokio::test]
    #[cfg(debug_assertions)]
    async fn test_streaming_corrupted_data() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));

        let codec = TwoPartCodec::new(None);
        let mut encoded = BytesMut::from(codec.encode_message(message).unwrap());

        // Flip a byte inside the header payload, which starts right after the prefix.
        encoded[PREFIX_LEN + 6] ^= 0xFF;

        let mut framed_read = FramedRead::new(Cursor::new(encoded.freeze()), codec.clone());

        match framed_read.next().await {
            Some(result) => assert!(
                matches!(result, Err(TwoPartCodecError::ChecksumMismatch)),
                "Expected ChecksumMismatch error, got {result:?}"
            ),
            None => panic!("Failed to read message from stream"),
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let codec = TwoPartCodec::new(None);
        let mut framed_read = FramedRead::new(Cursor::new(Vec::<u8>::new()), codec.clone());

        if let Some(result) = framed_read.next().await {
            panic!("Expected no messages, but got {:?}", result);
        }
    }

    #[tokio::test]
    async fn test_streaming_multiple_messages() {
        let header_data1 = Bytes::from("header1");
        let data1 = Bytes::from("data1");
        let message1 = TwoPartMessage::from_parts(header_data1.clone(), data1.clone());

        let header_data2 = Bytes::from("header2");
        let data2 = Bytes::from("data2");
        let message2 = TwoPartMessage::from_parts(header_data2.clone(), data2.clone());

        let codec = TwoPartCodec::new(None);

        let mut combined = BytesMut::new();
        combined.extend_from_slice(&codec.encode_message(message1).unwrap());
        combined.extend_from_slice(&codec.encode_message(message2).unwrap());

        let mut framed_read = FramedRead::new(Cursor::new(combined.freeze()), codec.clone());

        match framed_read.next().await {
            Some(Ok(decoded_message)) => {
                assert_eq!(decoded_message.header, header_data1);
                assert_eq!(decoded_message.data, data1);
            }
            _ => panic!("Failed to decode first message from stream"),
        }

        match framed_read.next().await {
            Some(Ok(decoded_message)) => {
                assert_eq!(decoded_message.header, header_data2);
                assert_eq!(decoded_message.data, data2);
            }
            _ => panic!("Failed to decode second message from stream"),
        }

        if let Some(result) = framed_read.next().await {
            panic!("Expected no more messages, but got {:?}", result);
        }
    }

    #[test]
    fn test_message_without_max_size() {
        // 1 MiB per part: well beyond any default limit, fine with `None`.
        let header_data = Bytes::from(vec![b'h'; 1024 * 1024]);
        let body_data = Bytes::from(vec![b'd'; 1024 * 1024]);
        let message = TwoPartMessage::from_parts(header_data.clone(), body_data.clone());

        let codec = TwoPartCodec::new(None);
        let encoded = codec.encode_message(message).unwrap();
        let decoded = codec.decode_message(encoded).unwrap();

        assert_eq!(decoded.header, header_data);
        assert_eq!(decoded.data, body_data);
    }
}