1use crate::error::WireError;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5
/// Largest payload, in bytes, that may be written or accepted on the wire (1 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 1 << 20;
/// Frame magic: the ASCII bytes `VR` interpreted as a big-endian `u16` (0x5652).
const MAGIC: u16 = u16::from_be_bytes(*b"VR");
/// Header length on the wire: magic (2) + flags (2) + length (4) + target (32) + CRC-32 (4).
const HEADER_SIZE: usize = 2 + 2 + 4 + 32 + 4;

/// Set when a 32-byte MAC tag follows the payload on the wire.
pub const FLAG_MAC_PRESENT: u16 = 0b0_0001;

/// Set when the payload bytes on the wire are zstd-compressed.
pub const FLAG_COMPRESSED: u16 = 0b0_0010;

/// Set when the payload must be sent as-is; `write_frame_raw` never auto-compresses such frames.
pub const FLAG_RAW_BINARY: u16 = 0b1_0000;

/// Marks a fragmented frame. Presumably the payload then begins with a fragment header
/// decoded by `parse_frag_header` — TODO confirm against callers.
pub const FLAG_FRAGMENTED: u16 = 0b0_0100;

/// Payloads at least this many bytes long are candidates for compression in
/// `write_frame_raw` (64 KiB).
pub const COMPRESS_THRESHOLD: usize = 64 * 1024;
26
/// Size in bytes of the fragment header: three `u16` fields followed by one `u32`.
pub const FRAG_HEADER_SIZE: usize = 2 + 2 + 2 + 4;

/// Fragment header decoded from the first [`FRAG_HEADER_SIZE`] bytes of a payload.
///
/// All fields are stored big-endian on the wire.
#[derive(Debug, Clone, Copy)]
pub struct FragmentHeader {
    /// Bytes 0..2 of the payload.
    pub fragment_id: u16,
    /// Bytes 2..4 of the payload.
    pub sequence: u16,
    /// Bytes 4..6 of the payload.
    pub total: u16,
    /// Bytes 6..10 of the payload.
    pub stream_id: u32,
}

