Skip to main content

oxigdal_websocket/protocol/
framing.rs

1//! Message framing for WebSocket protocol
2
3use crate::error::{Error, Result};
4use bytes::{BufMut, Bytes, BytesMut};
5
/// Kind of frame, carried in the upper 4 bits of header byte 0.
///
/// Each discriminant is the on-wire code, so `variant as u8` yields the encoded value.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FrameType {
    /// Ordinary data frame.
    Data = 0,
    /// Control frame.
    Control = 1,
    /// Heartbeat frame.
    Heartbeat = 2,
    /// Start of a fragmented message.
    FragmentStart = 3,
    /// Continuation fragment of a fragmented message.
    FragmentContinuation = 4,
    /// End of a fragmented message.
    FragmentEnd = 5,
}
23
24impl TryFrom<u8> for FrameType {
25    type Error = Error;
26
27    fn try_from(value: u8) -> Result<Self> {
28        match value {
29            0 => Ok(FrameType::Data),
30            1 => Ok(FrameType::Control),
31            2 => Ok(FrameType::Heartbeat),
32            3 => Ok(FrameType::FragmentStart),
33            4 => Ok(FrameType::FragmentContinuation),
34            5 => Ok(FrameType::FragmentEnd),
35            _ => Err(Error::Protocol(format!("Invalid frame type: {}", value))),
36        }
37    }
38}
39
40/// Frame header structure
41///
42/// Layout (8 bytes):
43/// - Byte 0: Frame type (4 bits) | Protocol version (4 bits)
44/// - Byte 1: Flags (compressed: 1 bit, fragmented: 1 bit, reserved: 6 bits)
45/// - Bytes 2-5: Payload length (u32, big-endian)
46/// - Bytes 6-7: Reserved
47#[derive(Debug, Clone)]
48pub struct FrameHeader {
49    /// Frame type
50    pub frame_type: FrameType,
51    /// Protocol version
52    pub version: u8,
53    /// Compressed flag
54    pub compressed: bool,
55    /// Fragmented flag
56    pub fragmented: bool,
57    /// Payload length
58    pub payload_length: u32,
59}
60
61impl FrameHeader {
62    /// Header size in bytes
63    pub const SIZE: usize = 8;
64
65    /// Create a new frame header
66    pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload_length: u32) -> Self {
67        Self {
68            frame_type,
69            version,
70            compressed,
71            fragmented: false,
72            payload_length,
73        }
74    }
75
76    /// Encode frame header to bytes
77    pub fn encode(&self) -> [u8; Self::SIZE] {
78        let mut buf = [0u8; Self::SIZE];
79
80        // Byte 0: frame type (upper 4 bits) | version (lower 4 bits)
81        buf[0] = ((self.frame_type as u8) << 4) | (self.version & 0x0F);
82
83        // Byte 1: flags
84        let mut flags = 0u8;
85        if self.compressed {
86            flags |= 0x80; // Set bit 7
87        }
88        if self.fragmented {
89            flags |= 0x40; // Set bit 6
90        }
91        buf[1] = flags;
92
93        // Bytes 2-5: payload length (big-endian)
94        buf[2..6].copy_from_slice(&self.payload_length.to_be_bytes());
95
96        // Bytes 6-7: reserved (zeros)
97        buf
98    }
99
100    /// Decode frame header from bytes
101    pub fn decode(data: &[u8]) -> Result<Self> {
102        if data.len() < Self::SIZE {
103            return Err(Error::Protocol(format!(
104                "Insufficient data for frame header: expected {}, got {}",
105                Self::SIZE,
106                data.len()
107            )));
108        }
109
110        // Parse byte 0
111        let frame_type = FrameType::try_from(data[0] >> 4)?;
112        let version = data[0] & 0x0F;
113
114        // Parse byte 1
115        let compressed = (data[1] & 0x80) != 0;
116        let fragmented = (data[1] & 0x40) != 0;
117
118        // Parse bytes 2-5
119        let payload_length = u32::from_be_bytes([data[2], data[3], data[4], data[5]]);
120
121        Ok(Self {
122            frame_type,
123            version,
124            compressed,
125            fragmented,
126            payload_length,
127        })
128    }
129
130    /// Get total frame size (header + payload)
131    pub fn total_size(&self) -> usize {
132        Self::SIZE + self.payload_length as usize
133    }
134}
135
136/// WebSocket frame
137#[derive(Debug, Clone)]
138pub struct Frame {
139    /// Frame header
140    pub header: FrameHeader,
141    /// Frame payload
142    pub payload: Bytes,
143}
144
145impl Frame {
146    /// Create a new frame
147    pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload: Bytes) -> Self {
148        let header = FrameHeader::new(frame_type, version, compressed, payload.len() as u32);
149        Self { header, payload }
150    }
151
152    /// Create a data frame
153    pub fn data(version: u8, compressed: bool, payload: Bytes) -> Self {
154        Self::new(FrameType::Data, version, compressed, payload)
155    }
156
157    /// Create a control frame
158    pub fn control(version: u8, payload: Bytes) -> Self {
159        Self::new(FrameType::Control, version, false, payload)
160    }
161
162    /// Create a heartbeat frame
163    pub fn heartbeat(version: u8) -> Self {
164        Self::new(FrameType::Heartbeat, version, false, Bytes::new())
165    }
166
167    /// Get frame size
168    pub fn size(&self) -> usize {
169        self.header.total_size()
170    }
171}
172
/// Frame codec for encoding and decoding frames.
///
/// Enforces an upper bound on payload size in both directions.
#[derive(Debug, Clone)]
pub struct FrameCodec {
    /// Largest payload length, in bytes, accepted when encoding or decoding.
    max_payload_size: u32,
}
177
178impl FrameCodec {
179    /// Create a new frame codec
180    pub fn new() -> Self {
181        Self {
182            max_payload_size: 16 * 1024 * 1024, // 16MB
183        }
184    }
185
186    /// Create a new frame codec with custom max payload size
187    pub fn with_max_payload_size(max_payload_size: u32) -> Self {
188        Self { max_payload_size }
189    }
190
191    /// Encode a frame to bytes
192    pub fn encode(&self, frame: &Frame) -> Result<Bytes> {
193        if frame.header.payload_length > self.max_payload_size {
194            return Err(Error::Protocol(format!(
195                "Payload size {} exceeds maximum {}",
196                frame.header.payload_length, self.max_payload_size
197            )));
198        }
199
200        let mut buf = BytesMut::with_capacity(frame.size());
201
202        // Write header
203        buf.put_slice(&frame.header.encode());
204
205        // Write payload
206        buf.put_slice(&frame.payload);
207
208        Ok(buf.freeze())
209    }
210
211    /// Decode a frame from bytes
212    pub fn decode(&self, data: &[u8]) -> Result<Frame> {
213        // Parse header
214        let header = FrameHeader::decode(data)?;
215
216        // Validate payload size
217        if header.payload_length > self.max_payload_size {
218            return Err(Error::Protocol(format!(
219                "Payload size {} exceeds maximum {}",
220                header.payload_length, self.max_payload_size
221            )));
222        }
223
224        // Check total data length
225        let total_size = header.total_size();
226        if data.len() < total_size {
227            return Err(Error::Protocol(format!(
228                "Insufficient data for frame: expected {}, got {}",
229                total_size,
230                data.len()
231            )));
232        }
233
234        // Extract payload
235        let payload = Bytes::copy_from_slice(&data[FrameHeader::SIZE..total_size]);
236
237        Ok(Frame { header, payload })
238    }
239
240    /// Decode multiple frames from a buffer
241    pub fn decode_all(&self, data: &[u8]) -> Result<Vec<Frame>> {
242        let mut frames = Vec::new();
243        let mut offset = 0;
244
245        while offset < data.len() {
246            if data.len() - offset < FrameHeader::SIZE {
247                break;
248            }
249
250            let header = FrameHeader::decode(&data[offset..])?;
251            let total_size = header.total_size();
252
253            if data.len() - offset < total_size {
254                break;
255            }
256
257            let frame = self.decode(&data[offset..])?;
258            frames.push(frame);
259
260            offset += total_size;
261        }
262
263        Ok(frames)
264    }
265}
266
267impl Default for FrameCodec {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_header_encode_decode() -> Result<()> {
        let header = FrameHeader::new(FrameType::Data, 1, true, 1024);
        let decoded = FrameHeader::decode(&header.encode())?;

        assert_eq!(header.frame_type, decoded.frame_type);
        assert_eq!(header.version, decoded.version);
        assert_eq!(header.compressed, decoded.compressed);
        assert_eq!(header.fragmented, decoded.fragmented);
        assert_eq!(header.payload_length, decoded.payload_length);
        Ok(())
    }

    #[test]
    fn test_frame_header_flags_and_version_roundtrip() -> Result<()> {
        let mut header = FrameHeader::new(FrameType::FragmentStart, 3, false, 0);
        header.fragmented = true;
        let decoded = FrameHeader::decode(&header.encode())?;

        assert_eq!(decoded.frame_type, FrameType::FragmentStart);
        assert_eq!(decoded.version, 3);
        assert!(!decoded.compressed);
        assert!(decoded.fragmented);
        Ok(())
    }

    #[test]
    fn test_frame_header_rejects_short_input() {
        assert!(FrameHeader::decode(&[0u8; FrameHeader::SIZE - 1]).is_err());
    }

    #[test]
    fn test_frame_header_rejects_unknown_type() {
        let mut bytes = FrameHeader::new(FrameType::Data, 1, false, 0).encode();
        // Type nibble 6 has no FrameType variant.
        bytes[0] = (6 << 4) | 1;
        assert!(FrameHeader::decode(&bytes).is_err());
    }

    #[test]
    fn test_frame_encode_decode() -> Result<()> {
        let codec = FrameCodec::new();
        let payload = Bytes::from(vec![1, 2, 3, 4, 5]);
        let frame = Frame::data(1, false, payload.clone());

        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;

        assert_eq!(frame.header.frame_type, decoded.header.frame_type);
        assert_eq!(frame.payload, decoded.payload);
        Ok(())
    }

    #[test]
    fn test_frame_decode_truncated_payload() -> Result<()> {
        let codec = FrameCodec::new();
        let encoded = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3, 4, 5])))?;

        assert!(codec.decode(&encoded[..encoded.len() - 1]).is_err());
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_all() -> Result<()> {
        let codec = FrameCodec::new();

        // Create multiple frames
        let frame1 = Frame::data(1, false, Bytes::from(vec![1, 2, 3]));
        let frame2 = Frame::data(1, false, Bytes::from(vec![4, 5, 6]));

        // Encode them back-to-back
        let mut buf = BytesMut::new();
        buf.put_slice(&codec.encode(&frame1)?);
        buf.put_slice(&codec.encode(&frame2)?);

        let frames = codec.decode_all(&buf)?;

        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        assert_eq!(frames[1].payload, Bytes::from(vec![4, 5, 6]));
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_all_stops_at_partial_frame() -> Result<()> {
        let codec = FrameCodec::new();
        let first = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3])))?;
        let second = codec.encode(&Frame::data(1, false, Bytes::from(vec![4, 5, 6])))?;

        let mut buf = BytesMut::new();
        buf.put_slice(&first);
        // Full header of the second frame but only 1 of its 3 payload bytes.
        buf.put_slice(&second[..FrameHeader::SIZE + 1]);

        let frames = codec.decode_all(&buf)?;

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        Ok(())
    }

    #[test]
    fn test_frame_heartbeat() -> Result<()> {
        let codec = FrameCodec::new();
        let frame = Frame::heartbeat(1);

        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;

        assert_eq!(decoded.header.frame_type, FrameType::Heartbeat);
        assert!(decoded.payload.is_empty());
        Ok(())
    }

    #[test]
    fn test_frame_max_size() {
        let codec = FrameCodec::with_max_payload_size(100);
        let payload = Bytes::from(vec![0; 200]);
        let frame = Frame::data(1, false, payload);

        let result = codec.encode(&frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_frame_decode_max_size() -> Result<()> {
        let encoded =
            FrameCodec::new().encode(&Frame::data(1, false, Bytes::from(vec![0; 200])))?;

        assert!(FrameCodec::with_max_payload_size(100).decode(&encoded).is_err());
        Ok(())
    }
}