1use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
35use asupersync::net::TcpStream;
36use std::future::poll_fn;
37use std::io;
38use std::pin::Pin;
39use std::task::Poll;
40
/// Computes the SHA-1 digest of `data`.
///
/// Used only to derive the `Sec-WebSocket-Accept` handshake value, where RFC 6455 mandates
/// SHA-1; no security property depends on it.
#[allow(clippy::many_single_char_names)]
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [
        0x6745_2301,
        0xEFCD_AB89,
        0x98BA_DCFE,
        0x1032_5476,
        0xC3D2_E1F0,
    ];

    // Padding: append 0x80, zero-fill until the length is 56 (mod 64), then append the
    // original message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let padded_len = (data.len() + 8) / 64 * 64 + 56;
    let mut message = Vec::with_capacity(padded_len + 8);
    message.extend_from_slice(data);
    message.push(0x80);
    message.resize(padded_len, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, 64 more from the recurrence.
        let mut schedule = [0u32; 80];
        for (word, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &word) in schedule.iter().enumerate() {
            // Round function and constant change every 20 rounds.
            let (f, k) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999_u32)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1_u32)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC_u32)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6_u32)
            };

            let temp = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = temp;
        }

        for (h, v) in state.iter_mut().zip([a, b, c, d, e]) {
            *h = h.wrapping_add(v);
        }
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
119
/// Standard base64 alphabet (RFC 4648 §4).
const BASE64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `data` as standard, `=`-padded base64.
fn base64_encode(data: &[u8]) -> String {
    // Maps the 6-bit group starting at bit `shift` of `packed` to its alphabet character.
    let sextet = |packed: u32, shift: u32| -> char {
        char::from(BASE64_CHARS[((packed >> shift) & 0x3F) as usize])
    };

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits; missing bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        encoded.push(sextet(packed, 18));
        encoded.push(sextet(packed, 12));
        encoded.push(if group.len() > 1 { sextet(packed, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(packed, 0) } else { '=' });
    }
    encoded
}
160
/// Fixed GUID appended to the client key when deriving `Sec-WebSocket-Accept` (RFC 6455).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Default cap on a single received frame's payload: 16 MiB.
pub const DEFAULT_MAX_FRAME_SIZE: usize = 16 << 20;

/// Default cap on a fully reassembled received message: 64 MiB.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 << 20;
173
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
180pub enum Opcode {
181 Continuation,
183 Text,
185 Binary,
187 Close,
189 Ping,
191 Pong,
193}
194
195impl Opcode {
196 fn from_u8(value: u8) -> Result<Self, WebSocketError> {
198 match value & 0x0F {
199 0x0 => Ok(Self::Continuation),
200 0x1 => Ok(Self::Text),
201 0x2 => Ok(Self::Binary),
202 0x8 => Ok(Self::Close),
203 0x9 => Ok(Self::Ping),
204 0xA => Ok(Self::Pong),
205 other => Err(WebSocketError::Protocol(format!(
206 "unknown opcode: 0x{other:X}"
207 ))),
208 }
209 }
210
211 fn to_u8(self) -> u8 {
212 match self {
213 Self::Continuation => 0x0,
214 Self::Text => 0x1,
215 Self::Binary => 0x2,
216 Self::Close => 0x8,
217 Self::Ping => 0x9,
218 Self::Pong => 0xA,
219 }
220 }
221
222 fn is_control(self) -> bool {
224 matches!(self, Self::Close | Self::Ping | Self::Pong)
225 }
226}
227
/// Close status code carried in a Close frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseCode {
    /// 1000: normal closure.
    Normal,
    /// 1001: endpoint going away.
    GoingAway,
    /// 1002: protocol error.
    ProtocolError,
    /// 1003: unsupported data.
    UnsupportedData,
    /// 1005: no status received (rejected by `is_valid_close_code`, never sent).
    NoStatusReceived,
    /// 1006: abnormal closure (rejected by `is_valid_close_code`, never sent).
    AbnormalClosure,
    /// 1007: invalid payload data.
    InvalidPayload,
    /// 1008: policy violation.
    PolicyViolation,
    /// 1009: message too big.
    MessageTooBig,
    /// 1010: mandatory extension missing.
    MandatoryExtension,
    /// 1011: internal server error.
    InternalError,
    /// 1012: service restart.
    ServiceRestart,
    /// 1013: try again later.
    TryAgainLater,
    /// 1014: bad gateway.
    BadGateway,
    /// Any other numeric code; `from_u16` only produces this for 3000..=4999.
    Application(u16),
}

impl CloseCode {
    /// Every named (non-`Application`) variant. `from_u16` searches this list so the
    /// numeric mapping is defined in exactly one place, `to_u16`.
    const REGISTERED: [Self; 14] = [
        Self::Normal,
        Self::GoingAway,
        Self::ProtocolError,
        Self::UnsupportedData,
        Self::NoStatusReceived,
        Self::AbnormalClosure,
        Self::InvalidPayload,
        Self::PolicyViolation,
        Self::MessageTooBig,
        Self::MandatoryExtension,
        Self::InternalError,
        Self::ServiceRestart,
        Self::TryAgainLater,
        Self::BadGateway,
    ];

