microsandbox_protocol/
codec.rs1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use crate::{
6 error::{ProtocolError, ProtocolResult},
7 message::Message,
8};
9
/// Largest CBOR payload, in bytes, that a single frame may carry (4 MiB).
///
/// The 4-byte length prefix is not counted against this limit; both the
/// encoder and the decoders compare only the payload length to it.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
16
17pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
25 let mut payload = Vec::new();
26 ciborium::into_writer(msg, &mut payload)?;
27
28 let len = u32::try_from(payload.len()).map_err(|_| ProtocolError::FrameTooLarge {
29 size: u32::MAX,
30 max: MAX_FRAME_SIZE,
31 })?;
32
33 if len > MAX_FRAME_SIZE {
34 return Err(ProtocolError::FrameTooLarge {
35 size: len,
36 max: MAX_FRAME_SIZE,
37 });
38 }
39
40 buf.extend_from_slice(&len.to_be_bytes());
41 buf.extend_from_slice(&payload);
42 Ok(())
43}
44
45pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
52 if buf.len() < 4 {
53 return Ok(None);
54 }
55
56 let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
57
58 if len > MAX_FRAME_SIZE {
59 return Err(ProtocolError::FrameTooLarge {
60 size: len,
61 max: MAX_FRAME_SIZE,
62 });
63 }
64
65 let len = len as usize;
66 if buf.len() < 4 + len {
67 return Ok(None);
68 }
69
70 let payload = &buf[4..4 + len];
71 let msg: Message = ciborium::from_reader(payload)?;
72
73 buf.drain(..4 + len);
74 Ok(Some(msg))
75}
76
77pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
81 let mut len_buf = [0u8; 4];
83 match reader.read_exact(&mut len_buf).await {
84 Ok(_) => {}
85 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
86 return Err(ProtocolError::UnexpectedEof);
87 }
88 Err(e) => return Err(e.into()),
89 }
90
91 let len = u32::from_be_bytes(len_buf);
92
93 if len > MAX_FRAME_SIZE {
94 return Err(ProtocolError::FrameTooLarge {
95 size: len,
96 max: MAX_FRAME_SIZE,
97 });
98 }
99
100 let mut payload = vec![0u8; len as usize];
102 reader.read_exact(&mut payload).await?;
103
104 let message: Message = ciborium::from_reader(&payload[..])?;
106 Ok(message)
107}
108
109pub async fn write_message<W: AsyncWrite + Unpin>(
113 writer: &mut W,
114 message: &Message,
115) -> ProtocolResult<()> {
116 let mut buf = Vec::new();
117 encode_to_buf(message, &mut buf)?;
118 writer.write_all(&buf).await?;
119 writer.flush().await?;
120 Ok(())
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{MessageType, PROTOCOL_VERSION};

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message {
            v: PROTOCOL_VERSION,
            t: MessageType::Ready,
            id: 0,
            p: Vec::new(),
        };

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
        }
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_codec_truncated_header_is_unexpected_eof() {
        // Only half of the 4-byte length prefix is available.
        let mut cursor: &[u8] = &[0, 0];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_codec_truncated_payload_is_error() {
        // Header announces 10 payload bytes, but the stream ends immediately.
        let mut cursor: &[u8] = &[0, 0, 0, 10];
        let result = read_message(&mut cursor).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_codec_read_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let bytes = huge_len.to_be_bytes();
        let mut cursor = &bytes[..];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_encode_appends_to_existing_buffer() {
        let mut buf = vec![0xAA, 0xBB];
        encode_to_buf(&Message::new(MessageType::Ready, 3, Vec::new()), &mut buf).unwrap();

        // Pre-existing bytes are preserved and the frame follows them.
        assert_eq!(&buf[..2], &[0xAA, 0xBB]);
        let len = u32::from_be_bytes([buf[2], buf[3], buf[4], buf[5]]) as usize;
        assert_eq!(buf.len(), 2 + 4 + len);
    }

    #[test]
    fn test_sync_decode_multiple_frames() {
        let first = Message::new(MessageType::Ready, 1, Vec::new());
        let second = Message::new(MessageType::Shutdown, 2, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&first, &mut buf).unwrap();
        encode_to_buf(&second, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::Ready);
        assert_eq!(decoded.id, 1);
        assert!(!buf.is_empty());

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::Shutdown);
        assert_eq!(decoded.id, 2);
        assert!(buf.is_empty());

        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_partial_header() {
        let mut buf = vec![0, 0];
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, [0, 0]);
    }

    #[test]
    fn test_sync_decode_incomplete() {
        // Header announces a 10-byte payload but none of it has arrived yet.
        let mut buf = vec![0, 0, 0, 10];
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        // An incomplete frame must stay in place so it can be completed later.
        assert_eq!(buf, [0, 0, 0, 10]);
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }
}