Skip to main content

fastapi_http/
websocket.rs

//! WebSocket protocol implementation (RFC 6455).
//!
//! This module provides a WebSocket implementation (without protocol
//! extensions such as permessage-deflate) built on asupersync's I/O
//! primitives, with no external dependencies for SHA-1 or base64.
//!
//! # Architecture
//!
//! The implementation is layered:
//!
//! 1. **Handshake** — HTTP upgrade negotiation (101 Switching Protocols)
//! 2. **Frame codec** — Binary frame parsing/encoding per RFC 6455 §5
//! 3. **WebSocket** — High-level API matching FastAPI/Starlette semantics
//!
//! # Example
//!
//! ```ignore
//! use fastapi_http::websocket::{WebSocket, Message};
//!
//! async fn handler(mut ws: WebSocket) {
//!     ws.accept(None).await.unwrap();
//!     loop {
//!         match ws.receive().await {
//!             Ok(Message::Text(text)) => {
//!                 ws.send_text(&text).await.unwrap();
//!             }
//!             Ok(Message::Close(_, _)) | Err(_) => break,
//!             Ok(Message::Binary(_)) => {}
//!             // receive() answers pings itself and never yields Ping/Pong.
//!             Ok(Message::Ping(_) | Message::Pong(_)) => unreachable!(),
//!         }
//!     }
//! }
//! ```
33
34use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
35use asupersync::net::TcpStream;
36use std::future::poll_fn;
37use std::io;
38use std::pin::Pin;
39use std::task::Poll;
40
41// ============================================================================
42// SHA-1 (RFC 3174) — minimal, safe implementation for WebSocket handshake
43// ============================================================================
44
/// Compute the SHA-1 digest (RFC 3174) of `data`, returning the 20-byte hash.
///
/// Only needed to derive the handshake's `Sec-WebSocket-Accept` value.
///
/// The single-letter names (`a`–`e`, `w`) mirror the notation used in RFC 3174,
/// which is why `many_single_char_names` is allowed.
#[allow(clippy::many_single_char_names)]
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [
        0x6745_2301,
        0xEFCD_AB89,
        0x98BA_DCFE,
        0x1032_5476,
        0xC3D2_E1F0,
    ];

    // Padding: a 0x80 byte, zero fill up to 56 mod 64, then the original
    // length in bits as a big-endian u64. The padded length is the smallest
    // multiple of 64 that can hold data + 1 + 8 bytes.
    let bit_len = (data.len() as u64) * 8;
    let padded_len = (data.len() + 8) / 64 * 64 + 64;
    let mut message = Vec::with_capacity(padded_len);
    message.extend_from_slice(data);
    message.push(0x80);
    message.resize(padded_len - 8, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut w = [0u32; 80];
        for (word, bytes) in w.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            w[t] = (w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16]).rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &word) in w.iter().enumerate() {
            // The round function and additive constant change every 20 rounds.
            let (f, k) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999_u32)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1_u32)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC_u32)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6_u32)
            };
            let temp = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = temp;
        }

        for (h, v) in state.iter_mut().zip([a, b, c, d, e]) {
            *h = h.wrapping_add(v);
        }
    }

    let mut digest = [0u8; 20];
    for (out, h) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&h.to_be_bytes());
    }
    digest
}
119
120// ============================================================================
121// Base64 encoding — minimal implementation for handshake accept key
122// ============================================================================
123
const BASE64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encode `data` as standard base64 (RFC 4648 alphabet) with `=` padding.
///
/// Every 3 input bytes become 4 output characters. A trailing group of 1 or 2
/// bytes is padded with `==` or `=` respectively.
fn base64_encode(data: &[u8]) -> String {
    // Map the 6-bit field starting at `shift` in `group` to its alphabet symbol.
    let symbol =
        |group: u32, shift: u32| char::from(BASE64_CHARS[((group >> shift) & 0x3F) as usize]);

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for chunk in data.chunks(3) {
        // Pack up to three bytes into the low 24 bits; missing bytes count as zero.
        let group = chunk
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));

        encoded.push(symbol(group, 18));
        encoded.push(symbol(group, 12));
        encoded.push(if chunk.len() > 1 { symbol(group, 6) } else { '=' });
        encoded.push(if chunk.len() > 2 { symbol(group, 0) } else { '=' });
    }
    encoded
}
160
161// ============================================================================
162// WebSocket constants (RFC 6455)
163// ============================================================================
164
/// The WebSocket GUID used in the handshake (RFC 6455 §4.2.2).
///
/// `accept_key` appends it to the trimmed client `Sec-WebSocket-Key` before
/// SHA-1 hashing and base64-encoding to produce `Sec-WebSocket-Accept`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Default maximum frame payload size (16 MiB).
///
/// Used by `WebSocketConfig::default()`. `read_frame` rejects a frame whose
/// declared payload length exceeds the configured limit with `MessageTooLarge`
/// before allocating or reading the payload.
pub const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