    /// Returns the numeric status code for this variant.
    pub fn to_u16(self) -> u16 {
        match self {
            Self::Normal => 1000,
            Self::GoingAway => 1001,
            Self::ProtocolError => 1002,
            Self::UnsupportedData => 1003,
            Self::NoStatusReceived => 1005,
            Self::AbnormalClosure => 1006,
            Self::InvalidPayload => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalError => 1011,
            Self::ServiceRestart => 1012,
            Self::TryAgainLater => 1013,
            Self::BadGateway => 1014,
            Self::Application(code) => code,
        }
    }

    /// Maps a numeric code to a variant.
    ///
    /// Named codes map to their variant and 3000..=4999 maps to `Application`. Every other
    /// value (e.g. 1004, 1015, 2999, 5000) collapses to `ProtocolError`.
    pub fn from_u16(code: u16) -> Self {
        if (3000..=4999).contains(&code) {
            return Self::Application(code);
        }
        Self::REGISTERED
            .iter()
            .copied()
            .find(|candidate| candidate.to_u16() == code)
            .unwrap_or(Self::ProtocolError)
    }
}

impl std::fmt::Display for CloseCode {
    /// Formats the code as its decimal number (e.g. `1000`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.to_u16())
    }
}
313
/// A complete WebSocket message.
///
/// NOTE(review): `read_message` answers pings and skips pongs internally, so `receive`
/// in this file currently only yields `Text`, `Binary`, or `Close`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// UTF-8 text, reassembled from all fragments.
    Text(String),
    /// Binary data, reassembled from all fragments.
    Binary(Vec<u8>),
    /// Ping payload.
    Ping(Vec<u8>),
    /// Pong payload.
    Pong(Vec<u8>),
    /// Close frame: optional status code and optional UTF-8 reason.
    Close(Option<CloseCode>, Option<String>),
}
328
/// A single decoded frame as produced by `read_frame`.
#[derive(Debug, Clone)]
struct Frame {
    /// FIN bit: `true` on the final fragment of a message.
    fin: bool,
    /// Frame opcode.
    opcode: Opcode,
    /// Payload bytes, already unmasked by `read_frame`.
    payload: Vec<u8>,
}
336
/// Errors produced by the WebSocket layer.
#[derive(Debug)]
pub enum WebSocketError {
    /// Underlying stream read/write/flush failure.
    Io(io::Error),
    /// The peer violated the framing or handshake protocol, or the API was misused.
    Protocol(String),
    /// The stream hit EOF, or the connection is already closed.
    ConnectionClosed,
    /// A frame or reassembled message exceeded its configured limit.
    MessageTooLarge { size: usize, limit: usize },
    /// A text message was not valid UTF-8.
    InvalidUtf8,
    /// The HTTP upgrade request or response could not be validated or built.
    HandshakeFailed(String),
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "WebSocket I/O error: {e}"),
            Self::Protocol(msg) => write!(f, "WebSocket protocol error: {msg}"),
            Self::HandshakeFailed(msg) => write!(f, "WebSocket handshake failed: {msg}"),
            Self::MessageTooLarge { size, limit } => write!(
                f,
                "WebSocket message too large: {size} bytes (limit: {limit})"
            ),
            Self::ConnectionClosed => write!(f, "WebSocket connection closed"),
            Self::InvalidUtf8 => write!(f, "WebSocket: invalid UTF-8 in text message"),
        }
    }
}

