sockudo_ws/
frame.rs

1//! WebSocket frame parsing and serialization
2//!
3//! This module implements RFC 6455 WebSocket frame handling with:
4//! - Zero-copy parsing using buffer views
5//! - Fast-path for small messages (< 126 bytes)
6//! - SIMD-accelerated masking
7//! - Minimal allocations in the hot path
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11use crate::error::{CloseReason, Error, Result};
12use crate::simd::apply_mask;
13use crate::utf8::validate_utf8;
14use crate::{MEDIUM_MESSAGE_THRESHOLD, SMALL_MESSAGE_THRESHOLD};
15
/// WebSocket frame opcode.
///
/// Each discriminant equals the 4-bit value carried on the wire, so
/// `op as u8` yields the encoded opcode. Values without a variant are
/// rejected by [`OpCode::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation frame (0x0)
    Continuation = 0x0,
    /// Text frame (0x1)
    Text = 0x1,
    /// Binary frame (0x2)
    Binary = 0x2,
    /// Connection close (0x8)
    Close = 0x8,
    /// Ping (0x9)
    Ping = 0x9,
    /// Pong (0xA)
    Pong = 0xA,
}

impl OpCode {
    /// Decode an opcode value, returning `None` for any value that has no
    /// variant (reserved opcodes or values wider than 4 bits).
    #[inline]
    pub fn from_u8(byte: u8) -> Option<Self> {
        let decoded = match byte {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xA => Self::Pong,
            _ => return None,
        };
        Some(decoded)
    }

    /// `true` for the control opcodes: Close, Ping and Pong.
    #[inline]
    pub fn is_control(&self) -> bool {
        matches!(self, Self::Close | Self::Ping | Self::Pong)
    }

