1use bytes::Bytes;
12use tokio_util::{
13 bytes::{Buf, BufMut, BytesMut},
14 codec::{Decoder, Encoder},
15};
16
17mod two_part;
18pub mod zero_copy_decoder;
19
20pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
21pub use zero_copy_decoder::{TcpRequestMessageZeroCopy, ZeroCopyTcpDecoder};
22
23#[derive(Debug, Clone, PartialEq, Eq)]
33pub struct TcpRequestMessage {
34 pub endpoint_path: String,
35 pub headers: std::collections::HashMap<String, String>,
36 pub payload: Bytes,
37}
38
39impl TcpRequestMessage {
40 pub fn new(endpoint_path: String, payload: Bytes) -> Self {
41 Self {
42 endpoint_path,
43 headers: std::collections::HashMap::new(),
44 payload,
45 }
46 }
47
48 pub fn with_headers(
49 endpoint_path: String,
50 headers: std::collections::HashMap<String, String>,
51 payload: Bytes,
52 ) -> Self {
53 Self {
54 endpoint_path,
55 headers,
56 payload,
57 }
58 }
59
60 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
62 let endpoint_bytes = self.endpoint_path.as_bytes();
63 let endpoint_len = endpoint_bytes.len();
64
65 if endpoint_len > u16::MAX as usize {
66 return Err(std::io::Error::new(
67 std::io::ErrorKind::InvalidInput,
68 format!("Endpoint path too long: {} bytes", endpoint_len),
69 ));
70 }
71
72 let headers_json = serde_json::to_vec(&self.headers).map_err(|e| {
74 std::io::Error::new(
75 std::io::ErrorKind::InvalidInput,
76 format!("Failed to encode headers: {}", e),
77 )
78 })?;
79 let headers_len = headers_json.len();
80
81 if headers_len > u16::MAX as usize {
82 return Err(std::io::Error::new(
83 std::io::ErrorKind::InvalidInput,
84 format!("Headers too large: {} bytes", headers_len),
85 ));
86 }
87
88 if self.payload.len() > u32::MAX as usize {
89 return Err(std::io::Error::new(
90 std::io::ErrorKind::InvalidInput,
91 format!("Payload too large: {} bytes", self.payload.len()),
92 ));
93 }
94
95 let mut buf =
97 BytesMut::with_capacity(2 + endpoint_len + 2 + headers_len + 4 + self.payload.len());
98
99 buf.put_u16(endpoint_len as u16);
101
102 buf.put_slice(endpoint_bytes);
104
105 buf.put_u16(headers_len as u16);
107
108 buf.put_slice(&headers_json);
110
111 buf.put_u32(self.payload.len() as u32);
113
114 buf.put_slice(&self.payload);
116
117 Ok(buf.freeze())
119 }
120
121 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
123 if bytes.len() < 2 {
124 return Err(std::io::Error::new(
125 std::io::ErrorKind::UnexpectedEof,
126 "Not enough bytes for endpoint path length",
127 ));
128 }
129
130 let endpoint_len = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
132 let mut offset = 2;
133
134 if bytes.len() < offset + endpoint_len {
135 return Err(std::io::Error::new(
136 std::io::ErrorKind::UnexpectedEof,
137 "Not enough bytes for endpoint path",
138 ));
139 }
140
141 let endpoint_path = String::from_utf8(bytes[offset..offset + endpoint_len].to_vec())
143 .map_err(|e| {
144 std::io::Error::new(
145 std::io::ErrorKind::InvalidData,
146 format!("Invalid UTF-8 in endpoint path: {}", e),
147 )
148 })?;
149 offset += endpoint_len;
150
151 if bytes.len() < offset + 2 {
152 return Err(std::io::Error::new(
153 std::io::ErrorKind::UnexpectedEof,
154 "Not enough bytes for headers length",
155 ));
156 }
157
158 let headers_len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]) as usize;
160 offset += 2;
161
162 if bytes.len() < offset + headers_len {
163 return Err(std::io::Error::new(
164 std::io::ErrorKind::UnexpectedEof,
165 "Not enough bytes for headers",
166 ));
167 }
168
169 let headers: std::collections::HashMap<String, String> =
171 serde_json::from_slice(&bytes[offset..offset + headers_len]).map_err(|e| {
172 std::io::Error::new(
173 std::io::ErrorKind::InvalidData,
174 format!("Invalid JSON in headers: {}", e),
175 )
176 })?;
177 offset += headers_len;
178
179 if bytes.len() < offset + 4 {
180 return Err(std::io::Error::new(
181 std::io::ErrorKind::UnexpectedEof,
182 "Not enough bytes for payload length",
183 ));
184 }
185
186 let payload_len = u32::from_be_bytes([
188 bytes[offset],
189 bytes[offset + 1],
190 bytes[offset + 2],
191 bytes[offset + 3],
192 ]) as usize;
193 offset += 4;
194
195 if bytes.len() < offset + payload_len {
196 return Err(std::io::Error::new(
197 std::io::ErrorKind::UnexpectedEof,
198 format!(
199 "Not enough bytes for payload: expected {}, got {}",
200 payload_len,
201 bytes.len() - offset
202 ),
203 ));
204 }
205
206 let payload = bytes.slice(offset..offset + payload_len);
208
209 Ok(Self {
210 endpoint_path,
211 headers,
212 payload,
213 })
214 }
215}
216
/// Frames [`TcpRequestMessage`] values for `tokio_util` with an optional size ceiling.
#[derive(Clone, Default)]
pub struct TcpRequestCodec {
    /// Largest accepted frame in bytes, length prefixes included; `None` means unlimited.
    max_message_size: Option<usize>,
}

