1use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10use crate::{
11 error::{ProtocolError, ProtocolResult},
12 message::{FRAME_HEADER_SIZE, Message},
13};
14
/// Largest value, in bytes, that a frame's length prefix may carry (4 MiB).
///
/// The prefix counts the frame header plus the body, but not the four bytes
/// of the prefix itself. Every encoder and decoder in this module rejects a
/// frame whose prefix exceeds this limit with `ProtocolError::FrameTooLarge`.
pub const MAX_FRAME_SIZE: u32 = 4 << 20;
24
/// A single frame with its header fields split out and its body left undecoded.
///
/// On the wire a frame is a big-endian `u32` length prefix, followed by the
/// big-endian `u32` id, the one-byte flags field, and then the body bytes.
#[derive(Clone, Debug)]
pub struct RawFrame {
    /// Identifier carried in the frame header.
    pub id: u32,

    /// Flag bits carried in the frame header.
    pub flags: u8,

    /// Bytes following the header. The message helpers in this module store
    /// the CBOR encoding of a `Message` here; the raw helpers treat it as opaque.
    pub body: Vec<u8>,
}
46
47pub fn encode_raw_to_buf(frame: &RawFrame, buf: &mut Vec<u8>) -> ProtocolResult<()> {
55 let frame_len = u32::try_from(FRAME_HEADER_SIZE + frame.body.len()).map_err(|_| {
56 ProtocolError::FrameTooLarge {
57 size: u32::MAX,
58 max: MAX_FRAME_SIZE,
59 }
60 })?;
61
62 if frame_len > MAX_FRAME_SIZE {
63 return Err(ProtocolError::FrameTooLarge {
64 size: frame_len,
65 max: MAX_FRAME_SIZE,
66 });
67 }
68
69 buf.extend_from_slice(&frame_len.to_be_bytes());
70 buf.extend_from_slice(&frame.id.to_be_bytes());
71 buf.push(frame.flags);
72 buf.extend_from_slice(&frame.body);
73 Ok(())
74}
75
76pub fn try_decode_raw_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<RawFrame>> {
83 if buf.len() < 4 {
84 return Ok(None);
85 }
86
87 let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
88
89 if frame_len > MAX_FRAME_SIZE {
90 return Err(ProtocolError::FrameTooLarge {
91 size: frame_len,
92 max: MAX_FRAME_SIZE,
93 });
94 }
95
96 let frame_len = frame_len as usize;
97 let total = 4 + frame_len;
98
99 if buf.len() < total {
100 return Ok(None);
101 }
102
103 if frame_len < FRAME_HEADER_SIZE {
104 return Err(ProtocolError::FrameTooShort {
105 size: frame_len as u32,
106 min: FRAME_HEADER_SIZE as u32,
107 });
108 }
109
110 let id = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
111 let flags = buf[8];
112 let body = buf[4 + FRAME_HEADER_SIZE..total].to_vec();
113
114 buf.drain(..total);
115 Ok(Some(RawFrame { id, flags, body }))
116}
117
118pub async fn read_raw_frame<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<RawFrame> {
122 let mut len_buf = [0u8; 4];
123 match reader.read_exact(&mut len_buf).await {
124 Ok(_) => {}
125 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
126 return Err(ProtocolError::UnexpectedEof);
127 }
128 Err(e) => return Err(e.into()),
129 }
130
131 let frame_len = u32::from_be_bytes(len_buf);
132
133 if frame_len > MAX_FRAME_SIZE {
134 return Err(ProtocolError::FrameTooLarge {
135 size: frame_len,
136 max: MAX_FRAME_SIZE,
137 });
138 }
139
140 let frame_len = frame_len as usize;
141
142 if frame_len < FRAME_HEADER_SIZE {
143 return Err(ProtocolError::FrameTooShort {
144 size: frame_len as u32,
145 min: FRAME_HEADER_SIZE as u32,
146 });
147 }
148
149 let mut payload = vec![0u8; frame_len];
150 reader.read_exact(&mut payload).await?;
151
152 let id = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
153 let flags = payload[4];
154 let body = payload[FRAME_HEADER_SIZE..].to_vec();
155
156 Ok(RawFrame { id, flags, body })
157}
158
159pub async fn write_raw_frame<W: AsyncWrite + Unpin>(
163 writer: &mut W,
164 frame: &RawFrame,
165) -> ProtocolResult<()> {
166 let mut buf = Vec::new();
167 encode_raw_to_buf(frame, &mut buf)?;
168 writer.write_all(&buf).await?;
169 writer.flush().await?;
170 Ok(())
171}
172
173pub fn encode_to_buf(msg: &Message, buf: &mut Vec<u8>) -> ProtocolResult<()> {
181 let mut body = Vec::new();
182 ciborium::into_writer(msg, &mut body)?;
183 encode_raw_to_buf(
184 &RawFrame {
185 id: msg.id,
186 flags: msg.flags,
187 body,
188 },
189 buf,
190 )
191}
192
193pub fn try_decode_from_buf(buf: &mut Vec<u8>) -> ProtocolResult<Option<Message>> {
200 if buf.len() < 4 {
201 return Ok(None);
202 }
203
204 let frame_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
205
206 if frame_len > MAX_FRAME_SIZE {
207 return Err(ProtocolError::FrameTooLarge {
208 size: frame_len,
209 max: MAX_FRAME_SIZE,
210 });
211 }
212
213 let frame_len = frame_len as usize;
214 let total = 4 + frame_len;
215
216 if buf.len() < total {
217 return Ok(None);
218 }
219
220 let msg = decode_message_frame(&buf[..total])?;
221 buf.drain(..total);
222 Ok(Some(msg))
223}
224
225pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> ProtocolResult<Message> {
229 let frame = read_raw_frame(reader).await?;
230 raw_frame_to_message(frame)
231}
232
233pub async fn write_message<W: AsyncWrite + Unpin>(
237 writer: &mut W,
238 message: &Message,
239) -> ProtocolResult<()> {
240 let mut buf = Vec::new();
241 encode_to_buf(message, &mut buf)?;
242 writer.write_all(&buf).await?;
243 writer.flush().await?;
244 Ok(())
245}
246
247pub fn raw_frame_to_message(frame: RawFrame) -> ProtocolResult<Message> {
249 let mut msg: Message = ciborium::from_reader(&frame.body[..])?;
250 msg.id = frame.id;
251 msg.flags = frame.flags;
252 Ok(msg)
253}
254
255pub fn decode_message_frame(frame: &[u8]) -> ProtocolResult<Message> {
260 if frame.len() < 4 {
261 return Err(ProtocolError::UnexpectedEof);
262 }
263
264 let frame_len = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
265 if frame_len > MAX_FRAME_SIZE {
266 return Err(ProtocolError::FrameTooLarge {
267 size: frame_len,
268 max: MAX_FRAME_SIZE,
269 });
270 }
271
272 let frame_len = frame_len as usize;
273 let total = 4 + frame_len;
274 if frame.len() < total {
275 return Err(ProtocolError::UnexpectedEof);
276 }
277
278 if frame_len < FRAME_HEADER_SIZE {
279 return Err(ProtocolError::FrameTooShort {
280 size: frame_len as u32,
281 min: FRAME_HEADER_SIZE as u32,
282 });
283 }
284
285 let mut msg: Message = ciborium::from_reader(&frame[4 + FRAME_HEADER_SIZE..total])?;
286 msg.id = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]);
287 msg.flags = frame[8];
288 Ok(msg)
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{FLAG_SESSION_START, FLAG_TERMINAL, MessageType, PROTOCOL_VERSION};

    /// A payload-less message survives an async write followed by an async read.
    #[tokio::test]
    async fn test_codec_roundtrip_empty_payload() {
        let original = Message::new(MessageType::Ready, 0, Vec::new());

        let mut wire = Vec::new();
        write_message(&mut wire, &original).await.unwrap();

        let mut reader = wire.as_slice();
        let roundtripped = read_message(&mut reader).await.unwrap();

        assert_eq!(roundtripped.v, original.v);
        assert_eq!(roundtripped.t, original.t);
        assert_eq!(roundtripped.id, original.id);
        assert_eq!(roundtripped.flags, 0);
    }

    /// A message carrying a typed payload survives the async roundtrip intact.
    #[tokio::test]
    async fn test_codec_roundtrip_with_payload() {
        use crate::exec::ExecExited;

        let original =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        let mut wire = Vec::new();
        write_message(&mut wire, &original).await.unwrap();

        let mut reader = wire.as_slice();
        let roundtripped = read_message(&mut reader).await.unwrap();

        assert_eq!(roundtripped.v, PROTOCOL_VERSION);
        assert_eq!(roundtripped.t, MessageType::ExecExited);
        assert_eq!(roundtripped.id, 7);
        assert_eq!(roundtripped.flags, FLAG_TERMINAL);

        let exited: ExecExited = roundtripped.payload().unwrap();
        assert_eq!(exited.code, 42);
    }

    /// Several frames written back to back are read out again in order.
    #[tokio::test]
    async fn test_codec_multiple_messages() {
        let sent = [
            Message::new(MessageType::Ready, 0, Vec::new()),
            Message::new(MessageType::ExecExited, 1, Vec::new()),
            Message::new(MessageType::Shutdown, 2, Vec::new()),
        ];

        let mut wire = Vec::new();
        for msg in sent.iter() {
            write_message(&mut wire, msg).await.unwrap();
        }

        let mut reader = wire.as_slice();
        for expected in sent.iter() {
            let received = read_message(&mut reader).await.unwrap();
            assert_eq!(received.t, expected.t);
            assert_eq!(received.id, expected.id);
            assert_eq!(received.flags, expected.flags);
        }
    }

    /// Reading from an empty stream reports `UnexpectedEof`.
    #[tokio::test]
    async fn test_codec_unexpected_eof() {
        let mut reader: &[u8] = &[];
        let outcome = read_message(&mut reader).await;
        assert!(matches!(outcome, Err(ProtocolError::UnexpectedEof)));
    }

    /// The buffer-based encode/decode pair roundtrips and consumes the frame.
    #[test]
    fn test_sync_encode_decode_roundtrip() {
        use crate::exec::ExecExited;

        let original =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut wire = Vec::new();
        encode_to_buf(&original, &mut wire).unwrap();

        let roundtripped = try_decode_from_buf(&mut wire).unwrap().unwrap();
        assert_eq!(roundtripped.t, MessageType::ExecExited);
        assert_eq!(roundtripped.id, 5);
        assert_eq!(roundtripped.flags, FLAG_TERMINAL);

        let exited: ExecExited = roundtripped.payload().unwrap();
        assert_eq!(exited.code, 0);
        assert!(wire.is_empty());
    }

    /// The borrowed decoder yields the same message and leaves its input alone.
    #[test]
    fn test_borrowed_decode_message_frame_roundtrip() {
        use crate::exec::ExecExited;

        let original =
            Message::with_payload(MessageType::ExecExited, 5, &ExecExited { code: 0 }).unwrap();

        let mut wire = Vec::new();
        encode_to_buf(&original, &mut wire).unwrap();

        let roundtripped = decode_message_frame(&wire).unwrap();
        assert_eq!(roundtripped.t, MessageType::ExecExited);
        assert_eq!(roundtripped.id, 5);
        assert_eq!(roundtripped.flags, FLAG_TERMINAL);

        let exited: ExecExited = roundtripped.payload().unwrap();
        assert_eq!(exited.code, 0);
        assert!(!wire.is_empty(), "borrowed decode must not consume input");
    }

    /// A prefix announcing more bytes than are present is an error for the
    /// borrowed decoder.
    #[test]
    fn test_borrowed_decode_message_frame_rejects_incomplete() {
        let truncated = vec![0, 0, 0, 10];
        assert!(matches!(
            decode_message_frame(&truncated),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    /// A prefix announcing more bytes than are buffered means "not ready yet"
    /// for the buffer decoder.
    #[test]
    fn test_sync_decode_incomplete() {
        let mut partial = vec![0, 0, 0, 10];
        assert!(try_decode_from_buf(&mut partial).unwrap().is_none());
    }

    /// A length prefix above `MAX_FRAME_SIZE` is rejected.
    #[test]
    fn test_sync_decode_frame_too_large() {
        let oversized: u32 = MAX_FRAME_SIZE + 1;
        let mut wire = Vec::new();
        wire.extend_from_slice(&oversized.to_be_bytes());

        let outcome = try_decode_from_buf(&mut wire);
        assert!(matches!(outcome, Err(ProtocolError::FrameTooLarge { .. })));
    }

    /// Pins the byte layout of the frame header: length, id, flags.
    #[test]
    fn test_frame_header_wire_format() {
        let msg = Message::new(MessageType::ExecRequest, 0x12345678, Vec::new());

        let mut wire = Vec::new();
        encode_to_buf(&msg, &mut wire).unwrap();

        // Bytes 0..4: length prefix, which excludes its own four bytes.
        let declared_len = u32::from_be_bytes([wire[0], wire[1], wire[2], wire[3]]);
        assert_eq!(declared_len as usize + 4, wire.len());

        // Bytes 4..8: big-endian frame id.
        let wire_id = u32::from_be_bytes([wire[4], wire[5], wire[6], wire[7]]);
        assert_eq!(wire_id, 0x12345678);

        // Byte 8: flags.
        assert_eq!(wire[8], FLAG_SESSION_START);
    }

    /// The terminal flag roundtrips and the session-start flag stays clear.
    #[test]
    fn test_flags_roundtrip_terminal() {
        let msg = Message::new(MessageType::ExecExited, 99, Vec::new());

        let mut wire = Vec::new();
        encode_to_buf(&msg, &mut wire).unwrap();

        let roundtripped = try_decode_from_buf(&mut wire).unwrap().unwrap();
        assert_ne!(roundtripped.flags & FLAG_TERMINAL, 0);
        assert_eq!(roundtripped.flags & FLAG_SESSION_START, 0);
    }

    /// The session-start flag roundtrips and the terminal flag stays clear.
    #[test]
    fn test_flags_roundtrip_session_start() {
        let msg = Message::new(MessageType::FsRequest, 42, Vec::new());

        let mut wire = Vec::new();
        encode_to_buf(&msg, &mut wire).unwrap();

        let roundtripped = try_decode_from_buf(&mut wire).unwrap().unwrap();
        assert_ne!(roundtripped.flags & FLAG_SESSION_START, 0);
        assert_eq!(roundtripped.flags & FLAG_TERMINAL, 0);
    }

    /// A fully buffered frame whose prefix is below the header size is rejected.
    #[test]
    fn test_sync_decode_frame_too_short() {
        let mut wire = Vec::new();
        wire.extend_from_slice(&3u32.to_be_bytes());
        // Supply the three announced bytes so the frame counts as complete.
        wire.extend_from_slice(&[0, 0, 0]);

        let outcome = try_decode_from_buf(&mut wire);
        assert!(matches!(outcome, Err(ProtocolError::FrameTooShort { .. })));
    }

    /// A raw frame survives an async write followed by an async read.
    #[tokio::test]
    async fn test_raw_frame_roundtrip() {
        let original = RawFrame {
            id: 0xDEADBEEF,
            flags: FLAG_TERMINAL,
            body: vec![1, 2, 3, 4, 5],
        };

        let mut wire = Vec::new();
        write_raw_frame(&mut wire, &original).await.unwrap();

        let mut reader = wire.as_slice();
        let roundtripped = read_raw_frame(&mut reader).await.unwrap();

        assert_eq!(roundtripped.id, original.id);
        assert_eq!(roundtripped.flags, original.flags);
        assert_eq!(roundtripped.body, original.body);
    }

    /// The buffer-based raw encode/decode pair roundtrips and consumes the frame.
    #[test]
    fn test_raw_frame_sync_roundtrip() {
        let original = RawFrame {
            id: 42,
            flags: FLAG_SESSION_START,
            body: vec![0xAA; 100],
        };

        let mut wire = Vec::new();
        encode_raw_to_buf(&original, &mut wire).unwrap();

        let roundtripped = try_decode_raw_from_buf(&mut wire).unwrap().unwrap();
        assert_eq!(roundtripped.id, original.id);
        assert_eq!(roundtripped.flags, original.flags);
        assert_eq!(roundtripped.body, original.body);
        assert!(wire.is_empty());
    }

    /// A raw frame decoded from an encoded message converts back to that message.
    #[test]
    fn test_raw_frame_to_message() {
        use crate::exec::ExecExited;

        let original =
            Message::with_payload(MessageType::ExecExited, 13, &ExecExited { code: 7 }).unwrap();

        let mut wire = Vec::new();
        encode_to_buf(&original, &mut wire).unwrap();

        let raw = try_decode_raw_from_buf(&mut wire).unwrap().unwrap();
        let roundtripped = raw_frame_to_message(raw).unwrap();

        assert_eq!(roundtripped.id, 13);
        assert_eq!(roundtripped.t, MessageType::ExecExited);
        let exited: ExecExited = roundtripped.payload().unwrap();
        assert_eq!(exited.code, 7);
    }
}