1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11 error::{ProtocolError, ProtocolResult},
12 message::{FRAME_HEADER_SIZE, Message},
13};
14
/// Upper bound, in bytes, on a frame's 4-byte length prefix (header + body): 4 MiB.
///
/// Every encoder and decoder in this module rejects frames whose declared length exceeds this
/// value, which also caps the buffer `read_raw_frame` allocates for an incoming frame.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
/// A single wire frame before its body is interpreted.
///
/// On the wire a frame is a 4-byte big-endian length prefix followed by the 4-byte big-endian
/// `id`, the 1-byte `flags`, and then `body`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawFrame {
    /// Identifier carried in the frame header (copied onto `Message::id` when the frame is
    /// decoded into a message).
    pub id: u32,

    /// Flag bits carried in the frame header (e.g. `FLAG_TERMINAL`, `FLAG_SESSION_START`).
    pub flags: u8,

    /// Payload bytes following the header; for message frames this is the CBOR encoding of
    /// the `Message`.
    pub body: Vec<u8>,
}
46
47pub fn encode_raw_to_buf(frame: &RawFrame, buf: &mut Vec<u8>) -> ProtocolResult<()> {
55 let frame_len = u32::try_from(FRAME_HEADER_SIZE + frame.body.len()).map_err(|_| {
56 ProtocolError::FrameTooLarge {
57 size: u32::MAX,
58 max: MAX_FRAME_SIZE,
59 }
60 })?;
61
62 if frame_len > MAX_FRAME_SIZE {
63 return Err(ProtocolError::FrameTooLarge {
64 size: frame_len,
65 max: MAX_FRAME_SIZE,
66 });
67 }
68
69 buf.extend_from_slice(&frame_len.to_be_bytes());
70 buf.extend_from_slice(&frame.id.to_be_bytes());
71 buf.push(frame.flags);
72 buf.extend_from_slice(&frame.body);
73 Ok(())
74}
75
76pub fn try_decode_raw_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<RawFrame>> {
83 if buf.len() < 4 {
84 return Ok(None);
85 }
86
87 let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
88
89 if frame_len > MAX_FRAME_SIZE {
90 return Err(ProtocolError::FrameTooLarge {
91 size: frame_len,
92 max: MAX_FRAME_SIZE,
93 });
94 }
95
96 let frame_len = frame_len as usize;
97 let total = 4 + frame_len;
98
99 if buf.len() < total {
100 return Ok(None);
101 }
102
103 if frame_len < FRAME_HEADER_SIZE {
104 return Err(ProtocolError::FrameTooShort {
105 size: frame_len as u32,
106 min: FRAME_HEADER_SIZE as u32,
107 });
108 }
109
110 let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
111 let flags = buf[8];
112 let body = buf[4 + FRAME_HEADER_SIZE..total].to_vec();
113
114 buf.drain(..total);
115 Ok(Some(RawFrame { id, flags, body }))
116}
117
118pub async fn read_raw_frame<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<RawFrame> {
122 let mut len_buf = [0u8; 4];
123 match reader.read_exact(&mut len_buf).await {
124 Ok(_) => {}
125 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
126 return Err(ProtocolError::UnexpectedEof);
127 }
128 Err(e) => return Err(e.into()),
129 }
130
131 let frame_len = u32::from_be_bytes(len_buf);
132
133 if frame_len > MAX_FRAME_SIZE {
134 return Err(ProtocolError::FrameTooLarge {
135 size: frame_len,
136 max: MAX_FRAME_SIZE,
137 });
138 }
139
140 let frame_len = frame_len as usize;
141
142 if frame_len < FRAME_HEADER_SIZE {
143 return Err(ProtocolError::FrameTooShort {
144 size: frame_len as u32,
145 min: FRAME_HEADER_SIZE as u32,
146 });
147 }
148
149 let mut payload = vec![0u8; frame_len];
150 reader.read_exact(&mut payload).await?;
151
152 let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
153 let flags = payload[4];
154 let body = payload[FRAME_HEADER_SIZE..].to_vec();
155
156 Ok(RawFrame { id, flags, body })
157}
158
159pub async fn write_raw_frame<W: AsyncWrite + Unpin>(
163 writer: &mut W,
164 frame: &RawFrame,
165) -> ProtocolResult<()> {
166 let mut buf = Vec::new();
167 encode_raw_to_buf(frame, &mut buf)?;
168 writer.write_all(&buf).await?;
169 writer.flush().await?;
170 Ok(())
171}
172
173pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
181 let mut body = Vec::new();
182 ciborium::into_writer(msg, &mut body)?;
183 encode_raw_to_buf(
184 &RawFrame {
185 id: msg.id,
186 flags: msg.flags,
187 body,
188 },
189 buf,
190 )
191}
192
193pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
200 match try_decode_raw_from_buf(buf)? {
201 Some(frame) => Ok(Some(raw_frame_to_message(frame)?)),
202 None => Ok(None),
203 }
204}
205
206pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
210 let frame = read_raw_frame(reader).await?;
211 raw_frame_to_message(frame)
212}
213
214pub async fn write_message<W: AsyncWrite + Unpin>(
218 writer: &mut W,
219 message: &Message,
220) -> ProtocolResult<()> {
221 let mut buf = Vec::new();
222 encode_to_buf(message, &mut buf)?;
223 writer.write_all(&buf).await?;
224 writer.flush().await?;
225 Ok(())
226}
227
228pub fn raw_frame_to_message(frame: RawFrame) -> ProtocolResult<Message> {
230 let mut msg: Message = ciborium::from_reader(&frame.body[..])?;
231 msg.id = frame.id;
232 msg.flags = frame.flags;
233 Ok(msg)
234}
235
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the async and buffer-based frame codecs.

    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let msg = Message::new(MessageType::Ready, 0, Vec::new());

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, msg.v);
        assert_eq!(decoded.t, msg.t);
        assert_eq!(decoded.id, msg.id);
        assert_eq!(decoded.flags, 0);
    }

    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_message(&mut cursor).await.unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let messages = vec![
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut buf = Vec::new();
        for msg in &messages {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut cursor = &buf[..];
        for expected in &messages {
            let decoded = read_message(&mut cursor).await.unwrap();
            assert_eq!(decoded.t, expected.t);
            assert_eq!(decoded.id, expected.id);
            assert_eq!(decoded.flags, expected.flags);
        }
    }

    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut cursor: &[u8] = &[];
        let result = read_message(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[tokio::test]
    async fn test_read_raw_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let input = huge_len.to_be_bytes();
        let mut cursor = &input[..];
        let result = read_raw_frame(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_read_raw_frame_too_short() {
        // Declared length 3 is smaller than the frame header.
        let mut input = Vec::new();
        input.extend_from_slice(&3u32.to_be_bytes());
        input.extend_from_slice(&[0, 0, 0]);
        let mut cursor = &input[..];
        let result = read_raw_frame(&mut cursor).await;
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.id, 5);
        assert_eq!(decoded.flags, FLAG_TERMINAL);

        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 0);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sync_decode_incomplete() {
        // Length prefix announces 10 bytes, none of which have arrived yet.
        let mut buf = vec![0, 0, 0, 10];
        assert!(try_decode_from_buf(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sync_decode_partial_payload_is_incomplete() {
        let frame = RawFrame {
            id: 9,
            flags: 0,
            body: vec![1, 2, 3],
        };
        let mut full = Vec::new();
        encode_raw_to_buf(&frame, &mut full).unwrap();

        // Everything but the last body byte has arrived.
        let mut buf = full[..full.len() - 1].to_vec();
        assert!(try_decode_raw_from_buf(&mut buf).unwrap().is_none());
        // Nothing is consumed while the frame is incomplete.
        assert_eq!(buf.len(), full.len() - 1);
    }

    #[test]
    fn test_sync_decode_frame_too_large() {
        let huge_len: u32 = MAX_FRAME_SIZE + 1;
        let mut buf = Vec::new();
        buf.extend_from_slice(&huge_len.to_be_bytes());
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    #[test]
    fn test_sync_decode_frame_too_short() {
        let mut buf = Vec::new();
        // Declared length 3 is smaller than the frame header.
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&[0, 0, 0]);
        let result = try_decode_from_buf(&mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooShort { .. })));
    }

    #[test]
    fn test_sync_decode_leaves_following_frame_in_buffer() {
        let first = RawFrame {
            id: 1,
            flags: 0,
            body: vec![0x11; 3],
        };
        let second = RawFrame {
            id: 2,
            flags: FLAG_TERMINAL,
            body: vec![0x22; 7],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&first, &mut buf).unwrap();
        encode_raw_to_buf(&second, &mut buf).unwrap();

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, first.id);
        assert_eq!(decoded.flags, first.flags);
        assert_eq!(decoded.body, first.body);
        assert!(!buf.is_empty());

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, second.id);
        assert_eq!(decoded.flags, second.flags);
        assert_eq!(decoded.body, second.body);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_encode_raw_rejects_oversized_body() {
        // Body alone equals the limit, so body + header exceeds it.
        let frame = RawFrame {
            id: 0,
            flags: 0,
            body: vec![0; MAX_FRAME_SIZE as usize],
        };
        let mut buf = Vec::new();
        let result = encode_raw_to_buf(&frame, &mut buf);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        // Length prefix counts everything after itself.
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
        assert_eq!(len as usize + 4, buf.len());

        // Big-endian id follows the length prefix.
        let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert_eq!(id, 0x12345678);

        // Flags byte follows the id.
        assert_eq!(buf[8], FLAG_SESSION_START);
    }

    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_TERMINAL, 0);
        assert_eq!(decoded.flags & FLAG_SESSION_START, 0);
    }

    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let decoded = try_decode_from_buf(&mut buf).unwrap().unwrap();
        assert_ne!(decoded.flags & FLAG_SESSION_START, 0);
        assert_eq!(decoded.flags & FLAG_TERMINAL, 0);
    }

    #[tokio::test]
    async fn test_raw_frame_roundtrip() {
        let frame = RawFrame {
            id: 0xDEADBEEF,
            flags: FLAG_TERMINAL,
            body: vec![1, 2, 3, 4, 5],
        };

        let mut buf = Vec::new();
        write_raw_frame(&mut buf, &frame).await.unwrap();

        let mut cursor = &buf[..];
        let decoded = read_raw_frame(&mut cursor).await.unwrap();

        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
    }

    #[test]
    fn test_raw_frame_sync_roundtrip() {
        let frame = RawFrame {
            id: 42,
            flags: FLAG_SESSION_START,
            body: vec![0xAA; 100],
        };

        let mut buf = Vec::new();
        encode_raw_to_buf(&frame, &mut buf).unwrap();

        let decoded = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.id, frame.id);
        assert_eq!(decoded.flags, frame.flags);
        assert_eq!(decoded.body, frame.body);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_raw_frame_to_message() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 13, &ExecExited { code: 7 }).unwrap();

        let mut buf = Vec::new();
        encode_to_buf(&msg, &mut buf).unwrap();

        let frame = try_decode_raw_from_buf(&mut buf).unwrap().unwrap();
        let decoded = raw_frame_to_message(frame).unwrap();

        assert_eq!(decoded.id, 13);
        assert_eq!(decoded.t, MessageType::ExecExited);
        let payload: ExecExited = decoded.payload().unwrap();
        assert_eq!(payload.code, 7);
    }
}