dynamo_runtime/pipeline/network/codec/
two_part.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
5use tokio_util::codec::{Decoder, Encoder};
6use xxhash_rust::xxh3::xxh3_64;
7
8use crate::pipeline::error::TwoPartCodecError;
9
/// Length-prefixed framing codec for [`TwoPartMessage`] values.
///
/// Every frame starts with a 24-byte prefix made of three big-endian `u64`
/// fields (header length, body length, checksum), followed by the header
/// bytes and then the body bytes.
#[derive(Clone, Debug, Default)]
pub struct TwoPartCodec {
    /// Upper bound, in bytes, for a whole frame (prefix included).
    /// `None` disables the size check.
    max_message_size: Option<usize>,
}
14
15impl TwoPartCodec {
16 pub fn new(max_message_size: Option<usize>) -> Self {
17 TwoPartCodec { max_message_size }
18 }
19
20 pub fn encode_message(&self, msg: TwoPartMessage) -> Result<Bytes, TwoPartCodecError> {
22 let mut buf = BytesMut::new();
23 let mut codec = self.clone();
24 codec.encode(msg, &mut buf)?;
25 Ok(buf.freeze())
26 }
27
28 pub fn decode_message(&self, data: Bytes) -> Result<TwoPartMessage, TwoPartCodecError> {
30 let mut buf = BytesMut::from(&data[..]);
31 let mut codec = self.clone();
32 match codec.decode(&mut buf)? {
33 Some(msg) => Ok(msg),
34 None => Err(TwoPartCodecError::InvalidMessage(
35 "No message decoded".to_string(),
36 )),
37 }
38 }
39}
40
41impl Decoder for TwoPartCodec {
42 type Item = TwoPartMessage;
43 type Error = TwoPartCodecError;
44
45 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
46 if src.len() < 24 {
48 return Ok(None);
49 }
50
51 let mut cursor = &src[..];
53
54 let header_len = cursor.get_u64() as usize;
55 let body_len = cursor.get_u64() as usize;
56 let _checksum = cursor.get_u64();
57
58 let total_len = 24usize
59 .checked_add(header_len)
60 .and_then(|n| n.checked_add(body_len))
61 .ok_or(TwoPartCodecError::MessageTooLarge(
62 usize::MAX,
63 self.max_message_size.unwrap_or(usize::MAX),
64 ))?;
65
66 if let Some(max_size) = self.max_message_size
68 && total_len > max_size
69 {
70 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
71 }
72
73 if src.len() < total_len {
75 return Ok(None);
76 }
77
78 src.advance(24);
80
81 #[cfg(debug_assertions)]
82 {
83 if _checksum != 0 {
85 let bytes_to_hash =
86 header_len
87 .checked_add(body_len)
88 .ok_or(TwoPartCodecError::InvalidMessage(
89 "Message exceeds max allowed length.".to_string(),
90 ))?;
91
92 let data_to_hash = &src[..bytes_to_hash];
93
94 let computed_checksum = xxh3_64(data_to_hash);
95
96 if _checksum != computed_checksum {
98 return Err(TwoPartCodecError::ChecksumMismatch);
99 }
100 }
101 }
102
103 let header = src.split_to(header_len).freeze();
105 let data = src.split_to(body_len).freeze();
106
107 Ok(Some(TwoPartMessage { header, data }))
108 }
109}
110
111impl Encoder<TwoPartMessage> for TwoPartCodec {
112 type Error = TwoPartCodecError;
113
114 fn encode(&mut self, item: TwoPartMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
115 let header_len = item.header.len();
116 let body_len = item.data.len();
117
118 let total_len = 24usize
119 .checked_add(header_len)
120 .and_then(|n| n.checked_add(body_len))
121 .ok_or(TwoPartCodecError::MessageTooLarge(
122 usize::MAX,
123 self.max_message_size.unwrap_or(usize::MAX),
124 ))?;
125
126 if let Some(max_size) = self.max_message_size
128 && total_len > max_size
129 {
130 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
131 }
132
133 dst.put_u64(header_len as u64);
134 dst.put_u64(body_len as u64);
135
136 #[cfg(debug_assertions)]
139 {
140 let mut data_to_hash = BytesMut::with_capacity(header_len + body_len);
142 data_to_hash.extend_from_slice(&item.header);
143 data_to_hash.extend_from_slice(&item.data);
144 let checksum = xxh3_64(&data_to_hash);
145
146 dst.put_u64(checksum);
147 }
148 #[cfg(not(debug_assertions))]
149 {
150 dst.put_u64(0);
151 }
152
153 dst.put_slice(&item.header);
155 dst.put_slice(&item.data);
156
157 Ok(())
158 }
159}
160
161pub enum TwoPartMessageType {
162 HeaderOnly(Bytes),
163 DataOnly(Bytes),
164 HeaderAndData(Bytes, Bytes),
165 Empty,
166}
167
/// A message made of two independent byte payloads: a header and a data part.
///
/// Either part may be empty; the accessors on this type treat an empty part
/// as absent.
#[derive(Clone, Debug)]
pub struct TwoPartMessage {
    /// Header bytes; empty when the message carries no header.
    pub header: Bytes,
    /// Data (body) bytes; empty when the message carries no data.
    pub data: Bytes,
}
173
174impl TwoPartMessage {
175 pub fn new(header: Bytes, data: Bytes) -> Self {
176 TwoPartMessage { header, data }
177 }
178
179 pub fn from_header(header: Bytes) -> Self {
180 TwoPartMessage {
181 header,
182 data: Bytes::new(),
183 }
184 }
185
186 pub fn from_data(data: Bytes) -> Self {
187 TwoPartMessage {
188 header: Bytes::new(),
189 data,
190 }
191 }
192
193 pub fn from_parts(header: Bytes, data: Bytes) -> Self {
194 TwoPartMessage { header, data }
195 }
196
197 pub fn parts(&self) -> (&Bytes, &Bytes) {
198 (&self.header, &self.data)
199 }
200
201 pub fn optional_parts(&self) -> (Option<&Bytes>, Option<&Bytes>) {
202 (self.header(), self.data())
203 }
204
205 pub fn into_parts(self) -> (Bytes, Bytes) {
206 (self.header, self.data)
207 }
208
209 pub fn header(&self) -> Option<&Bytes> {
210 if self.header.is_empty() {
211 None
212 } else {
213 Some(&self.header)
214 }
215 }
216
217 pub fn data(&self) -> Option<&Bytes> {
218 if self.data.is_empty() {
219 None
220 } else {
221 Some(&self.data)
222 }
223 }
224
225 pub fn into_message_type(self) -> TwoPartMessageType {
226 if self.header.is_empty() && self.data.is_empty() {
227 TwoPartMessageType::Empty
228 } else if self.header.is_empty() {
229 TwoPartMessageType::DataOnly(self.data)
230 } else if self.data.is_empty() {
231 TwoPartMessageType::HeaderOnly(self.header)
232 } else {
233 TwoPartMessageType::HeaderAndData(self.header, self.data)
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;
    use tokio::io::AsyncRead;
    use tokio::io::ReadBuf;
    use tokio_util::codec::{Decoder, FramedRead};

    use super::*;

    /// Encodes `message` with `codec` and decodes the resulting frame again.
    fn round_trip(codec: &TwoPartCodec, message: TwoPartMessage) -> TwoPartMessage {
        let frame = codec.encode_message(message).unwrap();
        codec.decode_message(frame).unwrap()
    }

    /// Returns `len` bytes, all equal to `byte`.
    fn filled(byte: u8, len: usize) -> Bytes {
        Bytes::from(vec![byte; len])
    }

    /// Concatenates already-encoded frames into one contiguous buffer.
    fn concat_frames(frames: &[Bytes]) -> BytesMut {
        let mut joined = BytesMut::new();
        for frame in frames {
            joined.extend_from_slice(frame);
        }
        joined
    }

    /// `AsyncRead` source that hands out its data at most `chunk_size` bytes
    /// per `poll_read` call.
    struct ChunkedReader {
        data: Bytes,
        pos: usize,
        chunk_size: usize,
    }

    impl AsyncRead for ChunkedReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let total = self.data.len();
            // Once everything has been handed out, report EOF by writing nothing.
            if self.pos < total {
                let end = (self.pos + self.chunk_size).min(total);
                buf.put_slice(&self.data[self.pos..end]);
                self.pos = end;
            }

            Poll::Ready(Ok(()))
        }
    }

    /// A message with both parts survives an encode/decode round trip.
    #[test]
    fn test_message_with_header_and_data() {
        let header = Bytes::from("header data");
        let body = Bytes::from("body data");

        let decoded = round_trip(
            &TwoPartCodec::new(None),
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A header-only message decodes with an empty data part.
    #[test]
    fn test_message_with_only_header() {
        let header = Bytes::from("header only");

        let decoded = round_trip(
            &TwoPartCodec::new(None),
            TwoPartMessage::from_header(header.clone()),
        );

        assert_eq!(decoded.header, header);
        assert!(decoded.data.is_empty());
    }

    /// A data-only message decodes with an empty header.
    #[test]
    fn test_message_with_only_data() {
        let body = Bytes::from("data only");

        let decoded = round_trip(
            &TwoPartCodec::new(None),
            TwoPartMessage::from_data(body.clone()),
        );

        assert!(decoded.header.is_empty());
        assert_eq!(decoded.data, body);
    }

    /// A message with two empty parts round-trips to two empty parts.
    #[test]
    fn test_empty_message() {
        let decoded = round_trip(
            &TwoPartCodec::new(None),
            TwoPartMessage::from_parts(Bytes::new(), Bytes::new()),
        );

        assert!(decoded.header.is_empty());
        assert!(decoded.data.is_empty());
    }

    /// A frame comfortably below the limit is accepted in both directions.
    #[test]
    fn test_message_under_max_size() {
        let max_size = 1024;
        let header = filled(b'h', 100);
        let body = filled(b'd', 200);

        let decoded = round_trip(
            &TwoPartCodec::new(Some(max_size)),
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A frame whose total size equals the limit is still accepted.
    #[test]
    fn test_message_exactly_at_max_size() {
        let max_size = 1024;
        let lengths_size = 24;
        let data_size = max_size - lengths_size;
        let header_size = data_size / 2;
        let body_size = data_size - header_size;

        let header = filled(b'h', header_size);
        let body = filled(b'd', body_size);
        let codec = TwoPartCodec::new(Some(max_size));

        let frame = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();
        assert_eq!(frame.len(), max_size);

        let decoded = codec.decode_message(frame).unwrap();
        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// Encoding a frame one byte over the limit is rejected.
    #[test]
    fn test_message_over_max_size() {
        let max_size = 1024;
        let data_size = max_size - 24 + 1;
        let header_size = data_size / 2;
        let body_size = data_size - header_size;

        let message =
            TwoPartMessage::from_parts(filled(b'h', header_size), filled(b'd', body_size));

        let result = TwoPartCodec::new(Some(max_size)).encode_message(message);
        assert!(result.is_err());

        let Err(TwoPartCodecError::MessageTooLarge(size, max)) = result else {
            panic!("Expected MessageTooLarge error");
        };
        assert_eq!(size, data_size + 24);
        assert_eq!(max, max_size);
    }

    /// Decoding with a limit rejects a frame produced by an unlimited codec.
    #[test]
    fn test_decoding_message_over_max_size() {
        let max_size = 1024;
        let data_size = max_size - 24 + 1;
        let header_size = data_size / 2;
        let body_size = data_size - header_size;

        let message =
            TwoPartMessage::from_parts(filled(b'h', header_size), filled(b'd', body_size));

        // Encode without a limit so the oversized frame can be produced at all.
        let frame = TwoPartCodec::new(None).encode_message(message).unwrap();

        let result = TwoPartCodec::new(Some(max_size)).decode_message(frame);
        assert!(result.is_err());

        let Err(TwoPartCodecError::MessageTooLarge(size, max)) = result else {
            panic!("Expected MessageTooLarge error");
        };
        assert_eq!(size, data_size + 24);
        assert_eq!(max, max_size);
    }

    /// Flipping the last payload byte is detected through the checksum.
    #[test]
    #[cfg(debug_assertions)]
    fn test_checksum_mismatch() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));
        let codec = TwoPartCodec::new(None);

        let mut corrupted = BytesMut::from(codec.encode_message(message).unwrap());
        let last = corrupted.len() - 1;
        corrupted[last] ^= 0xFF;

        let result = codec.decode_message(corrupted.freeze());
        assert!(result.is_err());

        let Err(TwoPartCodecError::ChecksumMismatch) = result else {
            panic!("Expected ChecksumMismatch error");
        };
    }

    /// A truncated frame cannot be decoded by the one-shot helper.
    #[test]
    fn test_partial_data() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));
        let codec = TwoPartCodec::new(None);

        let frame = codec.encode_message(message).unwrap();
        let truncated = frame.slice(0..frame.len() - 5);

        let result = codec.decode_message(truncated);
        assert!(result.is_err());

        let Err(TwoPartCodecError::InvalidMessage(_)) = result else {
            panic!("Expected InvalidMessage error");
        };
    }

    /// Two frames stored back to back decode one after the other.
    #[test]
    fn test_multiple_messages_in_buffer() {
        let (header1, body1) = (Bytes::from("header1"), Bytes::from("data1"));
        let (header2, body2) = (Bytes::from("header2"), Bytes::from("data2"));

        let codec = TwoPartCodec::new(None);
        let frame1 = codec
            .encode_message(TwoPartMessage::from_parts(header1.clone(), body1.clone()))
            .unwrap();
        let frame2 = codec
            .encode_message(TwoPartMessage::from_parts(header2.clone(), body2.clone()))
            .unwrap();

        let mut buffer = concat_frames(&[frame1, frame2]);
        let mut decoder = codec.clone();

        let first = decoder.decode(&mut buffer).unwrap().unwrap();
        let second = decoder.decode(&mut buffer).unwrap().unwrap();

        assert_eq!(first.header, header1);
        assert_eq!(first.data, body1);

        assert_eq!(second.header, header2);
        assert_eq!(second.data, body2);
    }

    /// A single frame is decoded from an in-memory stream.
    #[tokio::test]
    async fn test_streaming_read() {
        let header = Bytes::from("header data");
        let body = Bytes::from("body data");
        let codec = TwoPartCodec::new(None);

        let frame = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();

        let mut framed = FramedRead::new(Cursor::new(frame), codec.clone());

        let Some(Ok(decoded)) = framed.next().await else {
            panic!("Failed to decode message from stream");
        };
        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A frame delivered in 5-byte chunks is reassembled and decoded.
    #[tokio::test]
    async fn test_streaming_partial_reads() {
        let header = Bytes::from("header data");
        let body = Bytes::from("body data");
        let codec = TwoPartCodec::new(None);

        let frame = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();

        let reader = ChunkedReader {
            data: frame,
            pos: 0,
            chunk_size: 5,
        };
        let mut framed = FramedRead::new(reader, codec.clone());

        let Some(Ok(decoded)) = framed.next().await else {
            panic!("Failed to decode message from stream");
        };
        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A corrupted payload byte surfaces as a checksum error on the stream.
    #[tokio::test]
    #[cfg(debug_assertions)]
    async fn test_streaming_corrupted_data() {
        let message =
            TwoPartMessage::from_parts(Bytes::from("header data"), Bytes::from("body data"));
        let codec = TwoPartCodec::new(None);

        let mut corrupted = BytesMut::from(codec.encode_message(message).unwrap());
        // Byte 30 lies past the 24-byte prefix, i.e. inside the payload.
        corrupted[30] ^= 0xFF;

        let mut framed = FramedRead::new(Cursor::new(corrupted), codec.clone());

        let Some(result) = framed.next().await else {
            panic!("Failed to read message from stream");
        };
        assert!(result.is_err());

        let Err(TwoPartCodecError::ChecksumMismatch) = result else {
            panic!("Expected ChecksumMismatch error");
        };
    }

    /// An empty stream yields no items at all.
    #[tokio::test]
    async fn test_empty_stream() {
        let reader = Cursor::new(Vec::new());
        let mut framed = FramedRead::new(reader, TwoPartCodec::new(None));

        if let Some(result) = framed.next().await {
            panic!("Expected no messages, but got {:?}", result);
        }
    }

    /// Two back-to-back frames are yielded in order, then the stream ends.
    #[tokio::test]
    async fn test_streaming_multiple_messages() {
        let (header1, body1) = (Bytes::from("header1"), Bytes::from("data1"));
        let (header2, body2) = (Bytes::from("header2"), Bytes::from("data2"));

        let codec = TwoPartCodec::new(None);
        let frame1 = codec
            .encode_message(TwoPartMessage::from_parts(header1.clone(), body1.clone()))
            .unwrap();
        let frame2 = codec
            .encode_message(TwoPartMessage::from_parts(header2.clone(), body2.clone()))
            .unwrap();

        let stream_bytes = concat_frames(&[frame1, frame2]).freeze();
        let mut framed = FramedRead::new(Cursor::new(stream_bytes), codec.clone());

        let Some(Ok(first)) = framed.next().await else {
            panic!("Failed to decode first message from stream");
        };
        assert_eq!(first.header, header1);
        assert_eq!(first.data, body1);

        let Some(Ok(second)) = framed.next().await else {
            panic!("Failed to decode second message from stream");
        };
        assert_eq!(second.header, header2);
        assert_eq!(second.data, body2);

        if let Some(result) = framed.next().await {
            panic!("Expected no more messages, but got {:?}", result);
        }
    }

    /// Without a limit, a message with 1 MiB per part round-trips.
    #[test]
    fn test_message_without_max_size() {
        let header = filled(b'h', 1024 * 1024);
        let body = filled(b'd', 1024 * 1024);

        let decoded = round_trip(
            &TwoPartCodec::new(None),
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }
}