Skip to main content

aerosocket_core/
frame.rs

//! WebSocket frame parsing and serialization
//!
//! This module parses and serializes WebSocket frames as specified by
//! RFC 6455, with optional payload compression behind the `compression` feature.
5
6use crate::{
7    error::{Error, FrameError, Result},
8    protocol::{frame::*, Opcode},
9};
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12/// Represents a WebSocket frame according to RFC 6455
13#[derive(Debug, Clone)]
14pub struct Frame {
15    /// Indicates if this is the final frame in a message
16    pub fin: bool,
17    /// Reserved bits (RSV1, RSV2, RSV3)
18    pub rsv: [bool; 3],
19    /// Frame opcode
20    pub opcode: Opcode,
21    /// Indicates if the payload is masked
22    pub masked: bool,
23    /// Masking key (if present)
24    pub mask: Option<[u8; 4]>,
25    /// Payload data
26    pub payload: Bytes,
27}
28
29impl Frame {
30    /// Create a new frame with the given opcode and payload
31    pub fn new(opcode: Opcode, payload: impl Into<Bytes>) -> Self {
32        Self {
33            fin: true,
34            rsv: [false; 3],
35            opcode,
36            masked: false,
37            mask: None,
38            payload: payload.into(),
39        }
40    }
41
42    /// Create a continuation frame
43    pub fn continuation(payload: impl Into<Bytes>) -> Self {
44        Self::new(Opcode::Continuation, payload)
45    }
46
47    /// Create a text frame
48    pub fn text(payload: impl Into<Bytes>) -> Self {
49        Self::new(Opcode::Text, payload)
50    }
51
52    /// Create a binary frame
53    pub fn binary(payload: impl Into<Bytes>) -> Self {
54        Self::new(Opcode::Binary, payload)
55    }
56
57    /// Create a close frame with optional code and reason
58    pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
59        let mut payload = BytesMut::new();
60
61        if let Some(code) = code {
62            payload.put_u16(code);
63        }
64
65        if let Some(reason) = reason {
66            payload.put_slice(reason.as_bytes());
67        }
68
69        Self::new(Opcode::Close, payload.freeze())
70    }
71
72    /// Create a ping frame
73    pub fn ping(payload: impl Into<Bytes>) -> Self {
74        Self::new(Opcode::Ping, payload)
75    }
76
77    /// Create a pong frame
78    pub fn pong(payload: impl Into<Bytes>) -> Self {
79        Self::new(Opcode::Pong, payload)
80    }
81
82    /// Set the FIN bit
83    pub fn fin(mut self, fin: bool) -> Self {
84        self.fin = fin;
85        self
86    }
87
88    /// Set reserved bits
89    pub fn rsv(mut self, rsv1: bool, rsv2: bool, rsv3: bool) -> Self {
90        self.rsv = [rsv1, rsv2, rsv3];
91        self
92    }
93
94    /// Apply masking to the frame (for client frames)
95    pub fn mask(mut self, enabled: bool) -> Self {
96        if enabled && !self.masked {
97            let mask = rand::random::<[u8; 4]>();
98            self.payload = mask_bytes(&self.payload, &mask);
99            self.masked = true;
100            self.mask = Some(mask);
101        } else if !enabled && self.masked {
102            // Unmask if possible (server frames should never be masked)
103            if let Some(mask) = self.mask {
104                self.payload = mask_bytes(&self.payload, &mask);
105            }
106            self.masked = false;
107            self.mask = None;
108        }
109        self
110    }
111
112    /// Apply compression to the frame (for data frames)
113    #[cfg(feature = "compression")]
114    pub fn compress(mut self, enabled: bool) -> Self {
115        if enabled && self.opcode.is_data() && !self.rsv[0] {
116            use flate2::write::DeflateEncoder;
117            use flate2::Compression;
118            use std::io::Write;
119
120            let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(6));
121            if encoder.write_all(&self.payload).is_ok() && encoder.flush().is_ok() {
122                if let Ok(compressed) = encoder.finish() {
123                    self.payload = Bytes::from(compressed);
124                    self.rsv[0] = true;
125                }
126            }
127        }
128        self
129    }
130
131    /// Serialize the frame to bytes
132    pub fn to_bytes(&self) -> Bytes {
133        let mut buf = BytesMut::new();
134        self.write_to(&mut buf);
135        buf.freeze()
136    }
137
138    /// Write the frame to a buffer
139    pub fn write_to(&self, buf: &mut BytesMut) {
140        // Write first byte
141        let first_byte = ((self.fin as u8) << 7)
142            | ((self.rsv[0] as u8) << 6)
143            | ((self.rsv[1] as u8) << 5)
144            | ((self.rsv[2] as u8) << 4)
145            | self.opcode.value();
146        buf.put_u8(first_byte);
147
148        // Write payload length and mask bit
149        let payload_len = self.payload.len();
150        let mask_bit = (self.masked as u8) << 7;
151
152        if payload_len < 126 {
153            buf.put_u8(mask_bit | payload_len as u8);
154        } else if payload_len <= u16::MAX as usize {
155            buf.put_u8(mask_bit | PAYLOAD_LEN_16);
156            buf.put_u16(payload_len as u16);
157        } else {
158            buf.put_u8(mask_bit | PAYLOAD_LEN_64);
159            buf.put_u64(payload_len as u64);
160        }
161
162        // Write masking key if present
163        if let Some(mask) = self.mask {
164            buf.put_slice(&mask);
165        }
166
167        // Write payload
168        buf.put_slice(&self.payload);
169    }
170
171    /// Parse a frame from bytes
172    pub fn parse(buf: &mut BytesMut, compression_enabled: bool) -> Result<Self> {
173        if buf.len() < 2 {
174            return Err(FrameError::InsufficientData {
175                needed: 2,
176                have: buf.len(),
177            }
178            .into());
179        }
180
181        let mut cursor = std::io::Cursor::new(&buf[..]);
182
183        // Read first byte
184        let first_byte = cursor.get_u8();
185        let fin = (first_byte & FIN_BIT) != 0;
186        let rsv1 = (first_byte & RSV1_BIT) != 0;
187        let rsv2 = (first_byte & RSV2_BIT) != 0;
188        let rsv3 = (first_byte & RSV3_BIT) != 0;
189        let opcode = Opcode::from(first_byte & OPCODE_MASK)
190            .ok_or(FrameError::InvalidOpcode(first_byte & OPCODE_MASK))?;
191
192        // Read second byte
193        let second_byte = cursor.get_u8();
194        let masked = (second_byte & MASK_BIT) != 0;
195        let mut payload_len = (second_byte & PAYLOAD_LEN_MASK) as usize;
196
197        // Read extended payload length if needed
198        if payload_len == 126 {
199            if buf.len() < 4 {
200                return Err(FrameError::InsufficientData {
201                    needed: 4,
202                    have: buf.len(),
203                }
204                .into());
205            }
206            payload_len = cursor.get_u16() as usize;
207        } else if payload_len == 127 {
208            if buf.len() < 10 {
209                return Err(FrameError::InsufficientData {
210                    needed: 10,
211                    have: buf.len(),
212                }
213                .into());
214            }
215            payload_len = cursor.get_u64() as usize;
216        }
217
218        // Read masking key if present
219        let mask = if masked {
220            if buf.len() < cursor.position() as usize + 4 + payload_len {
221                return Err(FrameError::InsufficientData {
222                    needed: cursor.position() as usize + 4 + payload_len,
223                    have: buf.len(),
224                }
225                .into());
226            }
227            let mut mask = [0u8; 4];
228            cursor.copy_to_slice(&mut mask);
229            Some(mask)
230        } else {
231            None
232        };
233
234        // Read payload
235        if buf.len() < cursor.position() as usize + payload_len {
236            return Err(FrameError::InsufficientData {
237                needed: cursor.position() as usize + payload_len,
238                have: buf.len(),
239            }
240            .into());
241        }
242
243        let mut payload = Bytes::copy_from_slice(
244            &buf[cursor.position() as usize..cursor.position() as usize + payload_len],
245        );
246
247        // Unmask payload if needed
248        if let Some(mask) = mask {
249            payload = mask_bytes(&payload, &mask);
250        }
251
252        // Decompress payload if needed
253        #[cfg(feature = "compression")]
254        if rsv1 && compression_enabled {
255            use flate2::read::DeflateDecoder;
256            use std::io::Read;
257
258            let mut decoder = DeflateDecoder::new(&payload[..]);
259            let mut decompressed = Vec::new();
260            if decoder.read_to_end(&mut decompressed).is_err() {
261                return Err(FrameError::DecompressionFailed.into());
262            }
263            payload = Bytes::from(decompressed);
264        }
265
266        // Advance the buffer
267        let frame_len = cursor.position() as usize + payload_len;
268        buf.advance(frame_len);
269
270        // Validate frame
271        if opcode.is_control() && !fin {
272            return Err(FrameError::FragmentedControlFrame.into());
273        }
274
275        if (rsv1 && !(compression_enabled && opcode.is_data())) || rsv2 || rsv3 {
276            return Err(FrameError::ReservedBitsSet.into());
277        }
278
279        Ok(Frame {
280            fin,
281            rsv: [rsv1, rsv2, rsv3],
282            opcode,
283            masked,
284            mask,
285            payload,
286        })
287    }
288
289    /// Get the frame kind
290    pub fn kind(&self) -> FrameKind {
291        match self.opcode {
292            Opcode::Text => FrameKind::Text,
293            Opcode::Binary => FrameKind::Binary,
294            Opcode::Close => FrameKind::Close,
295            Opcode::Ping => FrameKind::Ping,
296            Opcode::Pong => FrameKind::Pong,
297            Opcode::Continuation => FrameKind::Continuation,
298            _ => FrameKind::Reserved,
299        }
300    }
301
302    /// Get the payload length
303    pub fn payload_len(&self) -> usize {
304        self.payload.len()
305    }
306
307    /// Check if this is a control frame
308    pub fn is_control(&self) -> bool {
309        self.opcode.is_control()
310    }
311
312    /// Check if this is a data frame
313    pub fn is_data(&self) -> bool {
314        self.opcode.is_data()
315    }
316
317    /// Check if this is the final frame
318    pub fn is_final(&self) -> bool {
319        self.fin
320    }
321}
322
/// Coarse classification of a frame by its opcode, handy for `match`
/// statements (see `Frame::kind`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FrameKind {
    /// Frame with the `Text` opcode
    Text,
    /// Frame with the `Binary` opcode
    Binary,
    /// Frame with the `Close` opcode
    Close,
    /// Frame with the `Ping` opcode
    Ping,
    /// Frame with the `Pong` opcode
    Pong,
    /// Frame with the `Continuation` opcode
    Continuation,
    /// Any opcode not listed above
    Reserved,
}
341
342/// Apply masking to bytes
343fn mask_bytes(data: &[u8], mask: &[u8; 4]) -> Bytes {
344    let mut masked = BytesMut::with_capacity(data.len());
345    for (i, &byte) in data.iter().enumerate() {
346        masked.put_u8(byte ^ mask[i % 4]);
347    }
348    masked.freeze()
349}
350
351/// Frame parser for incremental parsing
352#[derive(Debug)]
353pub struct FrameParser {
354    /// Buffer for partial frame data
355    buffer: BytesMut,
356    /// Expected frame size (if known)
357    expected_size: Option<usize>,
358    /// Whether compression is enabled for this connection
359    compression_enabled: bool,
360}
361
362impl Default for FrameParser {
363    fn default() -> Self {
364        Self {
365            buffer: BytesMut::new(),
366            expected_size: None,
367            compression_enabled: false,
368        }
369    }
370}
371
372impl FrameParser {
373    /// Create a new frame parser
374    pub fn new() -> Self {
375        Self::default()
376    }
377
378    /// Create a new frame parser with compression enabled
379    pub fn with_compression(compression_enabled: bool) -> Self {
380        Self {
381            buffer: BytesMut::new(),
382            expected_size: None,
383            compression_enabled,
384        }
385    }
386
387    /// Feed data to the parser and try to extract frames
388    pub fn feed(&mut self, data: &[u8]) -> Vec<Result<Frame>> {
389        self.buffer.extend_from_slice(data);
390        self.extract_frames()
391    }
392
393    /// Extract complete frames from the buffer
394    fn extract_frames(&mut self) -> Vec<Result<Frame>> {
395        let mut frames = Vec::new();
396
397        while let Some(frame) = self.try_parse_frame() {
398            match frame {
399                Ok(f) => frames.push(Ok(f)),
400                Err(e) => {
401                    frames.push(Err(e));
402                    break;
403                }
404            }
405        }
406
407        frames
408    }
409
410    /// Try to parse a single frame from the buffer
411    fn try_parse_frame(&mut self) -> Option<Result<Frame>> {
412        let mut buf = self.buffer.clone();
413
414        match Frame::parse(&mut buf, self.compression_enabled) {
415            Ok(frame) => {
416                // Remove the parsed data from the buffer
417                let parsed_len = self.buffer.len() - buf.len();
418                self.buffer.advance(parsed_len);
419                Some(Ok(frame))
420            }
421            Err(Error::Frame(FrameError::InsufficientData { .. })) => {
422                // Not enough data, wait for more
423                None
424            }
425            Err(e) => {
426                // Parse error, consume the problematic data
427                self.buffer.clear();
428                Some(Err(e))
429            }
430        }
431    }
432
433    /// Get the number of bytes currently buffered
434    pub fn buffered_bytes(&self) -> usize {
435        self.buffer.len()
436    }
437
438    /// Clear the parser buffer
439    pub fn clear(&mut self) {
440        self.buffer.clear();
441        self.expected_size = None;
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_frame_serialization() {
        let frame = Frame::text("hello");
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x81); // FIN=1, RSV=000, Opcode=0001
        assert_eq!(bytes[1], 0x05); // MASK=0, Length=5
        assert_eq!(&bytes[2..], b"hello");
    }

    #[test]
    fn test_masked_frame() {
        let frame = Frame::text("hello").mask(true);
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1] & 0x80, 0x80); // MASK bit set
        assert_eq!(bytes.len(), 2 + 4 + 5); // header + mask + payload
    }

    #[test]
    fn test_masked_frame_round_trip() {
        let bytes = Frame::binary(vec![1u8, 2, 3, 4, 5]).mask(true).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert!(parsed.masked);
        assert!(parsed.mask.is_some());
        assert_eq!(&parsed.payload[..], &[1u8, 2, 3, 4, 5][..]);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_frame_parsing() {
        let original = Frame::text("hello");
        let bytes = original.to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert_eq!(parsed.kind(), FrameKind::Text);
        assert_eq!(parsed.payload, "hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_medium_frame_uses_16_bit_length() {
        let bytes = Frame::binary(vec![0u8; 200]).to_bytes();

        assert_eq!(bytes[1], 126); // Extended 16-bit length
        assert_eq!(bytes[2..4], 200u16.to_be_bytes());
        assert_eq!(bytes.len(), 4 + 200);

        let mut buf = BytesMut::from(&bytes[..]);
        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert_eq!(parsed.payload_len(), 200);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_large_frame() {
        let payload = vec![0u8; 65536]; // 64KB
        let frame = Frame::binary(payload.clone());
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1], 127); // Extended 64-bit length
        assert_eq!(bytes[2..10], (65536u64).to_be_bytes());
    }

    #[test]
    fn test_partial_header_needs_more_data() {
        let bytes = Frame::binary(vec![0u8; 200]).to_bytes();
        let mut buf = BytesMut::from(&bytes[..3]);

        let err = Frame::parse(&mut buf, false).unwrap_err();
        assert!(matches!(
            err,
            Error::Frame(FrameError::InsufficientData { needed: 4, have: 3 })
        ));
        assert_eq!(buf.len(), 3); // Nothing consumed
    }

    #[test]
    fn test_fragmented_control_frame_rejected() {
        let bytes = Frame::ping("x").fin(false).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let err = Frame::parse(&mut buf, false).unwrap_err();
        assert!(matches!(err, Error::Frame(FrameError::FragmentedControlFrame)));
    }

    #[test]
    fn test_reserved_bits_rejected() {
        let frames = vec![
            Frame::text("x").rsv(true, false, false),
            Frame::text("x").rsv(false, true, false),
            Frame::text("x").rsv(false, false, true),
        ];
        for frame in frames {
            let bytes = frame.to_bytes();
            let mut buf = BytesMut::from(&bytes[..]);
            let err = Frame::parse(&mut buf, false).unwrap_err();
            assert!(matches!(err, Error::Frame(FrameError::ReservedBitsSet)));
        }
    }

    #[test]
    fn test_close_frame() {
        let frame = Frame::close(Some(1000), Some("Goodbye"));
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x88); // FIN=1, Opcode=8
        assert_eq!(bytes[1], 0x09); // Length=9 (2-byte status code + 7-byte "Goodbye")
        assert_eq!(&bytes[2..4], 1000u16.to_be_bytes());
        assert_eq!(&bytes[4..], b"Goodbye");
        assert_eq!(bytes.len(), 11); // Total frame length
    }

    #[test]
    fn test_frame_parser() {
        let mut parser = FrameParser::new();

        let frame1 = Frame::text("frame1");
        let frame2 = Frame::ping("ping");

        let bytes1 = frame1.to_bytes();
        let bytes2 = frame2.to_bytes();

        // Feed partial data
        let frames = parser.feed(&bytes1[..5]);
        assert_eq!(frames.len(), 0); // Not enough data

        // Feed remaining data
        let frames = parser.feed(&bytes1[5..]);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_data());

        // Feed second frame
        let frames = parser.feed(&bytes2);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_control());
    }

    #[test]
    fn test_frame_parser_multiple_frames_in_one_feed() {
        let mut parser = FrameParser::new();
        let mut data = BytesMut::new();
        Frame::text("a").write_to(&mut data);
        Frame::binary(vec![1u8, 2]).write_to(&mut data);

        let frames = parser.feed(&data);
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].as_ref().unwrap().kind(), FrameKind::Text);
        assert_eq!(frames[1].as_ref().unwrap().kind(), FrameKind::Binary);
        assert_eq!(parser.buffered_bytes(), 0);
    }

    #[test]
    fn test_frame_parser_discards_buffer_after_error() {
        let mut parser = FrameParser::new();
        let bad = Frame::text("x").rsv(false, false, true).to_bytes();

        let frames = parser.feed(&bad);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].is_err());
        assert_eq!(parser.buffered_bytes(), 0);
    }
}