/// Default maximum message size (64 MiB, for multi-frame messages).
///
/// Used by `WebSocketConfig::default()`. Applies to the total payload of a
/// message reassembled from its fragments.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
173
174// ============================================================================
175// Types
176// ============================================================================
177
178/// WebSocket frame opcode (RFC 6455 §5.2).
179#[derive(Debug, Clone, Copy, PartialEq, Eq)]
180pub enum Opcode {
181    /// Continuation frame (0x0).
182    Continuation,
183    /// Text frame (0x1) — payload must be valid UTF-8.
184    Text,
185    /// Binary frame (0x2).
186    Binary,
187    /// Connection close (0x8).
188    Close,
189    /// Ping (0x9).
190    Ping,
191    /// Pong (0xA).
192    Pong,
193}
194
195impl Opcode {
196    /// Parse an opcode from the low 4 bits of the first frame byte.
197    fn from_u8(value: u8) -> Result<Self, WebSocketError> {
198        match value & 0x0F {
199            0x0 => Ok(Self::Continuation),
200            0x1 => Ok(Self::Text),
201            0x2 => Ok(Self::Binary),
202            0x8 => Ok(Self::Close),
203            0x9 => Ok(Self::Ping),
204            0xA => Ok(Self::Pong),
205            other => Err(WebSocketError::Protocol(format!(
206                "unknown opcode: 0x{other:X}"
207            ))),
208        }
209    }
210
211    fn to_u8(self) -> u8 {
212        match self {
213            Self::Continuation => 0x0,
214            Self::Text => 0x1,
215            Self::Binary => 0x2,
216            Self::Close => 0x8,
217            Self::Ping => 0x9,
218            Self::Pong => 0xA,
219        }
220    }
221
222    /// Returns true for control frames (close, ping, pong).
223    fn is_control(self) -> bool {
224        matches!(self, Self::Close | Self::Ping | Self::Pong)
225    }
226}
227
/// WebSocket close status code (RFC 6455 §7.4.1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseCode {
    /// Normal closure (1000).
    Normal,
    /// Endpoint going away (1001).
    GoingAway,
    /// Protocol error (1002).
    ProtocolError,
    /// Unsupported data type (1003).
    UnsupportedData,
    /// No status code present (1005) — must not be sent in a frame.
    NoStatusReceived,
    /// Abnormal closure (1006) — must not be sent in a frame.
    AbnormalClosure,
    /// Invalid payload data (1007).
    InvalidPayload,
    /// Policy violation (1008).
    PolicyViolation,
    /// Message too big (1009).
    MessageTooBig,
    /// Missing expected extension (1010).
    MandatoryExtension,
    /// Internal server error (1011).
    InternalError,
    /// Service restart (1012).
    ServiceRestart,
    /// Try again later (1013).
    TryAgainLater,
    /// Bad gateway (1014).
    BadGateway,
    /// Application-defined or registered code in the 3000-4999 range.
    ///
    /// The value is not range-checked when constructed directly.
    Application(u16),
}

impl CloseCode {
    /// Every variant with a fixed wire value, so that `from_u16` can invert
    /// `to_u16` without a second hand-maintained mapping.
    const NAMED: [Self; 14] = [
        Self::Normal,
        Self::GoingAway,
        Self::ProtocolError,
        Self::UnsupportedData,
        Self::NoStatusReceived,
        Self::AbnormalClosure,
        Self::InvalidPayload,
        Self::PolicyViolation,
        Self::MessageTooBig,
        Self::MandatoryExtension,
        Self::InternalError,
        Self::ServiceRestart,
        Self::TryAgainLater,
        Self::BadGateway,
    ];

    /// Convert to the 2-byte wire representation.
    ///
    /// `Application(code)` yields `code` unchanged.
    pub fn to_u16(self) -> u16 {
        match self {
            Self::Application(code) => code,
            Self::Normal => 1000,
            Self::GoingAway => 1001,
            Self::ProtocolError => 1002,
            Self::UnsupportedData => 1003,
            Self::NoStatusReceived => 1005,
            Self::AbnormalClosure => 1006,
            Self::InvalidPayload => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalError => 1011,
            Self::ServiceRestart => 1012,
            Self::TryAgainLater => 1013,
            Self::BadGateway => 1014,
        }
    }

    /// Parse from a 2-byte wire value.
    ///
    /// 3000–4999 become `Application`. 1000–1003 and 1005–1014 map to their
    /// named variant. Every other value (for example 1004, 1015, below 1000,
    /// or 5000 and up) maps to `ProtocolError`, so unknown codes do not
    /// round-trip.
    pub fn from_u16(code: u16) -> Self {
        if (3000..=4999).contains(&code) {
            return Self::Application(code);
        }
        Self::NAMED
            .iter()
            .copied()
            .find(|named| named.to_u16() == code)
            .unwrap_or(Self::ProtocolError)
    }
}

impl std::fmt::Display for CloseCode {
    /// Formats the numeric wire value, e.g. `1000`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wire = self.to_u16();
        write!(f, "{wire}")
    }
}
313
/// A WebSocket message (assembled from one or more frames).
///
/// `WebSocket::receive` answers Ping frames itself and discards Pong frames
/// while reading, so in practice it yields only `Text`, `Binary` and `Close`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// UTF-8 text message (validated after all fragments are reassembled).
    Text(String),
    /// Binary message.
    Binary(Vec<u8>),
    /// Ping with optional payload (max 125 bytes).
    Ping(Vec<u8>),
    /// Pong with optional payload (max 125 bytes).
    Pong(Vec<u8>),
    /// Close with optional code and reason.
    ///
    /// Both are `None` when the close frame had an empty payload. The reason
    /// is `None` when only a status code was sent.
    Close(Option<CloseCode>, Option<String>),
}

/// A raw WebSocket frame.
///
/// As produced by `read_frame`, the payload has already been unmasked.
#[derive(Debug, Clone)]
struct Frame {
    // FIN bit: true on the final frame of a message (always true for control frames).
    fin: bool,
    // Decoded opcode from the low 4 bits of the first header byte.
    opcode: Opcode,
    // Payload bytes.
    payload: Vec<u8>,
}
336
/// Errors produced by the WebSocket handshake, codec and connection API.
#[derive(Debug)]
pub enum WebSocketError {
    /// I/O error on the underlying stream.
    Io(io::Error),
    /// Protocol violation.
    Protocol(String),
    /// Connection closed.
    ConnectionClosed,
    /// Frame or message exceeds configured size limit.
    MessageTooLarge {
        /// Size that was declared or reached, in bytes.
        size: usize,
        /// Configured limit that was exceeded, in bytes.
        limit: usize,
    },
    /// Invalid UTF-8 in a text message.
    InvalidUtf8,
    /// Handshake failed.
    HandshakeFailed(String),
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionClosed => f.write_str("WebSocket connection closed"),
            Self::InvalidUtf8 => f.write_str("WebSocket: invalid UTF-8 in text message"),
            Self::Io(e) => write!(f, "WebSocket I/O error: {e}"),
            Self::Protocol(msg) => write!(f, "WebSocket protocol error: {msg}"),
            Self::HandshakeFailed(msg) => write!(f, "WebSocket handshake failed: {msg}"),
            Self::MessageTooLarge { size, limit } => write!(
                f,
                "WebSocket message too large: {size} bytes (limit: {limit})"
            ),
        }
    }
}

