// fastapi_core/websocket.rs

//! WebSocket protocol support (RFC 6455).
//!
//! This module provides:
//! - WebSocket handshake helpers (`Sec-WebSocket-Accept`)
//! - A minimal frame codec (mask/unmask, ping/pong/close, text/binary)
//!
//! Design constraints for this project:
//! - No Tokio
//! - Minimal dependencies (SHA-1 and base64 are implemented locally)
//! - Cancel-correct: all I/O is async and can be cancelled via asupersync
11
12use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
13use asupersync::net::TcpStream;
14use std::future::poll_fn;
15use std::io;
16use std::pin::Pin;
17use std::task::Poll;
18
/// The GUID appended to `Sec-WebSocket-Key` before hashing to compute
/// `Sec-WebSocket-Accept` (RFC 6455 §1.3).
pub const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Size cap (64 MiB) applied to incoming frame payloads and reassembled text messages.
const MAX_TEXT_MESSAGE_BYTES: usize = 64 << 20;

/// Control frames (close/ping/pong) may carry at most 125 payload bytes (RFC 6455 §5.5).
const MAX_CONTROL_PAYLOAD_BYTES: usize = 125;

/// A close payload is a 2-byte status code followed by the reason text.
const MAX_CLOSE_REASON_BYTES: usize = MAX_CONTROL_PAYLOAD_BYTES - 2;

// Close status codes (RFC 6455 §7.4.1) this module sends when the peer misbehaves.
const CLOSE_CODE_PROTOCOL_ERROR: u16 = 1002;
const CLOSE_CODE_UNSUPPORTED_DATA: u16 = 1003;
const CLOSE_CODE_INVALID_PAYLOAD: u16 = 1007;
const CLOSE_CODE_MESSAGE_TOO_BIG: u16 = 1009;
28
/// Reasons a WebSocket upgrade request can be rejected during the handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketHandshakeError {
    /// A required header is absent or empty; carries the (lower-case) header name.
    MissingHeader(&'static str),
    /// `Sec-WebSocket-Key` is not valid base64.
    InvalidKeyBase64,
    /// `Sec-WebSocket-Key` decoded to `decoded_len` bytes instead of the required 16.
    InvalidKeyLength { decoded_len: usize },
}

impl std::fmt::Display for WebSocketHandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::MissingHeader(header) => {
                write!(f, "missing required websocket header: {header}")
            }
            Self::InvalidKeyBase64 => {
                f.write_str("invalid Sec-WebSocket-Key (base64 decode failed)")
            }
            Self::InvalidKeyLength { decoded_len } => write!(
                f,
                "invalid Sec-WebSocket-Key (decoded length {decoded_len}, expected 16)"
            ),
        }
    }
}

// No underlying cause to expose: the default `source()` (None) is correct.
impl std::error::Error for WebSocketHandshakeError {}
54
55/// Compute `Sec-WebSocket-Accept` from `Sec-WebSocket-Key` (RFC 6455).
56///
57/// Validates that the key is base64 and decodes to 16 bytes (as required by RFC 6455).
58pub fn websocket_accept_from_key(key: &str) -> Result<String, WebSocketHandshakeError> {
59    let key = key.trim();
60    if key.is_empty() {
61        return Err(WebSocketHandshakeError::MissingHeader("sec-websocket-key"));
62    }
63
64    let decoded = base64_decode(key).ok_or(WebSocketHandshakeError::InvalidKeyBase64)?;
65    if decoded.len() != 16 {
66        return Err(WebSocketHandshakeError::InvalidKeyLength {
67            decoded_len: decoded.len(),
68        });
69    }
70
71    let mut input = Vec::with_capacity(key.len() + WS_GUID.len());
72    input.extend_from_slice(key.as_bytes());
73    input.extend_from_slice(WS_GUID.as_bytes());
74
75    let digest = sha1(&input);
76    Ok(base64_encode(&digest))
77}
78
/// Frame opcode (the low 4 bits of the first header byte, RFC 6455 §5.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Continues a fragmented message started by a `Text` or `Binary` frame.
    Continuation = 0x0,
    /// UTF-8 text data.
    Text = 0x1,
    /// Arbitrary binary data.
    Binary = 0x2,
    /// Connection close (control frame).
    Close = 0x8,
    /// Ping (control frame).
    Ping = 0x9,
    /// Pong (control frame).
    Pong = 0xA,
}

