1#![cfg_attr(not(feature = "std"), no_std)]
14
15use base64::EncodeSliceError;
16use byteorder::{BigEndian, ByteOrder};
17use core::{cmp, result, str};
18use heapless::{String, Vec};
19use rand_core::RngCore;
20use sha1::{Digest, Sha1};
21
22mod http;
23pub mod random;
24pub use self::http::{read_http_header, WebSocketContext};
25pub use self::random::EmptyRng;
26
27pub mod framer;
30pub mod framer_async;
/// Length in bytes of the masking key that precedes a masked frame payload.
const MASK_KEY_LEN: usize = core::mem::size_of::<u32>();
32
33pub type Result<T> = result::Result<T, Error>;
35
36pub type WebSocketKey = String<24>;
38
39pub type WebSocketSubProtocol = String<24>;
41
/// The kind of message passed to `WebSocket::write`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebSocketSendMessageType {
    /// A text message (sent with the text op code).
    Text = 1,
    /// A binary message (sent with the binary op code).
    Binary = 2,
    /// A ping control frame.
    Ping = 9,
    /// A pong control frame.
    Pong = 10,
    /// A close frame sent in answer to a close frame received from the peer.
    CloseReply = 11,
}
57
58impl WebSocketSendMessageType {
59 fn to_op_code(self) -> WebSocketOpCode {
60 match self {
61 WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
62 WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
63 WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
64 WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
65 WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
66 }
67 }
68}
69
/// The kind of message reported by `WebSocket::read`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebSocketReceiveMessageType {
    /// Text payload (possibly one chunk of a larger message).
    Text = 1,
    /// Binary payload (possibly one chunk of a larger message).
    Binary = 2,
    /// The peer answered a close frame we sent earlier; the websocket is now `Closed`.
    CloseCompleted = 7,
    /// The peer started the close handshake; the websocket is now `CloseReceived` and the
    /// caller is expected to answer with a `WebSocketSendMessageType::CloseReply`.
    CloseMustReply = 8,
    /// A ping control frame was received.
    Ping = 9,
    /// A pong control frame was received.
    Pong = 10,
}
89
/// Status code carried in the first two payload bytes of a close frame.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebSocketCloseStatusCode {
    /// Code 1000.
    NormalClosure,
    /// Code 1001.
    EndpointUnavailable,
    /// Code 1002.
    ProtocolError,
    /// Code 1003.
    InvalidMessageType,
    /// Code 1004.
    Reserved,
    /// Code 1005.
    Empty,
    /// Code 1007.
    InvalidPayloadData,
    /// Code 1008.
    PolicyViolation,
    /// Code 1009.
    MessageTooBig,
    /// Code 1010.
    MandatoryExtension,
    /// Code 1011.
    InternalServerError,
    /// Code 1015.
    TlsHandshake,
    /// Any code that has no dedicated variant, kept verbatim.
    Custom(u16),
}

