Skip to main content

microsandbox_protocol/
codec.rs

1//! Length-prefixed frame codec for reading and writing protocol messages.
2//!
3//! Wire format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
4//!
5//! The correlation ID and flags sit in a fixed-position binary header so that
6//! relay intermediaries can route frames without CBOR parsing.
7
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11    error::{ProtocolError, ProtocolResult},
12    message::{FRAME_HEADER_SIZE, Message},
13};
14
15//--------------------------------------------------------------------------------------------------
16// Constants
17//--------------------------------------------------------------------------------------------------
18
/// Upper bound on the length a frame may declare: 4 MiB.
///
/// The limit applies to the value carried in the 4-byte length prefix, i.e.
/// to the bytes that follow it — `id (4) + flags (1) + CBOR payload`. The
/// prefix itself is not counted.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
25//--------------------------------------------------------------------------------------------------
26// Types
27//--------------------------------------------------------------------------------------------------
28
/// A frame with the binary header parsed but the CBOR body left untouched.
///
/// Used by routers, relays, and FFI consumers that want to handle framing
/// without paying for CBOR (de)serialization. The [`body`](Self::body) field
/// contains the exact CBOR-encoded `Message` body bytes — `v`, `t`, `p` —
/// the same bytes that follow the binary header on the wire.
///
/// Two frames compare equal when their `id`, `flags`, and `body` bytes are all
/// equal; the body is compared byte-for-byte, not by decoded CBOR value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawFrame {
    /// Correlation ID. Same as [`Message::id`].
    pub id: u32,

    /// Frame flags. Same as [`Message::flags`].
    pub flags: u8,

    /// Raw CBOR bytes of the message body (`v`, `t`, `p`). Not decoded.
    pub body: Vec<u8>,
}
46
47//--------------------------------------------------------------------------------------------------
48// Functions: Raw frame codec (CBOR-blind)
49//--------------------------------------------------------------------------------------------------
50
51/// Encodes a raw frame to a byte buffer using the length-prefixed format.
52///
53/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
54pub fn encode_raw_to_buf(frame: &RawFrame, buf: &mut Vec<u8>) -> ProtocolResult<()> {
55    let frame_len = u32::try_from(FRAME_HEADER_SIZE + frame.body.len()).map_err(|_| {
56        ProtocolError::FrameTooLarge {
57            size: u32::MAX,
58            max: MAX_FRAME_SIZE,
59        }
60    })?;
61
62    if frame_len > MAX_FRAME_SIZE {
63        return Err(ProtocolError::FrameTooLarge {
64            size: frame_len,
65            max: MAX_FRAME_SIZE,
66        });
67    }
68
69    buf.extend_from_slice(&frame_len.to_be_bytes());
70    buf.extend_from_slice(&frame.id.to_be_bytes());
71    buf.push(frame.flags);
72    buf.extend_from_slice(&frame.body);
73    Ok(())
74}
75
76/// Tries to decode a complete raw frame from a byte buffer.
77///
78/// Returns `Some(RawFrame)` if a complete frame is available, consuming
79/// the bytes. Returns `None` if more data is needed.
80///
81/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
82pub fn try_decode_raw_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<RawFrame>> {
83    if buf.len() < 4 {
84        return Ok(None);
85    }
86
87    let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
88
89    if frame_len > MAX_FRAME_SIZE {
90        return Err(ProtocolError::FrameTooLarge {
91            size: frame_len,
92            max: MAX_FRAME_SIZE,
93        });
94    }
95
96    let frame_len = frame_len as usize;
97    let total = 4 + frame_len;
98
99    if buf.len() < total {
100        return Ok(None);
101    }
102
103    if frame_len < FRAME_HEADER_SIZE {
104        return Err(ProtocolError::FrameTooShort {
105            size: frame_len as u32,
106            min: FRAME_HEADER_SIZE as u32,
107        });
108    }
109
110    let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
111    let flags = buf[8];
112    let body = buf[4 + FRAME_HEADER_SIZE..total].to_vec();
113
114    buf.drain(..total);
115    Ok(Some(RawFrame { id, flags, body }))
116}
117
118/// Reads a length-prefixed raw frame from the given reader.
119///
120/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
121pub async fn read_raw_frame<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<RawFrame> {
122    let mut len_buf = [0u8; 4];
123    match reader.read_exact(&mut len_buf).await {
124        Ok(_) => {}
125        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
126            return Err(ProtocolError::UnexpectedEof);
127        }
128        Err(e) => return Err(e.into()),
129    }
130
131    let frame_len = u32::from_be_bytes(len_buf);
132
133    if frame_len > MAX_FRAME_SIZE {
134        return Err(ProtocolError::FrameTooLarge {
135            size: frame_len,
136            max: MAX_FRAME_SIZE,
137        });
138    }
139
140    let frame_len = frame_len as usize;
141
142    if frame_len < FRAME_HEADER_SIZE {
143        return Err(ProtocolError::FrameTooShort {
144            size: frame_len as u32,
145            min: FRAME_HEADER_SIZE as u32,
146        });
147    }
148
149    let mut payload = vec![0u8; frame_len];
150    reader.read_exact(&mut payload).await?;
151
152    let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
153    let flags = payload[4];
154    let body = payload[FRAME_HEADER_SIZE..].to_vec();
155
156    Ok(RawFrame { id, flags, body })
157}
158
159/// Writes a length-prefixed raw frame to the given writer.
160///
161/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][body...]`
162pub async fn write_raw_frame<W: AsyncWrite + Unpin>(
163    writer: &mut W,
164    frame: &RawFrame,
165) -> ProtocolResult<()> {
166    let mut buf = Vec::new();
167    encode_raw_to_buf(frame, &mut buf)?;
168    writer.write_all(&buf).await?;
169    writer.flush().await?;
170    Ok(())
171}
172
173//--------------------------------------------------------------------------------------------------
174// Functions: Typed message codec (CBOR-aware)
175//--------------------------------------------------------------------------------------------------
176
177/// Encodes a message to a byte buffer using the length-prefixed frame format.
178///
179/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
180pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
181    let mut body = Vec::new();
182    ciborium::into_writer(msg, &mut body)?;
183    encode_raw_to_buf(
184        &RawFrame {
185            id: msg.id,
186            flags: msg.flags,
187            body,
188        },
189        buf,
190    )
191}
192
193/// Tries to decode a complete message from a byte buffer.
194///
195/// Returns `Some(Message)` if a complete frame is available, consuming
196/// the bytes. Returns `None` if more data is needed.
197///
198/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
199pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
200    if buf.len() < 4 {
201        return Ok(None);
202    }
203
204    let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
205
206    if frame_len > MAX_FRAME_SIZE {
207        return Err(ProtocolError::FrameTooLarge {
208            size: frame_len,
209            max: MAX_FRAME_SIZE,
210        });
211    }
212
213    let frame_len = frame_len as usize;
214    let total = 4 + frame_len;
215
216    if buf.len() < total {
217        return Ok(None);
218    }
219
220    let msg = decode_message_frame(&buf[..total])?;
221    buf.drain(..total);
222    Ok(Some(msg))
223}
224
225/// Reads a length-prefixed message from the given reader.
226///
227/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
228pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
229    let frame = read_raw_frame(reader).await?;
230    raw_frame_to_message(frame)
231}
232
233/// Writes a length-prefixed message to the given writer.
234///
235/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
236pub async fn write_message<W: AsyncWrite + Unpin>(
237    writer: &mut W,
238    message: &Message,
239) -> ProtocolResult<()> {
240    let mut buf = Vec::new();
241    encode_to_buf(message, &mut buf)?;
242    writer.write_all(&buf).await?;
243    writer.flush().await?;
244    Ok(())
245}
246
247/// Decodes a [`RawFrame`] into a typed [`Message`] by CBOR-deserializing the body.
248pub fn raw_frame_to_message(frame: RawFrame) -> ProtocolResult<Message> {
249    let mut msg: Message = ciborium::from_reader(&frame.body[..])?;
250    msg.id = frame.id;
251    msg.flags = frame.flags;
252    Ok(msg)
253}
254
255/// Decodes one complete length-prefixed frame from a borrowed byte slice.
256///
257/// The input must include the 4-byte length prefix, frame header, and CBOR body.
258/// The slice is not consumed or copied.
259pub fn decode_message_frame(frame: &[u8]) -> ProtocolResult<Message> {
260    if frame.len() < 4 {
261        return Err(ProtocolError::UnexpectedEof);
262    }
263
264    let frame_len = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
265    if frame_len > MAX_FRAME_SIZE {
266        return Err(ProtocolError::FrameTooLarge {
267            size: frame_len,
268            max: MAX_FRAME_SIZE,
269        });
270    }
271
272    let frame_len = frame_len as usize;
273    let total = 4 + frame_len;
274    if frame.len() < total {
275        return Err(ProtocolError::UnexpectedEof);
276    }
277
278    if frame_len < FRAME_HEADER_SIZE {
279        return Err(ProtocolError::FrameTooShort {
280            size: frame_len as u32,
281            min: FRAME_HEADER_SIZE as u32,
282        });
283    }
284
285    let mut msg: Message = ciborium::from_reader(&frame[4 + FRAME_HEADER_SIZE..total])?;
286    msg.id = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]);
287    msg.flags = frame[8];
288    Ok(msg)
289}
290
291//--------------------------------------------------------------------------------------------------
292// Tests
293//--------------------------------------------------------------------------------------------------
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    //----------------------------------------------------------------------------------------------
    // Typed messages: async reader/writer
    //----------------------------------------------------------------------------------------------

    /// A message with an empty payload survives write + read unchanged.
    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert_eq!(decoded.flags, 0);
    }

    /// A message with a typed payload survives write + read, including flags.
    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    /// Back-to-back frames in one stream are read in order, one per call.
    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
            assert_eq!(decoded.flags, expected.flags);
        }
    }

    /// An empty stream reports `UnexpectedEof` rather than a generic I/O error.
    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    //----------------------------------------------------------------------------------------------
    // Typed messages: sync buffer codec
    //----------------------------------------------------------------------------------------------

    /// Buffer encode + decode round-trips and consumes the whole frame.
    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    /// Borrowed decode yields the same message and leaves the input intact.
    #[test]
    fn test_borrowed_decode_message_frame_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = decode_message_frame(&buf).unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(!buf.is_empty(), "borrowed decode must not consume input");
    }

    /// Borrowed decode requires the complete frame to be present.
    #[test]
    fn test_borrowed_decode_message_frame_rejects_incomplete() {
        let buf = vec![0, 0, 0, 10];
        assert!(matches!(
            decode_message_frame(&buf),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    /// A valid prefix without its payload means "need more data", not an error.
    #[test]
    fn test_sync_decode_incomplete() {
        let mut buf = vec![0, 0, 0, 10]; // Length 10 but no payload bytes.
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    /// A declared length above `MAX_FRAME_SIZE` is rejected from the prefix alone.
    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    /// Pins the fixed-position binary header layout.
    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        // Bytes 0–3: length prefix (u32 BE).
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
        assert_eq!(len as usize + 4, buf.len());

        // Bytes 4–7: correlation ID (u32 BE).
        let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert_eq!(id, 0x12345678);

        // Byte 8: flags.
        assert_eq!(buf[8], FLAG_SESSION_START);

        // Bytes 9..: CBOR body (v, t, p — no id or flags).
    }

    /// `ExecExited` carries the terminal flag and not the session-start flag.
    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_TERMINAL, 0);
        assert_eq!(decoded.flags & FLAG_SESSION_START, 0);
    }

    /// `FsRequest` carries the session-start flag and not the terminal flag.
    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_SESSION_START, 0);
        assert_eq!(decoded.flags & FLAG_TERMINAL, 0);
    }

    /// A fully-present frame whose length cannot hold the header is rejected.
    #[test]
    fn test_sync_decode_frame_too_short() {
        // Frame with len=3 (too short for id+flags header).
        let mut buf = Vec::new();
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&[0, 0, 0]); // 3 bytes of payload.

        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    //----------------------------------------------------------------------------------------------
    // Raw frames
    //----------------------------------------------------------------------------------------------

    /// Raw frames round-trip through the async writer/reader without CBOR.
    #[tokio::test]
    async fn test_raw_frame_roundtrip() {
        let frame = RawFrame {
            id: 0xDEADBEEF,
            flags: FLAG_TERMINAL,
            body: vec![1, 2, 3, 4, 5],
        };

        let mut buf = Vec::new();
        write_raw_frame(&mut buf, &frame).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_raw_frame(&mut cursor).await.unwrap();

        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
    }

    /// Raw frames round-trip through the sync buffer codec and are consumed.
    #[test]
    fn test_raw_frame_sync_roundtrip() {
        let frame = RawFrame {
            id: 42,
            flags: FLAG_SESSION_START,
            body: vec![0xAA; 100],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&frame, &mut buf).unwrap();

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
        assert!(buf.is_empty());
    }

    /// A raw frame missing its last byte is "need more data" and is not consumed.
    #[test]
    fn test_raw_frame_sync_decode_incomplete() {
        let frame = RawFrame {
            id: 1,
            flags: 0,
            body: vec![0xBB; 16],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&frame, &mut buf).unwrap();
        buf.pop();
        let len_before = buf.len();

        assert!(try_decode_raw_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), len_before, "partial frame must stay buffered");
    }

    /// The encoder enforces `MAX_FRAME_SIZE` and writes nothing on failure.
    #[test]
    fn test_raw_frame_encode_too_large() {
        // Body alone is already at the limit, so header + body exceeds it.
        let frame = RawFrame {
            id: 0,
            flags: 0,
            body: vec![0; MAX_FRAME_SIZE as usize],
        };

        let mut buf = Vec::new();
        let result = encode_raw_to_buf(&frame, &mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
        assert!(buf.is_empty());
    }

    /// The async reader rejects an oversized length prefix.
    #[tokio::test]
    async fn test_read_raw_frame_too_large() {
        let bytes = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let mut cursor = &bytes[..];

        let result = read_raw_frame(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    /// The async reader rejects a length that cannot hold the header.
    #[tokio::test]
    async fn test_read_raw_frame_too_short() {
        // Same shape as the sync too-short test: len=3 followed by 3 bytes.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&3u32.to_be_bytes());
        bytes.extend_from_slice(&[0, 0, 0]);
        let mut cursor = &bytes[..];

        let result = read_raw_frame(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    /// Raw decode followed by `raw_frame_to_message` matches the typed path.
    #[test]
    fn test_raw_frame_to_message() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 13, &ExecExited { code: 7 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let frame = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        let decoded = raw_frame_to_message(frame).unwrap();

        assert_eq!(decoded.id, 13);
        assert_eq!(decoded.t, MessageType::ExecExited);
        // Flags live only in the binary header, so they must be restored from it.
        assert_eq!(decoded.flags, msg.flags);
        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 7);
    }
}