mitoxide_proto/
codec.rs

1//! Frame codec for async streams
2
3use crate::{Frame, ProtocolError};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
/// Default upper bound on the serialized size of a single frame body,
/// in bytes (16 MiB = 2^24). The 4-byte length prefix is not counted.
pub const MAX_FRAME_SIZE: usize = 1 << 24;
9
10/// Frame codec for encoding/decoding frames over async streams
11pub struct FrameCodec {
12    /// Read buffer for incoming data
13    read_buf: BytesMut,
14    /// Maximum frame size allowed
15    max_frame_size: usize,
16}
17
18impl Default for FrameCodec {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl FrameCodec {
25    /// Create a new frame codec with default settings
26    pub fn new() -> Self {
27        Self {
28            read_buf: BytesMut::with_capacity(8192),
29            max_frame_size: MAX_FRAME_SIZE,
30        }
31    }
32    
33    /// Create a new frame codec with custom max frame size
34    pub fn with_max_frame_size(max_frame_size: usize) -> Self {
35        Self {
36            read_buf: BytesMut::with_capacity(8192),
37            max_frame_size,
38        }
39    }
40    
41    /// Encode a frame to bytes with length prefix
42    pub fn encode_frame(&self, frame: &Frame) -> Result<Bytes, ProtocolError> {
43        // Serialize the frame to MessagePack
44        let frame_bytes = frame.to_msgpack()?;
45        
46        // Check frame size limit
47        if frame_bytes.len() > self.max_frame_size {
48            return Err(ProtocolError::FrameTooLarge {
49                size: frame_bytes.len(),
50                max: self.max_frame_size,
51            });
52        }
53        
54        // Create buffer with length prefix (4 bytes) + frame data
55        let mut buf = BytesMut::with_capacity(4 + frame_bytes.len());
56        buf.put_u32(frame_bytes.len() as u32);
57        buf.put_slice(&frame_bytes);
58        
59        Ok(buf.freeze())
60    }
61    
62    /// Write a frame to an async writer
63    pub async fn write_frame<W>(&self, writer: &mut W, frame: &Frame) -> Result<(), ProtocolError>
64    where
65        W: AsyncWrite + Unpin,
66    {
67        let encoded = self.encode_frame(frame)?;
68        writer.write_all(&encoded).await
69            .map_err(|e| ProtocolError::Serialization(format!("Write error: {}", e)))?;
70        writer.flush().await
71            .map_err(|e| ProtocolError::Serialization(format!("Flush error: {}", e)))?;
72        Ok(())
73    }
74    
75    /// Read a frame from an async reader
76    pub async fn read_frame<R>(&mut self, reader: &mut R) -> Result<Option<Frame>, ProtocolError>
77    where
78        R: AsyncRead + Unpin,
79    {
80        loop {
81            // Try to decode a frame from the buffer
82            if let Some(frame) = self.try_decode_frame()? {
83                return Ok(Some(frame));
84            }
85            
86            // Need more data, read from the stream
87            let mut temp_buf = [0u8; 8192];
88            let n = reader.read(&mut temp_buf).await
89                .map_err(|e| ProtocolError::Serialization(format!("Read error: {}", e)))?;
90            
91            if n == 0 {
92                // EOF reached
93                if self.read_buf.is_empty() {
94                    return Ok(None);
95                } else {
96                    return Err(ProtocolError::InvalidFrame);
97                }
98            }
99            
100            self.read_buf.extend_from_slice(&temp_buf[..n]);
101        }
102    }
103    
104    /// Try to decode a frame from the internal buffer
105    pub fn try_decode_frame(&mut self) -> Result<Option<Frame>, ProtocolError> {
106        if self.read_buf.len() < 4 {
107            // Not enough data for length prefix
108            return Ok(None);
109        }
110        
111        // Read the length prefix without consuming it
112        let frame_len = (&self.read_buf[..4]).get_u32() as usize;
113        
114        // Check frame size limit
115        if frame_len > self.max_frame_size {
116            return Err(ProtocolError::FrameTooLarge {
117                size: frame_len,
118                max: self.max_frame_size,
119            });
120        }
121        
122        // Check if we have the complete frame
123        if self.read_buf.len() < 4 + frame_len {
124            return Ok(None);
125        }
126        
127        // We have a complete frame, consume the length prefix
128        self.read_buf.advance(4);
129        
130        // Extract the frame data
131        let frame_data = self.read_buf.split_to(frame_len);
132        
133        // Deserialize the frame
134        let frame = Frame::from_msgpack(&frame_data)?;
135        Ok(Some(frame))
136    }
137    
138    /// Get the current buffer size
139    pub fn buffer_size(&self) -> usize {
140        self.read_buf.len()
141    }
142    
143    /// Clear the internal buffer
144    pub fn clear_buffer(&mut self) {
145        self.read_buf.clear();
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::io::Cursor;

    #[tokio::test]
    async fn test_frame_encode_decode() {
        let codec = FrameCodec::new();
        let frame = Frame::data(1, 42, Bytes::from("test payload"));

        let encoded = codec.encode_frame(&frame).unwrap();
        assert!(encoded.len() > 4); // Should have length prefix

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(encoded);
        let decoded = codec2.read_frame(&mut cursor).await.unwrap().unwrap();

        assert_eq!(frame.stream_id, decoded.stream_id);
        assert_eq!(frame.sequence, decoded.sequence);
        assert_eq!(frame.flags, decoded.flags);
        assert_eq!(frame.payload, decoded.payload);
    }

    #[tokio::test]
    async fn test_write_read_frame() {
        let codec = FrameCodec::new();
        let frame = Frame::end_stream(123, 456);

        let mut buffer = Vec::new();
        codec.write_frame(&mut buffer, &frame).await.unwrap();

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(buffer);
        let decoded = codec2.read_frame(&mut cursor).await.unwrap().unwrap();

        assert_eq!(frame.stream_id, decoded.stream_id);
        assert_eq!(frame.sequence, decoded.sequence);
        assert!(decoded.is_end_stream());
    }

    // `try_decode_frame` is synchronous, so no async runtime is needed.
    #[test]
    fn test_partial_frame_reading() {
        let codec = FrameCodec::new();
        let frame = Frame::data(1, 1, Bytes::from("test"));
        let encoded = codec.encode_frame(&frame).unwrap();

        // Test the try_decode_frame method directly with partial data
        let mut codec2 = FrameCodec::new();

        // Add partial data to the buffer
        let mid = encoded.len() / 2;
        codec2.read_buf.extend_from_slice(&encoded[..mid]);

        // Should return None (incomplete)
        let result1 = codec2.try_decode_frame().unwrap();
        assert!(result1.is_none());

        // Add the rest of the data
        codec2.read_buf.extend_from_slice(&encoded[mid..]);

        // Should now return the complete frame
        let result2 = codec2.try_decode_frame().unwrap().unwrap();

        assert_eq!(frame.stream_id, result2.stream_id);
        assert_eq!(frame.payload, result2.payload);
    }

    #[tokio::test]
    async fn test_multiple_frames_in_buffer() {
        let codec = FrameCodec::new();
        let frame1 = Frame::data(1, 1, Bytes::from("first"));
        let frame2 = Frame::data(2, 2, Bytes::from("second"));

        let encoded1 = codec.encode_frame(&frame1).unwrap();
        let encoded2 = codec.encode_frame(&frame2).unwrap();

        // Combine both frames in one buffer
        let mut combined = BytesMut::new();
        combined.extend_from_slice(&encoded1);
        combined.extend_from_slice(&encoded2);

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(combined.freeze());

        // Read first frame
        let decoded1 = codec2.read_frame(&mut cursor).await.unwrap().unwrap();
        assert_eq!(frame1.stream_id, decoded1.stream_id);
        assert_eq!(frame1.payload, decoded1.payload);

        // Read second frame
        let decoded2 = codec2.read_frame(&mut cursor).await.unwrap().unwrap();
        assert_eq!(frame2.stream_id, decoded2.stream_id);
        assert_eq!(frame2.payload, decoded2.payload);

        // No more frames
        let result3 = codec2.read_frame(&mut cursor).await.unwrap();
        assert!(result3.is_none());
    }

    // `encode_frame` is synchronous, so no async runtime is needed.
    #[test]
    fn test_frame_too_large() {
        let codec = FrameCodec::with_max_frame_size(100);
        let large_payload = Bytes::from(vec![0u8; 200]);
        let frame = Frame::data(1, 1, large_payload);

        let result = codec.encode_frame(&frame);
        assert!(matches!(result, Err(ProtocolError::FrameTooLarge { .. })));
    }

    // The decoder must enforce the size limit from the length prefix alone,
    // before any of the body has arrived.
    #[test]
    fn test_decode_rejects_oversized_length_prefix() {
        let mut codec = FrameCodec::with_max_frame_size(100);
        codec.read_buf.put_u32(1000);

        let result = codec.try_decode_frame();
        assert!(matches!(
            result,
            Err(ProtocolError::FrameTooLarge { size: 1000, max: 100 })
        ));
    }

    // A frame delivered across several reads, split inside the 4-byte length
    // prefix, must still be reassembled into a single frame.
    #[tokio::test]
    async fn test_frame_split_across_reads() {
        let codec = FrameCodec::new();
        let frame = Frame::data(3, 9, Bytes::from("split payload"));
        let encoded = codec.encode_frame(&frame).unwrap();

        let mut reader = AsyncReadExt::chain(
            Cursor::new(encoded.slice(..2)),
            Cursor::new(encoded.slice(2..)),
        );

        let mut codec2 = FrameCodec::new();
        let decoded = codec2.read_frame(&mut reader).await.unwrap().unwrap();
        assert_eq!(frame.stream_id, decoded.stream_id);
        assert_eq!(frame.sequence, decoded.sequence);
        assert_eq!(frame.payload, decoded.payload);

        // The stream is now exhausted with nothing buffered: clean EOF.
        assert!(codec2.read_frame(&mut reader).await.unwrap().is_none());
    }

    // EOF in the middle of a frame is a protocol error, not a clean end.
    #[tokio::test]
    async fn test_truncated_stream_is_invalid_frame() {
        let codec = FrameCodec::new();
        let frame = Frame::data(7, 8, Bytes::from("truncated"));
        let encoded = codec.encode_frame(&frame).unwrap();

        let mut codec2 = FrameCodec::new();
        let mut cursor = Cursor::new(encoded.slice(..encoded.len() - 1));
        let result = codec2.read_frame(&mut cursor).await;

        assert!(matches!(result, Err(ProtocolError::InvalidFrame)));
    }

    #[tokio::test]
    async fn test_invalid_frame_data() {
        let mut codec = FrameCodec::new();

        // Create invalid frame data (valid length prefix but invalid MessagePack)
        let mut invalid_data = BytesMut::new();
        invalid_data.put_u32(4); // Length prefix
        invalid_data.put_slice(&[0xFF, 0xFF, 0xFF, 0xFF]); // Invalid MessagePack

        let mut cursor = Cursor::new(invalid_data.freeze());
        let result = codec.read_frame(&mut cursor).await;

        assert!(matches!(result, Err(ProtocolError::Serialization(_))));
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let mut codec = FrameCodec::new();
        let mut cursor = Cursor::new(Vec::<u8>::new());

        let result = codec.read_frame(&mut cursor).await.unwrap();
        assert!(result.is_none());
    }

    proptest! {
        #[test]
        fn test_codec_roundtrip_properties(
            stream_id in any::<u32>(),
            sequence in any::<u32>(),
            payload in prop::collection::vec(any::<u8>(), 0..1000)
        ) {
            tokio_test::block_on(async {
                let codec = FrameCodec::new();
                let frame = Frame::data(stream_id, sequence, Bytes::from(payload));

                let encoded = codec.encode_frame(&frame)?;

                let mut codec2 = FrameCodec::new();
                let mut cursor = Cursor::new(encoded);
                let decoded = codec2.read_frame(&mut cursor).await?.unwrap();

                prop_assert_eq!(frame.stream_id, decoded.stream_id);
                prop_assert_eq!(frame.sequence, decoded.sequence);
                prop_assert_eq!(frame.flags, decoded.flags);
                prop_assert_eq!(frame.payload, decoded.payload);

                Ok(())
            })?;
        }
    }
}