1#![cfg_attr(not(feature = "std"), no_std)]
14
15use byteorder::{BigEndian, ByteOrder};
16use core::{cmp, result, str};
17use heapless::{String, Vec};
18use rand_core::RngCore;
19use sha1::{Digest, Sha1};
20
21mod http;
22pub mod random;
23pub use self::http::{read_http_header, WebSocketContext};
24pub use self::random::EmptyRng;
25
26pub mod framer;
29pub mod framer_async;
/// Length in bytes of the masking key carried by a masked frame.
const MASK_KEY_LEN: usize = 4;

/// Result type used throughout this module; the error type is always [`Error`].
pub type Result<T> = result::Result<T, Error>;

/// Fixed-capacity (24 byte) string used for the websocket key exchanged during the handshake.
pub type WebSocketKey = String<24>;

/// Fixed-capacity (24 byte) string used for a sub protocol name.
pub type WebSocketSubProtocol = String<24>;
40
/// Kind of message a caller asks [`WebSocket::write`] to send.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketSendMessageType {
    /// Sent with the text frame op code.
    Text = 1,
    /// Sent with the binary frame op code.
    Binary = 2,
    /// Sent with the ping op code.
    Ping = 9,
    /// Sent with the pong op code.
    Pong = 10,
    /// Sent with the connection close op code; writing it moves the socket to
    /// [`WebSocketState::Closed`].
    CloseReply = 11,
}
56
57impl WebSocketSendMessageType {
58 fn to_op_code(self) -> WebSocketOpCode {
59 match self {
60 WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
61 WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
62 WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
63 WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
64 WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
65 }
66 }
67}
68
/// Kind of message reported by [`WebSocket::read`].
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketReceiveMessageType {
    /// Payload of a text frame (or of a continuation of a text message).
    Text = 1,
    /// Payload of a binary frame (or of a continuation of a binary message).
    Binary = 2,
    /// A close frame arrived after this side had already sent one; the socket is now
    /// [`WebSocketState::Closed`].
    CloseCompleted = 7,
    /// A close frame arrived while the socket was open; the socket is now
    /// [`WebSocketState::CloseReceived`] and a `CloseReply` may still be written.
    CloseMustReply = 8,
    /// A ping frame was received.
    Ping = 9,
    /// A pong frame was received.
    Pong = 10,
}
88
/// Status code carried in the first two payload bytes of a close frame.
///
/// Named variants map to fixed numeric values (see `from_u16` / `to_u16`); every other
/// value is preserved in [`WebSocketCloseStatusCode::Custom`].
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketCloseStatusCode {
    /// Wire value 1000.
    NormalClosure,
    /// Wire value 1001.
    EndpointUnavailable,
    /// Wire value 1002.
    ProtocolError,
    /// Wire value 1003.
    InvalidMessageType,
    /// Wire value 1004.
    Reserved,
    /// Wire value 1005.
    Empty,
    /// Wire value 1007.
    InvalidPayloadData,
    /// Wire value 1008.
    PolicyViolation,
    /// Wire value 1009.
    MessageTooBig,
    /// Wire value 1010.
    MandatoryExtension,
    /// Wire value 1011.
    InternalServerError,
    /// Wire value 1015.
    TlsHandshake,
    /// Any value without a named variant, kept verbatim.
    Custom(u16),
}