/// Decodes the fragment header at the start of `payload`.
///
/// Returns `None` if `payload` is shorter than [`FRAG_HEADER_SIZE`]. Any bytes after the
/// header are ignored. No consistency checks are made between the decoded fields.
pub fn parse_frag_header(payload: &[u8]) -> Option<FragmentHeader> {
    let head = payload.get(..FRAG_HEADER_SIZE)?;
    let word_at = |at: usize| u16::from_be_bytes([head[at], head[at + 1]]);

    Some(FragmentHeader {
        fragment_id: word_at(0),
        sequence: word_at(2),
        total: word_at(4),
        stream_id: u32::from_be_bytes([head[6], head[7], head[8], head[9]]),
    })
}
61
/// Default time allowed for reading the rest of a frame once its first byte has arrived.
const FRAME_READ_TIMEOUT: Duration = Duration::from_millis(10_000);
67
/// A single wire frame, either decoded by the reader or about to be encoded by a writer.
///
/// On the wire the header is 44 bytes: magic, flags, length, target and CRC-32, in that
/// order, with integers big-endian (see `serialize_header`). The payload follows the header,
/// and then the optional 32-byte MAC tag.
#[derive(Debug, Clone)]
pub struct Frame {
    /// Protocol magic. The reader rejects frames whose magic is not `MAGIC`.
    pub magic: u16,
    /// Bit set of `FLAG_*` values. Frames returned by the reader never have `FLAG_COMPRESSED`
    /// set, because the payload has already been decompressed.
    pub flags: u16,
    /// Payload length in bytes. On frames returned by the reader this is the length of the
    /// (decompressed) `payload`.
    pub length: u32,
    /// Destination name, conventionally zero-padded to 32 bytes (as `write_frame` does).
    /// `target_as_str` reads it up to the first zero byte.
    pub target: [u8; 32],
    /// CRC-32 of the payload. On frames returned by the reader it is computed over the
    /// decompressed bytes.
    pub crc32: u32,
    /// Frame body, held in an `Arc` so cloning a frame does not copy the bytes.
    pub payload: Arc<[u8]>,
    /// Optional 32-byte tag sent after the payload. On frames returned by the reader it is
    /// `Some` exactly when `FLAG_MAC_PRESENT` was set. What the tag authenticates is not
    /// determined in this module — TODO confirm with the producer.
    pub mac: Option<[u8; 32]>,
}
82
83pub fn serialize_header(frame: &Frame) -> [u8; HEADER_SIZE] {
86 let mut header = [0u8; HEADER_SIZE];
87 header[0..2].copy_from_slice(&frame.magic.to_be_bytes());
88 header[2..4].copy_from_slice(&frame.flags.to_be_bytes());
89 header[4..8].copy_from_slice(&frame.length.to_be_bytes());
90 header[8..40].copy_from_slice(&frame.target);
91 header[40..44].copy_from_slice(&frame.crc32.to_be_bytes());
92 header
93}
94
95pub async fn write_frame<W>(
96 stream: &mut W,
97 target: &str,
98 flags: u16,
99 payload: &[u8],
100) -> Result<(), WireError>
101where
102 W: AsyncWrite + Unpin,
103{
104 if payload.len() > MAX_PAYLOAD_SIZE {
105 return Err(WireError::PayloadTooLarge(payload.len()));
106 }
107
108 let mut header = [0u8; HEADER_SIZE];
109 header[0..2].copy_from_slice(&MAGIC.to_be_bytes());
110 header[2..4].copy_from_slice(&flags.to_be_bytes());
111 header[4..8].copy_from_slice(&(payload.len() as u32).to_be_bytes());
112
113 let target_bytes = target.as_bytes();
114 let copy_len = target_bytes.len().min(32);
115 header[8..8 + copy_len].copy_from_slice(&target_bytes[..copy_len]);
116
117 let checksum = crc32fast::hash(payload);
118 header[40..44].copy_from_slice(&checksum.to_be_bytes());
119
120 stream.write_all(&header).await?;
121 stream.write_all(payload).await?;
122 Ok(())
123}
124
125pub async fn write_frame_raw<W>(stream: &mut W, frame: &Frame) -> Result<(), WireError>
126where
127 W: AsyncWrite + Unpin,
128{
129 if frame.payload.len() > MAX_PAYLOAD_SIZE {
132 return Err(WireError::PayloadTooLarge(frame.payload.len()));
133 }
134
135 let (wire_payload, wire_flags): (Arc<[u8]>, u16) = if frame.payload.len() >= COMPRESS_THRESHOLD
138 && frame.flags & FLAG_COMPRESSED == 0
139 && frame.flags & FLAG_RAW_BINARY == 0
140 {
141 match zstd::bulk::compress(&frame.payload, 3) {
142 Ok(c) if c.len() < frame.payload.len() => (Arc::from(c), frame.flags | FLAG_COMPRESSED),
143 _ => (frame.payload.clone(), frame.flags),
146 }
147 } else {
148 (frame.payload.clone(), frame.flags)
149 };
150
151 let wire_crc = crc32fast::hash(&wire_payload);
153 let wire_frame = Frame {
154 magic: frame.magic,
155 flags: wire_flags,
156 length: wire_payload.len() as u32,
157 target: frame.target,
158 crc32: wire_crc,
159 payload: wire_payload,
160 mac: frame.mac,
161 };
162 let header = serialize_header(&wire_frame);
163
164 stream.write_all(&header).await?;
165 stream.write_all(&wire_frame.payload).await?;
166 if let Some(tag) = &wire_frame.mac {
167 stream.write_all(tag).await?;
168 }
169 Ok(())
170}
171
172pub async fn read_frame<R>(stream: &mut R) -> Result<Frame, WireError>
173where
174 R: AsyncRead + Unpin,
175{
176 read_frame_with_timeout(stream, FRAME_READ_TIMEOUT).await
177}
178
179pub async fn read_frame_with_timeout<R>(
180 stream: &mut R,
181 frame_timeout: Duration,
182) -> Result<Frame, WireError>
183where
184 R: AsyncRead + Unpin,
185{
186 let mut first = [0u8; 1];
190 stream.read_exact(&mut first).await?;
191
192 match tokio::time::timeout(frame_timeout, read_frame_body(stream, first[0])).await {
193 Ok(result) => result,
194 Err(_) => Err(WireError::FrameReadTimeout),
195 }
196}
197
198async fn read_frame_body<R>(stream: &mut R, first_byte: u8) -> Result<Frame, WireError>
199where
200 R: AsyncRead + Unpin,
201{
202 let mut header = [0u8; HEADER_SIZE];
203 header[0] = first_byte;
204 stream.read_exact(&mut header[1..]).await?;
205
206 let magic = u16::from_be_bytes([header[0], header[1]]);
207 if magic != MAGIC {
208 return Err(WireError::FrameMagicMismatch);
209 }
210
211 let flags = u16::from_be_bytes([header[2], header[3]]);
212 let length = u32::from_be_bytes([header[4], header[5], header[6], header[7]]);
213
214 if length as usize > MAX_PAYLOAD_SIZE {
215 return Err(WireError::PayloadTooLarge(length as usize));
216 }
217
218 let mut target = [0u8; 32];
219 target.copy_from_slice(&header[8..40]);
220
221 let crc32 = u32::from_be_bytes([header[40], header[41], header[42], header[43]]);
222
223 let mut payload = vec![0u8; length as usize];
224 if length > 0 {
225 stream.read_exact(&mut payload).await?;
226 }
227
228 let computed = crc32fast::hash(&payload);
230 if computed != crc32 {
231 return Err(WireError::FrameCrcMismatch);
232 }
233
234 let (payload, flags, length, crc32) = if flags & FLAG_COMPRESSED != 0 {
240 let decompressed = zstd::bulk::decompress(&payload, MAX_PAYLOAD_SIZE)
241 .map_err(|e| WireError::Internal(format!("decompress frame: {e}")))?;
242 let plain_len = decompressed.len() as u32;
243 let plain_crc = crc32fast::hash(&decompressed);
244 (decompressed, flags & !FLAG_COMPRESSED, plain_len, plain_crc)
245 } else {
246 (payload, flags, length, crc32)
247 };
248
249 let mac = if flags & FLAG_MAC_PRESENT != 0 {
250 let mut tag = [0u8; 32];
251 stream.read_exact(&mut tag).await?;
252 Some(tag)
253 } else {
254 None
255 };
256
257 Ok(Frame {
258 magic,
259 flags,
260 length,
261 target,
262 crc32,
263 payload: payload.into(),
264 mac,
265 })
266}
267
268pub fn target_as_str(frame: &Frame) -> Option<&str> {
271 let end = frame.target.iter().position(|&b| b == 0).unwrap_or(32);
272 std::str::from_utf8(&frame.target[..end]).ok()
273}