1use serde::{Serialize, de::DeserializeOwned};
2use std::time::Duration;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4
/// Largest frame body, in bytes, that the readers in this module accept and the
/// writers will emit (10 MiB). Anything larger fails with
/// [`WireError::LengthOverflow`].
const MAX_FRAME_BYTES: usize = 10 << 20;
6
7#[derive(thiserror::Error, Debug)]
14pub enum WireError {
15 #[error("io: {0}")]
16 Io(#[from] std::io::Error),
17 #[error("decode: {0}")]
18 Decode(#[from] serde_json::Error),
19 #[error("frame length overflow: {0} bytes (max {MAX_FRAME_BYTES})")]
20 LengthOverflow(u32),
21 #[error("timeout after {0:?}")]
22 Timeout(Duration),
23}
24
25impl From<WireError> for std::io::Error {
26 fn from(e: WireError) -> Self {
29 match e {
30 WireError::Io(io) => io,
31 other => std::io::Error::new(std::io::ErrorKind::InvalidData, other.to_string()),
32 }
33 }
34}
35
36pub async fn write_frame_with_deadline<W, T>(
42 writer: &mut W,
43 msg: &T,
44 deadline: Option<Duration>,
45) -> Result<(), WireError>
46where
47 W: AsyncWriteExt + Unpin,
48 T: Serialize,
49{
50 let fut = async {
51 let body = serde_json::to_vec(msg)?;
52 let len = u32::try_from(body.len()).map_err(|_| WireError::LengthOverflow(u32::MAX))?;
53 if body.len() > MAX_FRAME_BYTES {
54 return Err(WireError::LengthOverflow(len));
55 }
56 writer.write_all(&len.to_be_bytes()).await?;
57 writer.write_all(&body).await?;
58 writer.flush().await?;
59 Ok::<(), WireError>(())
60 };
61 match deadline {
62 None => fut.await,
63 Some(d) => tokio::time::timeout(d, fut)
64 .await
65 .map_err(|_| WireError::Timeout(d))?,
66 }
67}
68
69pub async fn read_frame_with_deadline<R, T>(
71 reader: &mut R,
72 deadline: Option<Duration>,
73) -> Result<T, WireError>
74where
75 R: AsyncReadExt + Unpin,
76 T: DeserializeOwned,
77{
78 let fut = async {
79 let mut len_buf = [0u8; 4];
80 reader.read_exact(&mut len_buf).await?;
81 let len = u32::from_be_bytes(len_buf);
82 if len as usize > MAX_FRAME_BYTES {
83 return Err(WireError::LengthOverflow(len));
84 }
85 let mut body = vec![0u8; len as usize];
86 reader.read_exact(&mut body).await?;
87 Ok::<T, WireError>(serde_json::from_slice(&body)?)
88 };
89 match deadline {
90 None => fut.await,
91 Some(d) => tokio::time::timeout(d, fut)
92 .await
93 .map_err(|_| WireError::Timeout(d))?,
94 }
95}
96
97pub async fn write_frame<W, T>(writer: &mut W, msg: &T) -> std::io::Result<()>
99where
100 W: AsyncWriteExt + Unpin,
101 T: Serialize,
102{
103 write_frame_with_deadline(writer, msg, None)
104 .await
105 .map_err(Into::into)
106}
107
108pub async fn read_frame<R, T>(reader: &mut R) -> std::io::Result<T>
110where
111 R: AsyncReadExt + Unpin,
112 T: DeserializeOwned,
113{
114 read_frame_with_deadline(reader, None)
115 .await
116 .map_err(Into::into)
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct M {
        kind: String,
        n: u32,
    }

    #[tokio::test]
    async fn write_then_read_roundtrip() {
        let msg = M {
            kind: "ping".into(),
            n: 7,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();

        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame(&mut cursor).await.unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn consecutive_frames_are_read_back_in_order() {
        // The reader must stop exactly at each frame boundary.
        let first = M {
            kind: "a".into(),
            n: 1,
        };
        let second = M {
            kind: "b".into(),
            n: 2,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &first).await.unwrap();
        write_frame(&mut buf, &second).await.unwrap();

        let mut cursor = std::io::Cursor::new(&buf);
        let a: M = read_frame(&mut cursor).await.unwrap();
        let b: M = read_frame(&mut cursor).await.unwrap();
        assert_eq!(a, first);
        assert_eq!(b, second);
    }

    #[tokio::test]
    async fn frame_uses_big_endian_u32_prefix() {
        let msg = M {
            kind: "x".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let body_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(body_len, buf.len() - 4);
    }

    #[tokio::test]
    async fn oversized_frame_errors() {
        // Header only: the length check must fire before any body is read.
        let mut header = Vec::new();
        let bogus_len: u32 = 20 * 1024 * 1024;
        header.extend_from_slice(&bogus_len.to_be_bytes());
        let mut cursor = std::io::Cursor::new(header);
        let err = read_frame::<_, M>(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn truncated_body_is_unexpected_eof() {
        // Prefix promises 10 body bytes but the stream ends after 4.
        let mut frame = Vec::new();
        frame.extend_from_slice(&10u32.to_be_bytes());
        frame.extend_from_slice(b"{\"k\"");
        let mut cursor = std::io::Cursor::new(frame);
        let err = read_frame::<_, M>(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn empty_body_is_decode_error() {
        let mut cursor = std::io::Cursor::new(0u32.to_be_bytes().to_vec());
        let err = read_frame_with_deadline::<_, M>(&mut cursor, None)
            .await
            .unwrap_err();
        assert!(matches!(err, WireError::Decode(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn read_frame_with_deadline_returns_timeout_on_no_data() {
        // Keep the producer half alive so the read blocks instead of seeing EOF.
        let (_producer, mut consumer) = tokio::io::duplex(64);
        let err = read_frame_with_deadline::<_, M>(&mut consumer, Some(Duration::from_millis(50)))
            .await
            .unwrap_err();
        assert!(matches!(err, WireError::Timeout(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn read_frame_with_deadline_succeeds_within_deadline() {
        let msg = M {
            kind: "ok".into(),
            n: 42,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame_with_deadline(&mut cursor, Some(Duration::from_millis(500)))
            .await
            .unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn read_frame_with_deadline_none_means_unbounded() {
        let msg = M {
            kind: "k".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame_with_deadline(&mut cursor, None).await.unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn write_frame_with_deadline_succeeds_to_in_memory_buf() {
        let msg = M {
            kind: "k".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame_with_deadline(&mut buf, &msg, Some(Duration::from_millis(500)))
            .await
            .unwrap();
        assert!(buf.len() > 4);
    }

    #[tokio::test]
    async fn write_frame_with_deadline_returns_timeout_on_blocked_writer() {
        // A 4-byte pipe that nobody drains: the frame cannot fit, so the write
        // stalls until the deadline fires.
        let (mut producer, _consumer) = tokio::io::duplex(4);
        let big_msg = M {
            kind: "x".repeat(50),
            n: 1,
        };
        let err =
            write_frame_with_deadline(&mut producer, &big_msg, Some(Duration::from_millis(50)))
                .await
                .unwrap_err();
        assert!(matches!(err, WireError::Timeout(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn oversized_write_is_rejected_before_writing() {
        let huge = M {
            kind: "x".repeat(MAX_FRAME_BYTES),
            n: 0,
        };
        let mut buf: Vec<u8> = Vec::new();
        let err = write_frame_with_deadline(&mut buf, &huge, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WireError::LengthOverflow(n) if n as usize > MAX_FRAME_BYTES),
            "got {err:?}"
        );
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn wire_error_into_io_error_preserves_io_underlying() {
        let io_err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "nope");
        let wire = WireError::Io(io_err);
        let back: std::io::Error = wire.into();
        assert_eq!(back.kind(), std::io::ErrorKind::PermissionDenied);
    }

    #[test]
    fn wire_error_into_io_error_wraps_non_io_as_invalid_data() {
        let wire = WireError::Timeout(Duration::from_millis(100));
        let back: std::io::Error = wire.into();
        assert_eq!(back.kind(), std::io::ErrorKind::InvalidData);
        assert!(back.to_string().contains("timeout"));
    }

    #[tokio::test]
    async fn length_overflow_distinct_from_io_error() {
        let mut header = Vec::new();
        let bogus_len: u32 = 20 * 1024 * 1024;
        header.extend_from_slice(&bogus_len.to_be_bytes());
        let mut cursor = std::io::Cursor::new(header);
        let err = read_frame_with_deadline::<_, M>(&mut cursor, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WireError::LengthOverflow(n) if n == bogus_len),
            "got {err:?}"
        );
    }
}