1use bytes::{BufMut, BytesMut};
16use std::io;
17use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
18use tracing::trace;
19
/// Largest message body, in bytes, that `MessageDeframer::read_framed` will accept (256 MiB).
/// Longer announced lengths are rejected with `InvalidData`.
const MAX_MESSAGE_SIZE: usize = 256 << 20;
21
/// Selects the width of the length prefix that the framer/deframer put in
/// front of every message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameMode {
    /// Messages are prefixed with a 2-byte (`u16`) length.
    Handshake,
    /// Messages are prefixed with a 4-byte (`u32`) length.
    Distribution,
}

impl FrameMode {
    /// Number of bytes occupied by the length prefix in this mode.
    pub fn length_prefix_size(&self) -> usize {
        use std::mem::size_of;

        match *self {
            Self::Handshake => size_of::<u16>(),
            Self::Distribution => size_of::<u32>(),
        }
    }
}
36
37pub struct MessageFramer {
38 mode: FrameMode,
39}
40
41impl MessageFramer {
42 pub fn new(mode: FrameMode) -> Self {
43 Self { mode }
44 }
45
46 pub fn set_mode(&mut self, mode: FrameMode) {
47 self.mode = mode;
48 }
49
50 pub fn frame_message(&self, data: &[u8]) -> Vec<u8> {
51 let mut buf = match self.mode {
52 FrameMode::Handshake => {
53 let len = data.len() as u16;
54 let mut b = BytesMut::with_capacity(2 + data.len());
55 b.put_u16(len);
56 b
57 }
58 FrameMode::Distribution => {
59 let len = data.len() as u32;
60 let mut b = BytesMut::with_capacity(4 + data.len());
61 b.put_u32(len);
62 b
63 }
64 };
65 buf.put_slice(data);
66 buf.to_vec()
67 }
68
69 pub async fn write_framed<W: AsyncWrite + Unpin>(
70 &self,
71 writer: &mut W,
72 data: &[u8],
73 ) -> io::Result<()> {
74 trace!(
75 "Writing {} bytes in {:?} mode: {:02x?}",
76 data.len(),
77 self.mode,
78 data
79 );
80
81 match self.mode {
82 FrameMode::Handshake => {
83 let len = data.len() as u16;
84 writer.write_u16(len).await?;
85 }
86 FrameMode::Distribution => {
87 let len = data.len() as u32;
88 writer.write_u32(len).await?;
89 }
90 }
91 writer.write_all(data).await?;
92 writer.flush().await?;
93 Ok(())
94 }
95}
96
97pub struct MessageDeframer {
98 mode: FrameMode,
99}
100
101impl MessageDeframer {
102 pub fn new(mode: FrameMode) -> Self {
103 Self { mode }
104 }
105
106 pub fn set_mode(&mut self, mode: FrameMode) {
107 self.mode = mode;
108 }
109
110 pub async fn read_framed<R: AsyncRead + Unpin>(&self, reader: &mut R) -> io::Result<Vec<u8>> {
111 let len = match self.mode {
112 FrameMode::Handshake => {
113 trace!("Reading message length (2 bytes, handshake mode)");
114 let len = reader.read_u16().await?;
115 trace!("Read length: {} bytes", len);
116 len as usize
117 }
118 FrameMode::Distribution => {
119 trace!("Reading message length (4 bytes, distribution mode)");
120 let mut len_bytes = [0u8; 4];
121 reader.read_exact(&mut len_bytes).await?;
122 let len = u32::from_be_bytes(len_bytes);
123 trace!("Read length: {} bytes (raw: {:02x?})", len, len_bytes);
124 len as usize
125 }
126 };
127
128 if len == 0 {
129 trace!("Received 0-byte message (heartbeat/tick)");
130 return Ok(Vec::new());
131 }
132
133 if len > MAX_MESSAGE_SIZE {
134 return Err(io::Error::new(
135 io::ErrorKind::InvalidData,
136 format!(
137 "Message too large: {} bytes (max: {})",
138 len, MAX_MESSAGE_SIZE
139 ),
140 ));
141 }
142
143 let mut buf = vec![0u8; len];
144 trace!("Reading {} bytes of message data", len);
145 reader.read_exact(&mut buf).await?;
146 trace!("Read message data (hex): {:02x?}", buf);
147
148 Ok(buf)
149 }
150}