impl TcpRequestCodec {
    /// Builds a codec. `Some(limit)` rejects frames longer than `limit` bytes.
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpRequestCodec { max_message_size }
    }
}
229
230impl Decoder for TcpRequestCodec {
231 type Item = TcpRequestMessage;
232 type Error = std::io::Error;
233
234 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
235 if src.len() < 2 {
237 return Ok(None);
238 }
239
240 let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize;
242
243 if src.len() < 2 + endpoint_len + 2 {
245 return Ok(None);
246 }
247
248 let headers_len_offset = 2 + endpoint_len;
250 let headers_len =
251 u16::from_be_bytes([src[headers_len_offset], src[headers_len_offset + 1]]) as usize;
252
253 let header_size = 2 + endpoint_len + 2 + headers_len + 4;
255 if src.len() < header_size {
256 return Ok(None);
257 }
258
259 let payload_len_offset = 2 + endpoint_len + 2 + headers_len;
261 let payload_len = u32::from_be_bytes([
262 src[payload_len_offset],
263 src[payload_len_offset + 1],
264 src[payload_len_offset + 2],
265 src[payload_len_offset + 3],
266 ]) as usize;
267
268 let total_len = header_size + payload_len;
269
270 if let Some(max_size) = self.max_message_size
272 && total_len > max_size
273 {
274 return Err(std::io::Error::new(
275 std::io::ErrorKind::InvalidData,
276 format!(
277 "Request too large: {} bytes (max: {} bytes)",
278 total_len, max_size
279 ),
280 ));
281 }
282
283 if src.len() < total_len {
285 return Ok(None);
286 }
287
288 src.advance(2);
290
291 let endpoint_bytes = src.split_to(endpoint_len);
293 let endpoint_path = String::from_utf8(endpoint_bytes.to_vec()).map_err(|e| {
294 std::io::Error::new(
295 std::io::ErrorKind::InvalidData,
296 format!("Invalid UTF-8 in endpoint path: {}", e),
297 )
298 })?;
299
300 src.advance(2);
302
303 let headers_bytes = src.split_to(headers_len);
305 let headers: std::collections::HashMap<String, String> =
306 serde_json::from_slice(&headers_bytes).map_err(|e| {
307 std::io::Error::new(
308 std::io::ErrorKind::InvalidData,
309 format!("Invalid JSON in headers: {}", e),
310 )
311 })?;
312
313 src.advance(4);
315
316 let payload = src.split_to(payload_len).freeze();
318
319 Ok(Some(TcpRequestMessage {
320 endpoint_path,
321 headers,
322 payload,
323 }))
324 }
325}
326
327impl Encoder<TcpRequestMessage> for TcpRequestCodec {
328 type Error = std::io::Error;
329
330 fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
331 let endpoint_bytes = item.endpoint_path.as_bytes();
332 let endpoint_len = endpoint_bytes.len();
333
334 if endpoint_len > u16::MAX as usize {
335 return Err(std::io::Error::new(
336 std::io::ErrorKind::InvalidInput,
337 format!("Endpoint path too long: {} bytes", endpoint_len),
338 ));
339 }
340
341 let headers_json = serde_json::to_vec(&item.headers).map_err(|e| {
343 std::io::Error::new(
344 std::io::ErrorKind::InvalidInput,
345 format!("Failed to encode headers: {}", e),
346 )
347 })?;
348 let headers_len = headers_json.len();
349
350 if headers_len > u16::MAX as usize {
351 return Err(std::io::Error::new(
352 std::io::ErrorKind::InvalidInput,
353 format!("Headers too large: {} bytes", headers_len),
354 ));
355 }
356
357 if item.payload.len() > u32::MAX as usize {
358 return Err(std::io::Error::new(
359 std::io::ErrorKind::InvalidInput,
360 format!("Payload too large: {} bytes", item.payload.len()),
361 ));
362 }
363
364 let total_len = 2 + endpoint_len + 2 + headers_len + 4 + item.payload.len();
365
366 if let Some(max_size) = self.max_message_size
368 && total_len > max_size
369 {
370 return Err(std::io::Error::new(
371 std::io::ErrorKind::InvalidInput,
372 format!(
373 "Request too large: {} bytes (max: {} bytes)",
374 total_len, max_size
375 ),
376 ));
377 }
378
379 dst.reserve(total_len);
381
382 dst.put_u16(endpoint_len as u16);
384
385 dst.put_slice(endpoint_bytes);
387
388 dst.put_u16(headers_len as u16);
390
391 dst.put_slice(&headers_json);
393
394 dst.put_u32(item.payload.len() as u32);
396
397 dst.put_slice(&item.payload);
399
400 Ok(())
401 }
402}
403
404#[derive(Debug, Clone, PartialEq, Eq)]
410pub struct TcpResponseMessage {
411 pub data: Bytes,
412}
413
414impl TcpResponseMessage {
415 pub fn new(data: Bytes) -> Self {
416 Self { data }
417 }
418
419 pub fn empty() -> Self {
420 Self { data: Bytes::new() }
421 }
422
423 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
425 if self.data.len() > u32::MAX as usize {
426 return Err(std::io::Error::new(
427 std::io::ErrorKind::InvalidInput,
428 format!("Response too large: {} bytes", self.data.len()),
429 ));
430 }
431
432 let mut buf = BytesMut::with_capacity(4 + self.data.len());
434
435 buf.put_u32(self.data.len() as u32);
437
438 buf.put_slice(&self.data);
440
441 Ok(buf.freeze())
443 }
444
445 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
447 if bytes.len() < 4 {
448 return Err(std::io::Error::new(
449 std::io::ErrorKind::UnexpectedEof,
450 "Not enough bytes for response length",
451 ));
452 }
453
454 let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
456
457 if bytes.len() < 4 + len {
458 return Err(std::io::Error::new(
459 std::io::ErrorKind::UnexpectedEof,
460 format!(
461 "Not enough bytes for response: expected {}, got {}",
462 len,
463 bytes.len() - 4
464 ),
465 ));
466 }
467
468 let data = bytes.slice(4..4 + len);
470
471 Ok(Self { data })
472 }
473}
474
/// Frames [`TcpResponseMessage`] values for `tokio_util` with an optional size ceiling.
#[derive(Clone, Default)]
pub struct TcpResponseCodec {
    /// Largest accepted frame in bytes, the 4-byte prefix included; `None` means unlimited.
    max_message_size: Option<usize>,
}

