1use bytes::Bytes;
12use tokio_util::{
13 bytes::{Buf, BufMut, BytesMut},
14 codec::{Decoder, Encoder},
15};
16
17mod two_part;
18pub mod zero_copy_decoder;
19
20pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
21pub use zero_copy_decoder::{TcpRequestMessageZeroCopy, ZeroCopyTcpDecoder};
22
/// Width in bytes of the big-endian `u16` prefix that carries the endpoint path length.
const TCP_REQUEST_ENDPOINT_LEN_WIDTH: usize = std::mem::size_of::<u16>();
/// Width in bytes of the big-endian `u16` prefix that carries the headers section length.
const TCP_REQUEST_HEADERS_LEN_WIDTH: usize = std::mem::size_of::<u16>();
/// Width in bytes of the big-endian `u32` prefix that carries the payload length.
const TCP_REQUEST_PAYLOAD_LEN_WIDTH: usize = std::mem::size_of::<u32>();
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq)]
28struct TcpRequestWireHeader {
29 endpoint_len: usize,
30 headers_len: usize,
31 payload_len: usize,
32 header_size: usize,
33 total_len: usize,
34}
35
36impl TcpRequestWireHeader {
37 fn endpoint_start(&self) -> usize {
38 TCP_REQUEST_ENDPOINT_LEN_WIDTH
39 }
40
41 fn endpoint_end(&self) -> usize {
42 self.endpoint_start() + self.endpoint_len
43 }
44
45 fn headers_start(&self) -> usize {
46 self.endpoint_end() + TCP_REQUEST_HEADERS_LEN_WIDTH
47 }
48
49 fn headers_end(&self) -> usize {
50 self.headers_start() + self.headers_len
51 }
52
53 fn payload_start(&self) -> usize {
54 self.header_size
55 }
56}
57
58fn tcp_request_header_size(endpoint_len: usize, headers_len: usize) -> usize {
59 TCP_REQUEST_ENDPOINT_LEN_WIDTH
60 + endpoint_len
61 + TCP_REQUEST_HEADERS_LEN_WIDTH
62 + headers_len
63 + TCP_REQUEST_PAYLOAD_LEN_WIDTH
64}
65
66fn tcp_request_total_len(
67 endpoint_len: usize,
68 headers_len: usize,
69 payload_len: usize,
70) -> Result<TcpRequestWireHeader, std::io::Error> {
71 let header_size = tcp_request_header_size(endpoint_len, headers_len);
72 let total_len = header_size.checked_add(payload_len).ok_or_else(|| {
73 std::io::Error::new(
74 std::io::ErrorKind::InvalidData,
75 "TCP request message length overflow",
76 )
77 })?;
78
79 Ok(TcpRequestWireHeader {
80 endpoint_len,
81 headers_len,
82 payload_len,
83 header_size,
84 total_len,
85 })
86}
87
88fn validate_tcp_request_encode_lengths(
89 endpoint_len: usize,
90 headers_len: usize,
91 payload_len: usize,
92) -> Result<TcpRequestWireHeader, std::io::Error> {
93 if endpoint_len > u16::MAX as usize {
94 return Err(std::io::Error::new(
95 std::io::ErrorKind::InvalidInput,
96 format!("Endpoint path too long: {} bytes", endpoint_len),
97 ));
98 }
99
100 if headers_len > u16::MAX as usize {
101 return Err(std::io::Error::new(
102 std::io::ErrorKind::InvalidInput,
103 format!("Headers too large: {} bytes", headers_len),
104 ));
105 }
106
107 if payload_len > u32::MAX as usize {
108 return Err(std::io::Error::new(
109 std::io::ErrorKind::InvalidInput,
110 format!("Payload too large: {} bytes", payload_len),
111 ));
112 }
113
114 tcp_request_total_len(endpoint_len, headers_len, payload_len)
115}
116
117fn tcp_request_endpoint_len(bytes: &[u8]) -> Result<usize, std::io::Error> {
118 if bytes.len() < TCP_REQUEST_ENDPOINT_LEN_WIDTH {
119 return Err(std::io::Error::new(
120 std::io::ErrorKind::UnexpectedEof,
121 "Not enough bytes for endpoint path length",
122 ));
123 }
124
125 Ok(u16::from_be_bytes([bytes[0], bytes[1]]) as usize)
126}
127
128fn tcp_request_headers_len(bytes: &[u8], endpoint_len: usize) -> Result<usize, std::io::Error> {
129 let endpoint_end = TCP_REQUEST_ENDPOINT_LEN_WIDTH + endpoint_len;
130 if bytes.len() < endpoint_end {
131 return Err(std::io::Error::new(
132 std::io::ErrorKind::UnexpectedEof,
133 "Not enough bytes for endpoint path",
134 ));
135 }
136
137 if bytes.len() < endpoint_end + TCP_REQUEST_HEADERS_LEN_WIDTH {
138 return Err(std::io::Error::new(
139 std::io::ErrorKind::UnexpectedEof,
140 "Not enough bytes for headers length",
141 ));
142 }
143
144 Ok(u16::from_be_bytes([bytes[endpoint_end], bytes[endpoint_end + 1]]) as usize)
145}
146
147fn parse_tcp_request_frame_header(bytes: &[u8]) -> Result<TcpRequestWireHeader, std::io::Error> {
148 let endpoint_len = tcp_request_endpoint_len(bytes)?;
149 let headers_len = tcp_request_headers_len(bytes, endpoint_len)?;
150
151 let headers_end =
152 TCP_REQUEST_ENDPOINT_LEN_WIDTH + endpoint_len + TCP_REQUEST_HEADERS_LEN_WIDTH + headers_len;
153 if bytes.len() < headers_end {
154 return Err(std::io::Error::new(
155 std::io::ErrorKind::UnexpectedEof,
156 "Not enough bytes for headers",
157 ));
158 }
159
160 if bytes.len() < headers_end + TCP_REQUEST_PAYLOAD_LEN_WIDTH {
161 return Err(std::io::Error::new(
162 std::io::ErrorKind::UnexpectedEof,
163 "Not enough bytes for payload length",
164 ));
165 }
166
167 let payload_len = u32::from_be_bytes([
168 bytes[headers_end],
169 bytes[headers_end + 1],
170 bytes[headers_end + 2],
171 bytes[headers_end + 3],
172 ]) as usize;
173
174 tcp_request_total_len(endpoint_len, headers_len, payload_len)
175}
176
177fn parse_tcp_request_frame(bytes: &[u8]) -> Result<TcpRequestWireHeader, std::io::Error> {
178 let parsed = parse_tcp_request_frame_header(bytes)?;
179 if bytes.len() < parsed.total_len {
180 return Err(std::io::Error::new(
181 std::io::ErrorKind::UnexpectedEof,
182 format!(
183 "Not enough bytes for payload: expected {}, got {}",
184 parsed.payload_len,
185 bytes.len().saturating_sub(parsed.payload_start())
186 ),
187 ));
188 }
189
190 Ok(parsed)
191}
192
/// Rejects a request frame whose total length exceeds `max_message_size`.
///
/// A frame of exactly `max_message_size` bytes is accepted.
///
/// # Errors
///
/// Returns `InvalidData`, with both sizes in the message, when
/// `total_len > max_message_size`.
fn check_tcp_request_max_message_size(
    total_len: usize,
    max_message_size: usize,
) -> Result<(), std::io::Error> {
    if total_len <= max_message_size {
        Ok(())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "message too large: {} bytes (max: {} bytes)",
                total_len, max_message_size
            ),
        ))
    }
}
209
210#[derive(Debug, Clone, PartialEq, Eq)]
220pub struct TcpRequestMessage {
221 pub endpoint_path: String,
222 pub headers: std::collections::HashMap<String, String>,
223 pub payload: Bytes,
224}
225
226impl TcpRequestMessage {
227 pub fn new(endpoint_path: String, payload: Bytes) -> Self {
228 Self {
229 endpoint_path,
230 headers: std::collections::HashMap::new(),
231 payload,
232 }
233 }
234
235 pub fn with_headers(
236 endpoint_path: String,
237 headers: std::collections::HashMap<String, String>,
238 payload: Bytes,
239 ) -> Self {
240 Self {
241 endpoint_path,
242 headers,
243 payload,
244 }
245 }
246
247 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
249 let endpoint_bytes = self.endpoint_path.as_bytes();
250 let endpoint_len = endpoint_bytes.len();
251
252 let headers_json = serde_json::to_vec(&self.headers).map_err(|e| {
254 std::io::Error::new(
255 std::io::ErrorKind::InvalidInput,
256 format!("Failed to encode headers: {}", e),
257 )
258 })?;
259 let headers_len = headers_json.len();
260
261 let parsed =
262 validate_tcp_request_encode_lengths(endpoint_len, headers_len, self.payload.len())?;
263
264 let mut buf = BytesMut::with_capacity(parsed.total_len);
266
267 buf.put_u16(endpoint_len as u16);
269
270 buf.put_slice(endpoint_bytes);
272
273 buf.put_u16(headers_len as u16);
275
276 buf.put_slice(&headers_json);
278
279 buf.put_u32(self.payload.len() as u32);
281
282 buf.put_slice(&self.payload);
284
285 Ok(buf.freeze())
287 }
288
289 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
291 let parsed = parse_tcp_request_frame(bytes)?;
292
293 let endpoint_path =
295 String::from_utf8(bytes[parsed.endpoint_start()..parsed.endpoint_end()].to_vec())
296 .map_err(|e| {
297 std::io::Error::new(
298 std::io::ErrorKind::InvalidData,
299 format!("Invalid UTF-8 in endpoint path: {}", e),
300 )
301 })?;
302
303 let headers: std::collections::HashMap<String, String> = serde_json::from_slice(
305 &bytes[parsed.headers_start()..parsed.headers_end()],
306 )
307 .map_err(|e| {
308 std::io::Error::new(
309 std::io::ErrorKind::InvalidData,
310 format!("Invalid JSON in headers: {}", e),
311 )
312 })?;
313
314 let payload = bytes.slice(parsed.payload_start()..parsed.total_len);
316
317 Ok(Self {
318 endpoint_path,
319 headers,
320 payload,
321 })
322 }
323}
324
325#[derive(Debug, Clone, PartialEq, Eq)]
331pub struct TcpResponseMessage {
332 pub data: Bytes,
333}
334
335impl TcpResponseMessage {
336 pub fn new(data: Bytes) -> Self {
337 Self { data }
338 }
339
340 pub fn empty() -> Self {
341 Self { data: Bytes::new() }
342 }
343
344 pub fn encode(&self) -> Result<Bytes, std::io::Error> {
346 if self.data.len() > u32::MAX as usize {
347 return Err(std::io::Error::new(
348 std::io::ErrorKind::InvalidInput,
349 format!("Response too large: {} bytes", self.data.len()),
350 ));
351 }
352
353 let mut buf = BytesMut::with_capacity(4 + self.data.len());
355
356 buf.put_u32(self.data.len() as u32);
358
359 buf.put_slice(&self.data);
361
362 Ok(buf.freeze())
364 }
365
366 pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
368 if bytes.len() < 4 {
369 return Err(std::io::Error::new(
370 std::io::ErrorKind::UnexpectedEof,
371 "Not enough bytes for response length",
372 ));
373 }
374
375 let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
377
378 if bytes.len() < 4 + len {
379 return Err(std::io::Error::new(
380 std::io::ErrorKind::UnexpectedEof,
381 format!(
382 "Not enough bytes for response: expected {}, got {}",
383 len,
384 bytes.len() - 4
385 ),
386 ));
387 }
388
389 let data = bytes.slice(4..4 + len);
391
392 Ok(Self { data })
393 }
394}
395
396#[derive(Clone, Default)]
399pub struct TcpResponseCodec {
400 max_message_size: Option<usize>,
401}
402
403impl TcpResponseCodec {
404 pub fn new(max_message_size: Option<usize>) -> Self {
405 Self { max_message_size }
406 }
407}
408
409impl Decoder for TcpResponseCodec {
410 type Item = TcpResponseMessage;
411 type Error = std::io::Error;
412
413 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
414 if src.len() < 4 {
416 return Ok(None);
417 }
418
419 let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
421 let total_len = 4 + data_len;
422
423 if let Some(max_size) = self.max_message_size
425 && total_len > max_size
426 {
427 return Err(std::io::Error::new(
428 std::io::ErrorKind::InvalidData,
429 format!(
430 "Response too large: {} bytes (max: {} bytes)",
431 total_len, max_size
432 ),
433 ));
434 }
435
436 if src.len() < total_len {
438 return Ok(None);
439 }
440
441 src.advance(4);
443
444 let data = src.split_to(data_len).freeze();
446
447 Ok(Some(TcpResponseMessage { data }))
448 }
449}
450
451impl Encoder<TcpResponseMessage> for TcpResponseCodec {
452 type Error = std::io::Error;
453
454 fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
455 if item.data.len() > u32::MAX as usize {
456 return Err(std::io::Error::new(
457 std::io::ErrorKind::InvalidInput,
458 format!("Response too large: {} bytes", item.data.len()),
459 ));
460 }
461
462 let total_len = 4 + item.data.len();
463
464 if let Some(max_size) = self.max_message_size
466 && total_len > max_size
467 {
468 return Err(std::io::Error::new(
469 std::io::ErrorKind::InvalidInput,
470 format!(
471 "Response too large: {} bytes (max: {} bytes)",
472 total_len, max_size
473 ),
474 ));
475 }
476
477 dst.reserve(total_len);
479
480 dst.put_u32(item.data.len() as u32);
482
483 dst.put_slice(&item.data);
485
486 Ok(())
487 }
488}
489
#[cfg(test)]
mod tests {
    // Round-trip and error-path tests for the request/response framing helpers.
    use super::*;