impl std::error::Error for WebSocketError {
    /// Exposes the wrapped `io::Error` for `Io`; every other variant has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<io::Error> for WebSocketError {
    /// Wraps a transport error so `?` works on `io::Result` values.
    fn from(err: io::Error) -> Self {
        WebSocketError::Io(err)
    }
}
386
387#[derive(Debug, Clone)]
389pub struct WebSocketConfig {
390 pub max_frame_size: usize,
392 pub max_message_size: usize,
394}
395
396impl Default for WebSocketConfig {
397 fn default() -> Self {
398 Self {
399 max_frame_size: DEFAULT_MAX_FRAME_SIZE,
400 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
401 }
402 }
403}
404
405pub fn accept_key(client_key: &str) -> String {
412 let mut input = String::with_capacity(client_key.len() + WS_GUID.len());
413 input.push_str(client_key.trim());
414 input.push_str(WS_GUID);
415 base64_encode(&sha1(input.as_bytes()))
416}
417
418pub fn validate_upgrade_request(
429 method: &str,
430 headers: &[(String, Vec<u8>)],
431) -> Result<String, WebSocketError> {
432 if !method.eq_ignore_ascii_case("GET") {
434 return Err(WebSocketError::HandshakeFailed(
435 "WebSocket upgrade requires GET method".into(),
436 ));
437 }
438
439 let find_header = |name: &str| -> Option<String> {
440 headers
441 .iter()
442 .find(|(k, _)| k.eq_ignore_ascii_case(name))
443 .and_then(|(_, v)| String::from_utf8(v.clone()).ok())
444 };
445
446 let upgrade = find_header("upgrade")
448 .ok_or_else(|| WebSocketError::HandshakeFailed("missing Upgrade header".into()))?;
449 if !upgrade
450 .split(',')
451 .any(|v| v.trim().eq_ignore_ascii_case("websocket"))
452 {
453 return Err(WebSocketError::HandshakeFailed(
454 "Upgrade header must contain 'websocket'".into(),
455 ));
456 }
457
458 let connection = find_header("connection")
460 .ok_or_else(|| WebSocketError::HandshakeFailed("missing Connection header".into()))?;
461 if !connection
462 .split(',')
463 .any(|v| v.trim().eq_ignore_ascii_case("upgrade"))
464 {
465 return Err(WebSocketError::HandshakeFailed(
466 "Connection header must contain 'upgrade'".into(),
467 ));
468 }
469
470 let key = find_header("sec-websocket-key").ok_or_else(|| {
472 WebSocketError::HandshakeFailed("missing Sec-WebSocket-Key header".into())
473 })?;
474 let key = key.trim();
475 if key.is_empty() {
476 return Err(WebSocketError::HandshakeFailed(
477 "Sec-WebSocket-Key must not be empty".into(),
478 ));
479 }
480 if fastapi_core::websocket_accept_from_key(key).is_err() {
481 return Err(WebSocketError::HandshakeFailed(
482 "invalid Sec-WebSocket-Key (must be valid base64 with 16 decoded bytes)".into(),
483 ));
484 }
485
486 let version = find_header("sec-websocket-version").ok_or_else(|| {
488 WebSocketError::HandshakeFailed("missing Sec-WebSocket-Version header".into())
489 })?;
490 if version.trim() != "13" {
491 return Err(WebSocketError::HandshakeFailed(format!(
492 "unsupported WebSocket version: {version} (expected 13)"
493 )));
494 }
495
496 Ok(key.to_string())
497}
498
/// Returns `true` if `value` is a non-empty HTTP token (visible ASCII, no separators).
///
/// This keeps CR/LF and other header-breaking characters out of `Sec-WebSocket-Protocol`.
fn is_valid_subprotocol_token(value: &str) -> bool {
    const SEPARATORS: &str = "()<>@,;:\\\"/[]?={} \t";
    let is_token_byte = |b: u8| b.is_ascii_graphic() && !SEPARATORS.as_bytes().contains(&b);
    !value.is_empty() && value.bytes().all(is_token_byte)
}
508
509pub fn build_accept_response(
513 client_key: &str,
514 subprotocol: Option<&str>,
515) -> Result<Vec<u8>, WebSocketError> {
516 let accept = accept_key(client_key);
517 let mut response = format!(
518 "HTTP/1.1 101 Switching Protocols\r\n\
519 Upgrade: websocket\r\n\
520 Connection: Upgrade\r\n\
521 Sec-WebSocket-Accept: {accept}\r\n"
522 );
523 if let Some(proto) = subprotocol {
524 if !is_valid_subprotocol_token(proto) {
525 return Err(WebSocketError::HandshakeFailed(
526 "invalid Sec-WebSocket-Protocol token".into(),
527 ));
528 }
529 response.push_str(&format!("Sec-WebSocket-Protocol: {proto}\r\n"));
530 }
531 response.push_str("\r\n");
532 Ok(response.into_bytes())
533}
534
535async fn read_frame(
544 stream: &mut TcpStream,
545 config: &WebSocketConfig,
546) -> Result<Frame, WebSocketError> {
547 let mut header = [0u8; 2];
549 read_exact(stream, &mut header).await?;
550
551 let fin = (header[0] & 0x80) != 0;
552 let rsv = (header[0] >> 4) & 0x07;
553 if rsv != 0 {
554 return Err(WebSocketError::Protocol(
555 "reserved bits must be 0 (no extensions negotiated)".into(),
556 ));
557 }
558
559 let opcode = Opcode::from_u8(header[0])?;
560 let masked = (header[1] & 0x80) != 0;
561 let payload_len_byte = header[1] & 0x7F;
562
563 if !masked {
564 return Err(WebSocketError::Protocol(
565 "client-to-server frames must be masked".into(),
566 ));
567 }
568
569 let payload_len: usize = match payload_len_byte {
571 0..=125 => payload_len_byte as usize,
572 126 => {
573 let mut len_bytes = [0u8; 2];
574 read_exact(stream, &mut len_bytes).await?;
575 u16::from_be_bytes(len_bytes) as usize
576 }
577 _ => {
578 let mut len_bytes = [0u8; 8];
580 read_exact(stream, &mut len_bytes).await?;
581 let len = u64::from_be_bytes(len_bytes);
582 if (len >> 63) != 0 {
584 return Err(WebSocketError::Protocol(
585 "64-bit frame length has most significant bit set".into(),
586 ));
587 }
588 if len > usize::MAX as u64 {
590 return Err(WebSocketError::MessageTooLarge {
591 size: usize::MAX,
592 limit: config.max_frame_size,
593 });
594 }
595 len as usize
596 }
597 };
598
599 if opcode.is_control() {
601 if !fin {
602 return Err(WebSocketError::Protocol(
603 "control frames must not be fragmented".into(),
604 ));
605 }
606 if payload_len > 125 {
607 return Err(WebSocketError::Protocol(
608 "control frame payload must not exceed 125 bytes".into(),
609 ));
610 }
611 }
612
613 if payload_len > config.max_frame_size {
615 return Err(WebSocketError::MessageTooLarge {
616 size: payload_len,
617 limit: config.max_frame_size,
618 });
619 }
620
621 let mask_key = if masked {
623 let mut key = [0u8; 4];
624 read_exact(stream, &mut key).await?;
625 Some(key)
626 } else {
627 None
628 };
629
630 let mut payload = vec![0u8; payload_len];
632 if payload_len > 0 {
633 read_exact(stream, &mut payload).await?;
634 }
635
636 if let Some(key) = mask_key {
638 for (i, byte) in payload.iter_mut().enumerate() {
639 *byte ^= key[i % 4];
640 }
641 }
642
643 Ok(Frame {
644 fin,
645 opcode,
646 payload,
647 })
648}
649
650async fn write_frame(
654 stream: &mut TcpStream,
655 fin: bool,
656 opcode: Opcode,
657 payload: &[u8],
658) -> Result<(), WebSocketError> {
659 let mut header = Vec::with_capacity(10);
660
661 let first_byte = if fin { 0x80 } else { 0x00 } | opcode.to_u8();
663 header.push(first_byte);
664
665 let len = payload.len();
667 if len < 126 {
668 header.push(len as u8);
669 } else if len <= 0xFFFF {
670 header.push(126);
671 header.extend_from_slice(&(len as u16).to_be_bytes());
672 } else {
673 header.push(127);
674 header.extend_from_slice(&(len as u64).to_be_bytes());
675 }
676
677 ws_write_all(stream, &header).await?;
679 if !payload.is_empty() {
680 ws_write_all(stream, payload).await?;
681 }
682 ws_flush(stream).await?;
683
684 Ok(())
685}
686
687async fn read_message(
691 stream: &mut TcpStream,
692 config: &WebSocketConfig,
693) -> Result<Message, WebSocketError> {
694 let mut message_opcode: Option<Opcode> = None;
695 let mut message_data: Vec<u8> = Vec::new();
696
697 loop {
698 let frame = read_frame(stream, config).await?;
699
700 if frame.opcode.is_control() {
702 match frame.opcode {
703 Opcode::Close => {
704 let (code, reason) = parse_close_payload(&frame.payload)?;
705 return Ok(Message::Close(code, reason));
706 }
707 Opcode::Ping => {
708 write_frame(stream, true, Opcode::Pong, &frame.payload).await?;
709 continue;
710 }
711 Opcode::Pong => continue,
712 _ => unreachable!(),
713 }
714 }
715
716 match frame.opcode {
718 Opcode::Continuation => {
719 if message_opcode.is_none() {
720 return Err(WebSocketError::Protocol(
721 "continuation frame without initial frame".into(),
722 ));
723 }
724 }
725 Opcode::Text | Opcode::Binary => {
726 if message_opcode.is_some() {
727 return Err(WebSocketError::Protocol(
728 "new data frame while previous message is incomplete".into(),
729 ));
730 }
731 message_opcode = Some(frame.opcode);
732 }
733 _ => {}
734 }
735
736 let new_size = message_data.len() + frame.payload.len();
738 if new_size > config.max_message_size {
739 return Err(WebSocketError::MessageTooLarge {
740 size: new_size,
741 limit: config.max_message_size,
742 });
743 }
744
745 message_data.extend_from_slice(&frame.payload);
746
747 if frame.fin {
748 break;
749 }
750 }
751
752 let opcode = message_opcode
753 .ok_or_else(|| WebSocketError::Protocol("empty message (no data frames)".into()))?;
754
755 match opcode {
756 Opcode::Text => {
757 let text = String::from_utf8(message_data).map_err(|_| WebSocketError::InvalidUtf8)?;
758 Ok(Message::Text(text))
759 }
760 Opcode::Binary => Ok(Message::Binary(message_data)),
761 _ => unreachable!(),
762 }
763}
764
765fn parse_close_payload(
767 payload: &[u8],
768) -> Result<(Option<CloseCode>, Option<String>), WebSocketError> {
769 if payload.len() < 2 {
770 if payload.is_empty() {
771 return Ok((None, None));
772 }
773 return Err(WebSocketError::Protocol(
774 "close frame payload must be empty or at least 2 bytes".into(),
775 ));
776 }
777 let code_raw = u16::from_be_bytes([payload[0], payload[1]]);
778 if !is_valid_close_code(code_raw) {
779 return Err(WebSocketError::Protocol(format!(
780 "invalid close code in close frame: {code_raw}"
781 )));
782 }
783 let code = CloseCode::from_u16(code_raw);
784 let reason = if payload.len() > 2 {
785 Some(
786 std::str::from_utf8(&payload[2..])
787 .map_err(|_| WebSocketError::Protocol("close reason must be valid UTF-8".into()))?
788 .to_string(),
789 )
790 } else {
791 None
792 };
793 Ok((Some(code), reason))
794}
795
796fn build_close_payload(code: CloseCode, reason: Option<&str>) -> Result<Vec<u8>, WebSocketError> {
798 if !is_valid_close_code(code.to_u16()) {
799 return Err(WebSocketError::Protocol(format!(
800 "invalid close code for close frame: {}",
801 code.to_u16()
802 )));
803 }
804 let mut payload = Vec::with_capacity(2 + reason.map_or(0, str::len));
805 payload.extend_from_slice(&code.to_u16().to_be_bytes());
806 if let Some(reason_str) = reason {
807 let max_reason = 123; let mut end = reason_str.len().min(max_reason);
810 while end > 0 && !reason_str.is_char_boundary(end) {
811 end -= 1;
812 }
813 payload.extend_from_slice(&reason_str.as_bytes()[..end]);
814 }
815 Ok(payload)
816}
817
/// Returns `true` for status codes that may appear in a Close frame on the wire.
///
/// Accepts 1000-1003, 1007-1014, and the 3000-4999 application range. 1004, 1005, 1006,
/// 1015, and everything else are rejected.
fn is_valid_close_code(code: u16) -> bool {
    matches!(code, 1000..=1003 | 1007..=1014 | 3000..=4999)
}
825
/// Lifecycle state of a [`WebSocket`], used to gate which operations are permitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsState {
    /// Constructed, but `accept()` has not yet sent the handshake response.
    Pending,
    /// Handshake sent; sending and receiving are allowed.
    Open,
    /// We sent a Close frame; receiving is still allowed, sending is rejected.
    CloseSent,
    /// Terminal state (Close received, or `close()` called before `accept()`); all I/O is rejected.
    Closed,
}
842
/// A server-side WebSocket connection over an upgraded `TcpStream`.
///
/// Construct with [`WebSocket::new`] or [`WebSocket::with_config`], then call
/// [`WebSocket::accept`] to send the `101 Switching Protocols` response before exchanging
/// messages.
pub struct WebSocket {
    // Underlying transport, owned by the connection.
    stream: TcpStream,
    // Lifecycle state; see `WsState`.
    state: WsState,
    // Client's `Sec-WebSocket-Key`, used by `accept` to compute the accept header.
    client_key: String,
    // Size limits applied when reading frames and messages.
    config: WebSocketConfig,
}
880
881impl WebSocket {
882 pub fn new(stream: TcpStream, client_key: String) -> Self {
887 Self {
888 stream,
889 state: WsState::Pending,
890 client_key,
891 config: WebSocketConfig::default(),
892 }
893 }
894
895 pub fn with_config(stream: TcpStream, client_key: String, config: WebSocketConfig) -> Self {
897 Self {
898 stream,
899 state: WsState::Pending,
900 client_key,
901 config,
902 }
903 }
904
905 pub async fn accept(&mut self, subprotocol: Option<&str>) -> Result<(), WebSocketError> {
914 if self.state != WsState::Pending {
915 return Err(WebSocketError::Protocol(
916 "accept() called on non-pending WebSocket".into(),
917 ));
918 }
919
920 let response_bytes = build_accept_response(&self.client_key, subprotocol)?;
921 ws_write_all(&mut self.stream, &response_bytes).await?;
922 ws_flush(&mut self.stream).await?;
923 self.state = WsState::Open;
924 Ok(())
925 }
926
927 pub async fn receive(&mut self) -> Result<Message, WebSocketError> {
937 self.ensure_can_receive()?;
938 let msg = read_message(&mut self.stream, &self.config).await?;
939 match msg {
940 Message::Close(code, reason) => {
941 if self.state == WsState::Open {
943 let payload = match code {
944 Some(close_code) => build_close_payload(close_code, reason.as_deref())?,
945 None => Vec::new(),
946 };
947 write_frame(&mut self.stream, true, Opcode::Close, &payload)
948 .await
949 .ok(); }
951 self.state = WsState::Closed;
952 Ok(Message::Close(code, reason))
953 }
954 _ => Ok(msg),
955 }
956 }
957
958 pub async fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
960 self.ensure_open()?;
961 write_frame(&mut self.stream, true, Opcode::Text, text.as_bytes()).await
962 }
963
964 pub async fn send_bytes(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
966 self.ensure_open()?;
967 write_frame(&mut self.stream, true, Opcode::Binary, data).await
968 }
969
970 pub async fn receive_text(&mut self) -> Result<String, WebSocketError> {
975 match self.receive().await? {
976 Message::Text(text) => Ok(text),
977 Message::Close(code, reason) => Err(WebSocketError::Protocol(format!(
978 "expected text, got close (code={code:?}, reason={reason:?})"
979 ))),
980 other => Err(WebSocketError::Protocol(format!(
981 "expected text message, got {other:?}"
982 ))),
983 }
984 }
985
986 pub async fn receive_bytes(&mut self) -> Result<Vec<u8>, WebSocketError> {
991 match self.receive().await? {
992 Message::Binary(data) => Ok(data),
993 Message::Close(code, reason) => Err(WebSocketError::Protocol(format!(
994 "expected binary, got close (code={code:?}, reason={reason:?})"
995 ))),
996 other => Err(WebSocketError::Protocol(format!(
997 "expected binary message, got {other:?}"
998 ))),
999 }
1000 }
1001
1002 pub async fn ping(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
1004 self.ensure_open()?;
1005 if data.len() > 125 {
1006 return Err(WebSocketError::Protocol(
1007 "ping payload must not exceed 125 bytes".into(),
1008 ));
1009 }
1010 write_frame(&mut self.stream, true, Opcode::Ping, data).await
1011 }
1012
1013 pub async fn pong(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
1015 self.ensure_open()?;
1016 if data.len() > 125 {
1017 return Err(WebSocketError::Protocol(
1018 "pong payload must not exceed 125 bytes".into(),
1019 ));
1020 }
1021 write_frame(&mut self.stream, true, Opcode::Pong, data).await
1022 }
1023
1024 pub async fn close(
1029 &mut self,
1030 code: CloseCode,
1031 reason: Option<&str>,
1032 ) -> Result<(), WebSocketError> {
1033 if self.state == WsState::Closed || self.state == WsState::CloseSent {
1034 return Ok(());
1035 }
1036 if self.state == WsState::Pending {
1037 self.state = WsState::Closed;
1038 return Ok(());
1039 }
1040
1041 let payload = build_close_payload(code, reason)?;
1042 write_frame(&mut self.stream, true, Opcode::Close, &payload).await?;
1043 self.state = WsState::CloseSent;
1044 Ok(())
1045 }
1046
1047 pub fn is_open(&self) -> bool {
1049 self.state == WsState::Open
1050 }
1051
1052 pub fn state(&self) -> &'static str {
1054 match self.state {
1055 WsState::Pending => "pending",
1056 WsState::Open => "open",
1057 WsState::CloseSent => "close_sent",
1058 WsState::Closed => "closed",
1059 }
1060 }
1061
1062 fn ensure_open(&self) -> Result<(), WebSocketError> {
1063 match self.state {
1064 WsState::Open => Ok(()),
1065 WsState::Pending => Err(WebSocketError::Protocol(
1066 "must call accept() before sending/receiving".into(),
1067 )),
1068 WsState::CloseSent | WsState::Closed => Err(WebSocketError::ConnectionClosed),
1069 }
1070 }
1071
1072 fn ensure_can_receive(&self) -> Result<(), WebSocketError> {
1073 match self.state {
1074 WsState::Open | WsState::CloseSent => Ok(()),
1075 WsState::Pending => Err(WebSocketError::Protocol(
1076 "must call accept() before sending/receiving".into(),
1077 )),
1078 WsState::Closed => Err(WebSocketError::ConnectionClosed),
1079 }
1080 }
1081}
1082
1083async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) -> Result<(), WebSocketError> {
1089 let mut offset = 0;
1090 while offset < buf.len() {
1091 let n = ws_read(stream, &mut buf[offset..]).await?;
1092 if n == 0 {
1093 return Err(WebSocketError::ConnectionClosed);
1094 }
1095 offset += n;
1096 }
1097 Ok(())
1098}
1099
1100async fn ws_read(stream: &mut TcpStream, buf: &mut [u8]) -> Result<usize, WebSocketError> {
1102 poll_fn(|cx| {
1103 let mut read_buf = ReadBuf::new(buf);
1104 match Pin::new(&mut *stream).poll_read(cx, &mut read_buf) {
1105 Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
1106 Poll::Ready(Err(e)) => Poll::Ready(Err(WebSocketError::Io(e))),
1107 Poll::Pending => Poll::Pending,
1108 }
1109 })
1110 .await
1111}
1112
1113async fn ws_write_all(stream: &mut TcpStream, mut buf: &[u8]) -> Result<(), WebSocketError> {
1115 while !buf.is_empty() {
1116 let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, buf))
1117 .await
1118 .map_err(WebSocketError::Io)?;
1119 if n == 0 {
1120 return Err(WebSocketError::Io(io::Error::new(
1121 io::ErrorKind::WriteZero,
1122 "failed to write to WebSocket stream",
1123 )));
1124 }
1125 buf = &buf[n..];
1126 }
1127 Ok(())
1128}
1129
1130async fn ws_flush(stream: &mut TcpStream) -> Result<(), WebSocketError> {
1132 poll_fn(|cx| Pin::new(&mut *stream).poll_flush(cx))
1133 .await
1134 .map_err(WebSocketError::Io)
1135}
1136
// Unit tests for the socket-free helpers: hashing, base64, accept-key derivation, opcode and
// close-code mapping, handshake validation, and close-payload encoding/decoding.
#[cfg(test)]
mod tests {
    use super::*;

