Skip to main content

microsandbox_protocol/
codec.rs

1//! Length-prefixed frame codec for reading and writing protocol messages.
2//!
3//! Wire format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
4//!
5//! The correlation ID and flags sit in a fixed-position binary header so that
6//! relay intermediaries can route frames without CBOR parsing.
7
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11    error::{ProtocolError, ProtocolResult},
12    message::{FRAME_HEADER_SIZE, Message},
13};
14
15//--------------------------------------------------------------------------------------------------
16// Constants
17//--------------------------------------------------------------------------------------------------
18
/// Upper bound, in bytes, on the value carried in a frame's length prefix (4 MiB).
///
/// The limit applies to everything that follows the 4-byte length prefix,
/// i.e. `id (4) + flags (1) + CBOR payload`; the prefix itself is not counted.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
25//--------------------------------------------------------------------------------------------------
26// Functions
27//--------------------------------------------------------------------------------------------------
28
29/// Encodes a message to a byte buffer using the length-prefixed frame format.
30///
31/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
32pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
33    // Serialize the CBOR body (v, t, p — id and flags are excluded via serde(skip)).
34    let mut cbor = Vec::new();
35    ciborium::into_writer(msg, &mut cbor)?;
36
37    // Total frame payload = id (4) + flags (1) + CBOR body.
38    let frame_len = u32::try_from(FRAME_HEADER_SIZE + cbor.len()).map_err(|_| {
39        ProtocolError::FrameTooLarge {
40            size: u32::MAX,
41            max: MAX_FRAME_SIZE,
42        }
43    })?;
44
45    if frame_len > MAX_FRAME_SIZE {
46        return Err(ProtocolError::FrameTooLarge {
47            size: frame_len,
48            max: MAX_FRAME_SIZE,
49        });
50    }
51
52    buf.extend_from_slice(&frame_len.to_be_bytes());
53    buf.extend_from_slice(&msg.id.to_be_bytes());
54    buf.push(msg.flags);
55    buf.extend_from_slice(&cbor);
56    Ok(())
57}
58
59/// Tries to decode a complete message from a byte buffer.
60///
61/// Returns `Some(Message)` if a complete frame is available, consuming
62/// the bytes. Returns `None` if more data is needed.
63///
64/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
65pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
66    if buf.len() < 4 {
67        return Ok(None);
68    }
69
70    let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
71
72    if frame_len > MAX_FRAME_SIZE {
73        return Err(ProtocolError::FrameTooLarge {
74            size: frame_len,
75            max: MAX_FRAME_SIZE,
76        });
77    }
78
79    let frame_len = frame_len as usize;
80    let total = 4 + frame_len;
81
82    if buf.len() < total {
83        return Ok(None);
84    }
85
86    if frame_len < FRAME_HEADER_SIZE {
87        return Err(ProtocolError::FrameTooShort {
88            size: frame_len as u32,
89            min: FRAME_HEADER_SIZE as u32,
90        });
91    }
92
93    // Extract header fields.
94    let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
95    let flags = buf[8];
96
97    // Deserialize the CBOR body.
98    let cbor = &buf[4 + FRAME_HEADER_SIZE..total];
99    let mut msg: Message = ciborium::from_reader(cbor)?;
100    msg.id = id;
101    msg.flags = flags;
102
103    buf.drain(..total);
104    Ok(Some(msg))
105}
106
107/// Reads a length-prefixed message from the given reader.
108///
109/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
110pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
111    // Read the 4-byte length prefix.
112    let mut len_buf = [0u8; 4];
113    match reader.read_exact(&mut len_buf).await {
114        Ok(_) => {}
115        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
116            return Err(ProtocolError::UnexpectedEof);
117        }
118        Err(e) => return Err(e.into()),
119    }
120
121    let frame_len = u32::from_be_bytes(len_buf);
122
123    if frame_len > MAX_FRAME_SIZE {
124        return Err(ProtocolError::FrameTooLarge {
125            size: frame_len,
126            max: MAX_FRAME_SIZE,
127        });
128    }
129
130    let frame_len = frame_len as usize;
131
132    if frame_len < FRAME_HEADER_SIZE {
133        return Err(ProtocolError::FrameTooShort {
134            size: frame_len as u32,
135            min: FRAME_HEADER_SIZE as u32,
136        });
137    }
138
139    // Read the full frame payload.
140    let mut payload = vec![0u8; frame_len];
141    reader.read_exact(&mut payload).await?;
142
143    // Extract header fields.
144    let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
145    let flags = payload[4];
146
147    // Deserialize the CBOR body.
148    let cbor = &payload[FRAME_HEADER_SIZE..];
149    let mut msg: Message = ciborium::from_reader(cbor)?;
150    msg.id = id;
151    msg.flags = flags;
152
153    Ok(msg)
154}
155
156/// Writes a length-prefixed message to the given writer.
157///
158/// Frame format: `[len: u32 BE][id: u32 BE][flags: u8][CBOR(v, t, p)]`
159pub async fn write_message<W: AsyncWrite + Unpin>(
160    writer: &mut W,
161    message: &Message,
162) -> ProtocolResult<()> {
163    let mut buf = Vec::new();
164    encode_to_buf(message, &mut buf)?;
165    writer.write_all(&buf).await?;
166    writer.flush().await?;
167    Ok(())
168}
169
170//--------------------------------------------------------------------------------------------------
171// Tests
172//--------------------------------------------------------------------------------------------------
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert_eq!(decoded.flags, 0);
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
            assert_eq!(decoded.flags, expected.flags);
        }
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_read_message_truncated_payload() {
        // A frame whose last byte is missing must not decode successfully.
        let msg = Message::new(MessageType::Ready, 4, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();
        buf.pop();

        let mut cursor = &buf[..];
        assert!(read_message(&mut cursor).await.is_err());
    }

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_decode_incomplete() {
        let mut buf = vec![0, 0, 0, 10]; // Length 10 but no payload bytes.
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_partial_length_prefix() {
        // Fewer than 4 bytes: the length prefix itself is not complete yet.
        let mut buf = vec![0, 0, 0];
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, vec![0, 0, 0]);
    }

    #[test]
    fn test_sync_decode_incomplete_body_keeps_buffer() {
        let msg = Message::new(MessageType::Ready, 3, Vec::new());

        let mut full = Vec::new();
        encode_to_buf(&msg, &mut full).unwrap();

        // Everything but the last byte: decoder must wait and leave the bytes in place.
        let mut partial = full[..full.len() - 1].to_vec();
        assert!(try_decode_from_buf(&mut partial).unwrap().is_none());
        assert_eq!(partial, full[..full.len() - 1]);
    }

    #[test]
    fn test_sync_decode_multiple_frames_in_buffer() {
        let mut buf = Vec::new();
        encode_to_buf(&Message::new(MessageType::Ready, 1, Vec::new()), &mut buf).unwrap();
        encode_to_buf(&Message::new(MessageType::Shutdown, 2, Vec::new()), &mut buf).unwrap();

        let first = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(first.t, MessageType::Ready);
        assert_eq!(first.id, 1);

        let second = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(second.t, MessageType::Shutdown);
        assert_eq!(second.id, 2);

        assert!(buf.is_empty());
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_sync_decode_malformed_body_is_error() {
        // Header-only frame (valid id + flags) with an empty, therefore invalid, CBOR body.
        let mut buf = Vec::new();
        buf.extend_from_slice(&(FRAME_HEADER_SIZE as u32).to_be_bytes());
        buf.extend_from_slice(&vec![0u8; FRAME_HEADER_SIZE]);

        assert!(try_decode_from_buf(&mut buf).is_err());
    }

    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        // Bytes 0–3: length prefix (u32 BE).
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
        assert_eq!(len as usize + 4, buf.len());

        // Bytes 4–7: correlation ID (u32 BE).
        let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert_eq!(id, 0x12345678);

        // Byte 8: flags.
        assert_eq!(buf[8], FLAG_SESSION_START);

        // Bytes 9..: CBOR body (v, t, p — no id or flags).
    }

    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_TERMINAL, 0);
        assert_eq!(decoded.flags & FLAG_SESSION_START, 0);
    }

    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_SESSION_START, 0);
        assert_eq!(decoded.flags & FLAG_TERMINAL, 0);
    }

    #[test]
    fn test_sync_decode_frame_too_short() {
        // Frame with len=3 (too short for id+flags header).
        let mut buf = Vec::new();
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&[0, 0, 0]); // 3 bytes of payload.

        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }
}