impl OpCode {
    /// Map a raw opcode value to a known opcode; reserved values yield `None`.
    fn from_u8(b: u8) -> Option<Self> {
        const KNOWN: [OpCode; 6] = [
            OpCode::Continuation,
            OpCode::Text,
            OpCode::Binary,
            OpCode::Close,
            OpCode::Ping,
            OpCode::Pong,
        ];
        KNOWN.into_iter().find(|candidate| *candidate as u8 == b)
    }

    /// Control opcodes are exactly those with bit 3 set (RFC 6455 §5.5).
    fn is_control(self) -> bool {
        ((self as u8) & 0x08) != 0
    }
}

/// One decoded WebSocket frame (payload already unmasked).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// `true` if this is the final fragment of its message.
    pub fin: bool,
    /// The frame's opcode.
    pub opcode: OpCode,
    /// Application payload bytes.
    pub payload: Vec<u8>,
}
116
/// Errors produced while reading or writing WebSocket frames.
#[derive(Debug)]
pub enum WebSocketError {
    /// Reading from or writing to the socket failed.
    Io(io::Error),
    /// A framing rule was violated (by the peer or by an outgoing frame), or the stream
    /// ended mid-frame; the message names the rule.
    Protocol(&'static str),
    /// A text message was not valid UTF-8.
    Utf8(std::str::Utf8Error),
    /// A frame or message of `size` bytes exceeded `limit`.
    MessageTooLarge { size: usize, limit: usize },
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(cause) => write!(f, "websocket I/O error: {cause}"),
            Self::Protocol(rule) => write!(f, "websocket protocol error: {rule}"),
            Self::Utf8(cause) => write!(f, "invalid utf-8 in websocket text frame: {cause}"),
            Self::MessageTooLarge { size, limit } => write!(
                f,
                "websocket message too large: {size} bytes (limit {limit})"
            ),
        }
    }
}

impl std::error::Error for WebSocketError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::Utf8(inner) => inner,
            Self::Protocol(_) | Self::MessageTooLarge { .. } => return None,
        };
        Some(cause)
    }
}

impl From<io::Error> for WebSocketError {
    fn from(cause: io::Error) -> Self {
        WebSocketError::Io(cause)
    }
}

