Skip to main content

microsandbox_protocol/
codec.rs

1//! Length-prefixed CBOR frame codec for reading and writing protocol messages.
2
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::{
6    error::{ProtocolError, ProtocolResult},
7    message::Message,
8};
9
10//--------------------------------------------------------------------------------------------------
11// Constants
12//--------------------------------------------------------------------------------------------------
13
/// Upper bound, in bytes, on the CBOR payload carried by a single frame (4 MiB).
///
/// The 4-byte length prefix is not counted. The encoder refuses to produce, and the decoders
/// refuse to accept, any frame whose payload length exceeds this value.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
16
17//--------------------------------------------------------------------------------------------------
18// Functions
19//--------------------------------------------------------------------------------------------------
20
21/// Encodes a message to a byte buffer using the length-prefixed frame format.
22///
23/// Frame format: `[len: u32 BE][CBOR payload]`
24pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
25    let mut payload = Vec::new();
26    ciborium::into_writer(msg, &mut payload)?;
27
28    let len = u32::try_from(payload.len()).map_err(|_| ProtocolError::FrameTooLarge {
29        size: u32::MAX,
30        max: MAX_FRAME_SIZE,
31    })?;
32
33    if len > MAX_FRAME_SIZE {
34        return Err(ProtocolError::FrameTooLarge {
35            size: len,
36            max: MAX_FRAME_SIZE,
37        });
38    }
39
40    buf.extend_from_slice(&len.to_be_bytes());
41    buf.extend_from_slice(&payload);
42    Ok(())
43}
44
45/// Tries to decode a complete message from a byte buffer.
46///
47/// Returns `Some(Message)` if a complete frame is available, consuming
48/// the bytes. Returns `None` if more data is needed.
49///
50/// Frame format: `[len: u32 BE][CBOR payload]`
51pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
52    if buf.len() < 4 {
53        return Ok(None);
54    }
55
56    let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
57
58    if len > MAX_FRAME_SIZE {
59        return Err(ProtocolError::FrameTooLarge {
60            size: len,
61            max: MAX_FRAME_SIZE,
62        });
63    }
64
65    let len = len as usize;
66    if buf.len() < 4 + len {
67        return Ok(None);
68    }
69
70    let payload = &buf[4..4 + len];
71    let msg: Message = ciborium::from_reader(payload)?;
72
73    buf.drain(..4 + len);
74    Ok(Some(msg))
75}
76
77/// Reads a length-prefixed CBOR message from the given reader.
78///
79/// Frame format: `[len: u32 BE][CBOR payload]`
80pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
81    // Read the 4-byte length prefix.
82    let mut len_buf = [0u8; 4];
83    match reader.read_exact(&mut len_buf).await {
84        Ok(_) => {}
85        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
86            return Err(ProtocolError::UnexpectedEof);
87        }
88        Err(e) => return Err(e.into()),
89    }
90
91    let len = u32::from_be_bytes(len_buf);
92
93    if len > MAX_FRAME_SIZE {
94        return Err(ProtocolError::FrameTooLarge {
95            size: len,
96            max: MAX_FRAME_SIZE,
97        });
98    }
99
100    // Read the CBOR payload.
101    let mut payload = vec![0u8; len as usize];
102    reader.read_exact(&mut payload).await?;
103
104    // Deserialize the message.
105    let message: Message = ciborium::from_reader(&payload[..])?;
106    Ok(message)
107}
108
109/// Writes a length-prefixed CBOR message to the given writer.
110///
111/// Frame format: `[len: u32 BE][CBOR payload]`
112pub async fn write_message<W: AsyncWrite + Unpin>(
113    writer: &mut W,
114    message: &Message,
115) -> ProtocolResult<()> {
116    let mut buf = Vec::new();
117    encode_to_buf(message, &mut buf)?;
118    writer.write_all(&buf).await?;
119    writer.flush().await?;
120    Ok(())
121}
122
123//--------------------------------------------------------------------------------------------------
124// Tests
125//--------------------------------------------------------------------------------------------------
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{MessageType, PROTOCOL_VERSION};

    // --- Async reader/writer path -------------------------------------------------------------

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message {
            v: PROTOCOL_VERSION,
            t: MessageType::Ready,
            id: 0,
            p: Vec::new(),
        };

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert!(decoded.p.is_empty());
        // Exactly one frame was written and exactly one frame was consumed.
        assert!(cursor.is_empty());
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
        assert!(cursor.is_empty());
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
        }
        assert!(cursor.is_empty());
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_codec_truncated_length_prefix_is_eof() {
        // The stream ends part-way through the 4-byte length prefix.
        let mut cursor: &[u8] = &[0, 0];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_codec_read_frame_too_large() {
        let header = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let mut cursor = &header[..];
        let result = read_message(&mut cursor).await;
        assert!(matches!(
            result,
            Err(ProtocolError::FrameTooLarge { size, max })
                if size == MAX_FRAME_SIZE + 1 && max == MAX_FRAME_SIZE
        ));
    }

    // --- Sync buffer path ---------------------------------------------------------------------

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_decode_consumes_one_frame_at_a_time() {
        let mut buf = Vec::new();
        encode_to_buf(&Message::new(MessageType::Ready, 1, Vec::new()), &mut buf).unwrap();
        let first_frame_len = buf.len();
        encode_to_buf(&Message::new(MessageType::Shutdown, 2, Vec::new()), &mut buf).unwrap();
        let total_len = buf.len();

        let first = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(first.t, MessageType::Ready);
        assert_eq!(first.id, 1);
        // Only the first frame was drained; the second is still queued.
        assert_eq!(buf.len(), total_len - first_frame_len);

        let second = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(second.t, MessageType::Shutdown);
        assert_eq!(second.id, 2);
        assert!(buf.is_empty());

        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_partial_length_prefix() {
        let mut buf = vec![0, 0, 0]; // Only three of the four prefix bytes.
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, [0, 0, 0]);
    }

    #[test]
    fn test_sync_decode_incomplete() {
        let mut buf = vec![0, 0, 0, 10]; // Length 10 but no payload bytes.
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        // Nothing is consumed until the whole frame is available.
        assert_eq!(buf, [0, 0, 0, 10]);
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
        assert_eq!(buf.len(), 4);
    }
}