1use rkyv::api::high::{HighDeserializer, HighSerializer, HighValidator};
6use rkyv::ser::allocator::ArenaHandle;
7use rkyv::util::AlignedVec;
8use rkyv::{Archive, Deserialize, Serialize, rancor};
9use std::io::{BufReader, BufWriter, Read, Write};
10use thiserror::Error;
11
/// Largest payload, in bytes, that [`write_frame`] will send or [`read_frame`] will
/// accept (16 MiB). The 4-byte length prefix in front of each payload is not counted.
pub const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

// The wire format carries the payload length as a little-endian `u32`; fail the build
// rather than let a raised limit produce silently truncated length prefixes.
const _: () = assert!(MAX_FRAME_SIZE as u64 <= u32::MAX as u64);
14
/// Errors produced while writing or reading length-prefixed rkyv frames.
///
/// Marked `#[non_exhaustive]`: matches outside this crate need a wildcard arm,
/// which leaves room for adding variants later.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FrameError {
    /// An underlying I/O operation failed. `#[from]` lets `?` on any
    /// `std::io::Error` convert into this variant.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// rkyv failed to serialize an outgoing message; holds the rkyv error text.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// `rkyv::from_bytes` rejected an incoming payload; holds the rkyv error text.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// A payload length exceeded the configured limit. `write_frame` returns this
    /// before writing anything; `read_frame` returns it before allocating the payload.
    #[error("Frame too large: {size} bytes (max {max} bytes)")]
    FrameTooLarge {
        /// The offending payload length, in bytes.
        size: usize,
        /// The limit that was exceeded (`MAX_FRAME_SIZE`).
        max: usize,
    },

    /// The frame is structurally invalid; e.g. `read_frame` rejects a
    /// zero-length frame with this variant.
    #[error("Invalid frame: {0}")]
    InvalidFrame(String),

    /// Returned by `read_frame` when the stream ends before a frame's length
    /// prefix could be read.
    #[error("End of stream")]
    EndOfStream,
}
48
49pub fn write_frame<W, T>(writer: &mut BufWriter<W>, message: &T) -> Result<(), FrameError>
58where
59 W: Write,
60 T: for<'a> Serialize<HighSerializer<AlignedVec, ArenaHandle<'a>, rancor::Error>>,
61{
62 let bytes = rkyv::to_bytes::<rancor::Error>(message)
64 .map_err(|e| FrameError::Serialization(e.to_string()))?;
65
66 let len = bytes.len();
67 if len > MAX_FRAME_SIZE {
68 return Err(FrameError::FrameTooLarge {
69 size: len,
70 max: MAX_FRAME_SIZE,
71 });
72 }
73
74 writer.write_all(&(len as u32).to_le_bytes())?;
76
77 writer.write_all(&bytes)?;
79
80 writer.flush()?;
82
83 Ok(())
84}
85
86pub fn read_frame<R, T>(reader: &mut BufReader<R>) -> Result<T, FrameError>
88where
89 R: Read,
90 T: Archive,
91 T::Archived: for<'a> rkyv::bytecheck::CheckBytes<HighValidator<'a, rancor::Error>>
92 + Deserialize<T, HighDeserializer<rancor::Error>>,
93{
94 let mut len_buf = [0u8; 4];
96 match reader.read_exact(&mut len_buf) {
97 Ok(()) => {}
98 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
99 return Err(FrameError::EndOfStream);
100 }
101 Err(e) => return Err(FrameError::Io(e)),
102 }
103
104 let len = u32::from_le_bytes(len_buf) as usize;
105
106 if len > MAX_FRAME_SIZE {
108 return Err(FrameError::FrameTooLarge {
109 size: len,
110 max: MAX_FRAME_SIZE,
111 });
112 }
113
114 if len == 0 {
115 return Err(FrameError::InvalidFrame("zero-length frame".to_string()));
116 }
117
118 let mut buf = AlignedVec::<16>::with_capacity(len);
120 buf.resize(len, 0);
121 reader.read_exact(&mut buf)?;
122
123 let value: T = rkyv::from_bytes::<T, rancor::Error>(&buf)
125 .map_err(|e| FrameError::Deserialization(e.to_string()))?;
126
127 Ok(value)
128}
129
130pub struct FrameWriter<W: Write> {
132 writer: BufWriter<W>,
133}
134
135impl<W: Write> FrameWriter<W> {
136 pub fn new(writer: W) -> Self {
138 Self {
139 writer: BufWriter::with_capacity(64 * 1024, writer), }
141 }
142
143 pub fn write<T>(&mut self, message: &T) -> Result<(), FrameError>
145 where
146 T: for<'a> Serialize<HighSerializer<AlignedVec, ArenaHandle<'a>, rancor::Error>>,
147 {
148 write_frame(&mut self.writer, message)
149 }
150
151 pub fn flush(&mut self) -> Result<(), FrameError> {
153 self.writer.flush()?;
154 Ok(())
155 }
156
157 pub fn inner_mut(&mut self) -> &mut BufWriter<W> {
159 &mut self.writer
160 }
161
162 pub fn into_inner(self) -> BufWriter<W> {
164 self.writer
165 }
166}
167
168pub struct FrameReader<R: Read> {
170 reader: BufReader<R>,
171}
172
173impl<R: Read> FrameReader<R> {
174 pub fn new(reader: R) -> Self {
176 Self {
177 reader: BufReader::with_capacity(64 * 1024, reader), }
179 }
180
181 pub fn read<T>(&mut self) -> Result<T, FrameError>
183 where
184 T: Archive,
185 T::Archived: for<'a> rkyv::bytecheck::CheckBytes<HighValidator<'a, rancor::Error>>
186 + Deserialize<T, HighDeserializer<rancor::Error>>,
187 {
188 read_frame(&mut self.reader)
189 }
190
191 pub fn has_buffered_data(&self) -> bool {
193 !self.reader.buffer().is_empty()
194 }
195
196 pub fn inner_mut(&mut self) -> &mut BufReader<R> {
198 &mut self.reader
199 }
200
201 pub fn into_inner(self) -> BufReader<R> {
203 self.reader
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
    use std::io::Cursor;

    #[derive(Debug, Clone, PartialEq, Archive, RkyvSerialize, RkyvDeserialize)]
    struct TestMessage {
        value: u64,
        text: String,
    }

    /// Builds a raw stream: a little-endian `u32` length prefix followed by `payload`.
    fn raw_frame(len: u32, payload: &[u8]) -> Vec<u8> {
        let mut bytes = len.to_le_bytes().to_vec();
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn test_roundtrip() {
        let original = TestMessage {
            value: 42,
            text: "hello world".to_string(),
        };

        let mut buffer = Vec::new();
        {
            let mut writer = FrameWriter::new(&mut buffer);
            writer.write(&original).unwrap();
        }

        let mut reader = FrameReader::new(Cursor::new(buffer));
        let decoded: TestMessage = reader.read().unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_multiple_messages() {
        let messages = vec![
            TestMessage {
                value: 1,
                text: "first".to_string(),
            },
            TestMessage {
                value: 2,
                text: "second".to_string(),
            },
            TestMessage {
                value: 3,
                text: "third".to_string(),
            },
        ];

        let mut buffer = Vec::new();
        {
            let mut writer = FrameWriter::new(&mut buffer);
            for msg in &messages {
                writer.write(msg).unwrap();
            }
        }

        let mut reader = FrameReader::new(Cursor::new(buffer));
        for expected in &messages {
            let decoded: TestMessage = reader.read().unwrap();
            assert_eq!(expected, &decoded);
        }

        // After the last frame, the stream ends cleanly.
        assert!(matches!(
            reader.read::<TestMessage>(),
            Err(FrameError::EndOfStream)
        ));
    }

    #[test]
    fn test_end_of_stream() {
        let buffer: Vec<u8> = Vec::new();
        let mut reader = FrameReader::new(Cursor::new(buffer));
        let result: Result<TestMessage, _> = reader.read();
        assert!(matches!(result, Err(FrameError::EndOfStream)));
    }

    #[test]
    fn test_zero_length_frame_rejected() {
        let mut reader = FrameReader::new(Cursor::new(raw_frame(0, &[])));
        let result: Result<TestMessage, _> = reader.read();
        assert!(matches!(result, Err(FrameError::InvalidFrame(_))));
    }

    #[test]
    fn test_oversized_frame_rejected() {
        let len = u32::try_from(MAX_FRAME_SIZE + 1).unwrap();
        let mut reader = FrameReader::new(Cursor::new(raw_frame(len, &[])));
        let result: Result<TestMessage, _> = reader.read();
        assert!(matches!(
            result,
            Err(FrameError::FrameTooLarge { size, max })
                if size == MAX_FRAME_SIZE + 1 && max == MAX_FRAME_SIZE
        ));
    }

    #[test]
    fn test_truncated_payload_is_io_error() {
        let mut reader = FrameReader::new(Cursor::new(raw_frame(10, &[1, 2, 3])));
        let result: Result<TestMessage, _> = reader.read();
        assert!(matches!(
            result,
            Err(FrameError::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }
}