1use crate::error::ProtocolError;
16use crate::MAX_PAYLOAD_SIZE;
17use bytes::{Buf, BufMut, Bytes, BytesMut};
18
/// Four-byte marker written at the very start of every encoded frame.
///
/// [`Frame::decode`] rejects any buffer whose first four bytes differ from it.
pub const MAGIC: [u8; 4] = [b'R', b'C', b'P', b'X'];
21
/// Length in bytes of the fixed header that precedes the optional header
/// extension and the payload.
///
/// Layout: magic (4) | version (2) | flags (2) | extension length (2) |
/// payload length (4) | CRC-32C (4).
pub const FRAME_HEADER_SIZE: usize = 4 + 2 + 2 + 2 + 4 + 4;
24
25#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
27pub struct FrameFlags(u16);
28
29impl FrameFlags {
30 pub const CRC_PRESENT: u16 = 1 << 0;
32 pub const COMPRESSED: u16 = 1 << 1;
34 pub const STREAM: u16 = 1 << 2;
36 pub const END_STREAM: u16 = 1 << 3;
38
39 const VALID_V1_MASK: u16 = 0x000F;
41
42 pub fn new() -> Self {
43 Self(0)
44 }
45
46 pub fn with_crc(mut self) -> Self {
47 self.0 |= Self::CRC_PRESENT;
48 self
49 }
50
51 pub fn with_stream(mut self) -> Self {
52 self.0 |= Self::STREAM;
53 self
54 }
55
56 pub fn with_end_stream(mut self) -> Self {
57 self.0 |= Self::END_STREAM;
58 self
59 }
60
61 pub fn has_crc(&self) -> bool {
62 self.0 & Self::CRC_PRESENT != 0
63 }
64
65 pub fn is_compressed(&self) -> bool {
66 self.0 & Self::COMPRESSED != 0
67 }
68
69 pub fn is_stream(&self) -> bool {
70 self.0 & Self::STREAM != 0
71 }
72
73 pub fn is_end_stream(&self) -> bool {
74 self.0 & Self::END_STREAM != 0
75 }
76
77 pub fn bits(&self) -> u16 {
78 self.0
79 }
80
81 pub fn from_bits(bits: u16) -> Result<Self, ProtocolError> {
82 if bits & !Self::VALID_V1_MASK != 0 {
83 return Err(ProtocolError::InvalidFlags(bits));
84 }
85 Ok(Self(bits))
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct Frame {
92 pub version: u16,
94 pub flags: FrameFlags,
96 pub header_extension: Bytes,
98 pub payload: Bytes,
100}
101
102impl Frame {
103 pub fn new(payload: Bytes) -> Self {
105 Self {
106 version: crate::PROTOCOL_VERSION,
107 flags: FrameFlags::new().with_crc(),
108 header_extension: Bytes::new(),
109 payload,
110 }
111 }
112
113 pub fn from_json<T: serde::Serialize>(value: &T) -> Result<Self, ProtocolError> {
115 let payload = serde_json::to_vec(value)?;
116 Ok(Self::new(Bytes::from(payload)))
117 }
118
119 pub fn encode(&self) -> Result<BytesMut, ProtocolError> {
121 let payload_len = self.payload.len() as u32;
122 if payload_len > MAX_PAYLOAD_SIZE {
123 return Err(ProtocolError::FrameTooLarge {
124 size: payload_len,
125 max: MAX_PAYLOAD_SIZE,
126 });
127 }
128
129 let header_len = self.header_extension.len() as u16;
130 let total_size = FRAME_HEADER_SIZE + header_len as usize + self.payload.len();
131 let mut buf = BytesMut::with_capacity(total_size);
132
133 buf.put_slice(&MAGIC);
135
136 buf.put_u16(self.version);
138
139 buf.put_u16(self.flags.bits());
141
142 buf.put_u16(header_len);
144
145 buf.put_u32(payload_len);
147
148 let crc = if self.flags.has_crc() {
150 crc32c::crc32c(&self.payload)
151 } else {
152 0
153 };
154 buf.put_u32(crc);
155
156 if !self.header_extension.is_empty() {
158 buf.put_slice(&self.header_extension);
159 }
160
161 buf.put_slice(&self.payload);
163
164 Ok(buf)
165 }
166
167 pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
172 if buf.len() < FRAME_HEADER_SIZE {
173 return Ok(None);
174 }
175
176 let magic: [u8; 4] = buf[0..4].try_into().unwrap();
178 if magic != MAGIC {
179 return Err(ProtocolError::InvalidMagic(magic));
180 }
181
182 let version = u16::from_be_bytes([buf[4], buf[5]]);
183 if version != crate::PROTOCOL_VERSION {
184 return Err(ProtocolError::UnsupportedVersion(version));
185 }
186
187 let flags_bits = u16::from_be_bytes([buf[6], buf[7]]);
188 let flags = FrameFlags::from_bits(flags_bits)?;
189
190 let header_len = u16::from_be_bytes([buf[8], buf[9]]) as usize;
191 let payload_len = u32::from_be_bytes([buf[10], buf[11], buf[12], buf[13]]) as usize;
192
193 if payload_len > MAX_PAYLOAD_SIZE as usize {
194 return Err(ProtocolError::FrameTooLarge {
195 size: payload_len as u32,
196 max: MAX_PAYLOAD_SIZE,
197 });
198 }
199
200 let crc_expected = u32::from_be_bytes([buf[14], buf[15], buf[16], buf[17]]);
201
202 let total_len = FRAME_HEADER_SIZE + header_len + payload_len;
203 if buf.len() < total_len {
204 return Ok(None);
205 }
206
207 buf.advance(FRAME_HEADER_SIZE);
209
210 let header_extension = buf.split_to(header_len).freeze();
212
213 let payload = buf.split_to(payload_len).freeze();
215
216 if flags.has_crc() {
218 let crc_actual = crc32c::crc32c(&payload);
219 if crc_actual != crc_expected {
220 return Err(ProtocolError::CrcMismatch {
221 expected: crc_expected,
222 actual: crc_actual,
223 });
224 }
225 }
226
227 Ok(Some(Self {
228 version,
229 flags,
230 header_extension,
231 payload,
232 }))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare fixed header for the current protocol version with the
    /// given flag bits and declared payload length (no extension, CRC field 0).
    fn raw_header(flags: u16, payload_len: u32) -> BytesMut {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_SIZE);
        buf.put_slice(&MAGIC);
        buf.put_u16(crate::PROTOCOL_VERSION);
        buf.put_u16(flags);
        buf.put_u16(0);
        buf.put_u32(payload_len);
        buf.put_u32(0);
        buf
    }

    #[test]
    fn test_frame_roundtrip() {
        let payload = Bytes::from(r#"{"type":"request","id":"1","op":"PING","params":{}}"#);
        let frame = Frame::new(payload.clone());

        let mut buf = frame.encode().unwrap();
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.version, crate::PROTOCOL_VERSION);
        assert!(decoded.flags.has_crc());
        assert_eq!(decoded.payload, payload);
        assert!(buf.is_empty(), "decode must consume the whole frame");
    }

    #[test]
    fn test_crc_validation() {
        let frame = Frame::new(Bytes::from(r#"{"test":"data"}"#));
        let mut encoded = frame.encode().unwrap();

        // Flip the last payload byte so the stored CRC no longer matches.
        let len = encoded.len();
        encoded[len - 1] ^= 0xFF;

        let result = Frame::decode(&mut encoded);
        assert!(matches!(result, Err(ProtocolError::CrcMismatch { .. })));
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf =
            BytesMut::from(&b"BADX\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"[..]);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::InvalidMagic(_))));
    }

    #[test]
    fn test_incomplete_frame() {
        // Fewer bytes than the fixed header.
        let mut buf = BytesMut::from(&b"RCPX\x00\x01\x00\x01"[..]);
        let result = Frame::decode(&mut buf);
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_truncated_body_is_incomplete_and_not_consumed() {
        let frame = Frame::new(Bytes::from(r#"{"id":"1"}"#));
        let mut buf = frame.encode().unwrap();
        let len = buf.len();
        buf.truncate(len - 1);
        let before = buf.clone();

        assert!(Frame::decode(&mut buf).unwrap().is_none());
        assert_eq!(buf, before, "an incomplete frame must not be consumed");
    }

    #[test]
    fn test_unsupported_version() {
        let mut buf =
            BytesMut::from(&b"RCPX\x00\x63\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"[..]);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::UnsupportedVersion(99))));
    }

    #[test]
    fn test_frame_flags() {
        let flags = FrameFlags::new().with_crc().with_stream().with_end_stream();

        assert!(flags.has_crc());
        assert!(flags.is_stream());
        assert!(flags.is_end_stream());
        assert!(!flags.is_compressed());
    }

    #[test]
    fn test_invalid_flags() {
        let result = FrameFlags::from_bits(0x0100);
        assert!(matches!(result, Err(ProtocolError::InvalidFlags(0x0100))));
    }

    #[test]
    fn test_decode_rejects_unknown_flag_bits() {
        let mut buf = raw_header(0x0100, 0);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::InvalidFlags(0x0100))));
    }

    #[test]
    fn test_frame_too_large() {
        let huge_payload = vec![0u8; (MAX_PAYLOAD_SIZE + 1) as usize];
        let frame = Frame::new(Bytes::from(huge_payload));
        let result = frame.encode();
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_decode_rejects_oversized_declared_payload() {
        // Only the header is supplied: the limit must be enforced without the body.
        let mut buf = raw_header(0, MAX_PAYLOAD_SIZE + 1);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_empty_payload() {
        let frame = Frame::new(Bytes::new());

        let mut buf = frame.encode().unwrap();
        assert_eq!(buf.len(), FRAME_HEADER_SIZE);

        let decoded = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(decoded.payload.is_empty());
        assert!(decoded.header_extension.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_frame_from_json() {
        #[derive(serde::Serialize)]
        struct TestMsg {
            value: i32,
        }
        let frame = Frame::from_json(&TestMsg { value: 42 }).unwrap();
        let payload_str = std::str::from_utf8(&frame.payload).unwrap();
        assert!(payload_str.contains("42"));
    }

    #[test]
    fn test_frame_with_header_extension() {
        let mut frame = Frame::new(Bytes::from(r#"{"test":true}"#));
        frame.header_extension = Bytes::from(&b"ext_data"[..]);

        let mut buf = frame.encode().unwrap();
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.header_extension.as_ref(), b"ext_data");
        assert_eq!(decoded.payload, frame.payload);
    }

    #[test]
    fn test_frame_without_crc() {
        let mut frame = Frame::new(Bytes::from(r#"{"test":true}"#));
        frame.flags = FrameFlags::new();

        let mut buf = frame.encode().unwrap();
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert!(!decoded.flags.has_crc());
    }

    #[test]
    fn test_frame_without_crc_skips_validation() {
        let mut frame = Frame::new(Bytes::from(r#"{"test":true}"#));
        frame.flags = FrameFlags::new();

        let mut buf = frame.encode().unwrap();
        let len = buf.len();
        buf[len - 1] ^= 0xFF;

        let decoded = Frame::decode(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.payload, frame.payload);
    }

    #[test]
    fn test_multiple_frames_in_buffer() {
        let frame1 = Frame::new(Bytes::from(r#"{"id":"1"}"#));
        let frame2 = Frame::new(Bytes::from(r#"{"id":"2"}"#));

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&frame1.encode().unwrap());
        buf.extend_from_slice(&frame2.encode().unwrap());

        let decoded1 = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(std::str::from_utf8(&decoded1.payload)
            .unwrap()
            .contains("\"1\""));

        let decoded2 = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(std::str::from_utf8(&decoded2.payload)
            .unwrap()
            .contains("\"2\""));

        assert!(buf.is_empty());
        assert!(Frame::decode(&mut buf).unwrap().is_none());
    }
}