impl From<std::str::Utf8Error> for WebSocketError {
    fn from(cause: std::str::Utf8Error) -> Self {
        WebSocketError::Utf8(cause)
    }
}
163
164/// A WebSocket connection (server-side).
165///
166/// Notes:
167/// - Server -> client frames are not masked.
168/// - Client -> server frames must be masked (enforced).
169#[derive(Debug)]
170pub struct WebSocket {
171    stream: TcpStream,
172    rx: Vec<u8>,
173}
174
175impl WebSocket {
176    /// Create a websocket from a TCP stream and an optional prefix of already-buffered bytes.
177    #[must_use]
178    pub fn new(stream: TcpStream, buffered: Vec<u8>) -> Self {
179        Self {
180            stream,
181            rx: buffered,
182        }
183    }
184
185    /// Read the next frame.
186    pub async fn read_frame(&mut self) -> Result<Frame, WebSocketError> {
187        let header = self.read_exact_buf(2).await?;
188        let b0 = header[0];
189        let b1 = header[1];
190
191        let fin = (b0 & 0x80) != 0;
192        let rsv = (b0 >> 4) & 0x07;
193        if rsv != 0 {
194            return Err(WebSocketError::Protocol(
195                "reserved bits must be 0 (no extensions negotiated)",
196            ));
197        }
198        let opcode =
199            OpCode::from_u8(b0 & 0x0f).ok_or(WebSocketError::Protocol("invalid opcode"))?;
200        let masked = (b1 & 0x80) != 0;
201        let mut len7 = u64::from(b1 & 0x7f);
202
203        if opcode.is_control() && !fin {
204            return Err(WebSocketError::Protocol(
205                "control frames must not be fragmented",
206            ));
207        }
208
209        if len7 == 126 {
210            let b = self.read_exact_buf(2).await?;
211            len7 = u64::from(u16::from_be_bytes([b[0], b[1]]));
212        } else if len7 == 127 {
213            let b = self.read_exact_buf(8).await?;
214            len7 = u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]);
215            // Most implementations reject lengths with the high bit set (non-minimal encoding).
216            if (len7 >> 63) != 0 {
217                return Err(WebSocketError::Protocol("invalid 64-bit length"));
218            }
219        }
220
221        if !masked {
222            return Err(WebSocketError::Protocol(
223                "client->server frames must be masked",
224            ));
225        }
226        let payload_len = usize::try_from(len7).map_err(|_| WebSocketError::MessageTooLarge {
227            size: usize::MAX,
228            limit: MAX_TEXT_MESSAGE_BYTES,
229        })?;
230
231        if opcode.is_control() && payload_len > 125 {
232            return Err(WebSocketError::Protocol("control frame too large"));
233        }
234        if payload_len > MAX_TEXT_MESSAGE_BYTES {
235            return Err(WebSocketError::MessageTooLarge {
236                size: payload_len,
237                limit: MAX_TEXT_MESSAGE_BYTES,
238            });
239        }
240
241        let mask = self.read_exact_buf(4).await?;
242
243        let mut payload = self.read_exact_buf(payload_len).await?;
244        for (i, b) in payload.iter_mut().enumerate() {
245            *b ^= mask[i & 3];
246        }
247
248        Ok(Frame {
249            fin,
250            opcode,
251            payload,
252        })
253    }
254
255    /// Write a frame to the peer (server-side, unmasked).
256    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), WebSocketError> {
257        validate_outgoing_frame(frame)?;
258
259        let mut out = Vec::with_capacity(2 + frame.payload.len() + 8);
260        let b0 = (if frame.fin { 0x80 } else { 0 }) | (frame.opcode as u8);
261        out.push(b0);
262
263        let len = u64::try_from(frame.payload.len())
264            .map_err(|_| WebSocketError::Protocol("len too large"))?;
265        if len <= 125 {
266            out.push(len as u8);
267        } else if let Ok(len16) = u16::try_from(len) {
268            out.push(126);
269            out.extend_from_slice(&len16.to_be_bytes());
270        } else {
271            out.push(127);
272            out.extend_from_slice(&len.to_be_bytes());
273        }
274
275        out.extend_from_slice(&frame.payload);
276        write_all(&mut self.stream, &out).await?;
277        flush(&mut self.stream).await?;
278        Ok(())
279    }
280
281    /// Convenience: read a text message.
282    pub async fn read_text(&mut self) -> Result<String, WebSocketError> {
283        self.read_text_or_close()
284            .await?
285            .ok_or(WebSocketError::Protocol("websocket closed"))
286    }
287
288    /// Convenience: read a text message, transparently handling ping/pong/close.
289    ///
290    /// Behavior:
291    /// - `Ping` frames are answered with `Pong` (same payload) and ignored.
292    /// - `Pong` frames are ignored.
293    /// - `Close` frames are replied to with a `Close` echo and return `Ok(None)`.
294    /// - Any non-text data frame returns a protocol error.
295    pub async fn read_text_or_close(&mut self) -> Result<Option<String>, WebSocketError> {
296        let mut text_fragments: Vec<u8> = Vec::new();
297        let mut collecting_text_fragments = false;
298
299        loop {
300            let frame = match self.read_frame().await {
301                Ok(frame) => frame,
302                Err(err @ WebSocketError::MessageTooLarge { .. }) => {
303                    let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
304                    return Err(err);
305                }
306                Err(err @ WebSocketError::Protocol(_)) => {
307                    // Malformed frames should trigger a protocol close frame.
308                    let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
309                    return Err(err);
310                }
311                Err(err) => return Err(err),
312            };
313            match frame.opcode {
314                OpCode::Text => {
315                    if collecting_text_fragments {
316                        let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
317                        return Err(WebSocketError::Protocol(
318                            "new text frame before fragmented text completed",
319                        ));
320                    }
321                    if frame.fin {
322                        match std::str::from_utf8(&frame.payload) {
323                            Ok(s) => return Ok(Some(s.to_string())),
324                            Err(err) => {
325                                let _ = self.send_close_code(CLOSE_CODE_INVALID_PAYLOAD).await;
326                                return Err(WebSocketError::Utf8(err));
327                            }
328                        }
329                    }
330
331                    if frame.payload.len() > MAX_TEXT_MESSAGE_BYTES {
332                        let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
333                        return Err(WebSocketError::Protocol("text message too large"));
334                    }
335                    text_fragments.extend_from_slice(&frame.payload);
336                    collecting_text_fragments = true;
337                }
338                OpCode::Ping => {
339                    self.send_pong(&frame.payload).await?;
340                }
341                OpCode::Pong => {}
342                OpCode::Close => {
343                    if !is_valid_close_payload(&frame.payload) {
344                        // RFC 6455: malformed close payload is a protocol error.
345                        let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
346                        return Err(WebSocketError::Protocol("invalid close frame payload"));
347                    }
348                    // Echo the close payload (if any) and let the caller exit cleanly.
349                    let close = Frame {
350                        fin: true,
351                        opcode: OpCode::Close,
352                        payload: frame.payload,
353                    };
354                    let _ = self.write_frame(&close).await;
355                    return Ok(None);
356                }
357                OpCode::Binary => {
358                    let _ = self.send_close_code(CLOSE_CODE_UNSUPPORTED_DATA).await;
359                    return Err(WebSocketError::Protocol(
360                        "expected text frame, got binary frame",
361                    ));
362                }
363                OpCode::Continuation => {
364                    if !collecting_text_fragments {
365                        let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
366                        return Err(WebSocketError::Protocol("unexpected continuation frame"));
367                    }
368
369                    let next_size = text_fragments.len().saturating_add(frame.payload.len());
370                    if next_size > MAX_TEXT_MESSAGE_BYTES {
371                        let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
372                        return Err(WebSocketError::Protocol("text message too large"));
373                    }
374                    text_fragments.extend_from_slice(&frame.payload);
375
376                    if frame.fin {
377                        match std::str::from_utf8(&text_fragments) {
378                            Ok(s) => return Ok(Some(s.to_string())),
379                            Err(err) => {
380                                let _ = self.send_close_code(CLOSE_CODE_INVALID_PAYLOAD).await;
381                                return Err(WebSocketError::Utf8(err));
382                            }
383                        }
384                    }
385                }
386            }
387        }
388    }
389
390    /// Send a `Pong` control frame (server-side, unmasked).
391    pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WebSocketError> {
392        if payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
393            return Err(WebSocketError::Protocol("pong payload too large"));
394        }
395        let frame = Frame {
396            fin: true,
397            opcode: OpCode::Pong,
398            payload: payload.to_vec(),
399        };
400        self.write_frame(&frame).await
401    }
402
403    /// Convenience: send a text message.
404    pub async fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
405        let frame = Frame {
406            fin: true,
407            opcode: OpCode::Text,
408            payload: text.as_bytes().to_vec(),
409        };
410        self.write_frame(&frame).await
411    }
412
413    /// Convenience: send a binary message.
414    pub async fn send_bytes(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
415        let frame = Frame {
416            fin: true,
417            opcode: OpCode::Binary,
418            payload: data.to_vec(),
419        };
420        self.write_frame(&frame).await
421    }
422
423    /// Convenience: send a `Ping` control frame.
424    pub async fn ping(&mut self, payload: &[u8]) -> Result<(), WebSocketError> {
425        if payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
426            return Err(WebSocketError::Protocol("ping payload too large"));
427        }
428        let frame = Frame {
429            fin: true,
430            opcode: OpCode::Ping,
431            payload: payload.to_vec(),
432        };
433        self.write_frame(&frame).await
434    }
435
436    /// Initiate a close handshake with an explicit close code and optional reason.
437    pub async fn close(
438        &mut self,
439        close_code: u16,
440        reason: Option<&str>,
441    ) -> Result<(), WebSocketError> {
442        let payload = build_close_payload(close_code, reason)?;
443        let frame = Frame {
444            fin: true,
445            opcode: OpCode::Close,
446            payload,
447        };
448        self.write_frame(&frame).await
449    }
450
451    async fn send_close_code(&mut self, close_code: u16) -> Result<(), WebSocketError> {
452        let frame = Frame {
453            fin: true,
454            opcode: OpCode::Close,
455            payload: close_code.to_be_bytes().to_vec(),
456        };
457        self.write_frame(&frame).await
458    }
459
460    async fn read_exact_buf(&mut self, n: usize) -> Result<Vec<u8>, WebSocketError> {
461        while self.rx.len() < n {
462            let mut tmp = vec![0u8; 8192];
463            let read = read_once(&mut self.stream, &mut tmp).await?;
464            if read == 0 {
465                return Err(WebSocketError::Protocol("unexpected EOF"));
466            }
467            self.rx.extend_from_slice(&tmp[..read]);
468        }
469
470        let out = self.rx.drain(..n).collect();
471        Ok(out)
472    }
473}
474
475async fn read_once(stream: &mut TcpStream, buffer: &mut [u8]) -> io::Result<usize> {
476    poll_fn(|cx| {
477        let mut read_buf = ReadBuf::new(buffer);
478        match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
479            Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
480            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
481            Poll::Pending => Poll::Pending,
482        }
483    })
484    .await
485}
486
487async fn write_all(stream: &mut TcpStream, mut buf: &[u8]) -> io::Result<()> {
488    while !buf.is_empty() {
489        let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf)).await?;
490        if n == 0 {
491            return Err(io::Error::new(io::ErrorKind::WriteZero, "write zero"));
492        }
493        buf = &buf[n..];
494    }
495    Ok(())
496}
497
498async fn flush(stream: &mut TcpStream) -> io::Result<()> {
499    poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await
500}
501
// =============================================================================
// SHA-1 (RFC 3174) - minimal implementation
// =============================================================================

