1use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
13use asupersync::net::TcpStream;
14use std::future::poll_fn;
15use std::io;
16use std::pin::Pin;
17use std::task::Poll;
18
/// GUID appended to the client's `Sec-WebSocket-Key` before hashing (RFC 6455 §1.3).
pub const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Upper bound (64 MiB) on a single frame payload and on a reassembled text message.
const MAX_TEXT_MESSAGE_BYTES: usize = 64 << 20;
/// Control frames (close/ping/pong) carry at most 125 payload bytes.
const MAX_CONTROL_PAYLOAD_BYTES: usize = 125;
/// A close payload spends 2 bytes on the status code, leaving this much for the reason.
const MAX_CLOSE_REASON_BYTES: usize = MAX_CONTROL_PAYLOAD_BYTES - 2;
// Close status codes this module sends when it aborts a connection.
const CLOSE_CODE_PROTOCOL_ERROR: u16 = 1002;
const CLOSE_CODE_UNSUPPORTED_DATA: u16 = 1003;
const CLOSE_CODE_INVALID_PAYLOAD: u16 = 1007;
const CLOSE_CODE_MESSAGE_TOO_BIG: u16 = 1009;
28
/// Reasons a client's `Sec-WebSocket-Key` can be rejected during the upgrade.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketHandshakeError {
    /// A required header is absent; also used when the key is empty after trimming.
    MissingHeader(&'static str),
    /// The key is not valid padded base64.
    InvalidKeyBase64,
    /// The key decoded to `decoded_len` bytes instead of the required 16.
    InvalidKeyLength { decoded_len: usize },
}

impl std::fmt::Display for WebSocketHandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidKeyBase64 => {
                f.write_str("invalid Sec-WebSocket-Key (base64 decode failed)")
            }
            Self::InvalidKeyLength { decoded_len } => write!(
                f,
                "invalid Sec-WebSocket-Key (decoded length {decoded_len}, expected 16)"
            ),
            Self::MissingHeader(header) => {
                write!(f, "missing required websocket header: {header}")
            }
        }
    }
}

impl std::error::Error for WebSocketHandshakeError {}
54
55pub fn websocket_accept_from_key(key: &str) -> Result<String, WebSocketHandshakeError> {
59 let key = key.trim();
60 if key.is_empty() {
61 return Err(WebSocketHandshakeError::MissingHeader("sec-websocket-key"));
62 }
63
64 let decoded = base64_decode(key).ok_or(WebSocketHandshakeError::InvalidKeyBase64)?;
65 if decoded.len() != 16 {
66 return Err(WebSocketHandshakeError::InvalidKeyLength {
67 decoded_len: decoded.len(),
68 });
69 }
70
71 let mut input = Vec::with_capacity(key.len() + WS_GUID.len());
72 input.extend_from_slice(key.as_bytes());
73 input.extend_from_slice(WS_GUID.as_bytes());
74
75 let digest = sha1(&input);
76 Ok(base64_encode(&digest))
77}
78
/// Frame opcodes understood by this implementation (RFC 6455 §5.2).
///
/// Reserved opcodes (0x3–0x7, 0xB–0xF) have no variant and are rejected when parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Continues a fragmented message.
    Continuation = 0x0,
    /// Text data.
    Text = 0x1,
    /// Binary data.
    Binary = 0x2,
    /// Connection close.
    Close = 0x8,
    /// Ping (expects a pong).
    Ping = 0x9,
    /// Pong.
    Pong = 0xA,
}

impl OpCode {
    /// Every variant; the lookup table for `from_u8`.
    const ALL: [OpCode; 6] = [
        OpCode::Continuation,
        OpCode::Text,
        OpCode::Binary,
        OpCode::Close,
        OpCode::Ping,
        OpCode::Pong,
    ];

    /// Maps a raw opcode value to its variant, or `None` if it is not a known opcode.
    fn from_u8(b: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|&op| op as u8 == b)
    }

    /// `true` for Close, Ping and Pong: control opcodes are exactly those with bit 3 set.
    fn is_control(self) -> bool {
        ((self as u8) & 0x08) != 0
    }
}

