1use std::io;
8
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use super::{
12 decode_frame_parts, encode_frame, frame_len_from_header, Frame, FrameError, FRAME_HEADER_SIZE,
13};
14
15#[derive(Debug)]
16pub enum RedWireIoError {
17 Io(io::Error),
18 Frame(FrameError),
19}
20
21impl std::fmt::Display for RedWireIoError {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Io(err) => write!(f, "{err}"),
25 Self::Frame(err) => write!(f, "{err}"),
26 }
27 }
28}
29
30impl std::error::Error for RedWireIoError {
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 match self {
33 Self::Io(err) => Some(err),
34 Self::Frame(err) => Some(err),
35 }
36 }
37}
38
39impl From<io::Error> for RedWireIoError {
40 fn from(err: io::Error) -> Self {
41 Self::Io(err)
42 }
43}
44
45impl From<FrameError> for RedWireIoError {
46 fn from(err: FrameError) -> Self {
47 Self::Frame(err)
48 }
49}
50
51pub async fn read_frame_async<S>(stream: &mut S) -> Result<Frame, RedWireIoError>
52where
53 S: AsyncRead + Unpin + Send,
54{
55 let mut header = [0u8; FRAME_HEADER_SIZE];
56 stream.read_exact(&mut header).await?;
57 let length = frame_len_from_header(&header)?;
58
59 let payload_len = length - FRAME_HEADER_SIZE;
60 let mut payload = vec![0u8; payload_len];
61 if length > FRAME_HEADER_SIZE {
62 stream.read_exact(&mut payload).await?;
63 }
64 Ok(decode_frame_parts(&header, &payload)?)
65}
66
67pub async fn write_frame_async<S>(stream: &mut S, frame: &Frame) -> Result<(), RedWireIoError>
68where
69 S: AsyncWrite + Unpin + Send,
70{
71 stream.write_all(&encode_frame(frame)).await?;
72 Ok(())
73}
74
75pub fn frame_to_bytes(frame: &Frame) -> Vec<u8> {
78 encode_frame(frame)
79}
80
81pub fn drain_next_frame(buffer: &mut Vec<u8>) -> Result<Option<Frame>, FrameError> {
87 if buffer.len() < FRAME_HEADER_SIZE {
88 return Ok(None);
89 }
90
91 let mut header = [0u8; FRAME_HEADER_SIZE];
92 header.copy_from_slice(&buffer[..FRAME_HEADER_SIZE]);
93 let length = frame_len_from_header(&header)?;
94 if buffer.len() < length {
95 return Ok(None);
96 }
97
98 let payload = &buffer[FRAME_HEADER_SIZE..length];
99 let frame = decode_frame_parts(&header, payload)?;
100 buffer.drain(..length);
101 Ok(Some(frame))
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::redwire::MessageKind;
    use tokio::io::AsyncWriteExt;

    #[tokio::test]
    async fn async_frame_io_round_trips_over_duplex() {
        let (mut left, mut right) = tokio::io::duplex(1024);
        let frame = Frame::new(MessageKind::Ping, 42, b"hello".to_vec());

        write_frame_async(&mut left, &frame).await.unwrap();
        let decoded = read_frame_async(&mut right).await.unwrap();

        assert_eq!(decoded, frame);
    }

    // A frame with an empty payload exercises the `length == FRAME_HEADER_SIZE`
    // path, where no payload bytes are read after the header.
    #[tokio::test]
    async fn async_frame_io_round_trips_header_only_frame() {
        let (mut left, mut right) = tokio::io::duplex(1024);
        let frame = Frame::new(MessageKind::Ping, 7, b"".to_vec());

        write_frame_async(&mut left, &frame).await.unwrap();
        let decoded = read_frame_async(&mut right).await.unwrap();

        assert_eq!(decoded, frame);
    }

    // Closing the writer mid-frame must surface as an I/O EOF, not a hang or panic.
    #[tokio::test]
    async fn async_frame_io_reports_eof_on_truncated_frame() {
        let (mut left, mut right) = tokio::io::duplex(1024);
        let bytes = frame_to_bytes(&Frame::new(MessageKind::Ping, 3, b"hello".to_vec()));

        left.write_all(&bytes[..bytes.len() - 1]).await.unwrap();
        drop(left);

        let err = read_frame_async(&mut right).await.unwrap_err();
        assert!(
            matches!(&err, RedWireIoError::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn chunked_frame_io_drains_complete_frames_and_keeps_leftover() {
        let first = Frame::new(MessageKind::Ping, 1, b"one".to_vec());
        let second = Frame::new(MessageKind::Pong, 2, b"two".to_vec());
        let mut bytes = frame_to_bytes(&first);
        bytes.extend_from_slice(&frame_to_bytes(&second));
        bytes.extend_from_slice(b"partial");

        assert_eq!(drain_next_frame(&mut bytes).unwrap(), Some(first));
        assert_eq!(drain_next_frame(&mut bytes).unwrap(), Some(second));
        assert_eq!(drain_next_frame(&mut bytes).unwrap(), None);
        assert_eq!(bytes, b"partial");
    }

    #[test]
    fn chunked_frame_io_waits_for_complete_payload() {
        let frame = Frame::new(MessageKind::Ping, 1, b"hello".to_vec());
        let mut bytes = frame_to_bytes(&frame);
        bytes.pop();
        let before = bytes.clone();

        assert_eq!(drain_next_frame(&mut bytes).unwrap(), None);
        // An incomplete frame must not consume any buffered bytes.
        assert_eq!(bytes, before);
    }

    #[test]
    fn chunked_frame_io_waits_for_complete_header() {
        let mut empty = Vec::new();
        assert_eq!(drain_next_frame(&mut empty).unwrap(), None);
        assert!(empty.is_empty());

        let mut bytes = frame_to_bytes(&Frame::new(MessageKind::Ping, 1, b"hi".to_vec()));
        bytes.truncate(FRAME_HEADER_SIZE - 1);
        let before = bytes.clone();

        assert_eq!(drain_next_frame(&mut bytes).unwrap(), None);
        assert_eq!(bytes, before);
    }
}