/// Compute the SHA-1 digest of `data` (RFC 3174 / FIPS 180).
///
/// Used only for the `Sec-WebSocket-Accept` handshake value, not for security.
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [0x6745_2301, 0xEFCD_AB89, 0x98BA_DCFE, 0x1032_5476, 0xC3D2_E1F0];

    // Whole 64-byte blocks are hashed straight from `data`. Only the remainder is copied
    // into `tail`, which then gets the padding: 0x80, zeros up to 56 (mod 64), and the
    // message length in bits as a big-endian u64.
    let split = data.len() - data.len() % 64;
    let (body, rest) = data.split_at(split);

    let mut tail = Vec::with_capacity(128);
    tail.extend_from_slice(rest);
    tail.push(0x80);
    let pad_to = if tail.len() <= 56 { 56 } else { 120 };
    tail.resize(pad_to, 0);
    let bit_len = (data.len() as u64) * 8;
    tail.extend_from_slice(&bit_len.to_be_bytes());

    for block in body.chunks_exact(64).chain(tail.chunks_exact(64)) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut schedule = [0u32; 80];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &w) in schedule.iter().enumerate() {
            // Round function and constant change every 20 rounds.
            let (f, k) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6)
            };
            let mixed = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(w);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = mixed;
        }

        for (h, v) in state.iter_mut().zip([a, b, c, d, e]) {
            *h = h.wrapping_add(v);
        }
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
582
// =============================================================================
// Base64 (RFC 4648) - minimal (no alloc-free tricks; small and deterministic)
// =============================================================================

