1use crate::{Frame, ProtocolError};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Default upper bound, in bytes, on a single serialized frame body (16 MiB).
///
/// Used by [`FrameCodec::new`]; a different limit can be chosen with
/// [`FrameCodec::with_max_frame_size`].
pub const MAX_FRAME_SIZE: usize = 1 << 24;
9
10pub struct FrameCodec {
12 read_buf: BytesMut,
14 max_frame_size: usize,
16}
17
18impl Default for FrameCodec {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl FrameCodec {
25 pub fn new() -> Self {
27 Self {
28 read_buf: BytesMut::with_capacity(8192),
29 max_frame_size: MAX_FRAME_SIZE,
30 }
31 }
32
33 pub fn with_max_frame_size(max_frame_size: usize) -> Self {
35 Self {
36 read_buf: BytesMut::with_capacity(8192),
37 max_frame_size,
38 }
39 }
40
41 pub fn encode_frame(&self, frame: &Frame) -> Result<Bytes, ProtocolError> {
43 let frame_bytes = frame.to_msgpack()?;
45
46 if frame_bytes.len() > self.max_frame_size {
48 return Err(ProtocolError::FrameTooLarge {
49 size: frame_bytes.len(),
50 max: self.max_frame_size,
51 });
52 }
53
54 let mut buf = BytesMut::with_capacity(4 + frame_bytes.len());
56 buf.put_u32(frame_bytes.len() as u32);
57 buf.put_slice(&frame_bytes);
58
59 Ok(buf.freeze())
60 }
61
62 pub async fn write_frame<W>(&self, writer: &mut W, frame: &Frame) -> Result<(), ProtocolError>
64 where
65 W: AsyncWrite + Unpin,
66 {
67 let encoded = self.encode_frame(frame)?;
68 writer.write_all(&encoded).await
69 .map_err(|e| ProtocolError::Serialization(format!("Write error: {}", e)))?;
70 writer.flush().await
71 .map_err(|e| ProtocolError::Serialization(format!("Flush error: {}", e)))?;
72 Ok(())
73 }
74
75 pub async fn read_frame<R>(&mut self, reader: &mut R) -> Result<Option<Frame>, ProtocolError>
77 where
78 R: AsyncRead + Unpin,
79 {
80 loop {
81 if let Some(frame) = self.try_decode_frame()? {
83 return Ok(Some(frame));
84 }
85
86 let mut temp_buf = [0u8; 8192];
88 let n = reader.read(&mut temp_buf).await
89 .map_err(|e| ProtocolError::Serialization(format!("Read error: {}", e)))?;
90
91 if n == 0 {
92 if self.read_buf.is_empty() {
94 return Ok(None);
95 } else {
96 return Err(ProtocolError::InvalidFrame);
97 }
98 }
99
100 self.read_buf.extend_from_slice(&temp_buf[..n]);
101 }
102 }
103
104 pub fn try_decode_frame(&mut self) -> Result<Option<Frame>, ProtocolError> {
106 if self.read_buf.len() < 4 {
107 return Ok(None);
109 }
110
111 let frame_len = (&self.read_buf[..4]).get_u32() as usize;
113
114 if frame_len > self.max_frame_size {
116 return Err(ProtocolError::FrameTooLarge {
117 size: frame_len,
118 max: self.max_frame_size,
119 });
120 }
121
122 if self.read_buf.len() < 4 + frame_len {
124 return Ok(None);
125 }
126
127 self.read_buf.advance(4);
129
130 let frame_data = self.read_buf.split_to(frame_len);
132
133 let frame = Frame::from_msgpack(&frame_data)?;
135 Ok(Some(frame))
136 }
137
138 pub fn buffer_size(&self) -> usize {
140 self.read_buf.len()
141 }
142
143 pub fn clear_buffer(&mut self) {
145 self.read_buf.clear();
146 }
147}
148
#[cfg(test)]
mod tests {
    // Unit and property tests for FrameCodec's length-prefixed framing.
    use super::*;
    use std::io::Cursor;
    use proptest::prelude::*;

    /// A data frame survives `encode_frame` followed by `read_frame` unchanged.
    #[tokio::test]
    async fn test_frame_encode_decode() {
        let codec = FrameCodec::new();
        let frame = Frame::data(1, 42, Bytes::from("test payload"));

        let encoded = codec.encode_frame(&frame).unwrap();
        // The output must be longer than the 4-byte length prefix alone.
        assert!(encoded.len() > 4);
        // Decode with a fresh codec so no buffer state is shared with the encoder.
        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(encoded);
        let decoded = codec2.read_frame(&mut cursor).await.unwrap().unwrap();

        assert_eq!(frame.stream_id, decoded.stream_id);
        assert_eq!(frame.sequence, decoded.sequence);
        assert_eq!(frame.flags, decoded.flags);
        assert_eq!(frame.payload, decoded.payload);
    }

    /// Bytes produced by `write_frame` (into a `Vec<u8>`) are readable by
    /// `read_frame`, and the end-of-stream marker is preserved.
    #[tokio::test]
    async fn test_write_read_frame() {
        let codec = FrameCodec::new();
        let frame = Frame::end_stream(123, 456);

        let mut buffer = Vec::new();
        codec.write_frame(&mut buffer, &frame).await.unwrap();

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(buffer);
        let decoded = codec2.read_frame(&mut cursor).await.unwrap().unwrap();

        assert_eq!(frame.stream_id, decoded.stream_id);
        assert_eq!(frame.sequence, decoded.sequence);
        assert!(decoded.is_end_stream());
    }

    /// `try_decode_frame` reports `None` for half a frame and decodes it once
    /// the remaining bytes are appended.
    #[tokio::test]
    async fn test_partial_frame_reading() {
        let codec = FrameCodec::new();
        let frame = Frame::data(1, 1, Bytes::from("test"));
        let encoded = codec.encode_frame(&frame).unwrap();

        let mut codec2 = FrameCodec::new();

        // Feed the first half straight into the internal buffer, bypassing I/O.
        let mid = encoded.len() / 2;
        codec2.read_buf.extend_from_slice(&encoded[..mid]);

        let result1 = codec2.try_decode_frame().unwrap();
        assert!(result1.is_none());

        // Supply the rest; the frame is now complete.
        codec2.read_buf.extend_from_slice(&encoded[mid..]);

        let result2 = codec2.try_decode_frame().unwrap().unwrap();

        assert_eq!(frame.stream_id, result2.stream_id);
        assert_eq!(frame.payload, result2.payload);
    }

    /// Two back-to-back frames are returned by successive `read_frame` calls,
    /// and the call after them reports a clean end of stream (`None`).
    #[tokio::test]
    async fn test_multiple_frames_in_buffer() {
        let codec = FrameCodec::new();
        let frame1 = Frame::data(1, 1, Bytes::from("first"));
        let frame2 = Frame::data(2, 2, Bytes::from("second"));

        let encoded1 = codec.encode_frame(&frame1).unwrap();
        let encoded2 = codec.encode_frame(&frame2).unwrap();

        // Concatenate both encodings into a single input stream.
        let mut combined = BytesMut::new();
        combined.extend_from_slice(&encoded1);
        combined.extend_from_slice(&encoded2);

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(combined.freeze());

        let decoded1 = codec2.read_frame(&mut cursor).await.unwrap().unwrap();
        assert_eq!(frame1.stream_id, decoded1.stream_id);
        assert_eq!(frame1.payload, decoded1.payload);

        let decoded2 = codec2.read_frame(&mut cursor).await.unwrap().unwrap();
        assert_eq!(frame2.stream_id, decoded2.stream_id);
        assert_eq!(frame2.payload, decoded2.payload);

        let result3 = codec2.read_frame(&mut cursor).await.unwrap();
        assert!(result3.is_none());
    }

    /// Encoding fails with `FrameTooLarge` when the serialized frame exceeds
    /// the configured limit (a 200-byte payload against a 100-byte limit).
    #[tokio::test]
    async fn test_frame_too_large() {
        let codec = FrameCodec::with_max_frame_size(100);
        let large_payload = Bytes::from(vec![0u8; 200]);
        let frame = Frame::data(1, 1, large_payload);

        let result = codec.encode_frame(&frame);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    /// A valid length prefix followed by undecodable body bytes surfaces as a
    /// `ProtocolError::Serialization` error from `read_frame`.
    #[tokio::test]
    async fn test_invalid_frame_data() {
        let mut codec = FrameCodec::new();

        // Hand-built frame: a prefix announcing a 4-byte body, then 4 bytes that
        // this test expects `Frame::from_msgpack` to reject.
        let mut invalid_data = BytesMut::new();
        invalid_data.put_u32(4);
        invalid_data.put_slice(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut cursor = Cursor::new(invalid_data.freeze());
        let result = codec.read_frame(&mut cursor).await;

        assert!(matches!(result, Err(ProtocolError::Serialization(_))));
    }

    /// EOF with nothing buffered yields `Ok(None)` rather than an error.
    #[tokio::test]
    async fn test_empty_stream() {
        let mut codec = FrameCodec::new();
        let mut cursor = Cursor::new(Vec::<u8>::new());

        let result = codec.read_frame(&mut cursor).await.unwrap();
        assert!(result.is_none());
    }

    proptest! {
        // Property: any stream id, sequence number and payload of 0..=999 bytes
        // round-trips through encode_frame and read_frame unchanged.
        #[test]
        fn test_codec_roundtrip_properties(
            stream_id in any::<u32>(),
            sequence in any::<u32>(),
            payload in prop::collection::vec(any::<u8>(), 0..1000)
        ) {
            // read_frame is async, so drive it to completion on a blocking executor.
            tokio_test::block_on(async {
                let codec = FrameCodec::new();
                let frame = Frame::data(stream_id, sequence, Bytes::from(payload));

                let encoded = codec.encode_frame(&frame)?;

                let mut codec2 = FrameCodec::new();
                let mut cursor = Cursor::new(encoded);
                let decoded = codec2.read_frame(&mut cursor).await?.unwrap();

                prop_assert_eq!(frame.stream_id, decoded.stream_id);
                prop_assert_eq!(frame.sequence, decoded.sequence);
                prop_assert_eq!(frame.flags, decoded.flags);
                prop_assert_eq!(frame.payload, decoded.payload);

                Ok(())
            })?;
        }
    }
}