impl WebSocketCloseStatusCode {
    /// Maps a numeric status code to its variant; codes without a variant become `Custom`.
    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
        match value {
            1000 => Self::NormalClosure,
            1001 => Self::EndpointUnavailable,
            1002 => Self::ProtocolError,
            1003 => Self::InvalidMessageType,
            1004 => Self::Reserved,
            1005 => Self::Empty,
            1007 => Self::InvalidPayloadData,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1010 => Self::MandatoryExtension,
            1011 => Self::InternalServerError,
            1015 => Self::TlsHandshake,
            other => Self::Custom(other),
        }
    }

    /// Maps the variant back to the numeric status code written on the wire.
    fn to_u16(self) -> u16 {
        match self {
            Self::NormalClosure => 1000,
            Self::EndpointUnavailable => 1001,
            Self::ProtocolError => 1002,
            Self::InvalidMessageType => 1003,
            Self::Reserved => 1004,
            Self::Empty => 1005,
            Self::InvalidPayloadData => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalServerError => 1011,
            Self::TlsHandshake => 1015,
            Self::Custom(code) => code,
        }
    }
}
171
/// Lifecycle state of a websocket.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebSocketState {
    /// Freshly created; no handshake performed yet.
    None = 0,
    /// Client handshake request built; waiting for the server's response.
    Connecting = 1,
    /// Handshake completed; data frames may be exchanged.
    Open = 2,
    /// We sent a close frame and are waiting for the peer's reply.
    CloseSent = 3,
    /// The peer sent a close frame that still needs a reply.
    CloseReceived = 4,
    /// The close handshake finished.
    Closed = 5,
    /// The handshake failed.
    Aborted = 6,
}
192
193#[derive(PartialEq, Eq, Debug)]
195pub enum Error {
196 InvalidOpCode,
198 InvalidFrameLength,
199 InvalidCloseStatusCode,
200 WebSocketNotOpen,
201 WebsocketAlreadyOpen,
202 Utf8Error,
203 Unknown,
204 HttpHeader(httparse::Error),
205 HttpHeaderNoPath,
206 HttpHeaderIncomplete,
207 WriteToBufferTooSmall,
208 ReadFrameIncomplete,
209 HttpResponseCodeInvalid(Option<u16>),
210 AcceptStringInvalid,
211 ConvertInfallible,
212 RandCore,
213 UnexpectedContinuationFrame,
214 EncodeSliceError(EncodeSliceError),
215}
216
217impl From<EncodeSliceError> for Error {
218 fn from(err: EncodeSliceError) -> Error {
219 Error::EncodeSliceError(err)
220 }
221}
222
223impl From<httparse::Error> for Error {
224 fn from(err: httparse::Error) -> Error {
225 Error::HttpHeader(err)
226 }
227}
228
229impl From<str::Utf8Error> for Error {
230 fn from(_: str::Utf8Error) -> Error {
231 Error::Utf8Error
232 }
233}
234
235impl From<core::convert::Infallible> for Error {
236 fn from(_: core::convert::Infallible) -> Error {
237 Error::ConvertInfallible
238 }
239}
240
241impl From<()> for Error {
242 fn from(_: ()) -> Error {
243 Error::Unknown
244 }
245}
246
247impl core::fmt::Display for Error {
248 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
249 match self {
250 Error::HttpHeader(error) => write!(f, "bad http header {error}"),
251 Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
252 _ => write!(f, "{:?}", self),
253 }
254 }
255}
256
#[cfg(feature = "std")]
impl std::error::Error for Error {
    /// Only HTTP header parse failures expose an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpHeader(error) => Some(error),
            _ => None,
        }
    }
}
267
/// Frame op codes understood by this implementation (low four bits of the first header byte).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WebSocketOpCode {
    /// Continues a fragmented message.
    ContinuationFrame = 0,
    /// Text data frame.
    TextFrame = 1,
    /// Binary data frame.
    BinaryFrame = 2,
    /// Close control frame.
    ConnectionClose = 8,
    /// Ping control frame.
    Ping = 9,
    /// Pong control frame.
    Pong = 10,
}
277
278impl WebSocketOpCode {
279 fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
280 match self {
281 WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
282 WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
283 _ => Err(Error::InvalidOpCode),
284 }
285 }
286}
287
288#[derive(Debug)]
290pub struct WebSocketReadResult {
291 pub len_from: usize,
293 pub len_to: usize,
295 pub end_of_message: bool,
299 pub close_status: Option<WebSocketCloseStatusCode>,
304 pub message_type: WebSocketReceiveMessageType,
306}
307
/// Parameters used to build a client's opening handshake request.
pub struct WebSocketOptions<'a> {
    /// Request path, e.g. `/chat`.
    pub path: &'a str,
    /// Value for the `Host` header.
    pub host: &'a str,
    /// Value for the `Origin` header.
    pub origin: &'a str,
    /// Sub-protocols to offer, if any.
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Extra headers to include (NOTE(review): exact expected format is defined in `http` — confirm).
    pub additional_headers: Option<&'a [&'a str]>,
}
331
332pub type WebSocketServer = WebSocket<EmptyRng, Server>;
334
335pub type WebSocketClient<T> = WebSocket<T, Client>;
337
338pub enum Server {}
341pub enum Client {}
342
343pub trait WebSocketType {}
344impl WebSocketType for Server {}
345impl WebSocketType for Client {}
346
347pub struct WebSocket<T, S: WebSocketType>
349where
350 T: RngCore,
351{
352 is_client: bool,
353 rng: T,
354 continuation_frame_op_code: Option<WebSocketOpCode>,
355 is_write_continuation: bool,
356 pub state: WebSocketState,
357 continuation_read: Option<ContinuationRead>,
358 marker: core::marker::PhantomData<S>,
359}
360
361impl<T, Type> WebSocket<T, Type>
362where
363 T: RngCore,
364 Type: WebSocketType,
365{
366 pub fn new_client(rng: T) -> WebSocketClient<T> {
377 WebSocket {
378 is_client: true,
379 rng,
380 continuation_frame_op_code: None,
381 is_write_continuation: false,
382 state: WebSocketState::None,
383 continuation_read: None,
384 marker: core::marker::PhantomData::<Client>,
385 }
386 }
387
388 pub fn new_server() -> WebSocketServer {
400 let rng = EmptyRng::new();
401 WebSocket {
402 is_client: false,
403 rng,
404 continuation_frame_op_code: None,
405 is_write_continuation: false,
406 state: WebSocketState::None,
407 continuation_read: None,
408 marker: core::marker::PhantomData::<Server>,
409 }
410 }
411}
412
413impl<T> WebSocket<T, Server>
414where
415 T: RngCore,
416{
417 pub fn server_accept(
450 &mut self,
451 sec_websocket_key: &WebSocketKey,
452 sec_websocket_protocol: Option<&WebSocketSubProtocol>,
453 to: &mut [u8],
454 ) -> Result<usize> {
455 if self.state == WebSocketState::Open {
456 return Err(Error::WebsocketAlreadyOpen);
457 }
458
459 match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
460 {
461 Ok(http_response_len) => {
462 self.state = WebSocketState::Open;
463 Ok(http_response_len)
464 }
465 Err(e) => {
466 self.state = WebSocketState::Aborted;
467 Err(e)
468 }
469 }
470 }
471}
472
473impl<T> WebSocket<T, Client>
474where
475 T: RngCore,
476{
477 pub fn client_connect(
511 &mut self,
512 websocket_options: &WebSocketOptions,
513 to: &mut [u8],
514 ) -> Result<(usize, WebSocketKey)> {
515 if self.state == WebSocketState::Open {
516 return Err(Error::WebsocketAlreadyOpen);
517 }
518
519 match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
520 Ok((request_len, sec_websocket_key)) => {
521 self.state = WebSocketState::Connecting;
522 Ok((request_len, sec_websocket_key))
523 }
524 Err(e) => Err(e),
525 }
526 }
527
528 pub fn client_accept(
551 &mut self,
552 sec_websocket_key: &WebSocketKey,
553 from: &[u8],
554 ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
555 if self.state == WebSocketState::Open {
556 return Err(Error::WebsocketAlreadyOpen);
557 }
558
559 match http::read_server_connect_handshake_response(sec_websocket_key, from) {
560 Ok((len, sec_websocket_protocol)) => {
561 self.state = WebSocketState::Open;
562 Ok((len, sec_websocket_protocol))
563 }
564 Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
565 Err(e) => {
566 self.state = WebSocketState::Aborted;
567 Err(e)
568 }
569 }
570 }
571}
572
573impl<T, Type> WebSocket<T, Type>
574where
575 T: RngCore,
576 Type: WebSocketType,
577{
578 pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
618 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
619 let frame = self.read_frame(from, to)?;
620
621 match frame.op_code {
622 WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
623 WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
624 WebSocketOpCode::TextFrame => {
625 Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
626 }
627 WebSocketOpCode::BinaryFrame => {
628 Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
629 }
630 WebSocketOpCode::ConnectionClose => match self.state {
631 WebSocketState::CloseSent => {
632 self.state = WebSocketState::Closed;
633 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
634 }
635 _ => {
636 self.state = WebSocketState::CloseReceived;
637 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
638 }
639 },
640 WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
641 Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
642 None => Err(Error::UnexpectedContinuationFrame),
643 },
644 }
645 } else {
646 Err(Error::WebSocketNotOpen)
647 }
648 }
649
650 pub fn write(
677 &mut self,
678 message_type: WebSocketSendMessageType,
679 end_of_message: bool,
680 from: &[u8],
681 to: &mut [u8],
682 ) -> Result<usize> {
683 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
684 let mut op_code = message_type.to_op_code();
685 if op_code == WebSocketOpCode::ConnectionClose {
686 self.state = WebSocketState::Closed
687 } else if self.is_write_continuation {
688 op_code = WebSocketOpCode::ContinuationFrame;
689 }
690
691 self.is_write_continuation = !end_of_message;
692 self.write_frame(from, to, op_code, end_of_message)
693 } else {
694 Err(Error::WebSocketNotOpen)
695 }
696 }
697
698 pub fn close(
706 &mut self,
707 close_status: WebSocketCloseStatusCode,
708 status_description: Option<&str>,
709 to: &mut [u8],
710 ) -> Result<usize> {
711 if self.state == WebSocketState::Open {
712 self.state = WebSocketState::CloseSent;
713 if let Some(status_description) = status_description {
714 let mut from_buffer: Vec<u8, 256> = Vec::new();
715 from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
716
717 let len = if status_description.len() < 254 {
719 status_description.len()
720 } else {
721 254
722 };
723
724 from_buffer.extend_from_slice(status_description[..len].as_bytes())?;
725 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
726 } else {
727 let mut from_buffer: [u8; 2] = [0; 2];
728 BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
729 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
730 }
731 } else {
732 Err(Error::WebSocketNotOpen)
733 }
734 }
735
736 fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
737 match &mut self.continuation_read {
738 Some(continuation_read) => {
739 let result = read_continuation(continuation_read, from_buffer, to_buffer);
740 if result.is_fin_bit_set {
741 self.continuation_read = None;
742 }
743 Ok(result)
744 }
745 None => {
746 let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
747
748 if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
750 result.op_code = continuation_frame_op_code;
751 }
752
753 self.continuation_frame_op_code = if result.is_fin_bit_set {
755 None
756 } else {
757 Some(result.op_code)
758 };
759
760 self.continuation_read = continuation_read;
761 Ok(result)
762 }
763 }
764 }
765
766 fn write_frame(
767 &mut self,
768 from_buffer: &[u8],
769 to_buffer: &mut [u8],
770 op_code: WebSocketOpCode,
771 end_of_message: bool,
772 ) -> Result<usize> {
773 let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
774 let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
775 let count = from_buffer.len();
776 const BYTE_HEADER_SIZE: usize = 2;
777 const SHORT_HEADER_SIZE: usize = 4;
778 const LONG_HEADER_SIZE: usize = 10;
779 const MASK_KEY_SIZE: usize = 4;
780 let header_size;
781 let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
782 let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
783
784 if count < 126 {
787 if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
788 return Err(Error::WriteToBufferTooSmall);
789 }
790 to_buffer[0] = byte1;
791 to_buffer[1] = mask_bit_set_as_byte | count as u8;
792 header_size = BYTE_HEADER_SIZE;
793 } else if count < 65535 {
794 if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
795 return Err(Error::WriteToBufferTooSmall);
796 }
797 to_buffer[0] = byte1;
798 to_buffer[1] = mask_bit_set_as_byte | 126;
799 BigEndian::write_u16(&mut to_buffer[2..], count as u16);
800 header_size = SHORT_HEADER_SIZE;
801 } else {
802 if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
803 return Err(Error::WriteToBufferTooSmall);
804 }
805 to_buffer[0] = byte1;
806 to_buffer[1] = mask_bit_set_as_byte | 127;
807 BigEndian::write_u64(&mut to_buffer[2..], count as u64);
808 header_size = LONG_HEADER_SIZE;
809 }
810
811 if self.is_client {
814 let mut mask_key = [0; MASK_KEY_SIZE];
815 self.rng.fill_bytes(&mut mask_key); to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
817 let to_buffer_start = header_size + MASK_KEY_SIZE;
818
819 for (i, (from, to)) in from_buffer[..count]
821 .iter()
822 .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
823 .enumerate()
824 {
825 *to = *from ^ mask_key[i % MASK_KEY_SIZE];
826 }
827
828 Ok(to_buffer_start + count)
829 } else {
830 to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
831 Ok(header_size + count)
832 }
833 }
834}
835
836struct ContinuationRead {
838 op_code: WebSocketOpCode,
839 count: usize,
840 is_fin_bit_set: bool,
841 mask_key: Option<[u8; 4]>,
842}
843
844struct WebSocketFrame {
845 is_fin_bit_set: bool,
846 op_code: WebSocketOpCode,
847 num_bytes_to: usize,
848 num_bytes_from: usize,
849 close_status: Option<WebSocketCloseStatusCode>,
850}
851
852impl WebSocketFrame {
853 fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
854 WebSocketReadResult {
855 len_from: self.num_bytes_from,
856 len_to: self.num_bytes_to,
857 end_of_message: self.is_fin_bit_set,
858 close_status: self.close_status,
859 message_type,
860 }
861 }
862}
863
/// Returns the smallest of three lengths.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    num1.min(num2).min(num3)
}
867
868fn read_into_buffer(
869 mask_key: &mut Option<[u8; 4]>,
870 from_buffer: &[u8],
871 to_buffer: &mut [u8],
872 len: usize,
873) -> usize {
874 let len_to_read = min(len, to_buffer.len(), from_buffer.len());
876
877 match mask_key {
878 Some(mask_key) => {
879 for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
881 *to = *from ^ mask_key[i % MASK_KEY_LEN];
882 }
883 mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
884 }
885 None => {
886 to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
887 }
888 }
889
890 len_to_read
891}
892
893fn read_continuation(
894 continuation_read: &mut ContinuationRead,
895 from_buffer: &[u8],
896 to_buffer: &mut [u8],
897) -> WebSocketFrame {
898 let len_read = read_into_buffer(
899 &mut continuation_read.mask_key,
900 from_buffer,
901 to_buffer,
902 continuation_read.count,
903 );
904
905 let is_complete = len_read == continuation_read.count;
906
907 let frame = match continuation_read.op_code {
908 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
909 _ => WebSocketFrame {
910 num_bytes_from: len_read,
911 num_bytes_to: len_read,
912 op_code: continuation_read.op_code,
913 close_status: None,
914 is_fin_bit_set: if is_complete {
915 continuation_read.is_fin_bit_set
916 } else {
917 false
918 },
919 },
920 };
921
922 continuation_read.count -= len_read;
923 frame
924}
925
926fn read_frame(
927 from_buffer: &[u8],
928 to_buffer: &mut [u8],
929) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
930 if from_buffer.len() < 2 {
931 return Err(Error::ReadFrameIncomplete);
932 }
933
934 let byte1 = from_buffer[0];
935 let byte2 = from_buffer[1];
936
937 const FIN_BIT_FLAG: u8 = 0x80;
939 const OP_CODE_FLAG: u8 = 0x0F;
940 let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
941 let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
942
943 const MASK_FLAG: u8 = 0x80;
945 let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
946 let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
947
948 num_bytes_read += 2;
949 let from_buffer = &from_buffer[num_bytes_read..];
950
951 let mut mask_key = if is_mask_bit_set {
953 if from_buffer.len() < MASK_KEY_LEN {
954 return Err(Error::ReadFrameIncomplete);
955 }
956 let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
957 mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
958 num_bytes_read += MASK_KEY_LEN;
959 Some(mask_key)
960 } else {
961 None
962 };
963
964 let len_read = if is_mask_bit_set {
965 let from_buffer = &from_buffer[MASK_KEY_LEN..];
967 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
968 } else {
969 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
970 };
971
972 let has_continuation = len_read < len;
973 num_bytes_read += len_read;
974
975 let frame = match op_code {
976 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
977 _ => WebSocketFrame {
978 num_bytes_from: num_bytes_read,
979 num_bytes_to: len_read,
980 op_code,
981 close_status: None,
982 is_fin_bit_set: if has_continuation {
983 false
984 } else {
985 is_fin_bit_set
986 },
987 },
988 };
989
990 if has_continuation {
991 let continuation_read = Some(ContinuationRead {
992 op_code,
993 count: len - len_read,
994 is_fin_bit_set,
995 mask_key,
996 });
997 Ok((frame, continuation_read))
998 } else {
999 Ok((frame, None))
1000 }
1001}
1002
1003fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
1004 match val {
1005 0 => Ok(WebSocketOpCode::ContinuationFrame),
1006 1 => Ok(WebSocketOpCode::TextFrame),
1007 2 => Ok(WebSocketOpCode::BinaryFrame),
1008 8 => Ok(WebSocketOpCode::ConnectionClose),
1009 9 => Ok(WebSocketOpCode::Ping),
1010 10 => Ok(WebSocketOpCode::Pong),
1011 _ => Err(Error::InvalidOpCode),
1012 }
1013}
1014
1015fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1017 let len = byte2 & 0x7F;
1018
1019 if len < 126 {
1020 return Ok((len as usize, 0));
1022 } else if len == 126 {
1023 if from_buffer.len() < 2 {
1025 return Err(Error::ReadFrameIncomplete);
1026 }
1027 let mut buf: [u8; 2] = [0; 2];
1028 buf.copy_from_slice(&from_buffer[..2]);
1029 return Ok((BigEndian::read_u16(&buf) as usize, 2));
1030 } else if len == 127 {
1031 if from_buffer.len() < 8 {
1033 return Err(Error::ReadFrameIncomplete);
1034 }
1035 let mut buf: [u8; 8] = [0; 8];
1036 buf.copy_from_slice(&from_buffer[..8]);
1037 return Ok((BigEndian::read_u64(&buf) as usize, 8));
1038 }
1039
1040 Err(Error::InvalidFrameLength)
1041}
1042
1043fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1044 if len >= 2 {
1045 let code = BigEndian::read_u16(buffer);
1047 let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1048
1049 return WebSocketFrame {
1050 num_bytes_from: num_bytes_read,
1051 num_bytes_to: len,
1052 op_code: WebSocketOpCode::ConnectionClose,
1053 close_status: Some(close_status_code),
1054 is_fin_bit_set: true,
1055 };
1056 }
1057
1058 build_client_disconnected_frame(num_bytes_read)
1059}
1060
1061fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1062 WebSocketFrame {
1063 num_bytes_from,
1064 num_bytes_to: 0,
1065 op_code: WebSocketOpCode::ConnectionClose,
1066 close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1067 is_fin_bit_set: true,
1068 }
1069}
1070
1071#[cfg(test)]
1076mod tests {
1077 extern crate std;
1078 use core::str::FromStr;
1079
1080 use super::*;
1081
1082 #[test]
1083 fn opening_handshake() {
1084 let client_request = "GET /chat HTTP/1.1
1085Host: localhost:5000
1086User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
1087Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
1088Accept-Language: en-US,en;q=0.5
1089Accept-Encoding: gzip, deflate
1090Sec-WebSocket-Version: 13
1091Origin: http://localhost:5000
1092Sec-WebSocket-Extensions: permessage-deflate
1093Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
1094Sec-WebSocket-Protocol: chat
1095DNT: 1
1096Connection: keep-alive, Upgrade
1097Pragma: no-cache
1098Cache-Control: no-cache
1099Upgrade: websocket
1100
1101";
1102
1103 let mut headers = [httparse::EMPTY_HEADER; 16];
1104 let mut request = httparse::Request::new(&mut headers);
1105 request.parse(client_request.as_bytes()).unwrap();
1106 let headers = headers.iter().map(|f| (f.name, f.value));
1107 let web_socket_context = read_http_header(headers).unwrap().unwrap();
1108 assert_eq!(
1109 "Z7OY1UwHOx/nkSz38kfPwg==",
1110 web_socket_context.sec_websocket_key
1111 );
1112 assert_eq!(
1113 "chat",
1114 web_socket_context
1115 .sec_websocket_protocol_list
1116 .get(0)
1117 .unwrap()
1118 .as_str()
1119 );
1120 let mut web_socket = WebSocketServer::new_server();
1121
1122 let mut ws_buffer: [u8; 3000] = [0; 3000];
1123 let size = web_socket
1124 .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
1125 .unwrap();
1126 let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
1127 let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
1128 assert_eq!(client_response_expected, response);
1129 }
1130
1131 #[test]
1132 fn server_write_frame() {
1133 let mut buffer: [u8; 1000] = [0; 1000];
1134 let mut ws_server = WebSocketServer::new_server();
1135 let len = ws_server
1136 .write_frame(
1137 "hello".as_bytes(),
1138 &mut buffer,
1139 WebSocketOpCode::TextFrame,
1140 true,
1141 )
1142 .unwrap();
1143 let expected = [129, 5, 104, 101, 108, 108, 111];
1144 assert_eq!(&expected, &buffer[..len]);
1145 }
1146
1147 #[test]
1148 fn server_accept_should_write_sub_protocol() {
1149 let mut buffer: [u8; 1000] = [0; 1000];
1150 let mut ws_server = WebSocketServer::new_server();
1151 let ws_key: WebSocketKey = String::from_str("Z7OY1UwHOx/nkSz38kfPwg==").unwrap();
1152 let sub_protocol: WebSocketSubProtocol = String::from_str("chat").unwrap();
1153 let size = ws_server
1154 .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
1155 .unwrap();
1156 let response = std::str::from_utf8(&buffer[..size]).unwrap();
1157 assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
1158 }
1159
1160 #[test]
1161 fn closing_handshake() {
1162 let mut buffer1: [u8; 500] = [0; 500];
1163 let mut buffer2: [u8; 500] = [0; 500];
1164
1165 let mut rng = rand::thread_rng();
1166
1167 let mut ws_client = WebSocketClient::new_client(&mut rng);
1168 ws_client.state = WebSocketState::Open;
1169
1170 let mut ws_server = WebSocketServer::new_server();
1171 ws_server.state = WebSocketState::Open;
1172
1173 ws_client
1175 .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
1176 .unwrap();
1177
1178 let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
1180 assert_eq!(
1181 WebSocketReceiveMessageType::CloseMustReply,
1182 ws_result.message_type
1183 );
1184
1185 ws_server
1187 .write(
1188 WebSocketSendMessageType::CloseReply,
1189 true,
1190 &buffer2[..ws_result.len_to],
1191 &mut buffer1,
1192 )
1193 .unwrap();
1194 assert_eq!(WebSocketState::Closed, ws_server.state);
1195
1196 let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1198 assert_eq!(WebSocketState::Closed, ws_client.state);
1199
1200 assert_eq!(
1201 WebSocketReceiveMessageType::CloseCompleted,
1202 ws_result.message_type
1203 );
1204 }
1205
1206 #[test]
1207 fn send_message_from_client_to_server() {
1208 let mut buffer1: [u8; 1000] = [0; 1000];
1209 let mut buffer2: [u8; 1000] = [0; 1000];
1210
1211 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1213
1214 ws_client.state = WebSocketState::Open;
1215 let mut ws_server = WebSocketServer::new_server();
1216 ws_server.state = WebSocketState::Open;
1217
1218 let hello = "hello";
1220 let num_bytes = ws_client
1221 .write(
1222 WebSocketSendMessageType::Text,
1223 true,
1224 &hello.as_bytes(),
1225 &mut buffer1,
1226 )
1227 .unwrap();
1228
1229 let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1231 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1232 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1233 assert_eq!(hello, received);
1234 }
1235
1236 #[test]
1237 fn send_message_from_server_to_client() {
1238 let mut buffer1: [u8; 1000] = [0; 1000];
1239 let mut buffer2: [u8; 1000] = [0; 1000];
1240
1241 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1242 ws_client.state = WebSocketState::Open;
1243 let mut ws_server = WebSocketServer::new_server();
1244 ws_server.state = WebSocketState::Open;
1245
1246 let hello = "hello";
1248 let num_bytes = ws_server
1249 .write(
1250 WebSocketSendMessageType::Text,
1251 true,
1252 &hello.as_bytes(),
1253 &mut buffer1,
1254 )
1255 .unwrap();
1256
1257 let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1259 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1260 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1261 assert_eq!(hello, received);
1262 }
1263
1264 #[test]
1265 fn receive_buffer_too_small() {
1266 let mut buffer1: [u8; 1000] = [0; 1000];
1267 let mut buffer2: [u8; 1000] = [0; 1000];
1268
1269 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1270 ws_client.state = WebSocketState::Open;
1271 let mut ws_server = WebSocketServer::new_server();
1272 ws_server.state = WebSocketState::Open;
1273
1274 let hello = "hello";
1275 ws_server
1276 .write(
1277 WebSocketSendMessageType::Text,
1278 true,
1279 &hello.as_bytes(),
1280 &mut buffer1,
1281 )
1282 .unwrap();
1283
1284 match ws_client.read(&buffer1[..1], &mut buffer2) {
1285 Err(Error::ReadFrameIncomplete) => {
1286 }
1288 _ => {
1289 assert_eq!(true, false);
1290 }
1291 }
1292 }
1293
1294 #[test]
1295 fn receive_large_frame_with_small_receive_buffer() {
1296 let mut buffer1: [u8; 1000] = [0; 1000];
1297 let mut buffer2: [u8; 1000] = [0; 1000];
1298
1299 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1300 ws_client.state = WebSocketState::Open;
1301 let mut ws_server = WebSocketServer::new_server();
1302 ws_server.state = WebSocketState::Open;
1303
1304 let hello = "hello";
1305 ws_server
1306 .write(
1307 WebSocketSendMessageType::Text,
1308 true,
1309 &hello.as_bytes(),
1310 &mut buffer1,
1311 )
1312 .unwrap();
1313
1314 let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
1315 assert_eq!(0, ws_result.len_to);
1316 assert_eq!(false, ws_result.end_of_message);
1317 let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
1318 assert_eq!(1, ws_result.len_to);
1319 assert_eq!(
1320 "h",
1321 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1322 );
1323 assert_eq!(false, ws_result.end_of_message);
1324 let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
1325 assert_eq!(4, ws_result.len_to);
1326 assert_eq!(
1327 "ello",
1328 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1329 );
1330 assert_eq!(true, ws_result.end_of_message);
1331 }
1332
1333 #[test]
1334 fn send_large_frame() {
1335 let buffer1 = [0u8; 15944];
1336 let mut buffer2 = [0u8; 64000];
1337 let mut buffer3 = [0u8; 64000];
1338
1339 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1340 ws_client.state = WebSocketState::Open;
1341 let mut ws_server = WebSocketServer::new_server();
1342 ws_server.state = WebSocketState::Open;
1343
1344 ws_client
1345 .write(
1346 WebSocketSendMessageType::Binary,
1347 true,
1348 &buffer1,
1349 &mut buffer2,
1350 )
1351 .unwrap();
1352
1353 let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1354 assert_eq!(true, ws_result.end_of_message);
1355 assert_eq!(buffer1.len(), ws_result.len_to);
1356 }
1357
1358 #[test]
1359 fn receive_large_frame_multi_read() {
1360 let mut buffer1 = [0_u8; 1000];
1361 let mut buffer2 = [0_u8; 1000];
1362
1363 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1364 ws_client.state = WebSocketState::Open;
1365 let mut ws_server = WebSocketServer::new_server();
1366 ws_server.state = WebSocketState::Open;
1367
1368 let message = "Hello, world. This is a long message that takes multiple reads";
1369 ws_server
1370 .write(
1371 WebSocketSendMessageType::Text,
1372 true,
1373 &message.as_bytes(),
1374 &mut buffer1,
1375 )
1376 .unwrap();
1377
1378 let mut buffer2_cursor = 0;
1379 let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
1380 assert_eq!(false, ws_result.end_of_message);
1381 buffer2_cursor += ws_result.len_to;
1382 let ws_result = ws_client
1383 .read(
1384 &buffer1[ws_result.len_from..],
1385 &mut buffer2[buffer2_cursor..],
1386 )
1387 .unwrap();
1388 assert_eq!(true, ws_result.end_of_message);
1389 buffer2_cursor += ws_result.len_to;
1390
1391 assert_eq!(
1392 message,
1393 std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1394 );
1395 }
1396
#[test]
fn multiple_messages_in_receive_buffer() {
    // Two complete client->server text frames are written back to back into a
    // single buffer. The server's first read must report how many input bytes
    // it consumed (`len_from`) so that the second frame can be read starting
    // at that offset.
    let mut buffer1 = [0_u8; 1000];
    let mut buffer2 = [0_u8; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let message1 = "Hello, world.";
    let len = ws_client
        .write(
            WebSocketSendMessageType::Text,
            true,
            message1.as_bytes(),
            &mut buffer1,
        )
        .unwrap();
    let message2 = "This is another message.";
    ws_client
        .write(
            WebSocketSendMessageType::Text,
            true,
            message2.as_bytes(),
            &mut buffer1[len..],
        )
        .unwrap();

    let mut buffer1_cursor = 0;
    let mut buffer2_cursor = 0;
    let ws_result = ws_server
        .read(&buffer1[buffer1_cursor..], &mut buffer2)
        .unwrap();
    assert!(ws_result.end_of_message);
    buffer1_cursor += ws_result.len_from;
    buffer2_cursor += ws_result.len_to;
    let ws_result = ws_server
        .read(&buffer1[buffer1_cursor..], &mut buffer2[buffer2_cursor..])
        .unwrap();
    assert!(ws_result.end_of_message);

    // The first payload occupies buffer2[..buffer2_cursor]; the second payload
    // was decoded directly after it.
    assert_eq!(
        message1,
        std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
    );

    assert_eq!(
        message2,
        std::str::from_utf8(&buffer2[buffer2_cursor..buffer2_cursor + ws_result.len_to])
            .unwrap()
    );
}
1449
#[test]
fn receive_large_frame_with_small_send_buffer() {
    // Despite the test name, the undersized buffer here is the *destination*
    // the client decodes into. The 5-byte payload is drained first into a
    // 1-byte slice, then into a 4-byte slice. `end_of_message` must only be set
    // once the whole payload has been copied out.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let hello = "hello";
    ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            hello.as_bytes(),
            &mut buffer1,
        )
        .unwrap();

    let ws_result = ws_client.read(&buffer1, &mut buffer2[..1]).unwrap();
    assert_eq!(1, ws_result.len_to);
    assert_eq!(
        "h",
        std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
    );
    assert!(!ws_result.end_of_message);
    let ws_result = ws_client
        .read(&buffer1[ws_result.len_from..], &mut buffer2[..4])
        .unwrap();
    assert_eq!(4, ws_result.len_to);
    assert_eq!(
        "ello",
        std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
    );
    assert!(ws_result.end_of_message);
}
1487
#[test]
fn send_two_frame_message() {
    // A text message is sent as two frames: the first written with
    // end_of_message = false, the second with end_of_message = true. Each
    // client read is given exactly one frame's bytes, and the decoded pieces
    // are concatenated in buffer2.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];
    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let hello = "Hello, ";
    let num_bytes_hello = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            hello.as_bytes(),
            &mut buffer1,
        )
        .unwrap();

    let world = "World!";
    let num_bytes_world = ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            world.as_bytes(),
            &mut buffer1[num_bytes_hello..],
        )
        .unwrap();

    let ws_result1 = ws_client
        .read(&buffer1[..num_bytes_hello], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
    assert!(!ws_result1.end_of_message);
    let ws_result2 = ws_client
        .read(
            &buffer1[num_bytes_hello..num_bytes_hello + num_bytes_world],
            &mut buffer2[ws_result1.len_to..],
        )
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
    assert!(ws_result2.end_of_message);

    let received =
        std::str::from_utf8(&buffer2[..ws_result1.len_to + ws_result2.len_to]).unwrap();
    assert_eq!("Hello, World!", received);
}
1540
#[test]
fn send_multi_frame_message() {
    // A text message is sent as three fragments. The test checks two things:
    // 1. The client decodes each fragment in turn as Text, with
    //    end_of_message set only on the last one.
    // 2. On the wire, the first frame header is a non-FIN TextFrame, and the
    //    later frames are ContinuationFrames, with FIN set only on the last.
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let fragment1 = "fragment1";
    let fragment1_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            fragment1.as_bytes(),
            &mut buffer1,
        )
        .unwrap();

    let fragment2 = "fragment2";
    let fragment2_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            fragment2.as_bytes(),
            &mut buffer1[fragment1_num_bytes..],
        )
        .unwrap();

    let fragment3 = "fragment3";
    let _fragment3_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            fragment3.as_bytes(),
            &mut buffer1[fragment1_num_bytes + fragment2_num_bytes..],
        )
        .unwrap();

    // Decode side: each read consumes one frame; `len_from` advances the cursor.
    let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment1, received);
    assert!(!ws_result.end_of_message);
    let mut read_cursor = ws_result.len_from;

    let ws_result = ws_client
        .read(&buffer1[read_cursor..], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment2, received);
    assert!(!ws_result.end_of_message);
    read_cursor += ws_result.len_from;

    let ws_result = ws_client
        .read(&buffer1[read_cursor..], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
    let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
    assert_eq!(fragment3, received);
    assert!(ws_result.end_of_message);

    // Wire side: inspect the first header byte of each encoded frame.
    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::TextFrame);

    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);

    let (is_fin_bit_set, op_code) =
        read_first_byte(buffer1[fragment1_num_bytes + fragment2_num_bytes]);
    assert!(is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
}
1627
1628 fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1629 const FIN_BIT_FLAG: u8 = 0x80;
1630 const OP_CODE_FLAG: u8 = 0x0F;
1631 let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1632 let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1633 (is_fin_bit_set, op_code)
1634 }
1635}