/// The standard base64 alphabet (RFC 4648 §4).
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encode `data` as padded standard base64 (RFC 4648 §4).
fn base64_encode(data: &[u8]) -> String {
    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into 24 bits; missing bytes count as zero.
        let byte_at = |i: usize| u32::from(group.get(i).copied().unwrap_or(0));
        let word24 = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);
        let symbol = |shift: u32| char::from(B64[((word24 >> shift) & 0x3f) as usize]);

        encoded.push(symbol(18));
        encoded.push(symbol(12));
        encoded.push(if group.len() >= 2 { symbol(6) } else { '=' });
        encoded.push(if group.len() == 3 { symbol(0) } else { '=' });
    }
    encoded
}
625
/// Decode padded standard base64 (RFC 4648 §4), ignoring surrounding whitespace.
///
/// Returns `None` if the trimmed length is not a multiple of 4, if any symbol is outside
/// the alphabet, or if `=` padding appears anywhere except the end of the final quad
/// (`xx==` or `xxx=`).
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let encoded = input.trim().as_bytes();
    if encoded.len() % 4 != 0 {
        return None;
    }
    let quad_count = encoded.len() / 4;
    let mut decoded = Vec::with_capacity(quad_count * 3);

    for (index, quad) in encoded.chunks_exact(4).enumerate() {
        let is_final = index + 1 == quad_count;
        // Number of trailing `=`; padding is only legal in the final quad.
        let padding: usize = match (quad[2], quad[3]) {
            (b'=', b'=') if is_final => 2,
            (b'=', _) => return None,
            (_, b'=') if is_final => 1,
            (_, b'=') => return None,
            _ => 0,
        };

        // Accumulate the non-padding sextets, then left-align them in 24 bits.
        let mut word24: u32 = 0;
        for &symbol in &quad[..4 - padding] {
            word24 = (word24 << 6) | u32::from(decode_b64(symbol)?);
        }
        word24 <<= 6 * padding;

        let be_bytes = word24.to_be_bytes(); // [0, byte0, byte1, byte2]
        decoded.extend_from_slice(&be_bytes[1..4 - padding]);
    }
    Some(decoded)
}

