1use std::io::{self, Read, Write};
6
/// Largest payload, in bytes, that a single frame may carry (16 MiB).
///
/// Enforced by both the writers (which refuse to send larger payloads) and the
/// readers (which reject a larger length header before allocating a buffer).
const MAX_MESSAGE_SIZE: u32 = 16 << 20;
9
10pub fn write_framed(stream: &mut impl Write, data: &[u8]) -> io::Result<()> {
12 let len: u32 = data.len().try_into().map_err(|_| {
13 io::Error::new(
14 io::ErrorKind::InvalidInput,
15 format!(
16 "Message too large: {} bytes (max {})",
17 data.len(),
18 MAX_MESSAGE_SIZE
19 ),
20 )
21 })?;
22 if len > MAX_MESSAGE_SIZE {
23 return Err(io::Error::new(
24 io::ErrorKind::InvalidInput,
25 format!(
26 "Message too large: {} bytes (max {})",
27 len, MAX_MESSAGE_SIZE
28 ),
29 ));
30 }
31 stream.write_all(&len.to_be_bytes())?;
32 stream.write_all(data)?;
33 stream.flush()
34}
35
36pub fn read_framed(stream: &mut impl Read) -> io::Result<Vec<u8>> {
38 let mut len_buf = [0u8; 4];
39 stream.read_exact(&mut len_buf)?;
40 let len = u32::from_be_bytes(len_buf);
41
42 if len > MAX_MESSAGE_SIZE {
43 return Err(io::Error::new(
44 io::ErrorKind::InvalidData,
45 format!(
46 "Message too large: {} bytes (max {})",
47 len, MAX_MESSAGE_SIZE
48 ),
49 ));
50 }
51
52 let mut buf = vec![0u8; len as usize];
53 stream.read_exact(&mut buf)?;
54 Ok(buf)
55}
56
57pub async fn write_framed_async(
59 stream: &mut (impl tokio::io::AsyncWriteExt + Unpin),
60 data: &[u8],
61) -> io::Result<()> {
62 let len: u32 = data.len().try_into().map_err(|_| {
63 io::Error::new(
64 io::ErrorKind::InvalidInput,
65 format!(
66 "Message too large: {} bytes (max {})",
67 data.len(),
68 MAX_MESSAGE_SIZE
69 ),
70 )
71 })?;
72 if len > MAX_MESSAGE_SIZE {
73 return Err(io::Error::new(
74 io::ErrorKind::InvalidInput,
75 format!(
76 "Message too large: {} bytes (max {})",
77 len, MAX_MESSAGE_SIZE
78 ),
79 ));
80 }
81 stream.write_all(&len.to_be_bytes()).await?;
82 stream.write_all(data).await?;
83 stream.flush().await
84}
85
86pub async fn read_framed_async(
88 stream: &mut (impl tokio::io::AsyncReadExt + Unpin),
89) -> io::Result<Vec<u8>> {
90 let mut len_buf = [0u8; 4];
91 stream.read_exact(&mut len_buf).await?;
92 let len = u32::from_be_bytes(len_buf);
93
94 if len > MAX_MESSAGE_SIZE {
95 return Err(io::Error::new(
96 io::ErrorKind::InvalidData,
97 format!(
98 "Message too large: {} bytes (max {})",
99 len, MAX_MESSAGE_SIZE
100 ),
101 ));
102 }
103
104 let mut buf = vec![0u8; len as usize];
105 stream.read_exact(&mut buf).await?;
106 Ok(buf)
107}
108
// Round-trip and error-path tests for the framing helpers (sync and async).
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_sync_roundtrip() {
        let data = b"hello world";
        let mut buf = Vec::new();
        write_framed(&mut buf, data).unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed(&mut cursor).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_empty_message() {
        let data = b"";
        let mut buf = Vec::new();
        write_framed(&mut buf, data).unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed(&mut cursor).unwrap();
        assert_eq!(result, data.to_vec());
    }

    #[test]
    fn test_length_prefix_is_big_endian() {
        // Pin the exact wire layout so peers in other languages stay compatible.
        let mut buf = Vec::new();
        write_framed(&mut buf, b"abc").unwrap();
        assert_eq!(buf, [0, 0, 0, 3, b'a', b'b', b'c']);
    }

    #[test]
    fn test_multiple_frames_in_sequence() {
        // Each read must consume exactly one frame and leave the next intact.
        let mut buf = Vec::new();
        write_framed(&mut buf, b"first").unwrap();
        write_framed(&mut buf, b"").unwrap();
        write_framed(&mut buf, b"third").unwrap();

        let mut cursor = Cursor::new(buf);
        assert_eq!(read_framed(&mut cursor).unwrap(), b"first");
        assert_eq!(read_framed(&mut cursor).unwrap(), b"");
        assert_eq!(read_framed(&mut cursor).unwrap(), b"third");
        assert_eq!(
            read_framed(&mut cursor).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn test_message_too_large() {
        let len = (MAX_MESSAGE_SIZE + 1).to_be_bytes();
        let mut cursor = Cursor::new(len.to_vec());
        let result = read_framed(&mut cursor);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_max_size_header_passes_limit_check() {
        // A header of exactly MAX_MESSAGE_SIZE is allowed. The read then fails only
        // because the payload is missing, not because of the limit.
        let mut cursor = Cursor::new(MAX_MESSAGE_SIZE.to_be_bytes().to_vec());
        let err = read_framed(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_write_rejects_oversized_message() {
        let data = vec![0u8; MAX_MESSAGE_SIZE as usize + 1];
        let mut buf = Vec::new();
        let err = write_framed(&mut buf, &data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(buf.is_empty(), "a rejected message must not emit a partial frame");
    }

    #[test]
    fn test_truncated_input_is_unexpected_eof() {
        let mut short_header = Cursor::new(vec![0u8, 0]);
        assert_eq!(
            read_framed(&mut short_header).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );

        let mut short_body = 10u32.to_be_bytes().to_vec();
        short_body.extend_from_slice(b"short");
        let mut cursor = Cursor::new(short_body);
        assert_eq!(
            read_framed(&mut cursor).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[tokio::test]
    async fn test_async_roundtrip() {
        let data = b"hello async world";
        let mut buf = Vec::new();
        write_framed_async(&mut buf, data).await.unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed_async(&mut cursor).await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_async_message_too_large() {
        let header = (MAX_MESSAGE_SIZE + 1).to_be_bytes();
        let mut cursor = Cursor::new(header.to_vec());
        let err = read_framed_async(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}