    #[test]
    fn test_tcp_request_encode_decode() {
        let msg = TcpRequestMessage::new(
            "test.endpoint".to_string(),
            Bytes::from(vec![1, 2, 3, 4, 5]),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_with_headers_roundtrip() {
        let headers = std::collections::HashMap::from([
            ("content-type".to_string(), "application/json".to_string()),
            ("x-request-id".to_string(), "42".to_string()),
        ]);
        let msg = TcpRequestMessage::with_headers(
            "svc.method".to_string(),
            headers,
            Bytes::from_static(b"{}"),
        );

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_empty_payload() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_large_payload() {
        // 1 MiB is past the u16 range, so this exercises the u32 payload length prefix.
        let payload = Bytes::from(vec![42u8; 1024 * 1024]);
        let msg = TcpRequestMessage::new("large".to_string(), payload);

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_request_encode_rejects_long_endpoint() {
        let msg = TcpRequestMessage::new("a".repeat(u16::MAX as usize + 1), Bytes::new());

        let err = msg.encode().unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("Endpoint path too long"));
    }

    #[test]
    fn test_tcp_request_decode_truncated() {
        let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Drop the last two payload bytes.
        let truncated = encoded.slice(..encoded.len() - 2);
        let err = TcpRequestMessage::decode(&truncated).unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        assert_eq!(
            err.to_string(),
            "Not enough bytes for payload: expected 5, got 3"
        );
    }

    #[test]
    fn test_tcp_request_decode_invalid_endpoint_utf8() {
        let mut encoded = BytesMut::new();
        encoded.put_u16(2);
        encoded.put_slice(&[0xff, 0xff]);
        encoded.put_u16(2);
        encoded.put_slice(b"{}");
        encoded.put_u32(0);

        let result = TcpRequestMessage::decode(&encoded.freeze());

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid UTF-8"));
    }

    #[test]
    fn test_tcp_request_decode_invalid_headers_json() {
        let mut encoded = BytesMut::new();
        encoded.put_u16(4);
        encoded.put_slice(b"test");
        encoded.put_u16(1);
        encoded.put_slice(b"{");
        encoded.put_u32(0);

        let result = TcpRequestMessage::decode(&encoded.freeze());

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid JSON"));
    }

    #[test]
    fn test_tcp_request_empty_endpoint_path() {
        let msg = TcpRequestMessage::new(String::new(), Bytes::from_static(b"payload"));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_encode_decode() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_empty() {
        let msg = TcpResponseMessage::empty();

        let encoded = msg.encode().unwrap();
        let decoded = TcpResponseMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
        assert_eq!(decoded.data.len(), 0);
    }

    #[test]
    fn test_tcp_response_decode_truncated() {
        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
        let encoded = msg.encode().unwrap();

        // Three bytes is not even a complete length prefix.
        let truncated = encoded.slice(..3);
        let err = TcpResponseMessage::decode(&truncated).unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_tcp_response_decode_length_exceeds_buffer() {
        let mut encoded = BytesMut::new();
        encoded.put_u32(10);
        encoded.put_slice(&[1, 2]);

        let err = TcpResponseMessage::decode(&encoded.freeze()).unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        assert_eq!(
            err.to_string(),
            "Not enough bytes for response: expected 10, got 2"
        );
    }

    #[test]
    fn test_tcp_request_unicode_endpoint() {
        let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));

        let encoded = msg.encode().unwrap();
        let decoded = TcpRequestMessage::decode(&encoded).unwrap();

        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec() {
        use tokio_util::codec::{Decoder, Encoder};

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();

        codec.encode(msg.clone(), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_multiple_frames() {
        use tokio_util::codec::{Decoder, Encoder};

        let first = TcpResponseMessage::new(Bytes::from_static(b"one"));
        let second = TcpResponseMessage::empty();

        let mut codec = TcpResponseCodec::new(None);
        let mut buf = BytesMut::new();
        codec.encode(first.clone(), &mut buf).unwrap();
        codec.encode(second.clone(), &mut buf).unwrap();

        assert_eq!(codec.decode(&mut buf).unwrap(), Some(first));
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(second));
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_tcp_response_codec_partial() {
        use tokio_util::codec::Decoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        let encoded = msg.encode().unwrap();
        let mut codec = TcpResponseCodec::new(None);

        // Only part of the length prefix has arrived.
        let mut buf = BytesMut::from(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&encoded[3..]);
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_tcp_response_codec_max_size() {
        use tokio_util::codec::Encoder;

        let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));

        // The frame is 4 + 5 = 9 bytes, above the 5-byte limit.
        let mut codec = TcpResponseCodec::new(Some(5));
        let mut buf = BytesMut::new();

        let err = codec.encode(msg, &mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_tcp_response_codec_decode_max_size() {
        use tokio_util::codec::Decoder;

        let encoded = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]))
            .encode()
            .unwrap();
        let mut codec = TcpResponseCodec::new(Some(8));
        let mut buf = BytesMut::from(&encoded[..]);

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "Response too large: 9 bytes (max: 8 bytes)");
    }
}
675}