1use std::io;
8
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use super::{
12 decode_frame_parts, encode_frame, frame_len_from_header, Frame, FrameError, FRAME_HEADER_SIZE,
13};
14
15#[derive(Debug)]
16pub enum RedWireIoError {
17 Io(io::Error),
18 Frame(FrameError),
19}
20
21impl std::fmt::Display for RedWireIoError {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Io(err) => write!(f, "{err}"),
25 Self::Frame(err) => write!(f, "{err}"),
26 }
27 }
28}
29
30impl std::error::Error for RedWireIoError {
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 match self {
33 Self::Io(err) => Some(err),
34 Self::Frame(err) => Some(err),
35 }
36 }
37}
38
39impl From<io::Error> for RedWireIoError {
40 fn from(err: io::Error) -> Self {
41 Self::Io(err)
42 }
43}
44
45impl From<FrameError> for RedWireIoError {
46 fn from(err: FrameError) -> Self {
47 Self::Frame(err)
48 }
49}
50
51pub async fn read_frame_async<S>(stream: &mut S) -> Result<Frame, RedWireIoError>
52where
53 S: AsyncRead + Unpin + Send,
54{
55 let mut header = [0u8; FRAME_HEADER_SIZE];
56 stream.read_exact(&mut header).await?;
57 let length = frame_len_from_header(&header)?;
58
59 let payload_len = length - FRAME_HEADER_SIZE;
60 let mut payload = vec![0u8; payload_len];
61 if length > FRAME_HEADER_SIZE {
62 stream.read_exact(&mut payload).await?;
63 }
64 Ok(decode_frame_parts(&header, &payload)?)
65}
66
67pub async fn write_frame_async<S>(stream: &mut S, frame: &Frame) -> Result<(), RedWireIoError>
68where
69 S: AsyncWrite + Unpin + Send,
70{
71 stream.write_all(&encode_frame(frame)).await?;
72 Ok(())
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::redwire::MessageKind;

    /// A frame written to one end of an in-memory duplex pipe must be read
    /// back unchanged from the other end.
    #[tokio::test]
    async fn async_frame_io_round_trips_over_duplex() {
        let sent = Frame::new(MessageKind::Ping, 42, b"hello".to_vec());
        let (mut writer, mut reader) = tokio::io::duplex(1024);

        write_frame_async(&mut writer, &sent).await.unwrap();
        let received = read_frame_async(&mut reader).await.unwrap();

        assert_eq!(received, sent);
    }
}