impl std::error::Error for WebSocketError {
    /// Only the `Io` variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<io::Error> for WebSocketError {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
386
387/// Configuration for WebSocket connections.
388#[derive(Debug, Clone)]
389pub struct WebSocketConfig {
390    /// Maximum frame payload size in bytes.
391    pub max_frame_size: usize,
392    /// Maximum message size in bytes (for multi-frame messages).
393    pub max_message_size: usize,
394}
395
396impl Default for WebSocketConfig {
397    fn default() -> Self {
398        Self {
399            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
400            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
401        }
402    }
403}
404
405// ============================================================================
406// Handshake
407// ============================================================================
408
409/// Compute the `Sec-WebSocket-Accept` value from the client's
410/// `Sec-WebSocket-Key` header (RFC 6455 §4.2.2 step 4).
411pub fn accept_key(client_key: &str) -> String {
412    let mut input = String::with_capacity(client_key.len() + WS_GUID.len());
413    input.push_str(client_key.trim());
414    input.push_str(WS_GUID);
415    base64_encode(&sha1(input.as_bytes()))
416}
417
418/// Validate that an HTTP request is a valid WebSocket upgrade request.
419///
420/// Checks (RFC 6455 §4.2.1):
421/// - Method is GET
422/// - `Upgrade: websocket` header present (case-insensitive)
423/// - `Connection: upgrade` header present (case-insensitive)
424/// - `Sec-WebSocket-Key` header present and non-empty
425/// - `Sec-WebSocket-Version: 13` header present
426///
427/// Returns the `Sec-WebSocket-Key` value on success.
428pub fn validate_upgrade_request(
429    method: &str,
430    headers: &[(String, Vec<u8>)],
431) -> Result<String, WebSocketError> {
432    // Must be GET
433    if !method.eq_ignore_ascii_case("GET") {
434        return Err(WebSocketError::HandshakeFailed(
435            "WebSocket upgrade requires GET method".into(),
436        ));
437    }
438
439    let find_header = |name: &str| -> Option<String> {
440        headers
441            .iter()
442            .find(|(k, _)| k.eq_ignore_ascii_case(name))
443            .and_then(|(_, v)| String::from_utf8(v.clone()).ok())
444    };
445
446    // Check Upgrade header
447    let upgrade = find_header("upgrade")
448        .ok_or_else(|| WebSocketError::HandshakeFailed("missing Upgrade header".into()))?;
449    if !upgrade
450        .split(',')
451        .any(|v| v.trim().eq_ignore_ascii_case("websocket"))
452    {
453        return Err(WebSocketError::HandshakeFailed(
454            "Upgrade header must contain 'websocket'".into(),
455        ));
456    }
457
458    // Check Connection header
459    let connection = find_header("connection")
460        .ok_or_else(|| WebSocketError::HandshakeFailed("missing Connection header".into()))?;
461    if !connection
462        .split(',')
463        .any(|v| v.trim().eq_ignore_ascii_case("upgrade"))
464    {
465        return Err(WebSocketError::HandshakeFailed(
466            "Connection header must contain 'upgrade'".into(),
467        ));
468    }
469
470    // Check Sec-WebSocket-Key
471    let key = find_header("sec-websocket-key").ok_or_else(|| {
472        WebSocketError::HandshakeFailed("missing Sec-WebSocket-Key header".into())
473    })?;
474    let key = key.trim();
475    if key.is_empty() {
476        return Err(WebSocketError::HandshakeFailed(
477            "Sec-WebSocket-Key must not be empty".into(),
478        ));
479    }
480    if fastapi_core::websocket_accept_from_key(key).is_err() {
481        return Err(WebSocketError::HandshakeFailed(
482            "invalid Sec-WebSocket-Key (must be valid base64 with 16 decoded bytes)".into(),
483        ));
484    }
485
486    // Check Sec-WebSocket-Version
487    let version = find_header("sec-websocket-version").ok_or_else(|| {
488        WebSocketError::HandshakeFailed("missing Sec-WebSocket-Version header".into())
489    })?;
490    if version.trim() != "13" {
491        return Err(WebSocketError::HandshakeFailed(format!(
492            "unsupported WebSocket version: {version} (expected 13)"
493        )));
494    }
495
496    Ok(key.to_string())
497}
498
/// Returns `true` if `value` is a non-empty HTTP `token`: visible ASCII
/// (0x21–0x7E) excluding separator characters.
///
/// RFC 6455 defines subprotocol names in terms of HTTP token syntax. Space,
/// tab, control characters (including CR/LF) and non-ASCII bytes are therefore
/// rejected, which also keeps the value safe to place in a response header.
fn is_valid_subprotocol_token(value: &str) -> bool {
    const SEPARATORS: &[u8] = b"()<>@,;:\\\"/[]?={} \t";
    let is_token_char = |byte: &u8| byte.is_ascii_graphic() && !SEPARATORS.contains(byte);
    !value.is_empty() && value.as_bytes().iter().all(is_token_char)
}
508
509/// Build the HTTP 101 Switching Protocols response bytes for a WebSocket upgrade.
510///
511/// If `subprotocol` is provided, includes `Sec-WebSocket-Protocol` in the response.
512pub fn build_accept_response(
513    client_key: &str,
514    subprotocol: Option<&str>,
515) -> Result<Vec<u8>, WebSocketError> {
516    let accept = accept_key(client_key);
517    let mut response = format!(
518        "HTTP/1.1 101 Switching Protocols\r\n\
519         Upgrade: websocket\r\n\
520         Connection: Upgrade\r\n\
521         Sec-WebSocket-Accept: {accept}\r\n"
522    );
523    if let Some(proto) = subprotocol {
524        if !is_valid_subprotocol_token(proto) {
525            return Err(WebSocketError::HandshakeFailed(
526                "invalid Sec-WebSocket-Protocol token".into(),
527            ));
528        }
529        response.push_str(&format!("Sec-WebSocket-Protocol: {proto}\r\n"));
530    }
531    response.push_str("\r\n");
532    Ok(response.into_bytes())
533}
534
535// ============================================================================
536// Frame codec
537// ============================================================================
538
539/// Read a single WebSocket frame from the stream.
540///
541/// Handles variable-length payload encoding, masking (client-to-server),
542/// and validates control frame constraints (max 125 bytes, must be FIN).
543async fn read_frame(
544    stream: &mut TcpStream,
545    config: &WebSocketConfig,
546) -> Result<Frame, WebSocketError> {
547    // Read first 2 bytes: FIN/opcode + mask/payload-len
548    let mut header = [0u8; 2];
549    read_exact(stream, &mut header).await?;
550
551    let fin = (header[0] & 0x80) != 0;
552    let rsv = (header[0] >> 4) & 0x07;
553    if rsv != 0 {
554        return Err(WebSocketError::Protocol(
555            "reserved bits must be 0 (no extensions negotiated)".into(),
556        ));
557    }
558
559    let opcode = Opcode::from_u8(header[0])?;
560    let masked = (header[1] & 0x80) != 0;
561    let payload_len_byte = header[1] & 0x7F;
562
563    if !masked {
564        return Err(WebSocketError::Protocol(
565            "client-to-server frames must be masked".into(),
566        ));
567    }
568
569    // Determine actual payload length
570    let payload_len: usize = match payload_len_byte {
571        0..=125 => payload_len_byte as usize,
572        126 => {
573            let mut len_bytes = [0u8; 2];
574            read_exact(stream, &mut len_bytes).await?;
575            u16::from_be_bytes(len_bytes) as usize
576        }
577        _ => {
578            // 127 — 8-byte length
579            let mut len_bytes = [0u8; 8];
580            read_exact(stream, &mut len_bytes).await?;
581            let len = u64::from_be_bytes(len_bytes);
582            // RFC 6455 §5.2: most significant bit MUST be 0
583            if (len >> 63) != 0 {
584                return Err(WebSocketError::Protocol(
585                    "64-bit frame length has most significant bit set".into(),
586                ));
587            }
588            // Check for overflow and excessive size
589            if len > usize::MAX as u64 {
590                return Err(WebSocketError::MessageTooLarge {
591                    size: usize::MAX,
592                    limit: config.max_frame_size,
593                });
594            }
595            len as usize
596        }
597    };
598
599    // Validate control frame constraints (RFC 6455 §5.5)
600    if opcode.is_control() {
601        if !fin {
602            return Err(WebSocketError::Protocol(
603                "control frames must not be fragmented".into(),
604            ));
605        }
606        if payload_len > 125 {
607            return Err(WebSocketError::Protocol(
608                "control frame payload must not exceed 125 bytes".into(),
609            ));
610        }
611    }
612
613    // Check frame size limit
614    if payload_len > config.max_frame_size {
615        return Err(WebSocketError::MessageTooLarge {
616            size: payload_len,
617            limit: config.max_frame_size,
618        });
619    }
620
621    // Read masking key (if present)
622    let mask_key = if masked {
623        let mut key = [0u8; 4];
624        read_exact(stream, &mut key).await?;
625        Some(key)
626    } else {
627        None
628    };
629
630    // Read payload
631    let mut payload = vec![0u8; payload_len];
632    if payload_len > 0 {
633        read_exact(stream, &mut payload).await?;
634    }
635
636    // Unmask payload (XOR with mask key)
637    if let Some(key) = mask_key {
638        for (i, byte) in payload.iter_mut().enumerate() {
639            *byte ^= key[i % 4];
640        }
641    }
642
643    Ok(Frame {
644        fin,
645        opcode,
646        payload,
647    })
648}
649
650/// Write a single WebSocket frame to the stream.
651///
652/// Server-to-client frames are NOT masked per RFC 6455 §5.1.
653async fn write_frame(
654    stream: &mut TcpStream,
655    fin: bool,
656    opcode: Opcode,
657    payload: &[u8],
658) -> Result<(), WebSocketError> {
659    let mut header = Vec::with_capacity(10);
660
661    // First byte: FIN + opcode
662    let first_byte = if fin { 0x80 } else { 0x00 } | opcode.to_u8();
663    header.push(first_byte);
664
665    // Second byte: mask=0 + payload length
666    let len = payload.len();
667    if len < 126 {
668        header.push(len as u8);
669    } else if len <= 0xFFFF {
670        header.push(126);
671        header.extend_from_slice(&(len as u16).to_be_bytes());
672    } else {
673        header.push(127);
674        header.extend_from_slice(&(len as u64).to_be_bytes());
675    }
676
677    // Write header + payload
678    ws_write_all(stream, &header).await?;
679    if !payload.is_empty() {
680        ws_write_all(stream, payload).await?;
681    }
682    ws_flush(stream).await?;
683
684    Ok(())
685}
686
687/// Assemble a complete message from potentially fragmented frames.
688///
689/// Handles continuation frames and interleaved control frames.
690async fn read_message(
691    stream: &mut TcpStream,
692    config: &WebSocketConfig,
693) -> Result<Message, WebSocketError> {
694    let mut message_opcode: Option<Opcode> = None;
695    let mut message_data: Vec<u8> = Vec::new();
696
697    loop {
698        let frame = read_frame(stream, config).await?;
699
700        // Handle control frames (can be interleaved with data frames)
701        if frame.opcode.is_control() {
702            match frame.opcode {
703                Opcode::Close => {
704                    let (code, reason) = parse_close_payload(&frame.payload)?;
705                    return Ok(Message::Close(code, reason));
706                }
707                Opcode::Ping => {
708                    write_frame(stream, true, Opcode::Pong, &frame.payload).await?;
709                    continue;
710                }
711                Opcode::Pong => continue,
712                _ => unreachable!(),
713            }
714        }
715
716        // Data frame handling
717        match frame.opcode {
718            Opcode::Continuation => {
719                if message_opcode.is_none() {
720                    return Err(WebSocketError::Protocol(
721                        "continuation frame without initial frame".into(),
722                    ));
723                }
724            }
725            Opcode::Text | Opcode::Binary => {
726                if message_opcode.is_some() {
727                    return Err(WebSocketError::Protocol(
728                        "new data frame while previous message is incomplete".into(),
729                    ));
730                }
731                message_opcode = Some(frame.opcode);
732            }
733            _ => {}
734        }
735
736        // Check message size limit
737        let new_size = message_data.len() + frame.payload.len();
738        if new_size > config.max_message_size {
739            return Err(WebSocketError::MessageTooLarge {
740                size: new_size,
741                limit: config.max_message_size,
742            });
743        }
744
745        message_data.extend_from_slice(&frame.payload);
746
747        if frame.fin {
748            break;
749        }
750    }
751
752    let opcode = message_opcode
753        .ok_or_else(|| WebSocketError::Protocol("empty message (no data frames)".into()))?;
754
755    match opcode {
756        Opcode::Text => {
757            let text = String::from_utf8(message_data).map_err(|_| WebSocketError::InvalidUtf8)?;
758            Ok(Message::Text(text))
759        }
760        Opcode::Binary => Ok(Message::Binary(message_data)),
761        _ => unreachable!(),
762    }
763}
764
765/// Parse a close frame payload into (code, reason).
766fn parse_close_payload(
767    payload: &[u8],
768) -> Result<(Option<CloseCode>, Option<String>), WebSocketError> {
769    if payload.len() < 2 {
770        if payload.is_empty() {
771            return Ok((None, None));
772        }
773        return Err(WebSocketError::Protocol(
774            "close frame payload must be empty or at least 2 bytes".into(),
775        ));
776    }
777    let code_raw = u16::from_be_bytes([payload[0], payload[1]]);
778    if !is_valid_close_code(code_raw) {
779        return Err(WebSocketError::Protocol(format!(
780            "invalid close code in close frame: {code_raw}"
781        )));
782    }
783    let code = CloseCode::from_u16(code_raw);
784    let reason = if payload.len() > 2 {
785        Some(
786            std::str::from_utf8(&payload[2..])
787                .map_err(|_| WebSocketError::Protocol("close reason must be valid UTF-8".into()))?
788                .to_string(),
789        )
790    } else {
791        None
792    };
793    Ok((Some(code), reason))
794}
795
796/// Build a close frame payload from code and reason.
797fn build_close_payload(code: CloseCode, reason: Option<&str>) -> Result<Vec<u8>, WebSocketError> {
798    if !is_valid_close_code(code.to_u16()) {
799        return Err(WebSocketError::Protocol(format!(
800            "invalid close code for close frame: {}",
801            code.to_u16()
802        )));
803    }
804    let mut payload = Vec::with_capacity(2 + reason.map_or(0, str::len));
805    payload.extend_from_slice(&code.to_u16().to_be_bytes());
806    if let Some(reason_str) = reason {
807        // Truncate reason to fit in 125 bytes total
808        let max_reason = 123; // 125 - 2 bytes for code
809        let mut end = reason_str.len().min(max_reason);
810        while end > 0 && !reason_str.is_char_boundary(end) {
811            end -= 1;
812        }
813        payload.extend_from_slice(&reason_str.as_bytes()[..end]);
814    }
815    Ok(payload)
816}
817
/// Returns `true` if `code` may appear in a close frame on the wire.
///
/// Accepts 1000–1003, 1007–1014 and the application/registered range
/// 3000–4999. Rejects 1004, 1005 and 1006 (reserved or never sent), 1015, and
/// everything else.
fn is_valid_close_code(code: u16) -> bool {
    (1000..=1003).contains(&code)
        || (1007..=1014).contains(&code)
        || (3000..=4999).contains(&code)
}
825
826// ============================================================================
827// WebSocket — high-level API
828// ============================================================================
829
/// Connection state machine.
///
/// `accept()` moves `Pending` → `Open` after the 101 response is flushed.
/// `receive()` moves to `Closed` once the peer's Close frame has been read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsState {
    /// Upgrade request received but not yet accepted.
    Pending,
    /// Handshake complete, frames can be sent/received.
    Open,
    /// Close frame sent, waiting for peer's close response.
    ///
    /// In this state `receive()` does not echo the peer's Close back.
    /// NOTE(review): presumably set by `close()`, which is not shown here — confirm.
    CloseSent,
    /// Connection fully closed.
    Closed,
}
842
/// A WebSocket connection.
///
/// Wraps a `TcpStream` that has been upgraded from HTTP. Provides a
/// high-level API matching FastAPI/Starlette semantics.
///
/// # Lifecycle
///
/// 1. Created by the server after detecting an upgrade request
/// 2. Handler calls [`accept()`](WebSocket::accept) to complete the handshake
/// 3. Handler sends/receives messages
/// 4. Handler calls [`close()`](WebSocket::close) or the peer closes
///
/// When the peer's Close frame arrives, [`receive()`](WebSocket::receive)
/// echoes it if this side has not sent a Close yet, and marks the connection
/// closed.
///
/// # Example
///
/// ```ignore
/// async fn chat(mut ws: WebSocket) {
///     ws.accept(None).await.unwrap();
///     while let Ok(msg) = ws.receive().await {
///         match msg {
///             Message::Text(text) => {
///                 ws.send_text(&format!("echo: {text}")).await.unwrap();
///             }
///             Message::Binary(_data) => {}
///             // receive() auto-replies to ping and does not surface pong.
///             Message::Ping(_) | Message::Pong(_) => unreachable!(),
///             Message::Close(_, _) => break,
///         }
///     }
///     ws.close(CloseCode::Normal, None).await.ok();
/// }
/// ```
pub struct WebSocket {
    // The upgraded TCP connection; the handshake response and all frames go over it.
    stream: TcpStream,
    // Current position in the connection state machine (see `WsState`).
    state: WsState,
    // The client's `Sec-WebSocket-Key`, used by `accept()` to build the 101 response.
    client_key: String,
    // Frame/message size limits passed to the frame reader.
    config: WebSocketConfig,
}
880
881impl WebSocket {
882    /// Create a new WebSocket from an upgraded TCP stream.
883    ///
884    /// The `client_key` is the `Sec-WebSocket-Key` header value from the
885    /// upgrade request.
886    pub fn new(stream: TcpStream, client_key: String) -> Self {
887        Self {
888            stream,
889            state: WsState::Pending,
890            client_key,
891            config: WebSocketConfig::default(),
892        }
893    }
894
895    /// Create a new WebSocket with custom configuration.
896    pub fn with_config(stream: TcpStream, client_key: String, config: WebSocketConfig) -> Self {
897        Self {
898            stream,
899            state: WsState::Pending,
900            client_key,
901            config,
902        }
903    }
904
905    /// Complete the WebSocket handshake by sending the 101 response.
906    ///
907    /// Optionally specify a subprotocol to include in the response.
908    ///
909    /// # Errors
910    ///
911    /// Returns an error if the handshake has already been completed or
912    /// if writing the response fails.
913    pub async fn accept(&mut self, subprotocol: Option<&str>) -> Result<(), WebSocketError> {
914        if self.state != WsState::Pending {
915            return Err(WebSocketError::Protocol(
916                "accept() called on non-pending WebSocket".into(),
917            ));
918        }
919
920        let response_bytes = build_accept_response(&self.client_key, subprotocol)?;
921        ws_write_all(&mut self.stream, &response_bytes).await?;
922        ws_flush(&mut self.stream).await?;
923        self.state = WsState::Open;
924        Ok(())
925    }
926
927    /// Receive the next message from the client.
928    ///
929    /// Automatically responds to ping frames with pong. Returns text,
930    /// binary, and close messages to the caller.
931    ///
932    /// # Errors
933    ///
934    /// Returns an error if the connection is closed or a protocol
935    /// violation occurs.
936    pub async fn receive(&mut self) -> Result<Message, WebSocketError> {
937        self.ensure_can_receive()?;
938        let msg = read_message(&mut self.stream, &self.config).await?;
939        match msg {
940            Message::Close(code, reason) => {
941                // If we haven't sent close yet, echo it back
942                if self.state == WsState::Open {
943                    let payload = match code {
944                        Some(close_code) => build_close_payload(close_code, reason.as_deref())?,
945                        None => Vec::new(),
946                    };
947                    write_frame(&mut self.stream, true, Opcode::Close, &payload)
948                        .await
949                        .ok(); // Best-effort
950                }
951                self.state = WsState::Closed;
952                Ok(Message::Close(code, reason))
953            }
954            _ => Ok(msg),
955        }
956    }
957
958    /// Send a text message.
959    pub async fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
960        self.ensure_open()?;
961        write_frame(&mut self.stream, true, Opcode::Text, text.as_bytes()).await
962    }
963
964    /// Send a binary message.
965    pub async fn send_bytes(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
966        self.ensure_open()?;
967        write_frame(&mut self.stream, true, Opcode::Binary, data).await
968    }
969
970    /// Receive a text message.
971    ///
972    /// Skips pong messages, auto-responds to pings. Returns an error
973    /// if a binary or close message is received.
974    pub async fn receive_text(&mut self) -> Result<String, WebSocketError> {
975        match self.receive().await? {
976            Message::Text(text) => Ok(text),
977            Message::Close(code, reason) => Err(WebSocketError::Protocol(format!(
978                "expected text, got close (code={code:?}, reason={reason:?})"
979            ))),
980            other => Err(WebSocketError::Protocol(format!(
981                "expected text message, got {other:?}"
982            ))),
983        }
984    }
985
986    /// Receive a binary message.
987    ///
988    /// Skips pong messages, auto-responds to pings. Returns an error
989    /// if a text or close message is received.
990    pub async fn receive_bytes(&mut self) -> Result<Vec<u8>, WebSocketError> {
991        match self.receive().await? {
992            Message::Binary(data) => Ok(data),
993            Message::Close(code, reason) => Err(WebSocketError::Protocol(format!(
994                "expected binary, got close (code={code:?}, reason={reason:?})"
995            ))),
996            other => Err(WebSocketError::Protocol(format!(
997                "expected binary message, got {other:?}"
998            ))),
999        }
1000    }
1001
1002    /// Send a ping frame with optional payload.
1003    pub async fn ping(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
1004        self.ensure_open()?;
1005        if data.len() > 125 {
1006            return Err(WebSocketError::Protocol(
1007                "ping payload must not exceed 125 bytes".into(),
1008            ));
1009        }
1010        write_frame(&mut self.stream, true, Opcode::Ping, data).await
1011    }
1012
1013    /// Send a pong frame with optional payload.
1014    pub async fn pong(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
1015        self.ensure_open()?;
1016        if data.len() > 125 {
1017            return Err(WebSocketError::Protocol(
1018                "pong payload must not exceed 125 bytes".into(),
1019            ));
1020        }
1021        write_frame(&mut self.stream, true, Opcode::Pong, data).await
1022    }
1023
1024    /// Initiate a close handshake.
1025    ///
1026    /// Sends a close frame and transitions to `CloseSent`. The peer should
1027    /// respond with its own close frame.
1028    pub async fn close(
1029        &mut self,
1030        code: CloseCode,
1031        reason: Option<&str>,
1032    ) -> Result<(), WebSocketError> {
1033        if self.state == WsState::Closed || self.state == WsState::CloseSent {
1034            return Ok(());
1035        }
1036        if self.state == WsState::Pending {
1037            self.state = WsState::Closed;
1038            return Ok(());
1039        }
1040
1041        let payload = build_close_payload(code, reason)?;
1042        write_frame(&mut self.stream, true, Opcode::Close, &payload).await?;
1043        self.state = WsState::CloseSent;
1044        Ok(())
1045    }
1046
1047    /// Returns `true` if the connection is open and can send/receive messages.
1048    pub fn is_open(&self) -> bool {
1049        self.state == WsState::Open
1050    }
1051
1052    /// Returns the current connection state.
1053    pub fn state(&self) -> &'static str {
1054        match self.state {
1055            WsState::Pending => "pending",
1056            WsState::Open => "open",
1057            WsState::CloseSent => "close_sent",
1058            WsState::Closed => "closed",
1059        }
1060    }
1061
1062    fn ensure_open(&self) -> Result<(), WebSocketError> {
1063        match self.state {
1064            WsState::Open => Ok(()),
1065            WsState::Pending => Err(WebSocketError::Protocol(
1066                "must call accept() before sending/receiving".into(),
1067            )),
1068            WsState::CloseSent | WsState::Closed => Err(WebSocketError::ConnectionClosed),
1069        }
1070    }
1071
1072    fn ensure_can_receive(&self) -> Result<(), WebSocketError> {
1073        match self.state {
1074            WsState::Open | WsState::CloseSent => Ok(()),
1075            WsState::Pending => Err(WebSocketError::Protocol(
1076                "must call accept() before sending/receiving".into(),
1077            )),
1078            WsState::Closed => Err(WebSocketError::ConnectionClosed),
1079        }
1080    }
1081}
1082
1083// ============================================================================
1084// I/O helpers (using asupersync primitives)
1085// ============================================================================
1086
1087/// Read exactly `buf.len()` bytes from the stream.
1088async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) -> Result<(), WebSocketError> {
1089    let mut offset = 0;
1090    while offset < buf.len() {
1091        let n = ws_read(stream, &mut buf[offset..]).await?;
1092        if n == 0 {
1093            return Err(WebSocketError::ConnectionClosed);
1094        }
1095        offset += n;
1096    }
1097    Ok(())
1098}
1099
1100/// Read some bytes from the stream.
1101async fn ws_read(stream: &mut TcpStream, buf: &mut [u8]) -> Result<usize, WebSocketError> {
1102    poll_fn(|cx| {
1103        let mut read_buf = ReadBuf::new(buf);
1104        match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
1105            Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
1106            Poll::Ready(Err(e)) => Poll::Ready(Err(WebSocketError::Io(e))),
1107            Poll::Pending => Poll::Pending,
1108        }
1109    })
1110    .await
1111}
1112
1113/// Write all bytes to the stream.
1114async fn ws_write_all(stream: &mut TcpStream, mut buf: &[u8]) -> Result<(), WebSocketError> {
1115    while !buf.is_empty() {
1116        let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf))
1117            .await
1118            .map_err(WebSocketError::Io)?;
1119        if n == 0 {
1120            return Err(WebSocketError::Io(io::Error::new(
1121                io::ErrorKind::WriteZero,
1122                "failed to write to WebSocket stream",
1123            )));
1124        }
1125        buf = &buf[n..];
1126    }
1127    Ok(())
1128}
1129
1130/// Flush the stream.
1131async fn ws_flush(stream: &mut TcpStream) -> Result<(), WebSocketError> {
1132    poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx))
1133        .await
1134        .map_err(WebSocketError::Io)
1135}
1136
1137// ============================================================================
1138// Tests
1139// ============================================================================
1140
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample `Sec-WebSocket-Key` from RFC 6455 §4.2.2, shared by the
    /// handshake and upgrade-validation tests.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    #[test]
    fn test_sha1_empty() {
        let want: [u8; 20] = [
            0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, 0xbf, 0xef, 0x95, 0x60,
            0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09,
        ];
        assert_eq!(sha1(b""), want);
    }

    #[test]
    fn test_sha1_abc() {
        let want: [u8; 20] = [
            0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81, 0x6a, 0xba, 0x3e, 0x25, 0x71, 0x78, 0x50,
            0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d,
        ];
        assert_eq!(sha1(b"abc"), want);
    }

    #[test]
    fn test_base64_encode() {
        // RFC 4648 §10 test vectors (subset), covering 0/1/2 padding chars.
        let vectors: [(&[u8], &str); 5] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (raw, encoded) in vectors {
            assert_eq!(base64_encode(raw), encoded);
        }
    }

    #[test]
    fn test_accept_key() {
        // RFC 6455 §4.2.2 worked example.
        assert_eq!(accept_key(SAMPLE_KEY), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_close_code_roundtrip() {
        let samples = [
            CloseCode::Normal,
            CloseCode::GoingAway,
            CloseCode::ProtocolError,
            CloseCode::UnsupportedData,
            CloseCode::InvalidPayload,
            CloseCode::PolicyViolation,
            CloseCode::MessageTooBig,
            CloseCode::MandatoryExtension,
            CloseCode::InternalError,
            CloseCode::ServiceRestart,
            CloseCode::TryAgainLater,
            CloseCode::BadGateway,
            CloseCode::Application(3000),
            CloseCode::Application(4000),
            CloseCode::Application(4999),
        ];
        for original in samples {
            let wire = original.to_u16();
            assert_eq!(CloseCode::from_u16(wire), original);
        }
    }

    #[test]
    fn test_opcode_roundtrip() {
        let samples = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];
        for original in samples {
            let wire = original.to_u8();
            assert_eq!(Opcode::from_u8(wire).unwrap(), original);
        }
    }

    #[test]
    fn test_opcode_unknown() {
        // Reserved non-control opcodes must be rejected.
        for reserved in [0x03, 0x07] {
            assert!(Opcode::from_u8(reserved).is_err());
        }
    }

    #[test]
    fn test_opcode_is_control() {
        for data_op in [Opcode::Continuation, Opcode::Text, Opcode::Binary] {
            assert!(!data_op.is_control());
        }
        for control_op in [Opcode::Close, Opcode::Ping, Opcode::Pong] {
            assert!(control_op.is_control());
        }
    }

    #[test]
    fn test_build_accept_response_basic() {
        let raw = build_accept_response(SAMPLE_KEY, None).expect("response build should succeed");
        let text = String::from_utf8(raw).unwrap();
        assert!(text.starts_with("HTTP/1.1 101 Switching Protocols\r\n"));
        for header_line in [
            "Upgrade: websocket\r\n",
            "Connection: Upgrade\r\n",
            "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n",
        ] {
            assert!(text.contains(header_line));
        }
        assert!(text.ends_with("\r\n\r\n"));
    }

    #[test]
    fn test_build_accept_response_with_subprotocol() {
        let raw = build_accept_response(SAMPLE_KEY, Some("graphql-ws"))
            .expect("response build should succeed");
        let text = String::from_utf8(raw).unwrap();
        assert!(text.contains("Sec-WebSocket-Protocol: graphql-ws\r\n"));
    }

    #[test]
    fn test_build_accept_response_rejects_invalid_subprotocol_token() {
        // A CRLF in the token would otherwise inject an extra response header.
        let injected = "graphql-ws\r\nX-Evil: 1";
        let failure = build_accept_response(SAMPLE_KEY, Some(injected))
            .expect_err("invalid subprotocol token must fail");
        assert!(matches!(failure, WebSocketError::HandshakeFailed(_)));
        let rendered = failure.to_string();
        assert!(rendered.contains("invalid Sec-WebSocket-Protocol token"));
    }

    #[test]
    fn test_validate_upgrade_request_valid() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), SAMPLE_KEY.as_bytes().to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let outcome = validate_upgrade_request("GET", &headers);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), SAMPLE_KEY);
    }

    #[test]
    fn test_validate_upgrade_request_wrong_method() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), SAMPLE_KEY.as_bytes().to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let outcome = validate_upgrade_request("POST", &headers);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_upgrade_request_missing_upgrade() {
        let headers = vec![
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), SAMPLE_KEY.as_bytes().to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let outcome = validate_upgrade_request("GET", &headers);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_upgrade_request_wrong_version() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), SAMPLE_KEY.as_bytes().to_vec()),
            ("Sec-WebSocket-Version".into(), b"8".to_vec()),
        ];
        let outcome = validate_upgrade_request("GET", &headers);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_upgrade_request_invalid_key_base64() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), b"not-base64".to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let outcome = validate_upgrade_request("GET", &headers);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_upgrade_request_invalid_key_length() {
        // Valid base64, but decodes to 3 bytes instead of the required 16.
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), b"Zm9v".to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let outcome = validate_upgrade_request("GET", &headers);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_close_payload_roundtrip() {
        let encoded = build_close_payload(CloseCode::Normal, Some("goodbye")).unwrap();
        let (decoded_code, decoded_reason) = parse_close_payload(&encoded).unwrap();
        assert_eq!(decoded_code, Some(CloseCode::Normal));
        assert_eq!(decoded_reason, Some("goodbye".into()));
    }

    #[test]
    fn test_close_payload_no_reason() {
        let encoded = build_close_payload(CloseCode::GoingAway, None).unwrap();
        let (decoded_code, decoded_reason) = parse_close_payload(&encoded).unwrap();
        assert_eq!(decoded_code, Some(CloseCode::GoingAway));
        assert_eq!(decoded_reason, None);
    }

    #[test]
    fn test_close_payload_empty() {
        let (decoded_code, decoded_reason) = parse_close_payload(&[]).unwrap();
        assert_eq!(decoded_code, None);
        assert_eq!(decoded_reason, None);
    }

    #[test]
    fn test_close_payload_len_one_is_invalid() {
        let failure = parse_close_payload(&[0x03]).expect_err("len=1 close payload must fail");
        assert!(matches!(failure, WebSocketError::Protocol(_)));
    }

    #[test]
    fn test_close_payload_invalid_code_is_rejected() {
        // 0x03EE == 1006, which is reserved and never valid on the wire.
        let failure = parse_close_payload(&[0x03, 0xEE]).expect_err("1006 must be rejected");
        assert!(matches!(failure, WebSocketError::Protocol(_)));
    }

    #[test]
    fn test_build_close_payload_rejects_unsendable_code() {
        let failure = build_close_payload(CloseCode::NoStatusReceived, None)
            .expect_err("1005 must not be sent");
        assert!(matches!(failure, WebSocketError::Protocol(_)));
    }

    #[test]
    fn test_build_close_payload_truncates_on_utf8_boundary() {
        let long_reason = "é".repeat(100); // 200 bytes UTF-8.
        let encoded = build_close_payload(CloseCode::Normal, Some(&long_reason)).unwrap();
        assert!(encoded.len() <= 125);
        let (_status, reason_bytes) = encoded.split_at(2);
        let reason_text =
            std::str::from_utf8(reason_bytes).expect("reason bytes must stay valid UTF-8");
        assert!(!reason_text.is_empty());
    }

    #[test]
    fn test_message_equality() {
        let text_a = Message::Text("hello".into());
        let text_b = Message::Text("hello".into());
        assert_eq!(text_a, text_b);

        let bin_a = Message::Binary(vec![1, 2, 3]);
        let bin_b = Message::Binary(vec![1, 2, 3]);
        assert_eq!(bin_a, bin_b);

        // Same bytes, different frame kind: must not compare equal.
        let as_text = Message::Text("hello".into());
        let as_binary = Message::Binary(b"hello".to_vec());
        assert_ne!(as_text, as_binary);
    }

    #[test]
    fn test_websocket_config_default() {
        let defaults = WebSocketConfig::default();
        assert_eq!(defaults.max_frame_size, DEFAULT_MAX_FRAME_SIZE);
        assert_eq!(defaults.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }
}