impl WebSocketCloseStatusCode {
    /// Decodes a numeric status code; unknown values become `Custom(value)`.
    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
        match value {
            1000 => Self::NormalClosure,
            1001 => Self::EndpointUnavailable,
            1002 => Self::ProtocolError,
            1003 => Self::InvalidMessageType,
            1004 => Self::Reserved,
            1005 => Self::Empty,
            1007 => Self::InvalidPayloadData,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1010 => Self::MandatoryExtension,
            1011 => Self::InternalServerError,
            1015 => Self::TlsHandshake,
            other => Self::Custom(other),
        }
    }

    /// Encodes the status as its numeric value; `Custom` yields the wrapped value unchanged.
    fn to_u16(self) -> u16 {
        match self {
            Self::Custom(raw) => raw,
            Self::NormalClosure => 1000,
            Self::EndpointUnavailable => 1001,
            Self::ProtocolError => 1002,
            Self::InvalidMessageType => 1003,
            Self::Reserved => 1004,
            Self::Empty => 1005,
            Self::InvalidPayloadData => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalServerError => 1011,
            Self::TlsHandshake => 1015,
        }
    }
}
170
/// Connection state tracked by [`WebSocket`].
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum WebSocketState {
    /// State of a freshly constructed socket, before any handshake call.
    None = 0,
    /// Set by `client_connect` once the handshake request has been built.
    Connecting = 1,
    /// Set by `server_accept` / `client_accept` when the handshake succeeds;
    /// `read`, `write` and `close` are usable.
    Open = 2,
    /// Set by `close`; `read` is still allowed so the peer's close frame can be received.
    CloseSent = 3,
    /// Set by `read` when a close frame arrives while not in `CloseSent`;
    /// `write` is still allowed so a close reply can be sent.
    CloseReceived = 4,
    /// Set when the close handshake finished (close reply written, or close frame read
    /// while in `CloseSent`).
    Closed = 5,
    /// Set when a handshake call fails with an error other than an incomplete header.
    Aborted = 6,
}
191
/// Errors reported by the websocket implementation.
///
/// NOTE(review): several variants are not produced in this part of the file and are
/// presumably raised by the `http` / framer modules — confirm there before relying on
/// the descriptions marked as assumptions.
#[derive(PartialEq, Eq, Debug)]
pub enum Error {
    /// A frame carried an op code value that `get_op_code` does not recognize, or an op
    /// code could not be mapped to a receive message type.
    InvalidOpCode,
    /// The payload length field of a frame could not be decoded.
    InvalidFrameLength,
    /// Not produced in the visible code (assumption: invalid close status code).
    InvalidCloseStatusCode,
    /// `read`, `write` or `close` was called in a state that does not permit it.
    WebSocketNotOpen,
    /// A handshake function was called while the socket was already open.
    WebsocketAlreadyOpen,
    /// Converted from [`core::str::Utf8Error`].
    Utf8Error,
    /// Converted from the unit error type `()`.
    Unknown,
    /// Converted from an [`httparse::Error`].
    HttpHeader(httparse::Error),
    /// Not produced in the visible code (assumption: request without a path).
    HttpHeaderNoPath,
    /// Returned by the handshake reader when more header bytes are needed;
    /// `client_accept` leaves the state untouched in this case.
    HttpHeaderIncomplete,
    /// The destination buffer passed to a write cannot hold the whole frame.
    WriteToBufferTooSmall,
    /// The input does not yet contain a complete frame header (or mask key).
    ReadFrameIncomplete,
    /// Not produced in the visible code; `Display` prints the code when one is present.
    HttpResponseCodeInvalid(Option<u16>),
    /// Not produced in the visible code (assumption: handshake accept string mismatch).
    AcceptStringInvalid,
    /// Converted from [`core::convert::Infallible`].
    ConvertInfallible,
    /// Not produced in the visible code (assumption: random number generator failure).
    RandCore,
    /// A continuation frame arrived while no fragmented message was in progress.
    UnexpectedContinuationFrame,
}
214
215impl From<httparse::Error> for Error {
216 fn from(err: httparse::Error) -> Error {
217 Error::HttpHeader(err)
218 }
219}
220
221impl From<str::Utf8Error> for Error {
222 fn from(_: str::Utf8Error) -> Error {
223 Error::Utf8Error
224 }
225}
226
227impl From<core::convert::Infallible> for Error {
228 fn from(_: core::convert::Infallible) -> Error {
229 Error::ConvertInfallible
230 }
231}
232
233impl From<()> for Error {
234 fn from(_: ()) -> Error {
235 Error::Unknown
236 }
237}
238
239impl core::fmt::Display for Error {
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 match self {
242 Error::HttpHeader(error) => write!(f, "bad http header {error}"),
243 Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
244 _ => write!(f, "{:?}", self),
245 }
246 }
247}
248
#[cfg(feature = "std")]
impl std::error::Error for Error {
    /// Only [`Error::HttpHeader`] wraps an underlying error; all other variants have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpHeader(error) => Some(error),
            _ => None,
        }
    }
}
259
/// Op code stored in the low four bits of the first byte of a frame header.
///
/// The discriminants are the numeric values written to and parsed from the wire.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum WebSocketOpCode {
    /// Follow-up frame of a fragmented message.
    ContinuationFrame = 0,
    /// Frame carrying text data.
    TextFrame = 1,
    /// Frame carrying binary data.
    BinaryFrame = 2,
    /// Close frame.
    ConnectionClose = 8,
    /// Ping frame.
    Ping = 9,
    /// Pong frame.
    Pong = 10,
}
269
270impl WebSocketOpCode {
271 fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
272 match self {
273 WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
274 WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
275 _ => Err(Error::InvalidOpCode),
276 }
277 }
278}
279
/// Outcome of a single [`WebSocket::read`] call.
#[derive(Debug)]
pub struct WebSocketReadResult {
    /// Number of bytes consumed from the input buffer (header bytes included).
    pub len_from: usize,
    /// Number of payload bytes written to the output buffer.
    pub len_to: usize,
    /// `true` when the frame was read completely and its FIN bit was set.
    pub end_of_message: bool,
    /// Status code decoded from a close frame; `None` for every other frame type.
    pub close_status: Option<WebSocketCloseStatusCode>,
    /// Kind of message the payload belongs to.
    pub message_type: WebSocketReceiveMessageType,
}
299
/// Settings supplied to [`WebSocket::client_connect`] for building the opening
/// handshake request.
///
/// NOTE(review): how each field is rendered into the request is decided by the `http`
/// module, which is not visible here — the field descriptions below follow the field
/// names and should be confirmed against it.
///
/// Derives `Debug` so the options can be logged and used in assertions, matching the
/// other public types in this module.
#[derive(Debug)]
pub struct WebSocketOptions<'a> {
    /// Request path, e.g. `/chat`.
    pub path: &'a str,
    /// Value for the `Host` header.
    pub host: &'a str,
    /// Value for the `Origin` header.
    pub origin: &'a str,
    /// Optional list of sub protocols to offer to the server.
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Optional extra header lines to include in the request.
    pub additional_headers: Option<&'a [&'a str]>,
}
323
/// Websocket in the server role; it uses [`EmptyRng`] as its random source.
pub type WebSocketServer = WebSocket<EmptyRng, Server>;

/// Websocket in the client role, generic over the random number generator `T`.
pub type WebSocketClient<T> = WebSocket<T, Client>;