/// A single WebSocket frame.
///
/// Received frames hold the already-unmasked payload. Frames to send hold the
/// raw payload, which the server writes unmasked.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Frame {
    /// FIN bit: `true` on the final fragment of a message.
    pub fin: bool,
    /// Frame opcode.
    pub opcode: OpCode,
    /// Unmasked payload bytes.
    pub payload: Vec<u8>,
}
116
/// Errors produced while reading or writing WebSocket frames.
#[derive(Debug)]
pub enum WebSocketError {
    /// The underlying stream failed.
    Io(io::Error),
    /// A protocol-level failure described by a static message: framing
    /// violations, unexpected EOF, peer close, or an invalid outgoing frame.
    Protocol(&'static str),
    /// A text message was not valid UTF-8.
    Utf8(std::str::Utf8Error),
    /// A payload of `size` bytes exceeded `limit`.
    MessageTooLarge { size: usize, limit: usize },
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MessageTooLarge { size, limit } => write!(
                f,
                "websocket message too large: {size} bytes (limit {limit})"
            ),
            Self::Protocol(reason) => write!(f, "websocket protocol error: {reason}"),
            Self::Utf8(inner) => write!(f, "invalid utf-8 in websocket text frame: {inner}"),
            Self::Io(inner) => write!(f, "websocket I/O error: {inner}"),
        }
    }
}

impl std::error::Error for WebSocketError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::Utf8(inner) => inner,
            Self::Protocol(_) | Self::MessageTooLarge { .. } => return None,
        };
        Some(cause)
    }
}

impl From<io::Error> for WebSocketError {
    fn from(err: io::Error) -> Self {
        WebSocketError::Io(err)
    }
}

