Skip to main content

rstmdb_protocol/
frame.rs

1//! Binary frame format for RCP.
2//!
3//! Frame layout (18 bytes header + optional header extension + payload):
4//!
5//! ```text
6//! +--------+---------+--------+------------+-------------+--------+
7//! | magic  | version | flags  | header_len | payload_len | crc32c |
8//! | 4 bytes| 2 bytes |2 bytes |  2 bytes   |   4 bytes   | 4 bytes|
9//! +--------+---------+--------+------------+-------------+--------+
10//! | [header_ext] | payload                                        |
11//! | header_len   | payload_len bytes                              |
12//! +--------------+------------------------------------------------+
13//! ```
14
15use crate::error::ProtocolError;
16use crate::MAX_PAYLOAD_SIZE;
17use bytes::{Buf, BufMut, Bytes, BytesMut};
18
/// Four-byte marker that opens every RCP frame: the ASCII characters "RCPX".
pub const MAGIC: [u8; 4] = [b'R', b'C', b'P', b'X'];

/// Length in bytes of the fixed frame header, written as the sum of its fields:
/// magic (4) + version (2) + flags (2) + header_len (2) + payload_len (4) + crc32c (4).
pub const FRAME_HEADER_SIZE: usize = 4 + 2 + 2 + 2 + 4 + 4;
24
25/// Frame flags bitfield.
26#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
27pub struct FrameFlags(u16);
28
29impl FrameFlags {
30    /// CRC32C checksum is present and valid.
31    pub const CRC_PRESENT: u16 = 1 << 0;
32    /// Payload is compressed (reserved for future use).
33    pub const COMPRESSED: u16 = 1 << 1;
34    /// This frame is part of a stream.
35    pub const STREAM: u16 = 1 << 2;
36    /// Final frame of a stream.
37    pub const END_STREAM: u16 = 1 << 3;
38
39    /// Valid flags mask for protocol version 1.
40    const VALID_V1_MASK: u16 = 0x000F;
41
42    pub fn new() -> Self {
43        Self(0)
44    }
45
46    pub fn with_crc(mut self) -> Self {
47        self.0 |= Self::CRC_PRESENT;
48        self
49    }
50
51    pub fn with_stream(mut self) -> Self {
52        self.0 |= Self::STREAM;
53        self
54    }
55
56    pub fn with_end_stream(mut self) -> Self {
57        self.0 |= Self::END_STREAM;
58        self
59    }
60
61    pub fn has_crc(&self) -> bool {
62        self.0 & Self::CRC_PRESENT != 0
63    }
64
65    pub fn is_compressed(&self) -> bool {
66        self.0 & Self::COMPRESSED != 0
67    }
68
69    pub fn is_stream(&self) -> bool {
70        self.0 & Self::STREAM != 0
71    }
72
73    pub fn is_end_stream(&self) -> bool {
74        self.0 & Self::END_STREAM != 0
75    }
76
77    pub fn bits(&self) -> u16 {
78        self.0
79    }
80
81    pub fn from_bits(bits: u16) -> Result<Self, ProtocolError> {
82        if bits & !Self::VALID_V1_MASK != 0 {
83            return Err(ProtocolError::InvalidFlags(bits));
84        }
85        Ok(Self(bits))
86    }
87}
88
89/// A parsed RCP frame.
90#[derive(Debug, Clone)]
91pub struct Frame {
92    /// Protocol version.
93    pub version: u16,
94    /// Frame flags.
95    pub flags: FrameFlags,
96    /// Optional header extension (reserved for future use).
97    pub header_extension: Bytes,
98    /// Frame payload (JSON data).
99    pub payload: Bytes,
100}
101
102impl Frame {
103    /// Creates a new frame with the given payload.
104    pub fn new(payload: Bytes) -> Self {
105        Self {
106            version: crate::PROTOCOL_VERSION,
107            flags: FrameFlags::new().with_crc(),
108            header_extension: Bytes::new(),
109            payload,
110        }
111    }
112
113    /// Creates a new frame from a JSON-serializable value.
114    pub fn from_json<T: serde::Serialize>(value: &T) -> Result<Self, ProtocolError> {
115        let payload = serde_json::to_vec(value)?;
116        Ok(Self::new(Bytes::from(payload)))
117    }
118
119    /// Encodes the frame into bytes.
120    pub fn encode(&self) -> Result<BytesMut, ProtocolError> {
121        let payload_len = self.payload.len() as u32;
122        if payload_len > MAX_PAYLOAD_SIZE {
123            return Err(ProtocolError::FrameTooLarge {
124                size: payload_len,
125                max: MAX_PAYLOAD_SIZE,
126            });
127        }
128
129        let header_len = self.header_extension.len() as u16;
130        let total_size = FRAME_HEADER_SIZE + header_len as usize + self.payload.len();
131        let mut buf = BytesMut::with_capacity(total_size);
132
133        // Magic (4 bytes)
134        buf.put_slice(&MAGIC);
135
136        // Version (2 bytes)
137        buf.put_u16(self.version);
138
139        // Flags (2 bytes)
140        buf.put_u16(self.flags.bits());
141
142        // Header extension length (2 bytes)
143        buf.put_u16(header_len);
144
145        // Payload length (4 bytes)
146        buf.put_u32(payload_len);
147
148        // CRC32C of payload (4 bytes)
149        let crc = if self.flags.has_crc() {
150            crc32c::crc32c(&self.payload)
151        } else {
152            0
153        };
154        buf.put_u32(crc);
155
156        // Header extension (if any)
157        if !self.header_extension.is_empty() {
158            buf.put_slice(&self.header_extension);
159        }
160
161        // Payload
162        buf.put_slice(&self.payload);
163
164        Ok(buf)
165    }
166
167    /// Decodes a frame from bytes.
168    ///
169    /// Returns `Ok(Some(frame))` if a complete frame was decoded,
170    /// `Ok(None)` if more data is needed, or `Err` on protocol errors.
171    pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
172        if buf.len() < FRAME_HEADER_SIZE {
173            return Ok(None);
174        }
175
176        // Peek at header without consuming
177        let magic: [u8; 4] = buf[0..4].try_into().unwrap();
178        if magic != MAGIC {
179            return Err(ProtocolError::InvalidMagic(magic));
180        }
181
182        let version = u16::from_be_bytes([buf[4], buf[5]]);
183        if version != crate::PROTOCOL_VERSION {
184            return Err(ProtocolError::UnsupportedVersion(version));
185        }
186
187        let flags_bits = u16::from_be_bytes([buf[6], buf[7]]);
188        let flags = FrameFlags::from_bits(flags_bits)?;
189
190        let header_len = u16::from_be_bytes([buf[8], buf[9]]) as usize;
191        let payload_len = u32::from_be_bytes([buf[10], buf[11], buf[12], buf[13]]) as usize;
192
193        if payload_len > MAX_PAYLOAD_SIZE as usize {
194            return Err(ProtocolError::FrameTooLarge {
195                size: payload_len as u32,
196                max: MAX_PAYLOAD_SIZE,
197            });
198        }
199
200        let crc_expected = u32::from_be_bytes([buf[14], buf[15], buf[16], buf[17]]);
201
202        let total_len = FRAME_HEADER_SIZE + header_len + payload_len;
203        if buf.len() < total_len {
204            return Ok(None);
205        }
206
207        // Consume header
208        buf.advance(FRAME_HEADER_SIZE);
209
210        // Read header extension
211        let header_extension = buf.split_to(header_len).freeze();
212
213        // Read payload
214        let payload = buf.split_to(payload_len).freeze();
215
216        // Validate CRC if present
217        if flags.has_crc() {
218            let crc_actual = crc32c::crc32c(&payload);
219            if crc_actual != crc_expected {
220                return Err(ProtocolError::CrcMismatch {
221                    expected: crc_expected,
222                    actual: crc_actual,
223                });
224            }
225        }
226
227        Ok(Some(Self {
228            version,
229            flags,
230            header_extension,
231            payload,
232        }))
233    }
234}
235
#[cfg(test)]
mod tests {
    //! Unit tests for frame flag handling and frame encoding/decoding.