/// Map one base64 alphabet symbol to its 6-bit value; `None` for anything else
/// (including `=`).
fn decode_b64(b: u8) -> Option<u8> {
    let value = match b {
        b'A'..=b'Z' => b - b'A',
        b'a'..=b'z' => b - b'a' + 26,
        b'0'..=b'9' => b - b'0' + 52,
        b'+' => 62,
        b'/' => 63,
        _ => return None,
    };
    Some(value)
}
684
685fn is_valid_close_payload(payload: &[u8]) -> bool {
686    if payload.is_empty() {
687        return true;
688    }
689    if payload.len() < 2 {
690        return false;
691    }
692
693    let code = u16::from_be_bytes([payload[0], payload[1]]);
694    if !is_valid_close_code(code) {
695        return false;
696    }
697
698    if payload.len() == 2 {
699        return true;
700    }
701
702    std::str::from_utf8(&payload[2..]).is_ok()
703}
704
705fn build_close_payload(close_code: u16, reason: Option<&str>) -> Result<Vec<u8>, WebSocketError> {
706    if !is_valid_close_code(close_code) {
707        return Err(WebSocketError::Protocol("invalid close code"));
708    }
709
710    let mut payload = Vec::with_capacity(2 + reason.map_or(0, str::len));
711    payload.extend_from_slice(&close_code.to_be_bytes());
712    if let Some(reason_str) = reason {
713        let mut end = reason_str.len().min(MAX_CLOSE_REASON_BYTES);
714        while end > 0 && !reason_str.is_char_boundary(end) {
715            end -= 1;
716        }
717        payload.extend_from_slice(&reason_str.as_bytes()[..end]);
718    }
719    Ok(payload)
720}
721
/// Whether `code` may appear in a `Close` frame (RFC 6455 §7.4 plus the IANA registry).
///
/// Accepts the defined codes 1000–1003 and 1007–1014, and the application/library range
/// 3000–4999. Codes reserved for local use only (1004–1006, 1015) are rejected.
fn is_valid_close_code(code: u16) -> bool {
    match code {
        1000..=1003 | 1007..=1014 => true,
        3000..=4999 => true,
        _ => false,
    }
}
729
730fn validate_outgoing_frame(frame: &Frame) -> Result<(), WebSocketError> {
731    if frame.opcode.is_control() {
732        if !frame.fin {
733            return Err(WebSocketError::Protocol(
734                "control frames must not be fragmented",
735            ));
736        }
737        if frame.payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
738            return Err(WebSocketError::Protocol("control frame too large"));
739        }
740        if matches!(frame.opcode, OpCode::Close) && !is_valid_close_payload(&frame.payload) {
741            return Err(WebSocketError::Protocol("invalid close frame payload"));
742        }
743    }
744    Ok(())
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower-case hex rendering, for comparing digests against published vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn accept_key_known_vector() {
        // RFC 6455 example
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = websocket_accept_from_key(key).unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn accept_key_ignores_surrounding_whitespace() {
        let accept = websocket_accept_from_key("  dGhlIHNhbXBsZSBub25jZQ==\t").unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn accept_key_rejects_invalid_keys() {
        assert_eq!(
            websocket_accept_from_key("   "),
            Err(WebSocketHandshakeError::MissingHeader("sec-websocket-key"))
        );
        assert_eq!(
            websocket_accept_from_key("not base64!"),
            Err(WebSocketHandshakeError::InvalidKeyBase64)
        );
        assert_eq!(
            websocket_accept_from_key("Zm9v"),
            Err(WebSocketHandshakeError::InvalidKeyLength { decoded_len: 3 })
        );
    }

    #[test]
    fn sha1_known_vectors() {
        // FIPS 180 / RFC 3174 test vectors (empty, one block, two blocks).
        assert_eq!(hex(&sha1(b"")), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
        assert_eq!(hex(&sha1(b"abc")), "a9993e364706816aba3e25717850c26c9cd0d89d");
        assert_eq!(
            hex(&sha1(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq")),
            "84983e441c3bd26ebaae4aa1f95129e5e54670f1"
        );
    }

    #[test]
    fn base64_roundtrip_small() {
        let data = b"hello world";
        let enc = base64_encode(data);
        let dec = base64_decode(&enc).unwrap();
        assert_eq!(dec, data);
    }

    #[test]
    fn base64_rfc4648_vectors() {
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, encoded) in vectors {
            assert_eq!(base64_encode(plain.as_bytes()), encoded);
            assert_eq!(base64_decode(encoded).as_deref(), Some(plain.as_bytes()));
        }
    }

    #[test]
    fn base64_decode_rejects_malformed_input() {
        for bad in ["Zg", "Zg=a", "Z===", "====", "Zg==Zm9v", "Zm8=Zm9v", "Zm9v!A=="] {
            assert!(base64_decode(bad).is_none(), "{bad:?} must be rejected");
        }
    }

    #[test]
    fn close_code_table() {
        for code in [1000, 1001, 1002, 1003, 1007, 1011, 1014, 3000, 4999] {
            assert!(is_valid_close_code(code), "{code} should be sendable");
        }
        for code in [0, 999, 1004, 1005, 1006, 1015, 2999, 5000] {
            assert!(!is_valid_close_code(code), "{code} must not be sendable");
        }
    }

    #[test]
    fn close_payload_validation() {
        assert!(is_valid_close_payload(&[]));
        assert!(!is_valid_close_payload(&[0x03]));
        assert!(!is_valid_close_payload(&[0x03, 0xEE])); // 1006 cannot be sent
        assert!(is_valid_close_payload(&[0x03, 0xE8])); // 1000
        assert!(is_valid_close_payload(&[0x03, 0xE8, b'o', b'k']));
        assert!(!is_valid_close_payload(&[0x03, 0xE8, 0xFF])); // invalid utf-8 reason
    }

    #[test]
    fn build_close_payload_rejects_invalid_code() {
        let err = build_close_payload(1006, None).expect_err("1006 must be rejected");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn build_close_payload_encodes_code_and_reason() {
        let payload = build_close_payload(1000, Some("bye")).expect("payload");
        assert_eq!(payload, vec![0x03, 0xE8, b'b', b'y', b'e']);
        assert_eq!(build_close_payload(1001, None).expect("payload"), vec![0x03, 0xE9]);
    }

    #[test]
    fn build_close_payload_truncates_on_utf8_boundary() {
        let reason = "é".repeat(100); // 200 bytes UTF-8.
        let payload = build_close_payload(1000, Some(&reason)).expect("payload");
        assert!(payload.len() <= MAX_CONTROL_PAYLOAD_BYTES);
        let reason_bytes = &payload[2..];
        assert!(
            std::str::from_utf8(reason_bytes).is_ok(),
            "close reason must remain valid UTF-8"
        );
        // 61 two-byte chars fit in the 123-byte budget; a 62nd would not.
        assert_eq!(reason_bytes.len(), 122, "truncation should keep every whole char that fits");
    }

    #[test]
    fn outgoing_frame_validation_rejects_fragmented_control() {
        let frame = Frame {
            fin: false,
            opcode: OpCode::Ping,
            payload: vec![],
        };
        let err = validate_outgoing_frame(&frame).expect_err("fragmented control frame must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_rejects_oversized_control() {
        let frame = Frame {
            fin: true,
            opcode: OpCode::Pong,
            payload: vec![0; MAX_CONTROL_PAYLOAD_BYTES + 1],
        };
        let err = validate_outgoing_frame(&frame).expect_err("oversized control frame must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_rejects_invalid_close_payload() {
        // 1006 is not a sendable close code.
        let frame = Frame {
            fin: true,
            opcode: OpCode::Close,
            payload: 1006u16.to_be_bytes().to_vec(),
        };
        let err = validate_outgoing_frame(&frame).expect_err("invalid close payload must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_accepts_data_frames() {
        let frame = Frame {
            fin: false,
            opcode: OpCode::Text,
            payload: vec![0; MAX_CONTROL_PAYLOAD_BYTES + 10],
        };
        assert!(validate_outgoing_frame(&frame).is_ok());
    }
}