impl From<std::str::Utf8Error> for WebSocketError {
    fn from(err: std::str::Utf8Error) -> Self {
        WebSocketError::Utf8(err)
    }
}
163
164#[derive(Debug)]
170pub struct WebSocket {
171 stream: TcpStream,
172 rx: Vec<u8>,
173}
174
175impl WebSocket {
176 #[must_use]
178 pub fn new(stream: TcpStream, buffered: Vec<u8>) -> Self {
179 Self {
180 stream,
181 rx: buffered,
182 }
183 }
184
185 pub async fn read_frame(&mut self) -> Result<Frame, WebSocketError> {
187 let header = self.read_exact_buf(2).await?;
188 let b0 = header[0];
189 let b1 = header[1];
190
191 let fin = (b0 & 0x80) != 0;
192 let rsv = (b0 >> 4) & 0x07;
193 if rsv != 0 {
194 return Err(WebSocketError::Protocol(
195 "reserved bits must be 0 (no extensions negotiated)",
196 ));
197 }
198 let opcode =
199 OpCode::from_u8(b0 & 0x0f).ok_or(WebSocketError::Protocol("invalid opcode"))?;
200 let masked = (b1 & 0x80) != 0;
201 let mut len7 = u64::from(b1 & 0x7f);
202
203 if opcode.is_control() && !fin {
204 return Err(WebSocketError::Protocol(
205 "control frames must not be fragmented",
206 ));
207 }
208
209 if len7 == 126 {
210 let b = self.read_exact_buf(2).await?;
211 len7 = u64::from(u16::from_be_bytes([b[0], b[1]]));
212 } else if len7 == 127 {
213 let b = self.read_exact_buf(8).await?;
214 len7 = u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]);
215 if (len7 >> 63) != 0 {
217 return Err(WebSocketError::Protocol("invalid 64-bit length"));
218 }
219 }
220
221 if !masked {
222 return Err(WebSocketError::Protocol(
223 "client->server frames must be masked",
224 ));
225 }
226 let payload_len = usize::try_from(len7).map_err(|_| WebSocketError::MessageTooLarge {
227 size: usize::MAX,
228 limit: MAX_TEXT_MESSAGE_BYTES,
229 })?;
230
231 if opcode.is_control() && payload_len > 125 {
232 return Err(WebSocketError::Protocol("control frame too large"));
233 }
234 if payload_len > MAX_TEXT_MESSAGE_BYTES {
235 return Err(WebSocketError::MessageTooLarge {
236 size: payload_len,
237 limit: MAX_TEXT_MESSAGE_BYTES,
238 });
239 }
240
241 let mask = self.read_exact_buf(4).await?;
242
243 let mut payload = self.read_exact_buf(payload_len).await?;
244 for (i, b) in payload.iter_mut().enumerate() {
245 *b ^= mask[i & 3];
246 }
247
248 Ok(Frame {
249 fin,
250 opcode,
251 payload,
252 })
253 }
254
255 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), WebSocketError> {
257 validate_outgoing_frame(frame)?;
258
259 let mut out = Vec::with_capacity(2 + frame.payload.len() + 8);
260 let b0 = (if frame.fin { 0x80 } else { 0 }) | (frame.opcode as u8);
261 out.push(b0);
262
263 let len = u64::try_from(frame.payload.len())
264 .map_err(|_| WebSocketError::Protocol("len too large"))?;
265 if len <= 125 {
266 out.push(len as u8);
267 } else if let Ok(len16) = u16::try_from(len) {
268 out.push(126);
269 out.extend_from_slice(&len16.to_be_bytes());
270 } else {
271 out.push(127);
272 out.extend_from_slice(&len.to_be_bytes());
273 }
274
275 out.extend_from_slice(&frame.payload);
276 write_all(&mut self.stream, &out).await?;
277 flush(&mut self.stream).await?;
278 Ok(())
279 }
280
281 pub async fn read_text(&mut self) -> Result<String, WebSocketError> {
283 self.read_text_or_close()
284 .await?
285 .ok_or(WebSocketError::Protocol("websocket closed"))
286 }
287
288 pub async fn read_text_or_close(&mut self) -> Result<Option<String>, WebSocketError> {
296 let mut text_fragments: Vec<u8> = Vec::new();
297 let mut collecting_text_fragments = false;
298
299 loop {
300 let frame = match self.read_frame().await {
301 Ok(frame) => frame,
302 Err(err @ WebSocketError::MessageTooLarge { .. }) => {
303 let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
304 return Err(err);
305 }
306 Err(err @ WebSocketError::Protocol(_)) => {
307 let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
309 return Err(err);
310 }
311 Err(err) => return Err(err),
312 };
313 match frame.opcode {
314 OpCode::Text => {
315 if collecting_text_fragments {
316 let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
317 return Err(WebSocketError::Protocol(
318 "new text frame before fragmented text completed",
319 ));
320 }
321 if frame.fin {
322 match std::str::from_utf8(&frame.payload) {
323 Ok(s) => return Ok(Some(s.to_string())),
324 Err(err) => {
325 let _ = self.send_close_code(CLOSE_CODE_INVALID_PAYLOAD).await;
326 return Err(WebSocketError::Utf8(err));
327 }
328 }
329 }
330
331 if frame.payload.len() > MAX_TEXT_MESSAGE_BYTES {
332 let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
333 return Err(WebSocketError::Protocol("text message too large"));
334 }
335 text_fragments.extend_from_slice(&frame.payload);
336 collecting_text_fragments = true;
337 }
338 OpCode::Ping => {
339 self.send_pong(&frame.payload).await?;
340 }
341 OpCode::Pong => {}
342 OpCode::Close => {
343 if !is_valid_close_payload(&frame.payload) {
344 let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
346 return Err(WebSocketError::Protocol("invalid close frame payload"));
347 }
348 let close = Frame {
350 fin: true,
351 opcode: OpCode::Close,
352 payload: frame.payload,
353 };
354 let _ = self.write_frame(&close).await;
355 return Ok(None);
356 }
357 OpCode::Binary => {
358 let _ = self.send_close_code(CLOSE_CODE_UNSUPPORTED_DATA).await;
359 return Err(WebSocketError::Protocol(
360 "expected text frame, got binary frame",
361 ));
362 }
363 OpCode::Continuation => {
364 if !collecting_text_fragments {
365 let _ = self.send_close_code(CLOSE_CODE_PROTOCOL_ERROR).await;
366 return Err(WebSocketError::Protocol("unexpected continuation frame"));
367 }
368
369 let next_size = text_fragments.len().saturating_add(frame.payload.len());
370 if next_size > MAX_TEXT_MESSAGE_BYTES {
371 let _ = self.send_close_code(CLOSE_CODE_MESSAGE_TOO_BIG).await;
372 return Err(WebSocketError::Protocol("text message too large"));
373 }
374 text_fragments.extend_from_slice(&frame.payload);
375
376 if frame.fin {
377 match std::str::from_utf8(&text_fragments) {
378 Ok(s) => return Ok(Some(s.to_string())),
379 Err(err) => {
380 let _ = self.send_close_code(CLOSE_CODE_INVALID_PAYLOAD).await;
381 return Err(WebSocketError::Utf8(err));
382 }
383 }
384 }
385 }
386 }
387 }
388 }
389
390 pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WebSocketError> {
392 if payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
393 return Err(WebSocketError::Protocol("pong payload too large"));
394 }
395 let frame = Frame {
396 fin: true,
397 opcode: OpCode::Pong,
398 payload: payload.to_vec(),
399 };
400 self.write_frame(&frame).await
401 }
402
403 pub async fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
405 let frame = Frame {
406 fin: true,
407 opcode: OpCode::Text,
408 payload: text.as_bytes().to_vec(),
409 };
410 self.write_frame(&frame).await
411 }
412
413 pub async fn send_bytes(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
415 let frame = Frame {
416 fin: true,
417 opcode: OpCode::Binary,
418 payload: data.to_vec(),
419 };
420 self.write_frame(&frame).await
421 }
422
423 pub async fn ping(&mut self, payload: &[u8]) -> Result<(), WebSocketError> {
425 if payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
426 return Err(WebSocketError::Protocol("ping payload too large"));
427 }
428 let frame = Frame {
429 fin: true,
430 opcode: OpCode::Ping,
431 payload: payload.to_vec(),
432 };
433 self.write_frame(&frame).await
434 }
435
436 pub async fn close(
438 &mut self,
439 close_code: u16,
440 reason: Option<&str>,
441 ) -> Result<(), WebSocketError> {
442 let payload = build_close_payload(close_code, reason)?;
443 let frame = Frame {
444 fin: true,
445 opcode: OpCode::Close,
446 payload,
447 };
448 self.write_frame(&frame).await
449 }
450
451 async fn send_close_code(&mut self, close_code: u16) -> Result<(), WebSocketError> {
452 let frame = Frame {
453 fin: true,
454 opcode: OpCode::Close,
455 payload: close_code.to_be_bytes().to_vec(),
456 };
457 self.write_frame(&frame).await
458 }
459
460 async fn read_exact_buf(&mut self, n: usize) -> Result<Vec<u8>, WebSocketError> {
461 while self.rx.len() < n {
462 let mut tmp = vec![0u8; 8192];
463 let read = read_once(&mut self.stream, &mut tmp).await?;
464 if read == 0 {
465 return Err(WebSocketError::Protocol("unexpected EOF"));
466 }
467 self.rx.extend_from_slice(&tmp[..read]);
468 }
469
470 let out = self.rx.drain(..n).collect();
471 Ok(out)
472 }
473}
474
475async fn read_once(stream: &mut TcpStream, buffer: &mut [u8]) -> io::Result<usize> {
476 poll_fn(|cx| {
477 let mut read_buf = ReadBuf::new(buffer);
478 match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
479 Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
480 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
481 Poll::Pending => Poll::Pending,
482 }
483 })
484 .await
485}
486
487async fn write_all(stream: &mut TcpStream, mut buf: &[u8]) -> io::Result<()> {
488 while !buf.is_empty() {
489 let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf)).await?;
490 if n == 0 {
491 return Err(io::Error::new(io::ErrorKind::WriteZero, "write zero"));
492 }
493 buf = &buf[n..];
494 }
495 Ok(())
496}
497
498async fn flush(stream: &mut TcpStream) -> io::Result<()> {
499 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx)).await
500}
501
/// Computes the SHA-1 digest of `data`.
///
/// Used here for the `Sec-WebSocket-Accept` hash. SHA-1's weak collision
/// resistance is irrelevant to that use.
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [
        0x6745_2301,
        0xEFCD_AB89,
        0x98BA_DCFE,
        0x1032_5476,
        0xC3D2_E1F0,
    ];

    // Padding: append a 0x80 marker, then zeros until the length is 56 mod 64,
    // then the original length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let mut message = Vec::with_capacity((data.len() + 9).div_ceil(64) * 64);
    message.extend_from_slice(data);
    message.push(0x80);
    let zero_pad = (64 + 56 - message.len() % 64) % 64;
    message.resize(message.len() + zero_pad, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut schedule = [0u32; 80];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &w) in schedule.iter().enumerate() {
            // Round function and constant change every 20 rounds.
            let (f, k) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6)
            };
            let next = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(w);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }

        for (h, v) in state.iter_mut().zip([a, b, c, d, e]) {
            *h = h.wrapping_add(v);
        }
    }

    let mut digest = [0u8; 20];
    for (out, h) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&h.to_be_bytes());
    }
    digest
}
582
/// Standard base64 alphabet (RFC 4648 §4).
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `data` as padded standard base64.
fn base64_encode(data: &[u8]) -> String {
    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Bytes missing from a short final group count as zero; padding hides them.
        let byte_at = |i: usize| u32::from(group.get(i).copied().unwrap_or(0));
        let triple = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);
        let sextet = |shift: u32| char::from(B64[((triple >> shift) & 0x3f) as usize]);

        encoded.push(sextet(18));
        encoded.push(sextet(12));
        encoded.push(if group.len() >= 2 { sextet(6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(0) } else { '=' });
    }
    encoded
}
625
/// Decodes padded standard base64, ignoring surrounding whitespace.
///
/// Returns `None` in any of these cases:
/// - the trimmed length is not a multiple of 4;
/// - a character is outside the alphabet;
/// - `=` padding appears anywhere other than the tail of the final quad.
///
/// Non-zero bits under the padding are not rejected.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let bytes = input.trim().as_bytes();
    if bytes.len() % 4 != 0 {
        return None;
    }

    let quad_count = bytes.len() / 4;
    let mut decoded = Vec::with_capacity(quad_count * 3);
    for (index, quad) in bytes.chunks_exact(4).enumerate() {
        let is_final = index + 1 == quad_count;
        let pad2 = quad[2] == b'=';
        let pad3 = quad[3] == b'=';
        // Padding is only legal at the end of the last quad, and "x=y" is never legal.
        if (pad2 && !pad3) || ((pad2 || pad3) && !is_final) {
            return None;
        }

        let s0 = u32::from(decode_b64(quad[0])?);
        let s1 = u32::from(decode_b64(quad[1])?);
        let s2 = if pad2 { 0 } else { u32::from(decode_b64(quad[2])?) };
        let s3 = if pad3 { 0 } else { u32::from(decode_b64(quad[3])?) };
        let triple = (s0 << 18) | (s1 << 12) | (s2 << 6) | s3;

        decoded.push((triple >> 16) as u8);
        if !pad2 {
            decoded.push((triple >> 8) as u8);
        }
        if !pad3 {
            decoded.push(triple as u8);
        }
    }
    Some(decoded)
}