    use super::*;

    /// A frame survives an encode/decode round trip unchanged.
    #[test]
    fn test_frame_roundtrip() {
        let payload = Bytes::from(r#"{"type":"request","id":"1","op":"PING","params":{}}"#);
        let frame = Frame::new(payload.clone());

        let encoded = frame.encode().unwrap();
        let mut buf = encoded;
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.version, crate::PROTOCOL_VERSION);
        assert!(decoded.flags.has_crc());
        assert_eq!(decoded.payload, payload);
    }

    /// A corrupted payload is detected through the CRC.
    #[test]
    fn test_crc_validation() {
        let payload = Bytes::from(r#"{"test":"data"}"#);
        let frame = Frame::new(payload);
        let mut encoded = frame.encode().unwrap();

        // Corrupt the last payload byte
        let len = encoded.len();
        encoded[len - 1] ^= 0xFF;

        let result = Frame::decode(&mut encoded);
        assert!(matches!(result, Err(ProtocolError::CrcMismatch { .. })));
    }

    #[test]
    fn test_invalid_magic() {
        // 18 bytes: 4 magic + 2 version + 2 flags + 2 header_len + 4 payload_len + 4 crc
        let mut buf =
            BytesMut::from(&b"BADX\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"[..]);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::InvalidMagic(_))));
    }

    /// Incomplete input yields `Ok(None)` and leaves the buffer untouched.
    #[test]
    fn test_incomplete_frame() {
        // Only 8 bytes, less than the 18-byte fixed header
        let mut buf = BytesMut::from(&b"RCPX\x00\x01\x00\x01"[..]);
        let result = Frame::decode(&mut buf);
        assert!(result.unwrap().is_none());
        assert_eq!(buf.len(), 8);

        // Complete header, but the last payload byte has not arrived yet
        let encoded = Frame::new(Bytes::from(r#"{"id":"1"}"#)).encode().unwrap();
        let mut partial = BytesMut::from(&encoded[..encoded.len() - 1]);
        let buffered = partial.len();
        assert!(Frame::decode(&mut partial).unwrap().is_none());
        assert_eq!(partial.len(), buffered);
    }

    #[test]
    fn test_unsupported_version() {
        // Valid magic but wrong version (99)
        let mut buf =
            BytesMut::from(&b"RCPX\x00\x63\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"[..]);
        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(ProtocolError::UnsupportedVersion(99))));
    }

    #[test]
    fn test_frame_flags() {
        let flags = FrameFlags::new().with_crc().with_stream().with_end_stream();

        assert!(flags.has_crc());
        assert!(flags.is_stream());
        assert!(flags.is_end_stream());
        assert!(!flags.is_compressed());
    }

    #[test]
    fn test_invalid_flags() {
        // Bit outside valid v1 mask
        let result = FrameFlags::from_bits(0x0100);
        assert!(matches!(result, Err(ProtocolError::InvalidFlags(0x0100))));
    }

    #[test]
    fn test_frame_too_large() {
        let huge_payload = vec![0u8; (MAX_PAYLOAD_SIZE + 1) as usize];
        let frame = Frame::new(Bytes::from(huge_payload));
        let result = frame.encode();
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    /// Both an empty JSON object and a zero-length payload round-trip.
    #[test]
    fn test_empty_payload() {
        // Smallest JSON document
        let payload = Bytes::from(r#"{}"#);
        let frame = Frame::new(payload.clone());

        let encoded = frame.encode().unwrap();
        let mut buf = encoded;
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.payload, payload);

        // Zero-length payload: the encoded frame is exactly the fixed header
        let mut bare = Frame::new(Bytes::new()).encode().unwrap();
        assert_eq!(bare.len(), FRAME_HEADER_SIZE);
        let decoded = Frame::decode(&mut bare).unwrap().unwrap();
        assert!(decoded.payload.is_empty());
        assert!(decoded.header_extension.is_empty());
        assert!(bare.is_empty());
    }

    #[test]
    fn test_frame_from_json() {
        #[derive(serde::Serialize)]
        struct TestMsg {
            value: i32,
        }
        let frame = Frame::from_json(&TestMsg { value: 42 }).unwrap();
        let payload_str = std::str::from_utf8(&frame.payload).unwrap();
        assert!(payload_str.contains("42"));
    }

    #[test]
    fn test_frame_with_header_extension() {
        let mut frame = Frame::new(Bytes::from(r#"{"test":true}"#));
        frame.header_extension = Bytes::from(&b"ext_data"[..]);

        let encoded = frame.encode().unwrap();
        let mut buf = encoded;
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.header_extension.as_ref(), b"ext_data");
        // The extension must not leak into the payload
        assert_eq!(decoded.payload.as_ref(), br#"{"test":true}"#);
    }

    #[test]
    fn test_frame_without_crc() {
        let mut frame = Frame::new(Bytes::from(r#"{"test":true}"#));
        frame.flags = FrameFlags::new(); // No CRC

        let encoded = frame.encode().unwrap();
        let mut buf = encoded;
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert!(!decoded.flags.has_crc());
    }

    /// Back-to-back frames are decoded one at a time, draining the buffer.
    #[test]
    fn test_multiple_frames_in_buffer() {
        let frame1 = Frame::new(Bytes::from(r#"{"id":"1"}"#));
        let frame2 = Frame::new(Bytes::from(r#"{"id":"2"}"#));

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&frame1.encode().unwrap());
        buf.extend_from_slice(&frame2.encode().unwrap());

        let decoded1 = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(std::str::from_utf8(&decoded1.payload)
            .unwrap()
            .contains("\"1\""));

        let decoded2 = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(std::str::from_utf8(&decoded2.payload)
            .unwrap()
            .contains("\"2\""));

        // Both frames consumed; a further call asks for more data
        assert!(buf.is_empty());
        assert!(Frame::decode(&mut buf).unwrap().is_none());
    }
}