impl TcpResponseCodec {
    /// Builds a codec. `Some(limit)` rejects frames longer than `limit` bytes.
    pub fn new(max_message_size: Option<usize>) -> Self {
        TcpResponseCodec { max_message_size }
    }
}
487
488impl Decoder for TcpResponseCodec {
489 type Item = TcpResponseMessage;
490 type Error = std::io::Error;
491
492 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
493 if src.len() < 4 {
495 return Ok(None);
496 }
497
498 let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
500 let total_len = 4 + data_len;
501
502 if let Some(max_size) = self.max_message_size
504 && total_len > max_size
505 {
506 return Err(std::io::Error::new(
507 std::io::ErrorKind::InvalidData,
508 format!(
509 "Response too large: {} bytes (max: {} bytes)",
510 total_len, max_size
511 ),
512 ));
513 }
514
515 if src.len() < total_len {
517 return Ok(None);
518 }
519
520 src.advance(4);
522
523 let data = src.split_to(data_len).freeze();
525
526 Ok(Some(TcpResponseMessage { data }))
527 }
528}
529
530impl Encoder<TcpResponseMessage> for TcpResponseCodec {
531 type Error = std::io::Error;
532
533 fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
534 if item.data.len() > u32::MAX as usize {
535 return Err(std::io::Error::new(
536 std::io::ErrorKind::InvalidInput,
537 format!("Response too large: {} bytes", item.data.len()),
538 ));
539 }
540
541 let total_len = 4 + item.data.len();
542
543 if let Some(max_size) = self.max_message_size
545 && total_len > max_size
546 {
547 return Err(std::io::Error::new(
548 std::io::ErrorKind::InvalidInput,
549 format!(
550 "Response too large: {} bytes (max: {} bytes)",
551 total_len, max_size
552 ),
553 ));
554 }
555
556 dst.reserve(total_len);
558
559 dst.put_u32(item.data.len() as u32);
561
562 dst.put_slice(&item.data);
564
565 Ok(())
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_large_payload() {
        // 1 MiB payload.
        let payload = Bytes::from(vec![42u8; 1024 * 1024]);
        let msg = TcpRequestMessage::new("large".to_string(), payload);

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Drop the last two payload bytes.
        let truncated = encoded.slice(..encoded.len() - 2);
        let result = TcpRequestMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_decode_invalid_utf8() {
        // endpoint_len = 1 followed by a lone 0xFF byte, which is never valid UTF-8.
        let bytes = Bytes::from_static(&[0x00, 0x01, 0xFF, 0x00, 0x02, b'{', b'}', 0, 0, 0, 0]);

        let err = TcpRequestMessage::decode(&bytes).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Cut inside the 4-byte length prefix.
        let truncated = encoded.slice(..3);
        let result = TcpResponseMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();

        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let mut codec = TcpRequestCodec::new(None);

        // A prefix of the frame must not yield a message yet.
        let mut buf = BytesMut::from(&encoded[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_max_size() {
        use tokio_util::codec::Encoder;

        // The encoded frame is 19 bytes: 2 + 4 ("test") + 2 + 2 ("{}") + 4 + 5.
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpRequestCodec::new(Some(10));
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let mut buf = BytesMut::from(&msg.encode().unwrap()[..]);

        // The 19-byte frame is rejected by a decoder limited to 10 bytes.
        let mut codec = TcpRequestCodec::new(Some(10));
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();

        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        // The encoded frame is 9 bytes (4-byte prefix + 5-byte body).
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_framed_codec_integration() {
        use futures::{SinkExt, StreamExt};
        use std::io::Cursor;
        use tokio_util::codec::{FramedRead, FramedWrite};

        let mut buffer = Vec::new();

        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None));

            let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1"));
            let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2"));

            writer.send(msg1).await.unwrap();
            writer.send(msg2).await.unwrap();
        }

        {
            let cursor = Cursor::new(&buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));

            let decoded1 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded1.endpoint_path, "endpoint1");
            assert_eq!(decoded1.payload, Bytes::from("data1"));

            let decoded2 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded2.endpoint_path, "endpoint2");
            assert_eq!(decoded2.payload, Bytes::from("data2"));
        }
    }

    #[tokio::test]
    async fn test_framed_codec_partial_messages() {
        use futures::StreamExt;
        use std::collections::VecDeque;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use tokio::io::{AsyncRead, ReadBuf};
        use tokio_util::codec::FramedRead;

        // Reader that hands out at most one queued chunk per `poll_read`. This
        // simulates a frame arriving split across several socket reads.
        struct ChunkedReader {
            chunks: VecDeque<Vec<u8>>,
        }

        impl AsyncRead for ChunkedReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                // An empty queue leaves `buf` untouched, which signals EOF.
                if let Some(mut chunk) = self.chunks.pop_front() {
                    let n = chunk.len().min(buf.remaining());
                    let rest = chunk.split_off(n);
                    buf.put_slice(&chunk);
                    if !rest.is_empty() {
                        self.chunks.push_front(rest);
                    }
                }
                Poll::Ready(Ok(()))
            }
        }

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello"));
        let encoded = msg.encode().unwrap();

        // Split inside the endpoint path, so the first read alone cannot yield a frame.
        let reader = ChunkedReader {
            chunks: VecDeque::from(vec![encoded[..5].to_vec(), encoded[5..].to_vec()]),
        };
        let mut framed = FramedRead::new(reader, TcpRequestCodec::new(None));

        let decoded = framed.next().await.unwrap().unwrap();
        assert_eq!(decoded.endpoint_path, "test");
        assert_eq!(decoded.payload, Bytes::from("hello"));

        // Both chunks belonged to one frame; the stream then ends cleanly at EOF.
        assert!(framed.next().await.is_none());
    }
}