aerosocket_core/
frame.rs

1//! WebSocket frame parsing and serialization
2//!
3//! This module provides efficient zero-copy frame parsing and serialization
4//! following the RFC 6455 WebSocket protocol specification.
5
6use crate::{
7    error::{Error, FrameError, Result},
8    protocol::{frame::*, Opcode},
9};
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12/// Represents a WebSocket frame according to RFC 6455
13#[derive(Debug, Clone)]
14pub struct Frame {
15    /// Indicates if this is the final frame in a message
16    pub fin: bool,
17    /// Reserved bits (RSV1, RSV2, RSV3)
18    pub rsv: [bool; 3],
19    /// Frame opcode
20    pub opcode: Opcode,
21    /// Indicates if the payload is masked
22    pub masked: bool,
23    /// Masking key (if present)
24    pub mask: Option<[u8; 4]>,
25    /// Payload data
26    pub payload: Bytes,
27}
28
29impl Frame {
30    /// Create a new frame with the given opcode and payload
31    pub fn new(opcode: Opcode, payload: impl Into<Bytes>) -> Self {
32        Self {
33            fin: true,
34            rsv: [false; 3],
35            opcode,
36            masked: false,
37            mask: None,
38            payload: payload.into(),
39        }
40    }
41
42    /// Create a continuation frame
43    pub fn continuation(payload: impl Into<Bytes>) -> Self {
44        Self::new(Opcode::Continuation, payload)
45    }
46
47    /// Create a text frame
48    pub fn text(payload: impl Into<Bytes>) -> Self {
49        Self::new(Opcode::Text, payload)
50    }
51
52    /// Create a binary frame
53    pub fn binary(payload: impl Into<Bytes>) -> Self {
54        Self::new(Opcode::Binary, payload)
55    }
56
57    /// Create a close frame with optional code and reason
58    pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
59        let mut payload = BytesMut::new();
60
61        if let Some(code) = code {
62            payload.put_u16(code);
63        }
64
65        if let Some(reason) = reason {
66            payload.put_slice(reason.as_bytes());
67        }
68
69        Self::new(Opcode::Close, payload.freeze())
70    }
71
72    /// Create a ping frame
73    pub fn ping(payload: impl Into<Bytes>) -> Self {
74        Self::new(Opcode::Ping, payload)
75    }
76
77    /// Create a pong frame
78    pub fn pong(payload: impl Into<Bytes>) -> Self {
79        Self::new(Opcode::Pong, payload)
80    }
81
82    /// Set the FIN bit
83    pub fn fin(mut self, fin: bool) -> Self {
84        self.fin = fin;
85        self
86    }
87
88    /// Set reserved bits
89    pub fn rsv(mut self, rsv1: bool, rsv2: bool, rsv3: bool) -> Self {
90        self.rsv = [rsv1, rsv2, rsv3];
91        self
92    }
93
94    /// Apply masking to the frame (for client frames)
95    pub fn mask(mut self, enabled: bool) -> Self {
96        if enabled && !self.masked {
97            let mask = rand::random::<[u8; 4]>();
98            self.payload = mask_bytes(&self.payload, &mask);
99            self.masked = true;
100            self.mask = Some(mask);
101        } else if !enabled && self.masked {
102            // Unmask if possible (server frames should never be masked)
103            if let Some(mask) = self.mask {
104                self.payload = mask_bytes(&self.payload, &mask);
105            }
106            self.masked = false;
107            self.mask = None;
108        }
109        self
110    }
111
112    /// Serialize the frame to bytes
113    pub fn to_bytes(&self) -> Bytes {
114        let mut buf = BytesMut::new();
115        self.write_to(&mut buf);
116        buf.freeze()
117    }
118
119    /// Write the frame to a buffer
120    pub fn write_to(&self, buf: &mut BytesMut) {
121        // Write first byte
122        let first_byte = ((self.fin as u8) << 7)
123            | ((self.rsv[0] as u8) << 6)
124            | ((self.rsv[1] as u8) << 5)
125            | ((self.rsv[2] as u8) << 4)
126            | self.opcode.value();
127        buf.put_u8(first_byte);
128
129        // Write payload length and mask bit
130        let payload_len = self.payload.len();
131        let mask_bit = (self.masked as u8) << 7;
132
133        if payload_len < 126 {
134            buf.put_u8(mask_bit | payload_len as u8);
135        } else if payload_len <= u16::MAX as usize {
136            buf.put_u8(mask_bit | PAYLOAD_LEN_16);
137            buf.put_u16(payload_len as u16);
138        } else {
139            buf.put_u8(mask_bit | PAYLOAD_LEN_64);
140            buf.put_u64(payload_len as u64);
141        }
142
143        // Write masking key if present
144        if let Some(mask) = self.mask {
145            buf.put_slice(&mask);
146        }
147
148        // Write payload
149        buf.put_slice(&self.payload);
150    }
151
152    /// Parse a frame from bytes
153    pub fn parse(buf: &mut BytesMut) -> Result<Self> {
154        if buf.len() < 2 {
155            return Err(FrameError::InsufficientData {
156                needed: 2,
157                have: buf.len(),
158            }
159            .into());
160        }
161
162        let mut cursor = std::io::Cursor::new(&buf[..]);
163
164        // Read first byte
165        let first_byte = cursor.get_u8();
166        let fin = (first_byte & FIN_BIT) != 0;
167        let rsv1 = (first_byte & RSV1_BIT) != 0;
168        let rsv2 = (first_byte & RSV2_BIT) != 0;
169        let rsv3 = (first_byte & RSV3_BIT) != 0;
170        let opcode = Opcode::from(first_byte & OPCODE_MASK)
171            .ok_or(FrameError::InvalidOpcode(first_byte & OPCODE_MASK))?;
172
173        // Read second byte
174        let second_byte = cursor.get_u8();
175        let masked = (second_byte & MASK_BIT) != 0;
176        let mut payload_len = (second_byte & PAYLOAD_LEN_MASK) as usize;
177
178        // Read extended payload length if needed
179        if payload_len == 126 {
180            if buf.len() < 4 {
181                return Err(FrameError::InsufficientData {
182                    needed: 4,
183                    have: buf.len(),
184                }
185                .into());
186            }
187            payload_len = cursor.get_u16() as usize;
188        } else if payload_len == 127 {
189            if buf.len() < 10 {
190                return Err(FrameError::InsufficientData {
191                    needed: 10,
192                    have: buf.len(),
193                }
194                .into());
195            }
196            payload_len = cursor.get_u64() as usize;
197        }
198
199        // Read masking key if present
200        let mask = if masked {
201            if buf.len() < cursor.position() as usize + 4 + payload_len {
202                return Err(FrameError::InsufficientData {
203                    needed: cursor.position() as usize + 4 + payload_len,
204                    have: buf.len(),
205                }
206                .into());
207            }
208            let mut mask = [0u8; 4];
209            cursor.copy_to_slice(&mut mask);
210            Some(mask)
211        } else {
212            None
213        };
214
215        // Read payload
216        if buf.len() < cursor.position() as usize + payload_len {
217            return Err(FrameError::InsufficientData {
218                needed: cursor.position() as usize + payload_len,
219                have: buf.len(),
220            }
221            .into());
222        }
223
224        let mut payload = Bytes::copy_from_slice(
225            &buf[cursor.position() as usize..cursor.position() as usize + payload_len],
226        );
227
228        // Unmask payload if needed
229        if let Some(mask) = mask {
230            payload = mask_bytes(&payload, &mask);
231        }
232
233        // Advance the buffer
234        let frame_len = cursor.position() as usize + payload_len;
235        buf.advance(frame_len);
236
237        // Validate frame
238        if opcode.is_control() && !fin {
239            return Err(FrameError::FragmentedControlFrame.into());
240        }
241
242        if rsv1 || rsv2 || rsv3 {
243            return Err(FrameError::ReservedBitsSet.into());
244        }
245
246        Ok(Frame {
247            fin,
248            rsv: [rsv1, rsv2, rsv3],
249            opcode,
250            masked,
251            mask,
252            payload,
253        })
254    }
255
256    /// Get the frame kind
257    pub fn kind(&self) -> FrameKind {
258        match self.opcode {
259            Opcode::Text => FrameKind::Text,
260            Opcode::Binary => FrameKind::Binary,
261            Opcode::Close => FrameKind::Close,
262            Opcode::Ping => FrameKind::Ping,
263            Opcode::Pong => FrameKind::Pong,
264            Opcode::Continuation => FrameKind::Continuation,
265            _ => FrameKind::Reserved,
266        }
267    }
268
269    /// Get the payload length
270    pub fn payload_len(&self) -> usize {
271        self.payload.len()
272    }
273
274    /// Check if this is a control frame
275    pub fn is_control(&self) -> bool {
276        self.opcode.is_control()
277    }
278
279    /// Check if this is a data frame
280    pub fn is_data(&self) -> bool {
281        self.opcode.is_data()
282    }
283
284    /// Check if this is the final frame
285    pub fn is_final(&self) -> bool {
286        self.fin
287    }
288}
289
/// Coarse classification of a frame by its opcode, convenient for `match`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FrameKind {
    /// UTF-8 text data frame
    Text,
    /// Binary data frame
    Binary,
    /// Connection close control frame
    Close,
    /// Ping control frame
    Ping,
    /// Pong control frame
    Pong,
    /// Continuation of a fragmented message
    Continuation,
    /// Any opcode not covered by the variants above
    Reserved,
}
308
309/// Apply masking to bytes
310fn mask_bytes(data: &[u8], mask: &[u8; 4]) -> Bytes {
311    let mut masked = BytesMut::with_capacity(data.len());
312    for (i, &byte) in data.iter().enumerate() {
313        masked.put_u8(byte ^ mask[i % 4]);
314    }
315    masked.freeze()
316}
317
318/// Frame parser for incremental parsing
319#[derive(Debug, Default)]
320pub struct FrameParser {
321    /// Buffer for partial frame data
322    buffer: BytesMut,
323    /// Expected frame size (if known)
324    expected_size: Option<usize>,
325}
326
327impl FrameParser {
328    /// Create a new frame parser
329    pub fn new() -> Self {
330        Self::default()
331    }
332
333    /// Feed data to the parser and try to extract frames
334    pub fn feed(&mut self, data: &[u8]) -> Vec<Result<Frame>> {
335        self.buffer.extend_from_slice(data);
336        self.extract_frames()
337    }
338
339    /// Extract complete frames from the buffer
340    fn extract_frames(&mut self) -> Vec<Result<Frame>> {
341        let mut frames = Vec::new();
342
343        while let Some(frame) = self.try_parse_frame() {
344            match frame {
345                Ok(f) => frames.push(Ok(f)),
346                Err(e) => {
347                    frames.push(Err(e));
348                    break;
349                }
350            }
351        }
352
353        frames
354    }
355
356    /// Try to parse a single frame from the buffer
357    fn try_parse_frame(&mut self) -> Option<Result<Frame>> {
358        let mut buf = self.buffer.clone();
359
360        match Frame::parse(&mut buf) {
361            Ok(frame) => {
362                // Remove the parsed data from the buffer
363                let parsed_len = self.buffer.len() - buf.len();
364                self.buffer.advance(parsed_len);
365                Some(Ok(frame))
366            }
367            Err(Error::Frame(FrameError::InsufficientData { .. })) => {
368                // Not enough data, wait for more
369                None
370            }
371            Err(e) => {
372                // Parse error, consume the problematic data
373                self.buffer.clear();
374                Some(Err(e))
375            }
376        }
377    }
378
379    /// Get the number of bytes currently buffered
380    pub fn buffered_bytes(&self) -> usize {
381        self.buffer.len()
382    }
383
384    /// Clear the parser buffer
385    pub fn clear(&mut self) {
386        self.buffer.clear();
387        self.expected_size = None;
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_frame_serialization() {
        let frame = Frame::text("hello");
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x81); // FIN=1, RSV=000, Opcode=0001
        assert_eq!(bytes[1], 0x05); // MASK=0, Length=5
        assert_eq!(&bytes[2..], b"hello");
    }

    #[test]
    fn test_masked_frame() {
        let frame = Frame::text("hello").mask(true);
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1] & 0x80, 0x80); // MASK bit set
        assert_eq!(bytes.len(), 2 + 4 + 5); // header + mask + payload
    }

    #[test]
    fn test_masked_frame_roundtrip() {
        let bytes = Frame::text("hello").mask(true).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf).unwrap();
        assert!(parsed.masked);
        assert!(parsed.mask.is_some());
        assert_eq!(parsed.payload, "hello"); // payload comes back unmasked
        assert!(buf.is_empty());
    }

    #[test]
    fn test_frame_parsing() {
        let original = Frame::text("hello");
        let bytes = original.to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf).unwrap();
        assert_eq!(parsed.kind(), FrameKind::Text);
        assert_eq!(parsed.payload, "hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_parse_incomplete_leaves_buffer() {
        let bytes = Frame::text("hello").to_bytes();
        let mut buf = BytesMut::from(&bytes[..4]);

        assert!(Frame::parse(&mut buf).is_err());
        assert_eq!(buf.len(), 4); // nothing consumed, caller can retry
    }

    #[test]
    fn test_reserved_bits_rejected() {
        let bytes = Frame::text("hi").rsv(true, false, false).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        assert!(Frame::parse(&mut buf).is_err());
    }

    #[test]
    fn test_medium_frame_uses_16bit_length() {
        let frame = Frame::binary(vec![7u8; 300]);
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1], 126); // Extended 16-bit length
        assert_eq!(bytes[2..4], 300u16.to_be_bytes());
        assert_eq!(bytes.len(), 4 + 300);

        let mut buf = BytesMut::from(&bytes[..]);
        let parsed = Frame::parse(&mut buf).unwrap();
        assert_eq!(parsed.payload_len(), 300);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_large_frame() {
        let payload = vec![0u8; 65536]; // 64KB
        let frame = Frame::binary(payload.clone());
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1], 127); // Extended 64-bit length
        assert_eq!(bytes[2..10], (65536u64).to_be_bytes());
    }

    #[test]
    fn test_close_frame() {
        let frame = Frame::close(Some(1000), Some("Goodbye"));
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x88); // FIN=1, Opcode=8
        assert_eq!(bytes[1], 0x09); // Length=9 (2-byte status code + 7-byte "Goodbye")
        assert_eq!(&bytes[2..4], 1000u16.to_be_bytes());
        assert_eq!(&bytes[4..], b"Goodbye");
        assert_eq!(bytes.len(), 11); // Total frame length
    }

    #[test]
    fn test_frame_parser() {
        let mut parser = FrameParser::new();

        let frame1 = Frame::text("frame1");
        let frame2 = Frame::ping("ping");

        let bytes1 = frame1.to_bytes();
        let bytes2 = frame2.to_bytes();

        // Feed partial data
        let frames = parser.feed(&bytes1[..5]);
        assert_eq!(frames.len(), 0); // Not enough data

        // Feed remaining data
        let frames = parser.feed(&bytes1[5..]);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_data());

        // Feed second frame
        let frames = parser.feed(&bytes2);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_control());
    }

    #[test]
    fn test_frame_parser_multiple_frames_in_one_feed() {
        let mut parser = FrameParser::new();

        let mut data = Vec::new();
        data.extend_from_slice(&Frame::text("a").to_bytes());
        data.extend_from_slice(&Frame::binary(vec![1u8, 2, 3]).to_bytes());

        let frames = parser.feed(&data);
        assert_eq!(frames.len(), 2);
        assert!(frames.iter().all(|f| f.is_ok()));
        assert_eq!(frames[0].as_ref().unwrap().kind(), FrameKind::Text);
        assert_eq!(frames[1].as_ref().unwrap().kind(), FrameKind::Binary);
        assert_eq!(parser.buffered_bytes(), 0);
    }
}