1use bytes::Bytes;
12use tokio_util::{
13 bytes::{Buf, BufMut, BytesMut},
14 codec::{Decoder, Encoder},
15};
16
17mod two_part;
18
19pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
20
21#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct TcpRequestMessage {
30 pub endpoint_path: String,
31 pub payload: Bytes,
32}
33
34impl TcpRequestMessage {
35 pub fn new(endpoint_path: String, payload: Bytes) -> Self {
36 Self {
37 endpoint_path,
38 payload,
39 }
40 }
41
42 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
44 let endpoint_bytes = self.endpoint_path.as_bytes();
45 let endpoint_len = endpoint_bytes.len();
46
47 if endpoint_len > u16::MAX as usize {
48 return Err(std::io::Error::new(
49 std::io::ErrorKind::InvalidInput,
50 format!("Endpoint path too long: {} bytes", endpoint_len),
51 ));
52 }
53
54 if self.payload.len() > u32::MAX as usize {
55 return Err(std::io::Error::new(
56 std::io::ErrorKind::InvalidInput,
57 format!("Payload too large: {} bytes", self.payload.len()),
58 ));
59 }
60
61 let mut buf = BytesMut::with_capacity(2 + endpoint_len + 4 + self.payload.len());
63
64 buf.put_u16(endpoint_len as u16);
66
67 buf.put_slice(endpoint_bytes);
69
70 buf.put_u32(self.payload.len() as u32);
72
73 buf.put_slice(&self.payload);
75
76 Ok(buf.freeze())
78 }
79
80 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
82 if bytes.len() < 2 {
83 return Err(std::io::Error::new(
84 std::io::ErrorKind::UnexpectedEof,
85 "Not enough bytes for endpoint path length",
86 ));
87 }
88
89 let endpoint_len = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
91 let mut offset = 2;
92
93 if bytes.len() < offset + endpoint_len {
94 return Err(std::io::Error::new(
95 std::io::ErrorKind::UnexpectedEof,
96 "Not enough bytes for endpoint path",
97 ));
98 }
99
100 let endpoint_path = String::from_utf8(bytes[offset..offset + endpoint_len].to_vec())
102 .map_err(|e| {
103 std::io::Error::new(
104 std::io::ErrorKind::InvalidData,
105 format!("Invalid UTF-8: {}", e),
106 )
107 })?;
108 offset += endpoint_len;
109
110 if bytes.len() < offset + 4 {
111 return Err(std::io::Error::new(
112 std::io::ErrorKind::UnexpectedEof,
113 "Not enough bytes for payload length",
114 ));
115 }
116
117 let payload_len = u32::from_be_bytes([
119 bytes[offset],
120 bytes[offset + 1],
121 bytes[offset + 2],
122 bytes[offset + 3],
123 ]) as usize;
124 offset += 4;
125
126 if bytes.len() < offset + payload_len {
127 return Err(std::io::Error::new(
128 std::io::ErrorKind::UnexpectedEof,
129 format!(
130 "Not enough bytes for payload: expected {}, got {}",
131 payload_len,
132 bytes.len() - offset
133 ),
134 ));
135 }
136
137 let payload = bytes.slice(offset..offset + payload_len);
139
140 Ok(Self {
141 endpoint_path,
142 payload,
143 })
144 }
145}
146
/// `tokio_util` codec that frames [`TcpRequestMessage`]s on a byte stream.
///
/// `max_message_size`, when set, bounds the complete encoded frame (length
/// prefixes included) on both encode and decode. [`Default`] leaves it unset,
/// which means frames of any size are accepted.
#[derive(Debug, Clone, Default)]
pub struct TcpRequestCodec {
    max_message_size: Option<usize>,
}
153
154impl TcpRequestCodec {
155 pub fn new(max_message_size: Option<usize>) -> Self {
156 Self { max_message_size }
157 }
158}
159
160impl Decoder for TcpRequestCodec {
161 type Item = TcpRequestMessage;
162 type Error = std::io::Error;
163
164 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
165 if src.len() < 2 {
167 return Ok(None);
168 }
169
170 let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize;
172 let header_size = 2 + endpoint_len + 4; if src.len() < header_size {
175 return Ok(None);
176 }
177
178 let payload_len_offset = 2 + endpoint_len;
180 let payload_len = u32::from_be_bytes([
181 src[payload_len_offset],
182 src[payload_len_offset + 1],
183 src[payload_len_offset + 2],
184 src[payload_len_offset + 3],
185 ]) as usize;
186
187 let total_len = header_size + payload_len;
188
189 if let Some(max_size) = self.max_message_size
191 && total_len > max_size
192 {
193 return Err(std::io::Error::new(
194 std::io::ErrorKind::InvalidData,
195 format!(
196 "Request too large: {} bytes (max: {} bytes)",
197 total_len, max_size
198 ),
199 ));
200 }
201
202 if src.len() < total_len {
204 return Ok(None);
205 }
206
207 src.advance(2);
209
210 let endpoint_bytes = src.split_to(endpoint_len);
212 let endpoint_path = String::from_utf8(endpoint_bytes.to_vec()).map_err(|e| {
213 std::io::Error::new(
214 std::io::ErrorKind::InvalidData,
215 format!("Invalid UTF-8 in endpoint path: {}", e),
216 )
217 })?;
218
219 src.advance(4);
221
222 let payload = src.split_to(payload_len).freeze();
224
225 Ok(Some(TcpRequestMessage {
226 endpoint_path,
227 payload,
228 }))
229 }
230}
231
232impl Encoder<TcpRequestMessage> for TcpRequestCodec {
233 type Error = std::io::Error;
234
235 fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
236 let endpoint_bytes = item.endpoint_path.as_bytes();
237 let endpoint_len = endpoint_bytes.len();
238
239 if endpoint_len > u16::MAX as usize {
240 return Err(std::io::Error::new(
241 std::io::ErrorKind::InvalidInput,
242 format!("Endpoint path too long: {} bytes", endpoint_len),
243 ));
244 }
245
246 if item.payload.len() > u32::MAX as usize {
247 return Err(std::io::Error::new(
248 std::io::ErrorKind::InvalidInput,
249 format!("Payload too large: {} bytes", item.payload.len()),
250 ));
251 }
252
253 let total_len = 2 + endpoint_len + 4 + item.payload.len();
254
255 if let Some(max_size) = self.max_message_size
257 && total_len > max_size
258 {
259 return Err(std::io::Error::new(
260 std::io::ErrorKind::InvalidInput,
261 format!(
262 "Request too large: {} bytes (max: {} bytes)",
263 total_len, max_size
264 ),
265 ));
266 }
267
268 dst.reserve(total_len);
270
271 dst.put_u16(endpoint_len as u16);
273
274 dst.put_slice(endpoint_bytes);
276
277 dst.put_u32(item.payload.len() as u32);
279
280 dst.put_slice(&item.payload);
282
283 Ok(())
284 }
285}
286
287#[derive(Debug, Clone, PartialEq, Eq)]
293pub struct TcpResponseMessage {
294 pub data: Bytes,
295}
296
297impl TcpResponseMessage {
298 pub fn new(data: Bytes) -> Self {
299 Self { data }
300 }
301
302 pub fn empty() -> Self {
303 Self { data: Bytes::new() }
304 }
305
306 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
308 if self.data.len() > u32::MAX as usize {
309 return Err(std::io::Error::new(
310 std::io::ErrorKind::InvalidInput,
311 format!("Response too large: {} bytes", self.data.len()),
312 ));
313 }
314
315 let mut buf = BytesMut::with_capacity(4 + self.data.len());
317
318 buf.put_u32(self.data.len() as u32);
320
321 buf.put_slice(&self.data);
323
324 Ok(buf.freeze())
326 }
327
328 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
330 if bytes.len() < 4 {
331 return Err(std::io::Error::new(
332 std::io::ErrorKind::UnexpectedEof,
333 "Not enough bytes for response length",
334 ));
335 }
336
337 let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
339
340 if bytes.len() < 4 + len {
341 return Err(std::io::Error::new(
342 std::io::ErrorKind::UnexpectedEof,
343 format!(
344 "Not enough bytes for response: expected {}, got {}",
345 len,
346 bytes.len() - 4
347 ),
348 ));
349 }
350
351 let data = bytes.slice(4..4 + len);
353
354 Ok(Self { data })
355 }
356}
357
/// `tokio_util` codec that frames [`TcpResponseMessage`]s on a byte stream.
///
/// `max_message_size`, when set, bounds the complete encoded frame (the 4-byte
/// length prefix included) on both encode and decode. [`Default`] leaves it
/// unset, which means frames of any size are accepted.
#[derive(Debug, Clone, Default)]
pub struct TcpResponseCodec {
    max_message_size: Option<usize>,
}
364
365impl TcpResponseCodec {
366 pub fn new(max_message_size: Option<usize>) -> Self {
367 Self { max_message_size }
368 }
369}
370
371impl Decoder for TcpResponseCodec {
372 type Item = TcpResponseMessage;
373 type Error = std::io::Error;
374
375 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
376 if src.len() < 4 {
378 return Ok(None);
379 }
380
381 let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
383 let total_len = 4 + data_len;
384
385 if let Some(max_size) = self.max_message_size
387 && total_len > max_size
388 {
389 return Err(std::io::Error::new(
390 std::io::ErrorKind::InvalidData,
391 format!(
392 "Response too large: {} bytes (max: {} bytes)",
393 total_len, max_size
394 ),
395 ));
396 }
397
398 if src.len() < total_len {
400 return Ok(None);
401 }
402
403 src.advance(4);
405
406 let data = src.split_to(data_len).freeze();
408
409 Ok(Some(TcpResponseMessage { data }))
410 }
411}
412
413impl Encoder<TcpResponseMessage> for TcpResponseCodec {
414 type Error = std::io::Error;
415
416 fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
417 if item.data.len() > u32::MAX as usize {
418 return Err(std::io::Error::new(
419 std::io::ErrorKind::InvalidInput,
420 format!("Response too large: {} bytes", item.data.len()),
421 ));
422 }
423
424 let total_len = 4 + item.data.len();
425
426 if let Some(max_size) = self.max_message_size
428 && total_len > max_size
429 {
430 return Err(std::io::Error::new(
431 std::io::ErrorKind::InvalidInput,
432 format!(
433 "Response too large: {} bytes (max: {} bytes)",
434 total_len, max_size
435 ),
436 ));
437 }
438
439 dst.reserve(total_len);
441
442 dst.put_u32(item.data.len() as u32);
444
445 dst.put_slice(&item.data);
447
448 Ok(())
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_large_payload() {
        // 1 MiB payload.
        let payload = Bytes::from(vec![42u8; 1024 * 1024]);
        let msg = TcpRequestMessage::new("large".to_string(), payload);

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Drop the last two payload bytes.
        let truncated = encoded.slice(..encoded.len() - 2);
        let result = TcpRequestMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Not even a complete length prefix.
        let truncated = encoded.slice(..3);
        let result = TcpResponseMessage::decode(&truncated);

        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();

        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let mut codec = TcpRequestCodec::new(None);

        // First chunk stops inside the endpoint path.
        let mut buf = BytesMut::from(&encoded[..5]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[5..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_codec_multiple_frames_in_buffer() {
        use tokio_util::codec::{Decoder, Encoder};

        let first = TcpRequestMessage::new("a".to_string(), Bytes::from("one"));
        let second = TcpRequestMessage::new("b".to_string(), Bytes::new());

        let mut codec = TcpRequestCodec::new(None);
        let mut buf = BytesMut::new();
        codec.encode(first.clone(), &mut buf).unwrap();
        codec.encode(second.clone(), &mut buf).unwrap();

        assert_eq!(codec.decode(&mut buf).unwrap(), Some(first));
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(second));
        assert!(buf.is_empty());
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn test_tcp_request_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));

        // Encoded size is 2 + 4 + 4 + 5 = 15 bytes, which exceeds the limit.
        let mut codec = TcpRequestCodec::new(Some(10));
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_request_codec_decode_rejects_oversized_frame() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));

        // Encode without a limit so the 15-byte frame reaches a limited decoder.
        let mut buf = BytesMut::new();
        TcpRequestCodec::new(None).encode(msg, &mut buf).unwrap();

        let err = TcpRequestCodec::new(Some(10)).decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_request_codec_rejects_invalid_utf8_endpoint() {
        use tokio_util::codec::Decoder;

        // Endpoint length 2 with bytes that are not UTF-8, followed by an empty payload.
        let mut buf = BytesMut::from(&b"\x00\x02\xff\xfe\x00\x00\x00\x00"[..]);

        let err = TcpRequestCodec::new(None).decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();

        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        // First chunk stops inside the length prefix.
        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        // Encoded size is 4 + 5 = 9 bytes, which exceeds the limit.
        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::new();

        let result = codec.encode(msg, &mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_tcp_response_codec_decode_rejects_oversized_frame() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut buf = BytesMut::new();
        TcpResponseCodec::new(None).encode(msg, &mut buf).unwrap();

        let err = TcpResponseCodec::new(Some(5)).decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_framed_codec_integration() {
        use futures::{SinkExt, StreamExt};
        use std::io::Cursor;
        use tokio_util::codec::{FramedRead, FramedWrite};

        let mut buffer = Vec::new();

        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None));

            let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1"));
            let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2"));

            writer.send(msg1).await.unwrap();
            writer.send(msg2).await.unwrap();
        }

        {
            let cursor = Cursor::new(&buffer[..]);
            let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));

            let decoded1 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded1.endpoint_path, "endpoint1");
            assert_eq!(decoded1.payload, Bytes::from("data1"));

            let decoded2 = reader.next().await.unwrap().unwrap();
            assert_eq!(decoded2.endpoint_path, "endpoint2");
            assert_eq!(decoded2.payload, Bytes::from("data2"));
        }
    }

    #[tokio::test]
    async fn test_framed_codec_partial_messages() {
        use futures::StreamExt;
        use std::collections::VecDeque;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use tokio::io::{AsyncRead, ReadBuf};
        use tokio_util::codec::FramedRead;

        // Reader that yields one pre-split chunk per `poll_read`, so `FramedRead`
        // really receives the frame in pieces. It reports EOF once the chunks run out.
        struct ChunkedReader {
            chunks: VecDeque<Vec<u8>>,
        }

        impl AsyncRead for ChunkedReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                if let Some(chunk) = self.chunks.pop_front() {
                    buf.put_slice(&chunk);
                }
                Poll::Ready(Ok(()))
            }
        }

        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello"));
        let encoded = msg.encode().unwrap();

        // Split inside the endpoint path.
        let reader = ChunkedReader {
            chunks: VecDeque::from(vec![encoded[..5].to_vec(), encoded[5..].to_vec()]),
        };
        let mut framed = FramedRead::new(reader, TcpRequestCodec::new(None));

        let decoded = framed.next().await.unwrap().unwrap();
        assert_eq!(decoded.endpoint_path, "test");
        assert_eq!(decoded.payload, Bytes::from("hello"));
        assert!(framed.next().await.is_none());
    }
}