    /// `true` for the data opcodes: Continuation, Text and Binary.
    #[inline]
    pub fn is_data(&self) -> bool {
        matches!(self, Self::Continuation | Self::Text | Self::Binary)
    }
}
61
62/// A parsed WebSocket frame header
63#[derive(Debug, Clone)]
64pub struct FrameHeader {
65    /// Final fragment flag
66    pub fin: bool,
67    /// RSV1 (used for compression)
68    pub rsv1: bool,
69    /// RSV2 (reserved)
70    pub rsv2: bool,
71    /// RSV3 (reserved)
72    pub rsv3: bool,
73    /// Frame opcode
74    pub opcode: OpCode,
75    /// Mask flag (must be true for client->server)
76    pub masked: bool,
77    /// Payload length
78    pub payload_len: u64,
79    /// Masking key (if masked)
80    pub mask: Option<[u8; 4]>,
81}
82
83impl FrameHeader {
84    /// Get the total header size in bytes
85    #[inline]
86    pub fn header_size(&self) -> usize {
87        let mut size = 2; // Base header
88
89        // Extended payload length
90        if self.payload_len > MEDIUM_MESSAGE_THRESHOLD as u64 {
91            size += 8;
92        } else if self.payload_len > SMALL_MESSAGE_THRESHOLD as u64 {
93            size += 2;
94        }
95
96        // Mask
97        if self.masked {
98            size += 4;
99        }
100
101        size
102    }
103
104    /// Encode the frame header into a buffer
105    #[inline]
106    pub fn encode(&self, buf: &mut BytesMut) {
107        // First byte: FIN, RSV1-3, opcode
108        let mut b0 = self.opcode as u8;
109        if self.fin {
110            b0 |= 0x80;
111        }
112        if self.rsv1 {
113            b0 |= 0x40;
114        }
115        if self.rsv2 {
116            b0 |= 0x20;
117        }
118        if self.rsv3 {
119            b0 |= 0x10;
120        }
121        buf.put_u8(b0);
122
123        // Second byte: mask flag, payload length
124        let mask_bit = if self.masked { 0x80 } else { 0x00 };
125
126        if self.payload_len <= SMALL_MESSAGE_THRESHOLD as u64 {
127            buf.put_u8(mask_bit | self.payload_len as u8);
128        } else if self.payload_len <= MEDIUM_MESSAGE_THRESHOLD as u64 {
129            buf.put_u8(mask_bit | 126);
130            buf.put_u16(self.payload_len as u16);
131        } else {
132            buf.put_u8(mask_bit | 127);
133            buf.put_u64(self.payload_len);
134        }
135
136        // Masking key
137        if let Some(mask) = self.mask {
138            buf.put_slice(&mask);
139        }
140    }
141}
142
143/// A complete WebSocket frame
144#[derive(Debug, Clone)]
145pub struct Frame {
146    /// Frame header
147    pub header: FrameHeader,
148    /// Frame payload (already unmasked)
149    pub payload: Bytes,
150}
151
152impl Frame {
153    /// Create a new frame
154    pub fn new(opcode: OpCode, payload: Bytes, fin: bool) -> Self {
155        Self {
156            header: FrameHeader {
157                fin,
158                rsv1: false,
159                rsv2: false,
160                rsv3: false,
161                opcode,
162                masked: false,
163                payload_len: payload.len() as u64,
164                mask: None,
165            },
166            payload,
167        }
168    }
169
170    /// Create a text frame
171    #[inline]
172    pub fn text(data: impl Into<Bytes>) -> Self {
173        Self::new(OpCode::Text, data.into(), true)
174    }
175
176    /// Create a binary frame
177    #[inline]
178    pub fn binary(data: impl Into<Bytes>) -> Self {
179        Self::new(OpCode::Binary, data.into(), true)
180    }
181
182    /// Create a ping frame
183    #[inline]
184    pub fn ping(data: impl Into<Bytes>) -> Self {
185        Self::new(OpCode::Ping, data.into(), true)
186    }
187
188    /// Create a pong frame
189    #[inline]
190    pub fn pong(data: impl Into<Bytes>) -> Self {
191        Self::new(OpCode::Pong, data.into(), true)
192    }
193
194    /// Create a close frame
195    #[inline]
196    pub fn close(code: u16, reason: &str) -> Self {
197        let mut payload = BytesMut::with_capacity(2 + reason.len());
198        payload.put_u16(code);
199        payload.put_slice(reason.as_bytes());
200        Self::new(OpCode::Close, payload.freeze(), true)
201    }
202
203    /// Create an empty close frame
204    #[inline]
205    pub fn close_empty() -> Self {
206        Self::new(OpCode::Close, Bytes::new(), true)
207    }
208
209    /// Check if this is a control frame
210    #[inline]
211    pub fn is_control(&self) -> bool {
212        self.header.opcode.is_control()
213    }
214
215    /// Check if this is the final fragment
216    #[inline]
217    pub fn is_final(&self) -> bool {
218        self.header.fin
219    }
220
221    /// Get the payload as a string (for text frames)
222    pub fn as_text(&self) -> Result<&str> {
223        if !validate_utf8(&self.payload) {
224            return Err(Error::InvalidUtf8);
225        }
226        // SAFETY: We just validated UTF-8
227        Ok(unsafe { std::str::from_utf8_unchecked(&self.payload) })
228    }
229
230    /// Parse close frame payload
231    pub fn parse_close(&self) -> Option<CloseReason> {
232        if self.payload.len() < 2 {
233            return None;
234        }
235        let code = u16::from_be_bytes([self.payload[0], self.payload[1]]);
236        let reason = if self.payload.len() > 2 {
237            String::from_utf8_lossy(&self.payload[2..]).into_owned()
238        } else {
239            String::new()
240        };
241        Some(CloseReason::new(code, reason))
242    }
243}
244
/// Frame parser state machine
///
/// The slow parsing path moves through these states in order. It skips the
/// extended-length states when the 7-bit length suffices, and skips `Mask`
/// for unmasked frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// Waiting for the 2 base header bytes (nothing of the frame consumed yet)
    Header,
    /// Base bytes stashed; waiting for the 2-byte extended payload length
    ExtendedLen16,
    /// Base bytes stashed; waiting for the 8-byte extended payload length
    ExtendedLen64,
    /// Payload length known; waiting for the 4-byte masking key
    Mask,
    /// Header complete; waiting until the whole payload is buffered
    Payload,
}
259
/// High-performance frame parser
///
/// Designed for zero-copy parsing with minimal allocations.
/// Uses a state machine to handle partial reads efficiently: header bytes
/// that arrive split across reads are consumed from the input and stashed in
/// `header_buf` until the header is complete, while the payload is only taken
/// once it is fully buffered.
pub struct FrameParser {
    /// Current position in the parse of the frame in progress
    state: ParseState,
    /// Header bytes already consumed from the input while the header was
    /// incomplete (2 base + up to 8 length + 4 mask = 14 bytes max)
    header_buf: [u8; 14],
    /// Number of valid bytes in `header_buf`
    header_len: usize,
    /// Header of the frame in progress, once its payload length is known
    header: Option<FrameHeader>,
    /// Largest payload, in bytes, accepted for a single frame
    max_frame_size: usize,
    /// Whether incoming frames must be masked (server mode) or must not be (client mode)
    expect_masked: bool,
    /// Whether RSV1 is allowed (compression enabled)
    allow_rsv1: bool,
}
278
279impl FrameParser {
280    /// Create a new frame parser
281    pub fn new(max_frame_size: usize, expect_masked: bool) -> Self {
282        Self {
283            state: ParseState::Header,
284            header_buf: [0; 14],
285            header_len: 0,
286            header: None,
287            max_frame_size,
288            expect_masked,
289            allow_rsv1: false,
290        }
291    }
292
293    /// Create a new frame parser with compression support
294    pub fn with_compression(max_frame_size: usize, expect_masked: bool) -> Self {
295        Self {
296            state: ParseState::Header,
297            header_buf: [0; 14],
298            header_len: 0,
299            header: None,
300            max_frame_size,
301            expect_masked,
302            allow_rsv1: true,
303        }
304    }
305
306    /// Enable or disable RSV1 (compression) support
307    pub fn set_compression(&mut self, enabled: bool) {
308        self.allow_rsv1 = enabled;
309    }
310
311    /// Reset parser state for next frame
312    #[inline]
313    fn reset(&mut self) {
314        self.state = ParseState::Header;
315        self.header_len = 0;
316        self.header = None;
317    }
318
319    /// Parse a frame from the buffer
320    ///
321    /// Returns:
322    /// - Ok(Some(frame)) if a complete frame was parsed
323    /// - Ok(None) if more data is needed
324    /// - Err(e) if parsing failed
325    #[inline]
326    pub fn parse(&mut self, buf: &mut BytesMut) -> Result<Option<Frame>> {
327        // Ultra-fast path for small unmasked frames (server->client)
328        // This handles the common case without any state machine overhead
329        if self.state == ParseState::Header && !self.expect_masked && buf.len() >= 2 {
330            let b0 = buf[0];
331            let b1 = buf[1];
332            let len_byte = b1 & 0x7F;
333
334            // Check: small frame, not masked, no extended length
335            if len_byte <= 125 && (b1 & 0x80) == 0 {
336                let payload_len = len_byte as usize;
337                let total_len = 2 + payload_len;
338
339                if buf.len() >= total_len {
340                    // We have a complete small frame - parse it inline
341                    let fin = b0 & 0x80 != 0;
342                    let rsv1 = b0 & 0x40 != 0;
343                    let rsv2 = b0 & 0x20 != 0;
344                    let rsv3 = b0 & 0x10 != 0;
345
346                    // Quick RSV validation
347                    if (rsv1 && !self.allow_rsv1) || rsv2 || rsv3 {
348                        return self.parse_slow(buf);
349                    }
350
351                    if let Some(opcode) = OpCode::from_u8(b0 & 0x0F) {
352                        // Control frame fragmentation check
353                        if opcode.is_control() && !fin {
354                            return Err(Error::Protocol("control frame must not be fragmented"));
355                        }
356
357                        // Extract payload
358                        buf.advance(2);
359                        let payload = buf.split_to(payload_len).freeze();
360
361                        return Ok(Some(Frame {
362                            header: FrameHeader {
363                                fin,
364                                rsv1,
365                                rsv2,
366                                rsv3,
367                                opcode,
368                                masked: false,
369                                payload_len: payload_len as u64,
370                                mask: None,
371                            },
372                            payload,
373                        }));
374                    }
375                }
376            }
377        }
378
379        // Ultra-fast path for small MASKED frames (client->server)
380        // This handles the common case of small client messages efficiently
381        if self.state == ParseState::Header && self.expect_masked && buf.len() >= 6 {
382            let b0 = buf[0];
383            let b1 = buf[1];
384            let len_byte = b1 & 0x7F;
385
386            // Check: small frame, masked, no extended length
387            if len_byte <= 125 && (b1 & 0x80) != 0 {
388                let payload_len = len_byte as usize;
389                let total_len = 2 + 4 + payload_len; // header + mask + payload
390
391                if buf.len() >= total_len {
392                    // We have a complete small masked frame - parse it inline
393                    let fin = b0 & 0x80 != 0;
394                    let rsv1 = b0 & 0x40 != 0;
395                    let rsv2 = b0 & 0x20 != 0;
396                    let rsv3 = b0 & 0x10 != 0;
397
398                    // Quick RSV validation
399                    if (rsv1 && !self.allow_rsv1) || rsv2 || rsv3 {
400                        return self.parse_slow(buf);
401                    }
402
403                    if let Some(opcode) = OpCode::from_u8(b0 & 0x0F) {
404                        // Control frame fragmentation check
405                        if opcode.is_control() && !fin {
406                            return Err(Error::Protocol("control frame must not be fragmented"));
407                        }
408
409                        // Extract mask
410                        let mask = [buf[2], buf[3], buf[4], buf[5]];
411
412                        // Extract and unmask payload
413                        buf.advance(6);
414                        let mut payload = buf.split_to(payload_len);
415                        apply_mask(&mut payload, mask);
416
417                        return Ok(Some(Frame {
418                            header: FrameHeader {
419                                fin,
420                                rsv1,
421                                rsv2,
422                                rsv3,
423                                opcode,
424                                masked: true,
425                                payload_len: payload_len as u64,
426                                mask: Some(mask),
427                            },
428                            payload: payload.freeze(),
429                        }));
430                    }
431                }
432            }
433        }
434
435        self.parse_slow(buf)
436    }
437
438    /// Slow path for frame parsing - handles all edge cases
439    fn parse_slow(&mut self, buf: &mut BytesMut) -> Result<Option<Frame>> {
440        const DEBUG: bool = false;
441        loop {
442            if DEBUG && !buf.is_empty() {
443                eprintln!(
444                    "[PARSER] State: {:?}, buf_len: {}, header_len: {}",
445                    self.state,
446                    buf.len(),
447                    self.header_len
448                );
449            }
450            match self.state {
451                ParseState::Header => {
452                    if buf.len() < 2 {
453                        return Ok(None);
454                    }
455
456                    // Fast path: try to parse header in one go
457                    let b0 = buf[0];
458                    let b1 = buf[1];
459
460                    // Parse first byte
461                    let fin = b0 & 0x80 != 0;
462                    let rsv1 = b0 & 0x40 != 0;
463                    let rsv2 = b0 & 0x20 != 0;
464                    let rsv3 = b0 & 0x10 != 0;
465
466                    // Check RSV bits (must be 0 unless extension negotiated)
467                    // RSV1 is allowed when compression is enabled
468                    if rsv1 && !self.allow_rsv1 {
469                        return Err(Error::Protocol(
470                            "RSV1 must be 0 (compression not negotiated)",
471                        ));
472                    }
473                    if rsv2 || rsv3 {
474                        return Err(Error::Protocol("RSV2 and RSV3 must be 0"));
475                    }
476
477                    let opcode =
478                        OpCode::from_u8(b0 & 0x0F).ok_or(Error::InvalidFrame("invalid opcode"))?;
479
480                    // Control frames must not be fragmented
481                    if opcode.is_control() && !fin {
482                        return Err(Error::Protocol("control frame must not be fragmented"));
483                    }
484
485                    // Parse second byte
486                    let masked = b1 & 0x80 != 0;
487                    let len_byte = b1 & 0x7F;
488
489                    // Validate masking
490                    if self.expect_masked && !masked {
491                        return Err(Error::Protocol("client frames must be masked"));
492                    }
493                    if !self.expect_masked && masked {
494                        return Err(Error::Protocol("server frames must not be masked"));
495                    }
496
497                    // Determine payload length
498                    let (payload_len, header_size) = if len_byte <= 125 {
499                        (len_byte as u64, 2)
500                    } else if len_byte == 126 {
501                        if buf.len() < 4 {
502                            // Need more data for extended length
503                            self.header_buf[0] = b0;
504                            self.header_buf[1] = b1;
505                            self.header_len = 2;
506                            buf.advance(2); // Consume the 2 bytes we saved
507                            self.state = ParseState::ExtendedLen16;
508                            return Ok(None);
509                        }
510                        let len = u16::from_be_bytes([buf[2], buf[3]]) as u64;
511                        // Validate minimum length for 16-bit encoding
512                        if len < 126 {
513                            return Err(Error::Protocol("payload length not minimal"));
514                        }
515                        (len, 4)
516                    } else {
517                        // len_byte == 127
518                        if buf.len() < 10 {
519                            self.header_buf[0] = b0;
520                            self.header_buf[1] = b1;
521                            self.header_len = 2;
522                            buf.advance(2); // Consume the 2 bytes we saved
523                            self.state = ParseState::ExtendedLen64;
524                            return Ok(None);
525                        }
526                        let len = u64::from_be_bytes([
527                            buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
528                        ]);
529                        // Validate minimum length for 64-bit encoding
530                        if len <= 0xFFFF {
531                            return Err(Error::Protocol("payload length not minimal"));
532                        }
533                        // Validate MSB is 0
534                        if len >> 63 != 0 {
535                            return Err(Error::Protocol("payload length MSB must be 0"));
536                        }
537                        (len, 10)
538                    };
539
540                    // Control frame length check
541                    if opcode.is_control() && payload_len > 125 {
542                        return Err(Error::Protocol("control frame too large"));
543                    }
544
545                    // Frame size check
546                    if payload_len > self.max_frame_size as u64 {
547                        return Err(Error::FrameTooLarge);
548                    }
549
550                    // Get mask if needed
551                    let total_header = header_size + if masked { 4 } else { 0 };
552                    if buf.len() < total_header {
553                        // Save partial header and advance buffer
554                        let to_copy = buf.len().min(14);
555                        self.header_buf[..to_copy].copy_from_slice(&buf[..to_copy]);
556                        self.header_len = to_copy;
557                        buf.advance(to_copy);
558
559                        // Create header without mask before transitioning to Mask state
560                        self.header = Some(FrameHeader {
561                            fin,
562                            rsv1,
563                            rsv2,
564                            rsv3,
565                            opcode,
566                            masked,
567                            payload_len,
568                            mask: None,
569                        });
570
571                        self.state = ParseState::Mask;
572                        return Ok(None);
573                    }
574
575                    let mask = if masked {
576                        Some([
577                            buf[header_size],
578                            buf[header_size + 1],
579                            buf[header_size + 2],
580                            buf[header_size + 3],
581                        ])
582                    } else {
583                        None
584                    };
585
586                    // We have the complete header
587                    buf.advance(total_header);
588
589                    self.header = Some(FrameHeader {
590                        fin,
591                        rsv1,
592                        rsv2,
593                        rsv3,
594                        opcode,
595                        masked,
596                        payload_len,
597                        mask,
598                    });
599                    self.state = ParseState::Payload;
600                }
601
602                ParseState::ExtendedLen16 => {
603                    // Need bytes 2 and 3 for 16-bit length (total 4 bytes for header so far)
604                    let target_len = 4;
605                    let needed = target_len - self.header_len;
606                    if buf.len() < needed {
607                        // Copy what we have and advance buffer
608                        let to_copy = buf.len();
609                        self.header_buf[self.header_len..self.header_len + to_copy]
610                            .copy_from_slice(&buf[..to_copy]);
611                        self.header_len += to_copy;
612                        buf.advance(to_copy);
613                        return Ok(None);
614                    }
615
616                    // Copy remaining length bytes
617                    self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
618                    buf.advance(needed);
619                    self.header_len = target_len;
620
621                    let payload_len =
622                        u16::from_be_bytes([self.header_buf[2], self.header_buf[3]]) as u64;
623
624                    if payload_len < 126 {
625                        return Err(Error::Protocol("payload length not minimal"));
626                    }
627
628                    // Continue to mask parsing
629                    self.parse_header_with_len(payload_len)?;
630
631                    if self.header.as_ref().unwrap().masked {
632                        self.state = ParseState::Mask;
633                    } else {
634                        self.state = ParseState::Payload;
635                    }
636                }
637
638                ParseState::ExtendedLen64 => {
639                    // Need bytes 2-9 for 64-bit length (total 10 bytes for header so far)
640                    let target_len = 10;
641                    let needed = target_len - self.header_len;
642                    if buf.len() < needed {
643                        // Copy what we have and advance buffer
644                        let to_copy = buf.len();
645                        self.header_buf[self.header_len..self.header_len + to_copy]
646                            .copy_from_slice(&buf[..to_copy]);
647                        self.header_len += to_copy;
648                        buf.advance(to_copy);
649                        return Ok(None);
650                    }
651
652                    self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
653                    buf.advance(needed);
654                    self.header_len = target_len;
655
656                    let payload_len = u64::from_be_bytes([
657                        self.header_buf[2],
658                        self.header_buf[3],
659                        self.header_buf[4],
660                        self.header_buf[5],
661                        self.header_buf[6],
662                        self.header_buf[7],
663                        self.header_buf[8],
664                        self.header_buf[9],
665                    ]);
666
667                    if payload_len <= 0xFFFF {
668                        return Err(Error::Protocol("payload length not minimal"));
669                    }
670                    if payload_len >> 63 != 0 {
671                        return Err(Error::Protocol("payload length MSB must be 0"));
672                    }
673
674                    self.parse_header_with_len(payload_len)?;
675
676                    if self.header.as_ref().unwrap().masked {
677                        self.state = ParseState::Mask;
678                    } else {
679                        self.state = ParseState::Payload;
680                    }
681                }
682
683                ParseState::Mask => {
684                    let header = self.header.as_mut().unwrap();
685                    // Calculate base header size (before mask)
686                    let header_base = if header.payload_len > MEDIUM_MESSAGE_THRESHOLD as u64 {
687                        10 // 2 + 8 bytes for 64-bit length
688                    } else if header.payload_len > SMALL_MESSAGE_THRESHOLD as u64 {
689                        4 // 2 + 2 bytes for 16-bit length
690                    } else {
691                        2 // just the 2 base bytes
692                    };
693                    // Total header with mask is header_base + 4
694                    let target_len = header_base + 4;
695                    let needed = target_len - self.header_len;
696
697                    if DEBUG {
698                        eprintln!(
699                            "[PARSER] Mask state: header_base={}, target_len={}, header_len={}, needed={}, payload_len={}",
700                            header_base, target_len, self.header_len, needed, header.payload_len
701                        );
702                        eprintln!(
703                            "[PARSER] header_buf so far: {:?}",
704                            &self.header_buf[..self.header_len]
705                        );
706                    }
707
708                    if buf.len() < needed {
709                        // Copy what we have and advance buffer
710                        let to_copy = buf.len();
711                        self.header_buf[self.header_len..self.header_len + to_copy]
712                            .copy_from_slice(&buf[..to_copy]);
713                        self.header_len += to_copy;
714                        buf.advance(to_copy);
715                        return Ok(None);
716                    }
717
718                    // Copy remaining bytes needed to complete the mask
719                    self.header_buf[self.header_len..target_len].copy_from_slice(&buf[..needed]);
720                    buf.advance(needed);
721                    self.header_len = target_len;
722
723                    // Extract mask from header_buf at the correct position
724                    header.mask = Some([
725                        self.header_buf[header_base],
726                        self.header_buf[header_base + 1],
727                        self.header_buf[header_base + 2],
728                        self.header_buf[header_base + 3],
729                    ]);
730
731                    self.state = ParseState::Payload;
732                }
733
734                ParseState::Payload => {
735                    let header = self.header.as_ref().unwrap();
736                    let payload_len = header.payload_len as usize;
737
738                    if DEBUG {
739                        eprintln!(
740                            "[PARSER] Payload state: need {} bytes, have {} bytes (opcode: {:?}, fin: {}, rsv1: {})",
741                            payload_len,
742                            buf.len(),
743                            header.opcode,
744                            header.fin,
745                            header.rsv1
746                        );
747                        if !buf.is_empty() {
748                            eprintln!(
749                                "[PARSER] First 16 bytes of buffer: {:?}",
750                                &buf[..buf.len().min(16)]
751                            );
752                        }
753                    }
754
755                    if buf.len() < payload_len {
756                        if DEBUG {
757                            eprintln!("[PARSER] Not enough payload data, waiting...");
758                        }
759                        return Ok(None);
760                    }
761
762                    if DEBUG {
763                        eprintln!(
764                            "[PARSER] Extracting payload of {} bytes, buf will have {} bytes remaining",
765                            payload_len,
766                            buf.len() - payload_len
767                        );
768                    }
769
770                    // Extract and unmask payload
771                    let mut payload = buf.split_to(payload_len);
772
773                    if let Some(mask) = header.mask {
774                        apply_mask(&mut payload, mask);
775                    }
776
777                    let frame = Frame {
778                        header: self.header.take().unwrap(),
779                        payload: payload.freeze(),
780                    };
781
782                    if DEBUG {
783                        eprintln!(
784                            "[PARSER] Frame complete! Resetting parser state. Buffer now has {} bytes",
785                            buf.len()
786                        );
787                    }
788
789                    self.reset();
790                    return Ok(Some(frame));
791                }
792            }
793        }
794    }
795
796    /// Parse header from saved buffer with known length
797    fn parse_header_with_len(&mut self, payload_len: u64) -> Result<()> {
798        let b0 = self.header_buf[0];
799        let b1 = self.header_buf[1];
800
801        let fin = b0 & 0x80 != 0;
802        let rsv1 = b0 & 0x40 != 0;
803        let rsv2 = b0 & 0x20 != 0;
804        let rsv3 = b0 & 0x10 != 0;
805        let opcode = OpCode::from_u8(b0 & 0x0F).ok_or(Error::InvalidFrame("invalid opcode"))?;
806        let masked = b1 & 0x80 != 0;
807
808        if opcode.is_control() && payload_len > 125 {
809            return Err(Error::Protocol("control frame too large"));
810        }
811
812        if payload_len > self.max_frame_size as u64 {
813            return Err(Error::FrameTooLarge);
814        }
815
816        self.header = Some(FrameHeader {
817            fin,
818            rsv1,
819            rsv2,
820            rsv3,
821            opcode,
822            masked,
823            payload_len,
824            mask: None,
825        });
826
827        Ok(())
828    }
829}
830
831/// Encode a frame into a buffer
832///
833/// This is the fast path for frame encoding. For masked frames (client mode),
834/// the payload will be copied and masked.
835#[inline]
836pub fn encode_frame(
837    buf: &mut BytesMut,
838    opcode: OpCode,
839    payload: &[u8],
840    fin: bool,
841    mask: Option<[u8; 4]>,
842) {
843    encode_frame_with_rsv(buf, opcode, payload, fin, mask, false)
844}
845
846/// Encode a frame with RSV1 bit control (for compression)
847///
848/// When `rsv1` is true, sets the RSV1 bit indicating compressed data.
849#[inline]
850pub fn encode_frame_with_rsv(
851    buf: &mut BytesMut,
852    opcode: OpCode,
853    payload: &[u8],
854    fin: bool,
855    mask: Option<[u8; 4]>,
856    rsv1: bool,
857) {
858    let payload_len = payload.len();
859
860    // Calculate header size
861    let ext_len_size = if payload_len > MEDIUM_MESSAGE_THRESHOLD {
862        8
863    } else if payload_len > SMALL_MESSAGE_THRESHOLD {
864        2
865    } else {
866        0
867    };
868    let mask_size = if mask.is_some() { 4 } else { 0 };
869    let header_size = 2 + ext_len_size + mask_size;
870    let total_size = header_size + payload_len;
871
872    // Reserve space for header + payload
873    buf.reserve(total_size);
874
875    // SAFETY: We just reserved enough space, and we'll set the length after writing
876    unsafe {
877        let base = buf.as_mut_ptr().add(buf.len());
878        let mut offset = 0;
879
880        // First byte: FIN + RSV1 + opcode
881        let mut b0 = opcode as u8;
882        if fin {
883            b0 |= 0x80;
884        }
885        if rsv1 {
886            b0 |= 0x40;
887        }
888        base.add(offset).write(b0);
889        offset += 1;
890
891        // Second byte: mask flag + length
892        let mask_bit = if mask.is_some() { 0x80u8 } else { 0x00u8 };
893
894        if payload_len <= SMALL_MESSAGE_THRESHOLD {
895            base.add(offset).write(mask_bit | payload_len as u8);
896            offset += 1;
897        } else if payload_len <= MEDIUM_MESSAGE_THRESHOLD {
898            base.add(offset).write(mask_bit | 126);
899            offset += 1;
900            let len_bytes = (payload_len as u16).to_be_bytes();
901            std::ptr::copy_nonoverlapping(len_bytes.as_ptr(), base.add(offset), 2);
902            offset += 2;
903        } else {
904            base.add(offset).write(mask_bit | 127);
905            offset += 1;
906            let len_bytes = (payload_len as u64).to_be_bytes();
907            std::ptr::copy_nonoverlapping(len_bytes.as_ptr(), base.add(offset), 8);
908            offset += 8;
909        }
910
911        // Mask and payload
912        if let Some(m) = mask {
913            // Write mask
914            std::ptr::copy_nonoverlapping(m.as_ptr(), base.add(offset), 4);
915            offset += 4;
916
917            // Copy and mask payload in a single pass
918            let payload_dst = base.add(offset);
919            encode_payload_masked_inline(payload_dst, payload.as_ptr(), payload_len, m);
920        } else {
921            // Fast path: just copy payload
922            std::ptr::copy_nonoverlapping(payload.as_ptr(), base.add(offset), payload_len);
923        }
924
925        // Update buffer length
926        buf.set_len(buf.len() + total_size);
927    }
928}
929
/// Copy `len` bytes from `src` to `dst`, XOR-ing byte `k` with `mask[k % 4]`.
///
/// Masking is fused with the copy so the payload is traversed only once:
/// 8-byte words first, then at most one 4-byte word, then single bytes for
/// the remainder. Word accesses are unaligned reads/writes, so neither
/// pointer needs any particular alignment.
///
/// # Safety
///
/// `src` must be valid for reads of `len` bytes and `dst` must be valid for
/// writes of `len` bytes. `dst` is only written, never read, so it may point
/// at uninitialized memory. The two ranges must not partially overlap.
#[inline]
unsafe fn encode_payload_masked_inline(dst: *mut u8, src: *const u8, len: usize, mask: [u8; 4]) {
    // Key and data words are both loaded in native byte order, so byte `k` of
    // a word always meets `mask[k % 4]` regardless of target endianness.
    let key32 = u32::from_ne_bytes(mask);
    let key64 = (u64::from(key32) << 32) | u64::from(key32);

    // Offset where the 8-byte loop stops (largest multiple of 8 not above len).
    let wide_end = len - len % 8;
    let mut pos = 0usize;

    // SAFETY: every offset accessed below lies in `0..len`, which the caller
    // guarantees is readable through `src` and writable through `dst`.
    unsafe {
        while pos < wide_end {
            let word = std::ptr::read_unaligned(src.add(pos).cast::<u64>());
            std::ptr::write_unaligned(dst.add(pos).cast::<u64>(), word ^ key64);
            pos += 8;
        }

        if len - pos >= 4 {
            let word = std::ptr::read_unaligned(src.add(pos).cast::<u32>());
            std::ptr::write_unaligned(dst.add(pos).cast::<u32>(), word ^ key32);
            pos += 4;
        }

        for k in pos..len {
            dst.add(k).write(src.add(k).read() ^ mask[k % 4]);
        }
    }
}
965
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opcode() {
        assert!(OpCode::Ping.is_control());
        assert!(OpCode::Pong.is_control());
        assert!(OpCode::Close.is_control());
        assert!(!OpCode::Text.is_control());
        assert!(!OpCode::Binary.is_control());
        assert!(OpCode::Text.is_data());
        assert!(OpCode::Binary.is_data());
        assert!(OpCode::Continuation.is_data());
    }

    #[test]
    fn test_parse_small_unmasked() {
        let mut parser = FrameParser::new(1024 * 1024, false);
        let mut buf = BytesMut::from(&[0x81, 0x05, b'h', b'e', b'l', b'l', b'o'][..]);

        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert!(frame.header.fin);
        assert_eq!(frame.header.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), b"hello");
    }

    #[test]
    fn test_parse_small_masked() {
        let mut parser = FrameParser::new(1024 * 1024, true);
        let mask = [0x37, 0xfa, 0x21, 0x3d];

        // "Hello" masked with the above key
        let mut payload = *b"Hello";
        apply_mask(&mut payload, mask);

        let mut buf = BytesMut::new();
        buf.put_u8(0x81); // FIN + Text
        buf.put_u8(0x85); // Masked + length 5
        buf.put_slice(&mask);
        buf.put_slice(&payload);

        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert_eq!(frame.payload.as_ref(), b"Hello");
    }

    #[test]
    fn test_parse_medium_length() {
        let mut parser = FrameParser::new(1024 * 1024, false);
        let payload = vec![0x42u8; 200];

        let mut buf = BytesMut::new();
        buf.put_u8(0x82); // FIN + Binary
        buf.put_u8(126); // Extended length marker
        buf.put_u16(200); // Actual length
        buf.put_slice(&payload);

        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert_eq!(frame.header.opcode, OpCode::Binary);
        assert_eq!(frame.payload.len(), 200);
    }

    #[test]
    fn test_encode_frame() {
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Text, b"hello", true, None);

        assert_eq!(buf[0], 0x81); // FIN + Text
        assert_eq!(buf[1], 0x05); // Length 5
        assert_eq!(&buf[2..], b"hello");
    }

    #[test]
    fn test_encode_frame_masked() {
        let mask = [0x01, 0x02, 0x03, 0x04];
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Text, b"test", true, Some(mask));

        assert_eq!(buf[0], 0x81); // FIN + Text
        assert_eq!(buf[1], 0x84); // Masked + Length 4
        assert_eq!(&buf[2..6], &mask);

        // Unmask and verify
        let mut payload = buf[6..].to_vec();
        apply_mask(&mut payload, mask);
        assert_eq!(&payload, b"test");
    }

    #[test]
    fn test_encode_frame_masked_unaligned_lengths() {
        // Lengths 0..=20 exercise the 8-byte, 4-byte and single-byte paths of
        // the single-pass masking copy.
        let mask = [0x37, 0xfa, 0x21, 0x3d];
        for len in 0..=20usize {
            let payload: Vec<u8> = (0..len as u8).collect();
            let mut buf = BytesMut::new();
            encode_frame(&mut buf, OpCode::Binary, &payload, true, Some(mask));

            assert_eq!(buf[1], 0x80 | len as u8); // Masked + 7-bit length
            assert_eq!(&buf[2..6], &mask);

            let mut unmasked = buf[6..].to_vec();
            apply_mask(&mut unmasked, mask);
            assert_eq!(unmasked, payload);
        }
    }

    #[test]
    fn test_encode_frame_medium_length() {
        let payload = vec![0x42u8; 200];
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Binary, &payload, true, None);

        assert_eq!(buf[0], 0x82); // FIN + Binary
        assert_eq!(buf[1], 126); // 16-bit extended length marker
        assert_eq!(&buf[2..4], &200u16.to_be_bytes());
        assert_eq!(&buf[4..], &payload[..]);
    }

    #[test]
    fn test_encode_frame_large_length() {
        let payload = vec![0x24u8; 70_000];
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Binary, &payload, true, None);

        assert_eq!(buf[1], 127); // 64-bit extended length marker
        assert_eq!(&buf[2..10], &70_000u64.to_be_bytes());
        assert_eq!(buf.len(), 10 + payload.len());
        assert_eq!(&buf[10..], &payload[..]);
    }

    #[test]
    fn test_encode_frame_length_boundaries() {
        // (payload length, expected second byte, expected header length)
        for &(len, marker, header_len) in &[
            (125usize, 125u8, 2usize),
            (126, 126, 4),
            (65_535, 126, 4),
            (65_536, 127, 10),
        ] {
            let payload = vec![0u8; len];
            let mut buf = BytesMut::new();
            encode_frame(&mut buf, OpCode::Binary, &payload, true, None);

            assert_eq!(buf[1], marker, "length {}", len);
            assert_eq!(buf.len(), header_len + len, "length {}", len);
        }
    }

    #[test]
    fn test_encode_frame_rsv1() {
        let mut buf = BytesMut::new();
        encode_frame_with_rsv(&mut buf, OpCode::Text, b"x", false, None, true);

        assert_eq!(buf[0], 0x41); // RSV1 + Text, FIN clear
        assert_eq!(buf[1], 0x01);
        assert_eq!(&buf[2..], b"x");
    }

    #[test]
    fn test_encode_frame_appends() {
        let mut buf = BytesMut::from(&b"xy"[..]);
        encode_frame(&mut buf, OpCode::Pong, b"", true, None);

        assert_eq!(&buf[..], &[b'x', b'y', 0x8A, 0x00][..]);
    }

    #[test]
    fn test_encode_parse_round_trip_masked() {
        let mask = [0xde, 0xad, 0xbe, 0xef];
        let payload: Vec<u8> = (0..300u32).map(|i| i as u8).collect();
        let mut buf = BytesMut::new();
        encode_frame(&mut buf, OpCode::Binary, &payload, true, Some(mask));

        let mut parser = FrameParser::new(1024 * 1024, true);
        let frame = parser.parse(&mut buf).unwrap().unwrap();
        assert!(frame.header.fin);
        assert_eq!(frame.header.opcode, OpCode::Binary);
        assert_eq!(frame.payload.as_ref(), &payload[..]);
    }

    #[test]
    fn test_control_frame_fragmentation() {
        let mut parser = FrameParser::new(1024, false);
        // Ping with FIN clear: control frames must not be fragmented.
        let mut buf = BytesMut::from(&[0x09, 0x00][..]);

        let result = parser.parse(&mut buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_close_frame() {
        let frame = Frame::close(1000, "goodbye");
        assert_eq!(frame.header.opcode, OpCode::Close);

        let close = frame.parse_close().unwrap();
        assert_eq!(close.code, 1000);
        assert_eq!(close.reason, "goodbye");
    }
}