1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11 error::{ProtocolError, ProtocolResult},
12 message::{FRAME_HEADER_SIZE, Message},
13};
14
/// Largest value accepted in a frame's 4-byte length prefix: 4 MiB.
///
/// The length prefix counts the frame header plus the CBOR body, not the
/// 4 prefix bytes themselves.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
25pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
33 let mut cbor = Vec::new();
35 ciborium::into_writer(msg, &mut cbor)?;
36
37 let frame_len = u32::try_from(FRAME_HEADER_SIZE + cbor.len()).map_err(|_| {
39 ProtocolError::FrameTooLarge {
40 size: u32::MAX,
41 max: MAX_FRAME_SIZE,
42 }
43 })?;
44
45 if frame_len > MAX_FRAME_SIZE {
46 return Err(ProtocolError::FrameTooLarge {
47 size: frame_len,
48 max: MAX_FRAME_SIZE,
49 });
50 }
51
52 buf.extend_from_slice(&frame_len.to_be_bytes());
53 buf.extend_from_slice(&msg.id.to_be_bytes());
54 buf.push(msg.flags);
55 buf.extend_from_slice(&cbor);
56 Ok(())
57}
58
59pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
66 if buf.len() < 4 {
67 return Ok(None);
68 }
69
70 let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
71
72 if frame_len > MAX_FRAME_SIZE {
73 return Err(ProtocolError::FrameTooLarge {
74 size: frame_len,
75 max: MAX_FRAME_SIZE,
76 });
77 }
78
79 let frame_len = frame_len as usize;
80 let total = 4 + frame_len;
81
82 if buf.len() < total {
83 return Ok(None);
84 }
85
86 if frame_len < FRAME_HEADER_SIZE {
87 return Err(ProtocolError::FrameTooShort {
88 size: frame_len as u32,
89 min: FRAME_HEADER_SIZE as u32,
90 });
91 }
92
93 let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
95 let flags = buf[8];
96
97 let cbor = &buf[4 + FRAME_HEADER_SIZE..total];
99 let mut msg: Message = ciborium::from_reader(cbor)?;
100 msg.id = id;
101 msg.flags = flags;
102
103 buf.drain(..total);
104 Ok(Some(msg))
105}
106
107pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
111 let mut len_buf = [0u8; 4];
113 match reader.read_exact(&mut len_buf).await {
114 Ok(_) => {}
115 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
116 return Err(ProtocolError::UnexpectedEof);
117 }
118 Err(e) => return Err(e.into()),
119 }
120
121 let frame_len = u32::from_be_bytes(len_buf);
122
123 if frame_len > MAX_FRAME_SIZE {
124 return Err(ProtocolError::FrameTooLarge {
125 size: frame_len,
126 max: MAX_FRAME_SIZE,
127 });
128 }
129
130 let frame_len = frame_len as usize;
131
132 if frame_len < FRAME_HEADER_SIZE {
133 return Err(ProtocolError::FrameTooShort {
134 size: frame_len as u32,
135 min: FRAME_HEADER_SIZE as u32,
136 });
137 }
138
139 let mut payload = vec![0u8; frame_len];
141 reader.read_exact(&mut payload).await?;
142
143 let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
145 let flags = payload[4];
146
147 let cbor = &payload[FRAME_HEADER_SIZE..];
149 let mut msg: Message = ciborium::from_reader(cbor)?;
150 msg.id = id;
151 msg.flags = flags;
152
153 Ok(msg)
154}
155
156pub async fn write_message<W: AsyncWrite + Unpin>(
160 writer: &mut W,
161 message: &Message,
162) -> ProtocolResult<()> {
163 let mut buf = Vec::new();
164 encode_to_buf(message, &mut buf)?;
165 writer.write_all(&buf).await?;
166 writer.flush().await?;
167 Ok(())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert_eq!(decoded.flags, 0);
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
            assert_eq!(decoded.flags, expected.flags);
        }
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_codec_truncated_payload_is_error() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();
        // Drop the last byte so the payload is one byte short of the prefix.
        buf.pop();

        let mut cursor = &buf[..];
        assert!(read_message(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn test_codec_read_frame_too_large() {
        let bytes = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let mut cursor = &bytes[..];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_codec_read_frame_too_short() {
        let bytes = 0u32.to_be_bytes();
        let mut cursor = &bytes[..];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_encode_appends_to_existing_buffer() {
        let msg = Message::new(MessageType::Ready, 3, Vec::new());

        let mut buf = vec![0xAA];
        encode_to_buf(&msg, &mut buf).unwrap();
        assert_eq!(buf[0], 0xAA);

        let mut frame = buf[1..].to_vec();
        let decoded = try_decode_from_buf(&mut frame).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::Ready);
        assert_eq!(decoded.id, 3);
        assert!(frame.is_empty());
    }

    #[test]
    fn test_sync_decode_incomplete() {
        // The length prefix announces 10 bytes, but none of them have arrived.
        let mut buf = vec![0, 0, 0, 10];
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, [0, 0, 0, 10]);
    }

    #[test]
    fn test_sync_decode_multiple_frames_and_partial_tail() {
        let first = Message::new(MessageType::Ready, 1, Vec::new());
        let second = Message::new(MessageType::Shutdown, 2, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&first, &mut buf).unwrap();
        encode_to_buf(&second, &mut buf).unwrap();
        // Start of a third frame whose length prefix is only half-received.
        buf.extend_from_slice(&[0, 0]);

        let a = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(a.t, MessageType::Ready);
        assert_eq!(a.id, 1);

        let b = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(b.t, MessageType::Shutdown);
        assert_eq!(b.id, 2);

        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
        assert_eq!(buf, [0, 0]);
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_sync_decode_frame_too_short() {
        // Length prefix of 3 plus 3 payload bytes: complete, but shorter than
        // the frame header.
        let mut buf = Vec::new();
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&[0, 0, 0]);
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_sync_decode_zero_length_frame() {
        let mut buf = 0u32.to_be_bytes().to_vec();
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        // The length prefix covers everything after itself.
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
        assert_eq!(len as usize + 4, buf.len());

        let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert_eq!(id, 0x12345678);

        assert_eq!(buf[8], FLAG_SESSION_START);
    }

    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_TERMINAL, 0);
        assert_eq!(decoded.flags & FLAG_SESSION_START, 0);
    }

    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_SESSION_START, 0);
        assert_eq!(decoded.flags & FLAG_TERMINAL, 0);
    }
}