    // SHA-1 of the empty string.
    #[test]
    fn test_sha1_empty() {
        let result = sha1(b"");
        let expected: [u8; 20] = [
            0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, 0xbf, 0xef, 0x95, 0x60,
            0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09,
        ];
        assert_eq!(result, expected);
    }

    // Standard "abc" known-answer vector.
    #[test]
    fn test_sha1_abc() {
        let result = sha1(b"abc");
        let expected: [u8; 20] = [
            0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81, 0x6a, 0xba, 0x3e, 0x25, 0x71, 0x78, 0x50,
            0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d,
        ];
        assert_eq!(result, expected);
    }

    // RFC 4648 vectors covering 0, 1, and 2 padding characters.
    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    // Example key/accept pair from RFC 6455.
    #[test]
    fn test_accept_key() {
        let key = accept_key("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    // Every named code, plus application-range bounds, survives to_u16 -> from_u16.
    #[test]
    fn test_close_code_roundtrip() {
        let codes = [
            CloseCode::Normal,
            CloseCode::GoingAway,
            CloseCode::ProtocolError,
            CloseCode::UnsupportedData,
            CloseCode::InvalidPayload,
            CloseCode::PolicyViolation,
            CloseCode::MessageTooBig,
            CloseCode::MandatoryExtension,
            CloseCode::InternalError,
            CloseCode::ServiceRestart,
            CloseCode::TryAgainLater,
            CloseCode::BadGateway,
            CloseCode::Application(3000),
            CloseCode::Application(4000),
            CloseCode::Application(4999),
        ];
        for code in codes {
            assert_eq!(CloseCode::from_u16(code.to_u16()), code);
        }
    }

    #[test]
    fn test_opcode_roundtrip() {
        let opcodes = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];
        for op in opcodes {
            assert_eq!(Opcode::from_u8(op.to_u8()).unwrap(), op);
        }
    }

    // Reserved data-frame opcodes must be rejected.
    #[test]
    fn test_opcode_unknown() {
        assert!(Opcode::from_u8(0x03).is_err());
        assert!(Opcode::from_u8(0x07).is_err());
    }

    #[test]
    fn test_opcode_is_control() {
        assert!(!Opcode::Continuation.is_control());
        assert!(!Opcode::Text.is_control());
        assert!(!Opcode::Binary.is_control());
        assert!(Opcode::Close.is_control());
        assert!(Opcode::Ping.is_control());
        assert!(Opcode::Pong.is_control());
    }

    #[test]
    fn test_build_accept_response_basic() {
        let resp = build_accept_response("dGhlIHNhbXBsZSBub25jZQ==", None)
            .expect("response build should succeed");
        let resp_str = String::from_utf8(resp).unwrap();
        assert!(resp_str.starts_with("HTTP/1.1 101 Switching Protocols\r\n"));
        assert!(resp_str.contains("Upgrade: websocket\r\n"));
        assert!(resp_str.contains("Connection: Upgrade\r\n"));
        assert!(resp_str.contains("Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"));
        // The header section must end with an empty line.
        assert!(resp_str.ends_with("\r\n\r\n"));
    }

    #[test]
    fn test_build_accept_response_with_subprotocol() {
        let resp = build_accept_response("dGhlIHNhbXBsZSBub25jZQ==", Some("graphql-ws"))
            .expect("response build should succeed");
        let resp_str = String::from_utf8(resp).unwrap();
        assert!(resp_str.contains("Sec-WebSocket-Protocol: graphql-ws\r\n"));
    }

    // A CR/LF in the subprotocol would inject an extra response header; it must be refused.
    #[test]
    fn test_build_accept_response_rejects_invalid_subprotocol_token() {
        let err =
            build_accept_response("dGhlIHNhbXBsZSBub25jZQ==", Some("graphql-ws\r\nX-Evil: 1"))
                .expect_err("invalid subprotocol token must fail");
        assert!(matches!(err, WebSocketError::HandshakeFailed(_)));
        assert!(
            err.to_string()
                .contains("invalid Sec-WebSocket-Protocol token")
        );
    }

    #[test]
    fn test_validate_upgrade_request_valid() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            (
                "Sec-WebSocket-Key".into(),
                b"dGhlIHNhbXBsZSBub25jZQ==".to_vec(),
            ),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        let result = validate_upgrade_request("GET", &headers);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "dGhlIHNhbXBsZSBub25jZQ==");
    }

    #[test]
    fn test_validate_upgrade_request_wrong_method() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            (
                "Sec-WebSocket-Key".into(),
                b"dGhlIHNhbXBsZSBub25jZQ==".to_vec(),
            ),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        assert!(validate_upgrade_request("POST", &headers).is_err());
    }

    #[test]
    fn test_validate_upgrade_request_missing_upgrade() {
        let headers = vec![
            ("Connection".into(), b"upgrade".to_vec()),
            (
                "Sec-WebSocket-Key".into(),
                b"dGhlIHNhbXBsZSBub25jZQ==".to_vec(),
            ),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        assert!(validate_upgrade_request("GET", &headers).is_err());
    }

    #[test]
    fn test_validate_upgrade_request_wrong_version() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            (
                "Sec-WebSocket-Key".into(),
                b"dGhlIHNhbXBsZSBub25jZQ==".to_vec(),
            ),
            ("Sec-WebSocket-Version".into(), b"8".to_vec()),
        ];
        assert!(validate_upgrade_request("GET", &headers).is_err());
    }

    #[test]
    fn test_validate_upgrade_request_invalid_key_base64() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), b"not-base64".to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        assert!(validate_upgrade_request("GET", &headers).is_err());
    }

    // "Zm9v" is valid base64 but decodes to 3 bytes, not the required 16.
    #[test]
    fn test_validate_upgrade_request_invalid_key_length() {
        let headers = vec![
            ("Upgrade".into(), b"websocket".to_vec()),
            ("Connection".into(), b"upgrade".to_vec()),
            ("Sec-WebSocket-Key".into(), b"Zm9v".to_vec()),
            ("Sec-WebSocket-Version".into(), b"13".to_vec()),
        ];
        assert!(validate_upgrade_request("GET", &headers).is_err());
    }

    #[test]
    fn test_close_payload_roundtrip() {
        let payload = build_close_payload(CloseCode::Normal, Some("goodbye")).unwrap();
        let (code, reason) = parse_close_payload(&payload).unwrap();
        assert_eq!(code, Some(CloseCode::Normal));
        assert_eq!(reason, Some("goodbye".into()));
    }

    #[test]
    fn test_close_payload_no_reason() {
        let payload = build_close_payload(CloseCode::GoingAway, None).unwrap();
        let (code, reason) = parse_close_payload(&payload).unwrap();
        assert_eq!(code, Some(CloseCode::GoingAway));
        assert_eq!(reason, None);
    }

    #[test]
    fn test_close_payload_empty() {
        let (code, reason) = parse_close_payload(&[]).unwrap();
        assert_eq!(code, None);
        assert_eq!(reason, None);
    }

    // A lone byte cannot hold a 2-byte status code.
    #[test]
    fn test_close_payload_len_one_is_invalid() {
        let err = parse_close_payload(&[0x03]).expect_err("len=1 close payload must fail");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    // 0x03EE == 1006 (AbnormalClosure), which may not appear in a Close frame.
    #[test]
    fn test_close_payload_invalid_code_is_rejected() {
        let err = parse_close_payload(&[0x03, 0xEE]).expect_err("1006 must be rejected");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    #[test]
    fn test_build_close_payload_rejects_unsendable_code() {
        let err = build_close_payload(CloseCode::NoStatusReceived, None)
            .expect_err("1005 must not be sent");
        assert!(matches!(err, WebSocketError::Protocol(_)));
    }

    // 100 two-byte characters (200 bytes): the 123-byte cap falls mid-character, so the
    // reason must be cut back to a character boundary and remain valid UTF-8.
    #[test]
    fn test_build_close_payload_truncates_on_utf8_boundary() {
        let reason = "é".repeat(100);
        let payload = build_close_payload(CloseCode::Normal, Some(&reason)).unwrap();
        assert!(payload.len() <= 125);
        let parsed =
            std::str::from_utf8(&payload[2..]).expect("reason bytes must stay valid UTF-8");
        assert!(!parsed.is_empty());
    }

    #[test]
    fn test_message_equality() {
        assert_eq!(Message::Text("hello".into()), Message::Text("hello".into()));
        assert_eq!(
            Message::Binary(vec![1, 2, 3]),
            Message::Binary(vec![1, 2, 3])
        );
        assert_ne!(
            Message::Text("hello".into()),
            Message::Binary(b"hello".to_vec())
        );
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.max_frame_size, DEFAULT_MAX_FRAME_SIZE);
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }
}