/// Type-level marker for the server role; it has no variants and is never instantiated.
pub enum Server {}
/// Type-level marker for the client role; it has no variants and is never instantiated.
pub enum Client {}

/// Implemented by the role markers ([`Server`] and [`Client`]) accepted by [`WebSocket`].
pub trait WebSocketType {}
impl WebSocketType for Server {}
impl WebSocketType for Client {}
338
/// Protocol state for one websocket connection.
///
/// All reading and writing happens on buffers supplied by the caller; `S` selects the
/// server or client role and `T` is the random number generator.
pub struct WebSocket<T, S: WebSocketType>
where
    T: RngCore,
{
    /// `true` for sockets built by `new_client`; makes `write_frame` mask outgoing payloads.
    is_client: bool,
    /// Random source used for mask keys and handed to the client handshake builder.
    rng: T,
    /// Op code (text or binary) of the fragmented message currently being received, if any.
    continuation_frame_op_code: Option<WebSocketOpCode>,
    /// Set by `write` when the frame just written was not the end of its message, so the
    /// next frame is sent as a continuation frame.
    is_write_continuation: bool,
    /// Current connection state.
    pub state: WebSocketState,
    /// Bookkeeping for a frame whose payload was only partly consumed by the previous read.
    continuation_read: Option<ContinuationRead>,
    /// Ties the socket to its role marker without storing a value of that type.
    marker: core::marker::PhantomData<S>,
}
352
353impl<T, Type> WebSocket<T, Type>
354where
355 T: RngCore,
356 Type: WebSocketType,
357{
358 pub fn new_client(rng: T) -> WebSocketClient<T> {
369 WebSocket {
370 is_client: true,
371 rng,
372 continuation_frame_op_code: None,
373 is_write_continuation: false,
374 state: WebSocketState::None,
375 continuation_read: None,
376 marker: core::marker::PhantomData::<Client>,
377 }
378 }
379
380 pub fn new_server() -> WebSocketServer {
392 let rng = EmptyRng::new();
393 WebSocket {
394 is_client: false,
395 rng,
396 continuation_frame_op_code: None,
397 is_write_continuation: false,
398 state: WebSocketState::None,
399 continuation_read: None,
400 marker: core::marker::PhantomData::<Server>,
401 }
402 }
403}
404
405impl<T> WebSocket<T, Server>
406where
407 T: RngCore,
408{
409 pub fn server_accept(
440 &mut self,
441 sec_websocket_key: &WebSocketKey,
442 sec_websocket_protocol: Option<&WebSocketSubProtocol>,
443 to: &mut [u8],
444 ) -> Result<usize> {
445 if self.state == WebSocketState::Open {
446 return Err(Error::WebsocketAlreadyOpen);
447 }
448
449 match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
450 {
451 Ok(http_response_len) => {
452 self.state = WebSocketState::Open;
453 Ok(http_response_len)
454 }
455 Err(e) => {
456 self.state = WebSocketState::Aborted;
457 Err(e)
458 }
459 }
460 }
461}
462
463impl<T> WebSocket<T, Client>
464where
465 T: RngCore,
466{
467 pub fn client_connect(
501 &mut self,
502 websocket_options: &WebSocketOptions,
503 to: &mut [u8],
504 ) -> Result<(usize, WebSocketKey)> {
505 if self.state == WebSocketState::Open {
506 return Err(Error::WebsocketAlreadyOpen);
507 }
508
509 match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
510 Ok((request_len, sec_websocket_key)) => {
511 self.state = WebSocketState::Connecting;
512 Ok((request_len, sec_websocket_key))
513 }
514 Err(e) => Err(e),
515 }
516 }
517
518 pub fn client_accept(
539 &mut self,
540 sec_websocket_key: &WebSocketKey,
541 from: &[u8],
542 ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
543 if self.state == WebSocketState::Open {
544 return Err(Error::WebsocketAlreadyOpen);
545 }
546
547 match http::read_server_connect_handshake_response(sec_websocket_key, from) {
548 Ok((len, sec_websocket_protocol)) => {
549 self.state = WebSocketState::Open;
550 Ok((len, sec_websocket_protocol))
551 }
552 Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
553 Err(e) => {
554 self.state = WebSocketState::Aborted;
555 Err(e)
556 }
557 }
558 }
559}
560
561impl<T, Type> WebSocket<T, Type>
562where
563 T: RngCore,
564 Type: WebSocketType,
565{
566 pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
606 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
607 let frame = self.read_frame(from, to)?;
608
609 match frame.op_code {
610 WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
611 WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
612 WebSocketOpCode::TextFrame => {
613 Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
614 }
615 WebSocketOpCode::BinaryFrame => {
616 Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
617 }
618 WebSocketOpCode::ConnectionClose => match self.state {
619 WebSocketState::CloseSent => {
620 self.state = WebSocketState::Closed;
621 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
622 }
623 _ => {
624 self.state = WebSocketState::CloseReceived;
625 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
626 }
627 },
628 WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
629 Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
630 None => Err(Error::UnexpectedContinuationFrame),
631 },
632 }
633 } else {
634 Err(Error::WebSocketNotOpen)
635 }
636 }
637
638 pub fn write(
665 &mut self,
666 message_type: WebSocketSendMessageType,
667 end_of_message: bool,
668 from: &[u8],
669 to: &mut [u8],
670 ) -> Result<usize> {
671 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
672 let mut op_code = message_type.to_op_code();
673 if op_code == WebSocketOpCode::ConnectionClose {
674 self.state = WebSocketState::Closed
675 } else if self.is_write_continuation {
676 op_code = WebSocketOpCode::ContinuationFrame;
677 }
678
679 self.is_write_continuation = !end_of_message;
680 self.write_frame(from, to, op_code, end_of_message)
681 } else {
682 Err(Error::WebSocketNotOpen)
683 }
684 }
685
686 pub fn close(
694 &mut self,
695 close_status: WebSocketCloseStatusCode,
696 status_description: Option<&str>,
697 to: &mut [u8],
698 ) -> Result<usize> {
699 if self.state == WebSocketState::Open {
700 self.state = WebSocketState::CloseSent;
701 if let Some(status_description) = status_description {
702 let mut from_buffer: Vec<u8, 256> = Vec::new();
703 from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
704
705 let len = if status_description.len() < 254 {
707 status_description.len()
708 } else {
709 254
710 };
711
712 from_buffer.extend_from_slice(&status_description.as_bytes()[..len])?;
713 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
714 } else {
715 let mut from_buffer: [u8; 2] = [0; 2];
716 BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
717 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
718 }
719 } else {
720 Err(Error::WebSocketNotOpen)
721 }
722 }
723
724 fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
725 match &mut self.continuation_read {
726 Some(continuation_read) => {
727 let result = read_continuation(continuation_read, from_buffer, to_buffer);
728 if result.is_fin_bit_set {
729 self.continuation_read = None;
730 self.continuation_frame_op_code = None;
731 }
732 Ok(result)
733 }
734 None => {
735 let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
736
737 match result.op_code {
739 WebSocketOpCode::BinaryFrame | WebSocketOpCode::TextFrame => {
740 self.continuation_frame_op_code = if result.is_fin_bit_set {
742 None
743 } else {
744 Some(result.op_code)
745 };
746 }
747 WebSocketOpCode::ContinuationFrame => {
748 if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
750 result.op_code = continuation_frame_op_code;
751 }
752 }
753 _ => {
754 }
756 }
757
758 self.continuation_read = continuation_read;
759 Ok(result)
760 }
761 }
762 }
763
764 fn write_frame(
765 &mut self,
766 from_buffer: &[u8],
767 to_buffer: &mut [u8],
768 op_code: WebSocketOpCode,
769 end_of_message: bool,
770 ) -> Result<usize> {
771 let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
772 let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
773 let count = from_buffer.len();
774 const BYTE_HEADER_SIZE: usize = 2;
775 const SHORT_HEADER_SIZE: usize = 4;
776 const LONG_HEADER_SIZE: usize = 10;
777 const MASK_KEY_SIZE: usize = 4;
778 let header_size;
779 let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
780 let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
781
782 if count < 126 {
785 if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
786 return Err(Error::WriteToBufferTooSmall);
787 }
788 to_buffer[0] = byte1;
789 to_buffer[1] = mask_bit_set_as_byte | count as u8;
790 header_size = BYTE_HEADER_SIZE;
791 } else if count < 65535 {
792 if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
793 return Err(Error::WriteToBufferTooSmall);
794 }
795 to_buffer[0] = byte1;
796 to_buffer[1] = mask_bit_set_as_byte | 126;
797 BigEndian::write_u16(&mut to_buffer[2..], count as u16);
798 header_size = SHORT_HEADER_SIZE;
799 } else {
800 if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
801 return Err(Error::WriteToBufferTooSmall);
802 }
803 to_buffer[0] = byte1;
804 to_buffer[1] = mask_bit_set_as_byte | 127;
805 BigEndian::write_u64(&mut to_buffer[2..], count as u64);
806 header_size = LONG_HEADER_SIZE;
807 }
808
809 if self.is_client {
812 let mut mask_key = [0; MASK_KEY_SIZE];
813 self.rng.fill_bytes(&mut mask_key); to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
815 let to_buffer_start = header_size + MASK_KEY_SIZE;
816
817 for (i, (from, to)) in from_buffer[..count]
819 .iter()
820 .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
821 .enumerate()
822 {
823 *to = *from ^ mask_key[i % MASK_KEY_SIZE];
824 }
825
826 Ok(to_buffer_start + count)
827 } else {
828 to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
829 Ok(header_size + count)
830 }
831 }
832}
833
/// State kept between reads when a frame's payload could not be consumed in one call.
struct ContinuationRead {
    /// Op code taken from the header of the partially read frame.
    op_code: WebSocketOpCode,
    /// Number of payload bytes of the frame that still have to be read.
    count: usize,
    /// FIN bit taken from the header of the partially read frame.
    is_fin_bit_set: bool,
    /// Mask key of the frame, if it was masked; `read_into_buffer` rotates it after each
    /// chunk so unmasking resumes at the right key byte.
    mask_key: Option<[u8; 4]>,
}
841
/// Description of one decoded frame, or of one chunk of a partially read frame.
struct WebSocketFrame {
    /// `true` when the frame has been read completely and its FIN bit was set.
    is_fin_bit_set: bool,
    /// Op code of the frame.
    op_code: WebSocketOpCode,
    /// Number of payload bytes written to the output buffer.
    num_bytes_to: usize,
    /// Number of bytes consumed from the input buffer.
    num_bytes_from: usize,
    /// Status code decoded from a close frame; `None` for other frames.
    close_status: Option<WebSocketCloseStatusCode>,
}
849
850impl WebSocketFrame {
851 fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
852 WebSocketReadResult {
853 len_from: self.num_bytes_from,
854 len_to: self.num_bytes_to,
855 end_of_message: self.is_fin_bit_set,
856 close_status: self.close_status,
857 message_type,
858 }
859 }
860}
861
/// Returns the smallest of the three values.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    [num2, num3]
        .iter()
        .fold(num1, |smallest, &candidate| cmp::min(smallest, candidate))
}
865
866fn read_into_buffer(
867 mask_key: &mut Option<[u8; 4]>,
868 from_buffer: &[u8],
869 to_buffer: &mut [u8],
870 len: usize,
871) -> usize {
872 let len_to_read = min(len, to_buffer.len(), from_buffer.len());
874
875 match mask_key {
876 Some(mask_key) => {
877 for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
879 *to = *from ^ mask_key[i % MASK_KEY_LEN];
880 }
881 mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
882 }
883 None => {
884 to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
885 }
886 }
887
888 len_to_read
889}
890
891fn read_continuation(
892 continuation_read: &mut ContinuationRead,
893 from_buffer: &[u8],
894 to_buffer: &mut [u8],
895) -> WebSocketFrame {
896 let len_read = read_into_buffer(
897 &mut continuation_read.mask_key,
898 from_buffer,
899 to_buffer,
900 continuation_read.count,
901 );
902
903 let is_complete = len_read == continuation_read.count;
904
905 let frame = match continuation_read.op_code {
906 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
907 _ => WebSocketFrame {
908 num_bytes_from: len_read,
909 num_bytes_to: len_read,
910 op_code: continuation_read.op_code,
911 close_status: None,
912 is_fin_bit_set: if is_complete {
913 continuation_read.is_fin_bit_set
914 } else {
915 false
916 },
917 },
918 };
919
920 continuation_read.count -= len_read;
921 frame
922}
923
924fn read_frame(
925 from_buffer: &[u8],
926 to_buffer: &mut [u8],
927) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
928 if from_buffer.len() < 2 {
929 return Err(Error::ReadFrameIncomplete);
930 }
931
932 let byte1 = from_buffer[0];
933 let byte2 = from_buffer[1];
934
935 const FIN_BIT_FLAG: u8 = 0x80;
937 const OP_CODE_FLAG: u8 = 0x0F;
938 let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
939 let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
940
941 const MASK_FLAG: u8 = 0x80;
943 let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
944 let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
945
946 num_bytes_read += 2;
947 let from_buffer = &from_buffer[num_bytes_read..];
948
949 let mut mask_key = if is_mask_bit_set {
951 if from_buffer.len() < MASK_KEY_LEN {
952 return Err(Error::ReadFrameIncomplete);
953 }
954 let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
955 mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
956 num_bytes_read += MASK_KEY_LEN;
957 Some(mask_key)
958 } else {
959 None
960 };
961
962 let len_read = if is_mask_bit_set {
963 let from_buffer = &from_buffer[MASK_KEY_LEN..];
965 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
966 } else {
967 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
968 };
969
970 let has_continuation = len_read < len;
971 num_bytes_read += len_read;
972
973 let frame = match op_code {
974 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
975 _ => WebSocketFrame {
976 num_bytes_from: num_bytes_read,
977 num_bytes_to: len_read,
978 op_code,
979 close_status: None,
980 is_fin_bit_set: if has_continuation {
981 false
982 } else {
983 is_fin_bit_set
984 },
985 },
986 };
987
988 if has_continuation {
989 let continuation_read = Some(ContinuationRead {
990 op_code,
991 count: len - len_read,
992 is_fin_bit_set,
993 mask_key,
994 });
995 Ok((frame, continuation_read))
996 } else {
997 Ok((frame, None))
998 }
999}
1000
1001fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
1002 match val {
1003 0 => Ok(WebSocketOpCode::ContinuationFrame),
1004 1 => Ok(WebSocketOpCode::TextFrame),
1005 2 => Ok(WebSocketOpCode::BinaryFrame),
1006 8 => Ok(WebSocketOpCode::ConnectionClose),
1007 9 => Ok(WebSocketOpCode::Ping),
1008 10 => Ok(WebSocketOpCode::Pong),
1009 _ => Err(Error::InvalidOpCode),
1010 }
1011}
1012
1013fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1015 let len = byte2 & 0x7F;
1016
1017 if len < 126 {
1018 return Ok((len as usize, 0));
1020 } else if len == 126 {
1021 if from_buffer.len() < 2 {
1023 return Err(Error::ReadFrameIncomplete);
1024 }
1025 let mut buf: [u8; 2] = [0; 2];
1026 buf.copy_from_slice(&from_buffer[..2]);
1027 return Ok((BigEndian::read_u16(&buf) as usize, 2));
1028 } else if len == 127 {
1029 if from_buffer.len() < 8 {
1031 return Err(Error::ReadFrameIncomplete);
1032 }
1033 let mut buf: [u8; 8] = [0; 8];
1034 buf.copy_from_slice(&from_buffer[..8]);
1035 return Ok((BigEndian::read_u64(&buf) as usize, 8));
1036 }
1037
1038 Err(Error::InvalidFrameLength)
1039}
1040
1041fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1042 if len >= 2 {
1043 let code = BigEndian::read_u16(buffer);
1045 let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1046
1047 return WebSocketFrame {
1048 num_bytes_from: num_bytes_read,
1049 num_bytes_to: len,
1050 op_code: WebSocketOpCode::ConnectionClose,
1051 close_status: Some(close_status_code),
1052 is_fin_bit_set: true,
1053 };
1054 }
1055
1056 build_client_disconnected_frame(num_bytes_read)
1057}
1058
1059fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1060 WebSocketFrame {
1061 num_bytes_from,
1062 num_bytes_to: 0,
1063 op_code: WebSocketOpCode::ConnectionClose,
1064 close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1065 is_fin_bit_set: true,
1066 }
1067}
1068
1069#[cfg(test)]
1074mod tests {
1075 extern crate std;
1076 use super::*;
1077
1078 #[test]
1079 fn opening_handshake() {
1080 let client_request = "GET /chat HTTP/1.1
1081Host: localhost:5000
1082User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
1083Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
1084Accept-Language: en-US,en;q=0.5
1085Accept-Encoding: gzip, deflate
1086Sec-WebSocket-Version: 13
1087Origin: http://localhost:5000
1088Sec-WebSocket-Extensions: permessage-deflate
1089Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
1090Sec-WebSocket-Protocol: chat
1091DNT: 1
1092Connection: keep-alive, Upgrade
1093Pragma: no-cache
1094Cache-Control: no-cache
1095Upgrade: websocket
1096
1097";
1098
1099 let mut headers = [httparse::EMPTY_HEADER; 16];
1100 let mut request = httparse::Request::new(&mut headers);
1101 request.parse(client_request.as_bytes()).unwrap();
1102 let headers = headers.iter().map(|f| (f.name, f.value));
1103 let web_socket_context = read_http_header(headers).unwrap().unwrap();
1104 assert_eq!(
1105 "Z7OY1UwHOx/nkSz38kfPwg==",
1106 web_socket_context.sec_websocket_key
1107 );
1108 assert_eq!(
1109 "chat",
1110 web_socket_context
1111 .sec_websocket_protocol_list
1112 .get(0)
1113 .unwrap()
1114 .as_str()
1115 );
1116 let mut web_socket = WebSocketServer::new_server();
1117
1118 let mut ws_buffer: [u8; 3000] = [0; 3000];
1119 let size = web_socket
1120 .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
1121 .unwrap();
1122 let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
1123 let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
1124 assert_eq!(client_response_expected, response);
1125 }
1126
1127 #[test]
1128 fn server_write_frame() {
1129 let mut buffer: [u8; 1000] = [0; 1000];
1130 let mut ws_server = WebSocketServer::new_server();
1131 let len = ws_server
1132 .write_frame(
1133 "hello".as_bytes(),
1134 &mut buffer,
1135 WebSocketOpCode::TextFrame,
1136 true,
1137 )
1138 .unwrap();
1139 let expected = [129, 5, 104, 101, 108, 108, 111];
1140 assert_eq!(&expected, &buffer[..len]);
1141 }
1142
1143 #[test]
1144 fn server_accept_should_write_sub_protocol() {
1145 let mut buffer: [u8; 1000] = [0; 1000];
1146 let mut ws_server = WebSocketServer::new_server();
1147 let ws_key = WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
1148 let sub_protocol = WebSocketSubProtocol::from("chat");
1149 let size = ws_server
1150 .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
1151 .unwrap();
1152 let response = std::str::from_utf8(&buffer[..size]).unwrap();
1153 assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
1154 }
1155
1156 #[test]
1157 fn closing_handshake() {
1158 let mut buffer1: [u8; 500] = [0; 500];
1159 let mut buffer2: [u8; 500] = [0; 500];
1160
1161 let mut rng = rand::thread_rng();
1162
1163 let mut ws_client = WebSocketClient::new_client(&mut rng);
1164 ws_client.state = WebSocketState::Open;
1165
1166 let mut ws_server = WebSocketServer::new_server();
1167 ws_server.state = WebSocketState::Open;
1168
1169 ws_client
1171 .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
1172 .unwrap();
1173
1174 let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
1176 assert_eq!(
1177 WebSocketReceiveMessageType::CloseMustReply,
1178 ws_result.message_type
1179 );
1180
1181 ws_server
1183 .write(
1184 WebSocketSendMessageType::CloseReply,
1185 true,
1186 &buffer2[..ws_result.len_to],
1187 &mut buffer1,
1188 )
1189 .unwrap();
1190 assert_eq!(WebSocketState::Closed, ws_server.state);
1191
1192 let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1194 assert_eq!(WebSocketState::Closed, ws_client.state);
1195
1196 assert_eq!(
1197 WebSocketReceiveMessageType::CloseCompleted,
1198 ws_result.message_type
1199 );
1200 }
1201
1202 #[test]
1203 fn send_message_from_client_to_server() {
1204 let mut buffer1: [u8; 1000] = [0; 1000];
1205 let mut buffer2: [u8; 1000] = [0; 1000];
1206
1207 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1209
1210 ws_client.state = WebSocketState::Open;
1211 let mut ws_server = WebSocketServer::new_server();
1212 ws_server.state = WebSocketState::Open;
1213
1214 let hello = "hello";
1216 let num_bytes = ws_client
1217 .write(
1218 WebSocketSendMessageType::Text,
1219 true,
1220 &hello.as_bytes(),
1221 &mut buffer1,
1222 )
1223 .unwrap();
1224
1225 let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1227 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1228 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1229 assert_eq!(hello, received);
1230 }
1231
1232 #[test]
1233 fn send_message_from_server_to_client() {
1234 let mut buffer1: [u8; 1000] = [0; 1000];
1235 let mut buffer2: [u8; 1000] = [0; 1000];
1236
1237 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1238 ws_client.state = WebSocketState::Open;
1239 let mut ws_server = WebSocketServer::new_server();
1240 ws_server.state = WebSocketState::Open;
1241
1242 let hello = "hello";
1244 let num_bytes = ws_server
1245 .write(
1246 WebSocketSendMessageType::Text,
1247 true,
1248 &hello.as_bytes(),
1249 &mut buffer1,
1250 )
1251 .unwrap();
1252
1253 let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1255 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1256 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1257 assert_eq!(hello, received);
1258 }
1259
1260 #[test]
1261 fn receive_buffer_too_small() {
1262 let mut buffer1: [u8; 1000] = [0; 1000];
1263 let mut buffer2: [u8; 1000] = [0; 1000];
1264
1265 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1266 ws_client.state = WebSocketState::Open;
1267 let mut ws_server = WebSocketServer::new_server();
1268 ws_server.state = WebSocketState::Open;
1269
1270 let hello = "hello";
1271 ws_server
1272 .write(
1273 WebSocketSendMessageType::Text,
1274 true,
1275 &hello.as_bytes(),
1276 &mut buffer1,
1277 )
1278 .unwrap();
1279
1280 match ws_client.read(&buffer1[..1], &mut buffer2) {
1281 Err(Error::ReadFrameIncomplete) => {
1282 }
1284 _ => {
1285 assert_eq!(true, false);
1286 }
1287 }
1288 }
1289
1290 #[test]
1291 fn receive_large_frame_with_small_receive_buffer() {
1292 let mut buffer1: [u8; 1000] = [0; 1000];
1293 let mut buffer2: [u8; 1000] = [0; 1000];
1294
1295 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1296 ws_client.state = WebSocketState::Open;
1297 let mut ws_server = WebSocketServer::new_server();
1298 ws_server.state = WebSocketState::Open;
1299
1300 let hello = "hello";
1301 ws_server
1302 .write(
1303 WebSocketSendMessageType::Text,
1304 true,
1305 &hello.as_bytes(),
1306 &mut buffer1,
1307 )
1308 .unwrap();
1309
1310 let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
1311 assert_eq!(0, ws_result.len_to);
1312 assert_eq!(false, ws_result.end_of_message);
1313 let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
1314 assert_eq!(1, ws_result.len_to);
1315 assert_eq!(
1316 "h",
1317 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1318 );
1319 assert_eq!(false, ws_result.end_of_message);
1320 let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
1321 assert_eq!(4, ws_result.len_to);
1322 assert_eq!(
1323 "ello",
1324 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1325 );
1326 assert_eq!(true, ws_result.end_of_message);
1327 }
1328
1329 #[test]
1330 fn send_large_frame() {
1331 let buffer1 = [0u8; 15944];
1332 let mut buffer2 = [0u8; 64000];
1333 let mut buffer3 = [0u8; 64000];
1334
1335 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1336 ws_client.state = WebSocketState::Open;
1337
1338 ws_client
1339 .write(
1340 WebSocketSendMessageType::Binary,
1341 true,
1342 &buffer1,
1343 &mut buffer2,
1344 )
1345 .unwrap();
1346
1347 let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1348 assert_eq!(true, ws_result.end_of_message);
1349 assert_eq!(buffer1.len(), ws_result.len_to);
1350 }
1351
1352 #[test]
1353 fn receive_large_frame_multi_read() {
1354 let mut buffer1 = [0_u8; 1000];
1355 let mut buffer2 = [0_u8; 1000];
1356
1357 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1358 ws_client.state = WebSocketState::Open;
1359 let mut ws_server = WebSocketServer::new_server();
1360 ws_server.state = WebSocketState::Open;
1361
1362 let message = "Hello, world. This is a long message that takes multiple reads";
1363 ws_server
1364 .write(
1365 WebSocketSendMessageType::Text,
1366 true,
1367 &message.as_bytes(),
1368 &mut buffer1,
1369 )
1370 .unwrap();
1371
1372 let mut buffer2_cursor = 0;
1373 let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
1374 assert_eq!(false, ws_result.end_of_message);
1375 buffer2_cursor += ws_result.len_to;
1376 let ws_result = ws_client
1377 .read(
1378 &buffer1[ws_result.len_from..],
1379 &mut buffer2[buffer2_cursor..],
1380 )
1381 .unwrap();
1382 assert_eq!(true, ws_result.end_of_message);
1383 buffer2_cursor += ws_result.len_to;
1384
1385 assert_eq!(
1386 message,
1387 std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1388 );
1389 }
1390
#[test]
fn multiple_messages_in_receive_buffer() {
    // Two complete text frames are written back to back into one buffer; each
    // `read` call is expected to yield exactly one whole message.
    let mut buffer1 = [0_u8; 1000];
    let mut buffer2 = [0_u8; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let message1 = "Hello, world.";
    let len = ws_client
        .write(WebSocketSendMessageType::Text, true, message1.as_bytes(), &mut buffer1)
        .unwrap();
    let message2 = "This is another message.";
    // The second frame starts right where the first one ended.
    ws_client
        .write(
            WebSocketSendMessageType::Text,
            true,
            message2.as_bytes(),
            &mut buffer1[len..],
        )
        .unwrap();

    let mut buffer1_cursor = 0;
    let mut buffer2_cursor = 0;
    let ws_result = ws_server
        .read(&buffer1[buffer1_cursor..], &mut buffer2)
        .unwrap();
    assert!(ws_result.end_of_message);
    buffer1_cursor += ws_result.len_from;
    buffer2_cursor += ws_result.len_to;

    // Continue from the offsets reported by the first read.
    let ws_result = ws_server
        .read(&buffer1[buffer1_cursor..], &mut buffer2[buffer2_cursor..])
        .unwrap();
    assert!(ws_result.end_of_message);

    assert_eq!(
        message1,
        std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
    );
    assert_eq!(
        message2,
        std::str::from_utf8(&buffer2[buffer2_cursor..buffer2_cursor + ws_result.len_to])
            .unwrap()
    );
}
1443
#[test]
fn receive_large_frame_with_small_send_buffer() {
    // The output slice passed to `read` is smaller than the payload, so the
    // five-byte message has to be drained over two calls.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let hello = "hello";
    ws_server
        .write(WebSocketSendMessageType::Text, true, hello.as_bytes(), &mut buffer1)
        .unwrap();

    // Room for a single byte only: expect "h" and an unfinished message.
    let ws_result = ws_client.read(&buffer1, &mut buffer2[..1]).unwrap();
    assert_eq!(1, ws_result.len_to);
    assert_eq!(
        "h",
        std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
    );
    assert!(!ws_result.end_of_message);

    // Resume at the input offset given by `len_from`; four bytes of room is
    // enough for the remaining "ello".
    let ws_result = ws_client
        .read(&buffer1[ws_result.len_from..], &mut buffer2[..4])
        .unwrap();
    assert_eq!(4, ws_result.len_to);
    assert_eq!(
        "ello",
        std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
    );
    assert!(ws_result.end_of_message);
}
1481
#[test]
fn send_two_frame_message() {
    // One text message is written as two frames (the first with
    // `end_of_message == false`); the client reads them one frame at a time and
    // the concatenated payloads must equal the full text.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];
    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let hello = "Hello, ";
    let num_bytes_hello = ws_server
        .write(WebSocketSendMessageType::Text, false, hello.as_bytes(), &mut buffer1)
        .unwrap();

    let world = "World!";
    // The second frame is placed directly after the first one in `buffer1`.
    let num_bytes_world = ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            world.as_bytes(),
            &mut buffer1[num_bytes_hello..],
        )
        .unwrap();

    // Feed the client exactly the bytes of the first frame.
    let ws_result1 = ws_client
        .read(&buffer1[..num_bytes_hello], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
    assert!(!ws_result1.end_of_message);

    // Feed exactly the bytes of the second frame, appending to the output.
    let ws_result2 = ws_client
        .read(
            &buffer1[num_bytes_hello..num_bytes_hello + num_bytes_world],
            &mut buffer2[ws_result1.len_to..],
        )
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
    assert!(ws_result2.end_of_message);

    let received =
        std::str::from_utf8(&buffer2[..ws_result1.len_to + ws_result2.len_to]).unwrap();
    assert_eq!("Hello, World!", received);
}
1534
#[test]
fn send_multi_frame_message() {
    // One text message is written as three frames. The test checks both the
    // decoded fragments on the reading side and the FIN bit / op code found in
    // the first byte of each encoded frame.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let fragment1 = "fragment1";
    let fragment1_num_bytes = ws_server
        .write(WebSocketSendMessageType::Text, false, fragment1.as_bytes(), &mut buffer1)
        .unwrap();

    let fragment2 = "fragment2";
    let fragment2_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            fragment2.as_bytes(),
            &mut buffer1[fragment1_num_bytes..],
        )
        .unwrap();

    // Offset in `buffer1` at which the third frame begins.
    let fragment3_offset = fragment1_num_bytes + fragment2_num_bytes;
    let fragment3 = "fragment3";
    ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            fragment3.as_bytes(),
            &mut buffer1[fragment3_offset..],
        )
        .unwrap();

    // First fragment: read from the start of the buffer.
    let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment1, received);
    assert!(!ws_result.end_of_message);
    let mut read_cursor = ws_result.len_from;

    // Second fragment: the output buffer is reused from its start each time.
    let ws_result = ws_client
        .read(&buffer1[read_cursor..], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment2, received);
    assert!(!ws_result.end_of_message);
    read_cursor += ws_result.len_from;

    // Third and final fragment.
    let ws_result = ws_client
        .read(&buffer1[read_cursor..], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment3, received);
    assert!(ws_result.end_of_message);

    // Encoded first frame: FIN clear, text op code.
    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::TextFrame);

    // Encoded second frame: FIN clear, continuation op code.
    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);

    // Encoded third frame: FIN set, continuation op code.
    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment3_offset]);
    assert!(is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
}
1621
1622 fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1623 const FIN_BIT_FLAG: u8 = 0x80;
1624 const OP_CODE_FLAG: u8 = 0x0F;
1625 let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1626 let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1627 (is_fin_bit_set, op_code)
1628 }
1629}