Skip to main content

microsandbox_protocol/
codec.rs

1//! Length-prefixed frame codec for reading and writing protocol messages.
2//!
3//! Wire format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
4//!
5//! The correlation ID and flags sit in a fixed-position binary header so that
6//! relay intermediaries can route frames without CBOR parsing.
7
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11    error::{ProtocolError, ProtocolResult},
12    message::{FRAME_HEADER_SIZE, Message},
13};
14
15//--------------------------------------------------------------------------------------------------
16// Constants
17//--------------------------------------------------------------------------------------------------
18
/// Upper bound on a single frame, in bytes (4 MiB).
///
/// The limit applies to the portion of the frame that follows the 4-byte
/// length prefix, i.e. `id (4) + flags (1) + CBOR payload`.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
25//--------------------------------------------------------------------------------------------------
26// Types
27//--------------------------------------------------------------------------------------------------
28
/// A frame with the binary header parsed but the CBOR body left untouched.
///
/// Used by routers, relays, and FFI consumers that want to handle framing
/// without paying for CBOR (de)serialization. The [`body`](Self::body) field
/// contains the exact CBOR-encoded `Message` body bytes — `v`, `t`, `p` —
/// the same bytes that follow the binary header on the wire.
///
/// `PartialEq`/`Eq` are derived so whole frames can be compared directly
/// (e.g. in relays or tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawFrame {
    /// Correlation ID. Same as [`Message::id`].
    pub id: u32,

    /// Frame flags. Same as [`Message::flags`].
    pub flags: u8,

    /// Raw CBOR bytes of the message body (`v`, `t`, `p`). Not decoded.
    pub body: Vec<u8>,
}
46
47//--------------------------------------------------------------------------------------------------
48// Functions: Raw frame codec (CBOR-blind)
49//--------------------------------------------------------------------------------------------------
50
51/// Encodes a raw frame to a byte buffer using the length-prefixed format.
52///
53/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
54pub fn encode_raw_to_buf(frame: &RawFrame, buf: &mut Vec<u8>) -> ProtocolResult<()> {
55    let frame_len = u32::try_from(FRAME_HEADER_SIZE + frame.body.len()).map_err(|_| {
56        ProtocolError::FrameTooLarge {
57            size: u32::MAX,
58            max: MAX_FRAME_SIZE,
59        }
60    })?;
61
62    if frame_len > MAX_FRAME_SIZE {
63        return Err(ProtocolError::FrameTooLarge {
64            size: frame_len,
65            max: MAX_FRAME_SIZE,
66        });
67    }
68
69    buf.extend_from_slice(&frame_len.to_be_bytes());
70    buf.extend_from_slice(&frame.id.to_be_bytes());
71    buf.push(frame.flags);
72    buf.extend_from_slice(&frame.body);
73    Ok(())
74}
75
76/// Tries to decode a complete raw frame from a byte buffer.
77///
78/// Returns `Some(RawFrame)` if a complete frame is available, consuming
79/// the bytes. Returns `None` if more data is needed.
80///
81/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
82pub fn try_decode_raw_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<RawFrame>> {
83    if buf.len() < 4 {
84        return Ok(None);
85    }
86
87    let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
88
89    if frame_len > MAX_FRAME_SIZE {
90        return Err(ProtocolError::FrameTooLarge {
91            size: frame_len,
92            max: MAX_FRAME_SIZE,
93        });
94    }
95
96    let frame_len = frame_len as usize;
97    let total = 4 + frame_len;
98
99    if buf.len() < total {
100        return Ok(None);
101    }
102
103    if frame_len < FRAME_HEADER_SIZE {
104        return Err(ProtocolError::FrameTooShort {
105            size: frame_len as u32,
106            min: FRAME_HEADER_SIZE as u32,
107        });
108    }
109
110    let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
111    let flags = buf[8];
112    let body = buf[4 + FRAME_HEADER_SIZE..total].to_vec();
113
114    buf.drain(..total);
115    Ok(Some(RawFrame { id, flags, body }))
116}
117
118/// Reads a length-prefixed raw frame from the given reader.
119///
120/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
121pub async fn read_raw_frame<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<RawFrame> {
122    let mut len_buf = [0u8; 4];
123    match reader.read_exact(&mut len_buf).await {
124        Ok(_) => {}
125        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
126            return Err(ProtocolError::UnexpectedEof);
127        }
128        Err(e) => return Err(e.into()),
129    }
130
131    let frame_len = u32::from_be_bytes(len_buf);
132
133    if frame_len > MAX_FRAME_SIZE {
134        return Err(ProtocolError::FrameTooLarge {
135            size: frame_len,
136            max: MAX_FRAME_SIZE,
137        });
138    }
139
140    let frame_len = frame_len as usize;
141
142    if frame_len < FRAME_HEADER_SIZE {
143        return Err(ProtocolError::FrameTooShort {
144            size: frame_len as u32,
145            min: FRAME_HEADER_SIZE as u32,
146        });
147    }
148
149    let mut payload = vec![0u8; frame_len];
150    reader.read_exact(&mut payload).await?;
151
152    let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
153    let flags = payload[4];
154    let body = payload[FRAME_HEADER_SIZE..].to_vec();
155
156    Ok(RawFrame { id, flags, body })
157}
158
159/// Writes a length-prefixed raw frame to the given writer.
160///
161/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
162pub async fn write_raw_frame<W: AsyncWrite + Unpin>(
163    writer: &mut W,
164    frame: &RawFrame,
165) -> ProtocolResult<()> {
166    let mut buf = Vec::new();
167    encode_raw_to_buf(frame, &mut buf)?;
168    writer.write_all(&buf).await?;
169    writer.flush().await?;
170    Ok(())
171}
172
173//--------------------------------------------------------------------------------------------------
174// Functions: Typed message codec (CBOR-aware)
175//--------------------------------------------------------------------------------------------------
176
177/// Encodes a message to a byte buffer using the length-prefixed frame format.
178///
179/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
180pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
181    let mut body = Vec::new();
182    ciborium::into_writer(msg, &mut body)?;
183    encode_raw_to_buf(
184        &RawFrame {
185            id: msg.id,
186            flags: msg.flags,
187            body,
188        },
189        buf,
190    )
191}
192
193/// Tries to decode a complete message from a byte buffer.
194///
195/// Returns `Some(Message)` if a complete frame is available, consuming
196/// the bytes. Returns `None` if more data is needed.
197///
198/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
199pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
200    match try_decode_raw_from_buf(buf)? {
201        Some(frame) => Ok(Some(raw_frame_to_message(frame)?)),
202        None => Ok(None),
203    }
204}
205
206/// Reads a length-prefixed message from the given reader.
207///
208/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
209pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
210    let frame = read_raw_frame(reader).await?;
211    raw_frame_to_message(frame)
212}
213
214/// Writes a length-prefixed message to the given writer.
215///
216/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
217pub async fn write_message<W: AsyncWrite + Unpin>(
218    writer: &mut W,
219    message: &Message,
220) -> ProtocolResult<()> {
221    let mut buf = Vec::new();
222    encode_to_buf(message, &mut buf)?;
223    writer.write_all(&buf).await?;
224    writer.flush().await?;
225    Ok(())
226}
227
228/// Decodes a [`RawFrame`] into a typed [`Message`] by CBOR-deserializing the body.
229pub fn raw_frame_to_message(frame: RawFrame) -> ProtocolResult<Message> {
230    let mut msg: Message = ciborium::from_reader(&frame.body[..])?;
231    msg.id = frame.id;
232    msg.flags = frame.flags;
233    Ok(msg)
234}
235
236//--------------------------------------------------------------------------------------------------
237// Tests
238//--------------------------------------------------------------------------------------------------
239
// Round-trip, wire-layout, and malformed-input tests for the frame codec.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert_eq!(decoded.flags, 0);
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
            assert_eq!(decoded.flags, expected.flags);
        }
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_read_raw_frame_too_short() {
        // Frame with len=3 (too short for id+flags header).
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&3u32.to_be_bytes());
        bytes.extend_from_slice(&[0, 0, 0]);

        let mut cursor = &bytes[..];
        let result = read_raw_frame(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_decode_incomplete() {
        let mut buf = vec![0, 0, 0, 10]; // Length 10 but no payload bytes.
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_partial_length_prefix() {
        // Fewer than 4 bytes: even the length prefix is incomplete.
        let mut buf = vec![0, 0];
        assert!(try_decode_raw_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, vec![0, 0], "incomplete input must not be consumed");
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_encode_raw_frame_too_large() {
        // Body alone fills the limit, so body + header exceeds it.
        let frame = RawFrame {
            id: 1,
            flags: 0,
            body: vec![0u8; MAX_FRAME_SIZE as usize],
        };

        let mut buf = Vec::new();
        let result = encode_raw_to_buf(&frame, &mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
        assert!(buf.is_empty(), "nothing may be written on error");
    }

    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        // Bytes 0–3: length prefix (u32 BE).
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
        assert_eq!(len as usize + 4, buf.len());

        // Bytes 4–7: correlation ID (u32 BE).
        let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert_eq!(id, 0x12345678);

        // Byte 8: flags.
        assert_eq!(buf[8], FLAG_SESSION_START);

        // Bytes 9..: CBOR body (v, t, p — no id or flags).
        let mut expected_body = Vec::new();
        ciborium::into_writer(&msg, &mut expected_body).unwrap();
        assert_eq!(&buf[4 + FRAME_HEADER_SIZE..], &expected_body[..]);
    }

    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_TERMINAL, 0);
        assert_eq!(decoded.flags & FLAG_SESSION_START, 0);
    }

    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_SESSION_START, 0);
        assert_eq!(decoded.flags & FLAG_TERMINAL, 0);
    }

    #[test]
    fn test_sync_decode_frame_too_short() {
        // Frame with len=3 (too short for id+flags header).
        let mut buf = Vec::new();
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&[0, 0, 0]); // 3 bytes of payload.

        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[tokio::test]
    async fn test_raw_frame_roundtrip() {
        let frame = RawFrame {
            id: 0xDEADBEEF,
            flags: FLAG_TERMINAL,
            body: vec![1, 2, 3, 4, 5],
        };

        let mut buf = Vec::new();
        write_raw_frame(&mut buf, &frame).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_raw_frame(&mut cursor).await.unwrap();

        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
    }

    #[test]
    fn test_raw_frame_sync_roundtrip() {
        let frame = RawFrame {
            id: 42,
            flags: FLAG_SESSION_START,
            body: vec![0xAA; 100],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&frame, &mut buf).unwrap();

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_raw_frame_empty_body_roundtrip() {
        let frame = RawFrame {
            id: 7,
            flags: 0,
            body: Vec::new(),
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&frame, &mut buf).unwrap();
        assert_eq!(buf.len(), 4 + FRAME_HEADER_SIZE);

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, 0);
        assert!(decoded.body.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_decode_leaves_following_frame_in_buffer() {
        let first = RawFrame {
            id: 1,
            flags: 0,
            body: vec![1, 2, 3],
        };
        let second = RawFrame {
            id: 2,
            flags: FLAG_TERMINAL,
            body: vec![4],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&first, &mut buf).unwrap();
        encode_raw_to_buf(&second, &mut buf).unwrap();

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, first.id);
        assert_eq!(decoded.flags, first.flags);
        assert_eq!(decoded.body, first.body);

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, second.id);
        assert_eq!(decoded.flags, second.flags);
        assert_eq!(decoded.body, second.body);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_raw_frame_to_message() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 13, &ExecExited { code: 7 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let frame = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        let decoded = raw_frame_to_message(frame).unwrap();

        assert_eq!(decoded.id, 13);
        assert_eq!(decoded.t, MessageType::ExecExited);
        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 7);
    }
}