/// Maps one standard-alphabet base64 character to its 6-bit value.
///
/// Returns `None` for any other byte, including `=`.
fn decode_b64(b: u8) -> Option<u8> {
    let value = match b {
        b'A'..=b'Z' => b - b'A',
        b'a'..=b'z' => b - b'a' + 26,
        b'0'..=b'9' => b - b'0' + 52,
        b'+' => 62,
        b'/' => 63,
        _ => return None,
    };
    Some(value)
}
684
685fn is_valid_close_payload(payload: &[u8]) -> bool {
686 if payload.is_empty() {
687 return true;
688 }
689 if payload.len() < 2 {
690 return false;
691 }
692
693 let code = u16::from_be_bytes([payload[0], payload[1]]);
694 if !is_valid_close_code(code) {
695 return false;
696 }
697
698 if payload.len() == 2 {
699 return true;
700 }
701
702 std::str::from_utf8(&payload[2..]).is_ok()
703}
704
705fn build_close_payload(close_code: u16, reason: Option<&str>) -> Result<Vec<u8>, WebSocketError> {
706 if !is_valid_close_code(close_code) {
707 return Err(WebSocketError::Protocol("invalid close code"));
708 }
709
710 let mut payload = Vec::with_capacity(2 + reason.map_or(0, str::len));
711 payload.extend_from_slice(&close_code.to_be_bytes());
712 if let Some(reason_str) = reason {
713 let mut end = reason_str.len().min(MAX_CLOSE_REASON_BYTES);
714 while end > 0 && !reason_str.is_char_boundary(end) {
715 end -= 1;
716 }
717 payload.extend_from_slice(&reason_str.as_bytes()[..end]);
718 }
719 Ok(payload)
720}
721
/// Returns `true` for status codes that may appear in a close frame.
///
/// Accepted codes are 1000–1003, 1007–1014, and the application/private range
/// 3000–4999. Codes 1004, 1005, 1006 and 1015 are reserved and must never be
/// sent.
fn is_valid_close_code(code: u16) -> bool {
    matches!(code, 1000..=1003 | 1007..=1014 | 3000..=4999)
}
729
730fn validate_outgoing_frame(frame: &Frame) -> Result<(), WebSocketError> {
731 if frame.opcode.is_control() {
732 if !frame.fin {
733 return Err(WebSocketError::Protocol(
734 "control frames must not be fragmented",
735 ));
736 }
737 if frame.payload.len() > MAX_CONTROL_PAYLOAD_BYTES {
738 return Err(WebSocketError::Protocol("control frame too large"));
739 }
740 if matches!(frame.opcode, OpCode::Close) && !is_valid_close_payload(&frame.payload) {
741 return Err(WebSocketError::Protocol("invalid close frame payload"));
742 }
743 }
744 Ok(())
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase hex rendering for comparing digests against published vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn accept_key_known_vector() {
        // Sample handshake from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = websocket_accept_from_key(key).unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn accept_key_ignores_surrounding_whitespace() {
        let accept = websocket_accept_from_key("  dGhlIHNhbXBsZSBub25jZQ==\r\n").unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn accept_key_rejects_bad_keys() {
        assert_eq!(
            websocket_accept_from_key("   "),
            Err(WebSocketHandshakeError::MissingHeader("sec-websocket-key"))
        );
        assert_eq!(
            websocket_accept_from_key("not base64!"),
            Err(WebSocketHandshakeError::InvalidKeyBase64)
        );
        assert_eq!(
            websocket_accept_from_key("Zm9v"),
            Err(WebSocketHandshakeError::InvalidKeyLength { decoded_len: 3 })
        );
    }

    #[test]
    fn sha1_known_vectors() {
        assert_eq!(hex(&sha1(b"")), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
        assert_eq!(hex(&sha1(b"abc")), "a9993e364706816aba3e25717850c26c9cd0d89d");
        // 56 bytes: the length field no longer fits, forcing a second block.
        assert_eq!(
            hex(&sha1(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq")),
            "84983e441c3bd26ebaae4aa1f95129e5e54670f1"
        );
    }

    #[test]
    fn base64_roundtrip_small() {
        let data = b"hello world";
        let enc = base64_encode(data);
        let dec = base64_decode(&enc).unwrap();
        assert_eq!(dec, data);
    }

    #[test]
    fn base64_rfc4648_vectors() {
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for &(raw, encoded) in &vectors {
            assert_eq!(base64_encode(raw.as_bytes()), encoded);
            assert_eq!(base64_decode(encoded).as_deref(), Some(raw.as_bytes()));
        }
    }

    #[test]
    fn base64_decode_rejects_malformed_input() {
        // Not a multiple of 4.
        assert_eq!(base64_decode("Zg="), None);
        // Padding before the final quad.
        assert_eq!(base64_decode("Zg==Zg=="), None);
        // Data after padding.
        assert_eq!(base64_decode("Zg=v"), None);
        // Character outside the alphabet.
        assert_eq!(base64_decode("Zm9$"), None);
    }

    #[test]
    fn close_payload_validation() {
        assert!(is_valid_close_payload(&[]));
        // A single byte cannot hold a status code.
        assert!(!is_valid_close_payload(&[0x03]));
        // 1006 (0x03EE) must never appear on the wire.
        assert!(!is_valid_close_payload(&[0x03, 0xEE]));
        // 1000 (0x03E8), with and without a UTF-8 reason.
        assert!(is_valid_close_payload(&[0x03, 0xE8]));
        assert!(is_valid_close_payload(&[0x03, 0xE8, b'o', b'k']));
        // The reason must be valid UTF-8.
        assert!(!is_valid_close_payload(&[0x03, 0xE8, 0xFF]));
    }

    #[test]
    fn close_code_ranges() {
        for code in [1000, 1001, 1002, 1003, 1007, 1011, 1014, 3000, 4999] {
            assert!(is_valid_close_code(code), "{code} should be accepted");
        }
        for code in [0, 999, 1004, 1005, 1006, 1015, 2999, 5000] {
            assert!(!is_valid_close_code(code), "{code} should be rejected");
        }
    }

    #[test]
    fn build_close_payload_rejects_invalid_code() {
        let err = build_close_payload(1006, None).expect_err("1006 must be rejected");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn build_close_payload_encodes_code_and_reason() {
        let payload = build_close_payload(1000, Some("bye")).expect("payload");
        assert_eq!(payload, [0x03, 0xE8, b'b', b'y', b'e']);
    }

    #[test]
    fn build_close_payload_truncates_on_utf8_boundary() {
        // "é" is 2 bytes, so 123 bytes would split a char; expect 122 reason bytes.
        let reason = "é".repeat(100);
        let payload = build_close_payload(1000, Some(&reason)).expect("payload");
        assert!(payload.len() <= MAX_CONTROL_PAYLOAD_BYTES);
        assert_eq!(payload.len(), 2 + 122);
        let reason_bytes = &payload[2..];
        assert!(
            std::str::from_utf8(reason_bytes).is_ok(),
            "close reason must remain valid UTF-8"
        );
    }

    #[test]
    fn outgoing_frame_validation_rejects_fragmented_control() {
        let frame = Frame {
            fin: false,
            opcode: OpCode::Ping,
            payload: vec![],
        };
        let err = validate_outgoing_frame(&frame).expect_err("fragmented control frame must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_rejects_oversized_control() {
        let frame = Frame {
            fin: true,
            opcode: OpCode::Pong,
            payload: vec![0; MAX_CONTROL_PAYLOAD_BYTES + 1],
        };
        let err = validate_outgoing_frame(&frame).expect_err("oversized control frame must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_rejects_invalid_close_payload() {
        // 1006 is reserved for local use and must not be sent.
        let frame = Frame {
            fin: true,
            opcode: OpCode::Close,
            payload: 1006u16.to_be_bytes().to_vec(),
        };
        let err = validate_outgoing_frame(&frame).expect_err("invalid close payload must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn outgoing_frame_validation_accepts_data_frames() {
        // Data frames may be fragmented and exceed the control-frame limit.
        let frame = Frame {
            fin: false,
            opcode: OpCode::Text,
            payload: vec![0; MAX_CONTROL_PAYLOAD_BYTES + 10],
        };
        assert!(validate_outgoing_frame(&frame).is_ok());
    }
}