1#![cfg_attr(not(feature = "std"), no_std)]
14
15use byteorder::{BigEndian, ByteOrder};
16use core::{cmp, result, str};
17use heapless::{String, Vec};
18use rand_core::RngCore;
19use sha1::{Digest, Sha1};
20
21mod http;
22pub mod random;
23pub use self::http::{read_http_header, WebSocketContext};
24pub use self::random::EmptyRng;
25
26pub mod framer;
29pub mod framer_async;
/// Length in bytes of the masking key that precedes the payload of a masked frame.
const MASK_KEY_LEN: usize = 4;

/// Result type used throughout this crate; the error type is always [`Error`].
pub type Result<T> = result::Result<T, Error>;

/// Value of the `Sec-WebSocket-Key` handshake header, stored in a fixed-capacity
/// (24 byte) heapless string.
pub type WebSocketKey = String<24>;

/// Name of a single negotiated websocket sub-protocol, stored in a fixed-capacity
/// (24 byte) heapless string.
pub type WebSocketSubProtocol = String<24>;
40
41#[derive(PartialEq, Eq, Debug, Copy, Clone)]
43pub enum WebSocketSendMessageType {
44 Text = 1,
46 Binary = 2,
48 Ping = 9,
50 Pong = 10,
52 CloseReply = 11,
55}
56
57impl WebSocketSendMessageType {
58 fn to_op_code(self) -> WebSocketOpCode {
59 match self {
60 WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
61 WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
62 WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
63 WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
64 WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
65 }
66 }
67}
68
/// Kind of message reported by `WebSocket::read`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WebSocketReceiveMessageType {
    /// Payload of a text message (or a fragment of one).
    Text = 1,
    /// Payload of a binary message (or a fragment of one).
    Binary = 2,
    /// A close frame arrived after we had already sent ours; the close handshake is done.
    CloseCompleted = 7,
    /// The peer initiated a close; the caller is expected to answer with a close frame.
    CloseMustReply = 8,
    /// A ping control frame.
    Ping = 9,
    /// A pong control frame.
    Pong = 10,
}
88
/// Status code carried in the first two bytes of a close frame.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WebSocketCloseStatusCode {
    /// 1000
    NormalClosure,
    /// 1001
    EndpointUnavailable,
    /// 1002
    ProtocolError,
    /// 1003
    InvalidMessageType,
    /// 1004
    Reserved,
    /// 1005
    Empty,
    /// 1007
    InvalidPayloadData,
    /// 1008
    PolicyViolation,
    /// 1009
    MessageTooBig,
    /// 1010
    MandatoryExtension,
    /// 1011
    InternalServerError,
    /// 1015
    TlsHandshake,
    /// Any code without a dedicated variant, kept verbatim.
    Custom(u16),
}

impl WebSocketCloseStatusCode {
    /// Every numeric code that has its own variant, paired with that variant.
    const NAMED_CODES: [(u16, WebSocketCloseStatusCode); 12] = [
        (1000, WebSocketCloseStatusCode::NormalClosure),
        (1001, WebSocketCloseStatusCode::EndpointUnavailable),
        (1002, WebSocketCloseStatusCode::ProtocolError),
        (1003, WebSocketCloseStatusCode::InvalidMessageType),
        (1004, WebSocketCloseStatusCode::Reserved),
        (1005, WebSocketCloseStatusCode::Empty),
        (1007, WebSocketCloseStatusCode::InvalidPayloadData),
        (1008, WebSocketCloseStatusCode::PolicyViolation),
        (1009, WebSocketCloseStatusCode::MessageTooBig),
        (1010, WebSocketCloseStatusCode::MandatoryExtension),
        (1011, WebSocketCloseStatusCode::InternalServerError),
        (1015, WebSocketCloseStatusCode::TlsHandshake),
    ];

    /// Decodes a wire status code; unknown values become `Custom(value)`.
    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
        Self::NAMED_CODES
            .iter()
            .find(|(code, _)| *code == value)
            .map_or(WebSocketCloseStatusCode::Custom(value), |&(_, status)| status)
    }

    /// Encodes this status back into its numeric wire value.
    fn to_u16(self) -> u16 {
        match self {
            Self::NormalClosure => 1000,
            Self::EndpointUnavailable => 1001,
            Self::ProtocolError => 1002,
            Self::InvalidMessageType => 1003,
            Self::Reserved => 1004,
            Self::Empty => 1005,
            Self::InvalidPayloadData => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalServerError => 1011,
            Self::TlsHandshake => 1015,
            Self::Custom(value) => value,
        }
    }
}
170
/// Lifecycle state of a [`WebSocket`] connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WebSocketState {
    /// Freshly constructed; no handshake has happened yet.
    None = 0,
    /// A client handshake request has been built and a response is awaited.
    Connecting = 1,
    /// Handshake finished; frames may be exchanged.
    Open = 2,
    /// We sent a close frame and are waiting for the peer's reply.
    CloseSent = 3,
    /// The peer sent a close frame that still needs our reply.
    CloseReceived = 4,
    /// The close handshake has finished.
    Closed = 5,
    /// The opening handshake failed.
    Aborted = 6,
}
191
192#[derive(PartialEq, Eq, Debug)]
194pub enum Error {
195 InvalidOpCode,
197 InvalidFrameLength,
198 InvalidCloseStatusCode,
199 WebSocketNotOpen,
200 WebsocketAlreadyOpen,
201 Utf8Error,
202 Unknown,
203 HttpHeader(httparse::Error),
204 HttpHeaderNoPath,
205 HttpHeaderIncomplete,
206 WriteToBufferTooSmall,
207 ReadFrameIncomplete,
208 HttpResponseCodeInvalid(Option<u16>),
209 AcceptStringInvalid,
210 ConvertInfallible,
211 RandCore,
212 UnexpectedContinuationFrame,
213}
214
215impl From<httparse::Error> for Error {
216 fn from(err: httparse::Error) -> Error {
217 Error::HttpHeader(err)
218 }
219}
220
221impl From<str::Utf8Error> for Error {
222 fn from(_: str::Utf8Error) -> Error {
223 Error::Utf8Error
224 }
225}
226
227impl From<core::convert::Infallible> for Error {
228 fn from(_: core::convert::Infallible) -> Error {
229 Error::ConvertInfallible
230 }
231}
232
233impl From<()> for Error {
234 fn from(_: ()) -> Error {
235 Error::Unknown
236 }
237}
238
239impl core::fmt::Display for Error {
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 match self {
242 Error::HttpHeader(error) => write!(f, "bad http header {error}"),
243 Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
244 _ => write!(f, "{:?}", self),
245 }
246 }
247}
248
249#[cfg(feature = "std")]
250impl std::error::Error for Error {
251 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
252 if let Self::HttpHeader(error) = self {
253 Some(error)
254 } else {
255 None
256 }
257 }
258}
259
260#[derive(Copy, Clone, Debug, PartialEq, Eq)]
261enum WebSocketOpCode {
262 ContinuationFrame = 0,
263 TextFrame = 1,
264 BinaryFrame = 2,
265 ConnectionClose = 8,
266 Ping = 9,
267 Pong = 10,
268}
269
270impl WebSocketOpCode {
271 fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
272 match self {
273 WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
274 WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
275 _ => Err(Error::InvalidOpCode),
276 }
277 }
278}
279
280#[derive(Debug)]
282pub struct WebSocketReadResult {
283 pub len_from: usize,
285 pub len_to: usize,
287 pub end_of_message: bool,
291 pub close_status: Option<WebSocketCloseStatusCode>,
296 pub message_type: WebSocketReceiveMessageType,
298}
299
/// Parameters used to build a client's opening handshake request.
pub struct WebSocketOptions<'a> {
    /// Request path of the websocket endpoint (e.g. `/chat`).
    pub path: &'a str,
    /// Value for the `Host` header.
    pub host: &'a str,
    /// Value for the `Origin` header.
    pub origin: &'a str,
    /// Sub-protocols to offer to the server, if any.
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Extra header lines to include in the request, if any.
    /// NOTE(review): exact expected format is defined by the `http` module — confirm there.
    pub additional_headers: Option<&'a [&'a str]>,
}
323
324pub type WebSocketServer = WebSocket<EmptyRng, Server>;
326
327pub type WebSocketClient<T> = WebSocket<T, Client>;
329
330pub enum Server {}
333pub enum Client {}
334
335pub trait WebSocketType {}
336impl WebSocketType for Server {}
337impl WebSocketType for Client {}
338
339pub struct WebSocket<T, S: WebSocketType>
341where
342 T: RngCore,
343{
344 is_client: bool,
345 rng: T,
346 continuation_frame_op_code: Option<WebSocketOpCode>,
347 is_write_continuation: bool,
348 pub state: WebSocketState,
349 continuation_read: Option<ContinuationRead>,
350 marker: core::marker::PhantomData<S>,
351}
352
353impl<T, Type> WebSocket<T, Type>
354where
355 T: RngCore,
356 Type: WebSocketType,
357{
358 pub fn new_client(rng: T) -> WebSocketClient<T> {
369 WebSocket {
370 is_client: true,
371 rng,
372 continuation_frame_op_code: None,
373 is_write_continuation: false,
374 state: WebSocketState::None,
375 continuation_read: None,
376 marker: core::marker::PhantomData::<Client>,
377 }
378 }
379
380 pub fn new_server() -> WebSocketServer {
392 let rng = EmptyRng::new();
393 WebSocket {
394 is_client: false,
395 rng,
396 continuation_frame_op_code: None,
397 is_write_continuation: false,
398 state: WebSocketState::None,
399 continuation_read: None,
400 marker: core::marker::PhantomData::<Server>,
401 }
402 }
403}
404
405impl<T> WebSocket<T, Server>
406where
407 T: RngCore,
408{
409 pub fn server_accept(
440 &mut self,
441 sec_websocket_key: &WebSocketKey,
442 sec_websocket_protocol: Option<&WebSocketSubProtocol>,
443 to: &mut [u8],
444 ) -> Result<usize> {
445 if self.state == WebSocketState::Open {
446 return Err(Error::WebsocketAlreadyOpen);
447 }
448
449 match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
450 {
451 Ok(http_response_len) => {
452 self.state = WebSocketState::Open;
453 Ok(http_response_len)
454 }
455 Err(e) => {
456 self.state = WebSocketState::Aborted;
457 Err(e)
458 }
459 }
460 }
461}
462
463impl<T> WebSocket<T, Client>
464where
465 T: RngCore,
466{
467 pub fn client_connect(
501 &mut self,
502 websocket_options: &WebSocketOptions,
503 to: &mut [u8],
504 ) -> Result<(usize, WebSocketKey)> {
505 if self.state == WebSocketState::Open {
506 return Err(Error::WebsocketAlreadyOpen);
507 }
508
509 match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
510 Ok((request_len, sec_websocket_key)) => {
511 self.state = WebSocketState::Connecting;
512 Ok((request_len, sec_websocket_key))
513 }
514 Err(e) => Err(e),
515 }
516 }
517
518 pub fn client_accept(
539 &mut self,
540 sec_websocket_key: &WebSocketKey,
541 from: &[u8],
542 ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
543 if self.state == WebSocketState::Open {
544 return Err(Error::WebsocketAlreadyOpen);
545 }
546
547 match http::read_server_connect_handshake_response(sec_websocket_key, from) {
548 Ok((len, sec_websocket_protocol)) => {
549 self.state = WebSocketState::Open;
550 Ok((len, sec_websocket_protocol))
551 }
552 Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
553 Err(e) => {
554 self.state = WebSocketState::Aborted;
555 Err(e)
556 }
557 }
558 }
559}
560
561impl<T, Type> WebSocket<T, Type>
562where
563 T: RngCore,
564 Type: WebSocketType,
565{
566 pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
606 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
607 let frame = self.read_frame(from, to)?;
608
609 match frame.op_code {
610 WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
611 WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
612 WebSocketOpCode::TextFrame => {
613 Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
614 }
615 WebSocketOpCode::BinaryFrame => {
616 Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
617 }
618 WebSocketOpCode::ConnectionClose => match self.state {
619 WebSocketState::CloseSent => {
620 self.state = WebSocketState::Closed;
621 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
622 }
623 _ => {
624 self.state = WebSocketState::CloseReceived;
625 Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
626 }
627 },
628 WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
629 Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
630 None => Err(Error::UnexpectedContinuationFrame),
631 },
632 }
633 } else {
634 Err(Error::WebSocketNotOpen)
635 }
636 }
637
638 pub fn write(
665 &mut self,
666 message_type: WebSocketSendMessageType,
667 end_of_message: bool,
668 from: &[u8],
669 to: &mut [u8],
670 ) -> Result<usize> {
671 if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
672 let mut op_code = message_type.to_op_code();
673 if op_code == WebSocketOpCode::ConnectionClose {
674 self.state = WebSocketState::Closed
675 } else if self.is_write_continuation {
676 op_code = WebSocketOpCode::ContinuationFrame;
677 }
678
679 self.is_write_continuation = !end_of_message;
680 self.write_frame(from, to, op_code, end_of_message)
681 } else {
682 Err(Error::WebSocketNotOpen)
683 }
684 }
685
686 pub fn close(
694 &mut self,
695 close_status: WebSocketCloseStatusCode,
696 status_description: Option<&str>,
697 to: &mut [u8],
698 ) -> Result<usize> {
699 if self.state == WebSocketState::Open {
700 self.state = WebSocketState::CloseSent;
701 if let Some(status_description) = status_description {
702 let mut from_buffer: Vec<u8, 256> = Vec::new();
703 from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
704
705 let len = if status_description.len() < 254 {
707 status_description.len()
708 } else {
709 254
710 };
711
712 from_buffer.extend_from_slice(status_description[..len].as_bytes())?;
713 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
714 } else {
715 let mut from_buffer: [u8; 2] = [0; 2];
716 BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
717 self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
718 }
719 } else {
720 Err(Error::WebSocketNotOpen)
721 }
722 }
723
724 fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
725 match &mut self.continuation_read {
726 Some(continuation_read) => {
727 let result = read_continuation(continuation_read, from_buffer, to_buffer);
728 if result.is_fin_bit_set {
729 self.continuation_read = None;
730 self.continuation_frame_op_code = None;
731 }
732 Ok(result)
733 }
734 None => {
735 let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
736
737 if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
739 result.op_code = continuation_frame_op_code;
740 }
741
742 self.continuation_frame_op_code = if result.is_fin_bit_set {
744 None
745 } else {
746 Some(result.op_code)
747 };
748
749 self.continuation_read = continuation_read;
750 Ok(result)
751 }
752 }
753 }
754
755 fn write_frame(
756 &mut self,
757 from_buffer: &[u8],
758 to_buffer: &mut [u8],
759 op_code: WebSocketOpCode,
760 end_of_message: bool,
761 ) -> Result<usize> {
762 let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
763 let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
764 let count = from_buffer.len();
765 const BYTE_HEADER_SIZE: usize = 2;
766 const SHORT_HEADER_SIZE: usize = 4;
767 const LONG_HEADER_SIZE: usize = 10;
768 const MASK_KEY_SIZE: usize = 4;
769 let header_size;
770 let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
771 let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
772
773 if count < 126 {
776 if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
777 return Err(Error::WriteToBufferTooSmall);
778 }
779 to_buffer[0] = byte1;
780 to_buffer[1] = mask_bit_set_as_byte | count as u8;
781 header_size = BYTE_HEADER_SIZE;
782 } else if count < 65535 {
783 if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
784 return Err(Error::WriteToBufferTooSmall);
785 }
786 to_buffer[0] = byte1;
787 to_buffer[1] = mask_bit_set_as_byte | 126;
788 BigEndian::write_u16(&mut to_buffer[2..], count as u16);
789 header_size = SHORT_HEADER_SIZE;
790 } else {
791 if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
792 return Err(Error::WriteToBufferTooSmall);
793 }
794 to_buffer[0] = byte1;
795 to_buffer[1] = mask_bit_set_as_byte | 127;
796 BigEndian::write_u64(&mut to_buffer[2..], count as u64);
797 header_size = LONG_HEADER_SIZE;
798 }
799
800 if self.is_client {
803 let mut mask_key = [0; MASK_KEY_SIZE];
804 self.rng.fill_bytes(&mut mask_key); to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
806 let to_buffer_start = header_size + MASK_KEY_SIZE;
807
808 for (i, (from, to)) in from_buffer[..count]
810 .iter()
811 .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
812 .enumerate()
813 {
814 *to = *from ^ mask_key[i % MASK_KEY_SIZE];
815 }
816
817 Ok(to_buffer_start + count)
818 } else {
819 to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
820 Ok(header_size + count)
821 }
822 }
823}
824
825struct ContinuationRead {
827 op_code: WebSocketOpCode,
828 count: usize,
829 is_fin_bit_set: bool,
830 mask_key: Option<[u8; 4]>,
831}
832
833struct WebSocketFrame {
834 is_fin_bit_set: bool,
835 op_code: WebSocketOpCode,
836 num_bytes_to: usize,
837 num_bytes_from: usize,
838 close_status: Option<WebSocketCloseStatusCode>,
839}
840
841impl WebSocketFrame {
842 fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
843 WebSocketReadResult {
844 len_from: self.num_bytes_from,
845 len_to: self.num_bytes_to,
846 end_of_message: self.is_fin_bit_set,
847 close_status: self.close_status,
848 message_type,
849 }
850 }
851}
852
/// Smallest of three lengths.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    let smaller_of_last_two = cmp::min(num2, num3);
    cmp::min(num1, smaller_of_last_two)
}
856
857fn read_into_buffer(
858 mask_key: &mut Option<[u8; 4]>,
859 from_buffer: &[u8],
860 to_buffer: &mut [u8],
861 len: usize,
862) -> usize {
863 let len_to_read = min(len, to_buffer.len(), from_buffer.len());
865
866 match mask_key {
867 Some(mask_key) => {
868 for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
870 *to = *from ^ mask_key[i % MASK_KEY_LEN];
871 }
872 mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
873 }
874 None => {
875 to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
876 }
877 }
878
879 len_to_read
880}
881
882fn read_continuation(
883 continuation_read: &mut ContinuationRead,
884 from_buffer: &[u8],
885 to_buffer: &mut [u8],
886) -> WebSocketFrame {
887 let len_read = read_into_buffer(
888 &mut continuation_read.mask_key,
889 from_buffer,
890 to_buffer,
891 continuation_read.count,
892 );
893
894 let is_complete = len_read == continuation_read.count;
895
896 let frame = match continuation_read.op_code {
897 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
898 _ => WebSocketFrame {
899 num_bytes_from: len_read,
900 num_bytes_to: len_read,
901 op_code: continuation_read.op_code,
902 close_status: None,
903 is_fin_bit_set: if is_complete {
904 continuation_read.is_fin_bit_set
905 } else {
906 false
907 },
908 },
909 };
910
911 continuation_read.count -= len_read;
912 frame
913}
914
915fn read_frame(
916 from_buffer: &[u8],
917 to_buffer: &mut [u8],
918) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
919 if from_buffer.len() < 2 {
920 return Err(Error::ReadFrameIncomplete);
921 }
922
923 let byte1 = from_buffer[0];
924 let byte2 = from_buffer[1];
925
926 const FIN_BIT_FLAG: u8 = 0x80;
928 const OP_CODE_FLAG: u8 = 0x0F;
929 let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
930 let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
931
932 const MASK_FLAG: u8 = 0x80;
934 let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
935 let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
936
937 num_bytes_read += 2;
938 let from_buffer = &from_buffer[num_bytes_read..];
939
940 let mut mask_key = if is_mask_bit_set {
942 if from_buffer.len() < MASK_KEY_LEN {
943 return Err(Error::ReadFrameIncomplete);
944 }
945 let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
946 mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
947 num_bytes_read += MASK_KEY_LEN;
948 Some(mask_key)
949 } else {
950 None
951 };
952
953 let len_read = if is_mask_bit_set {
954 let from_buffer = &from_buffer[MASK_KEY_LEN..];
956 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
957 } else {
958 read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
959 };
960
961 let has_continuation = len_read < len;
962 num_bytes_read += len_read;
963
964 let frame = match op_code {
965 WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
966 _ => WebSocketFrame {
967 num_bytes_from: num_bytes_read,
968 num_bytes_to: len_read,
969 op_code,
970 close_status: None,
971 is_fin_bit_set: if has_continuation {
972 false
973 } else {
974 is_fin_bit_set
975 },
976 },
977 };
978
979 if has_continuation {
980 let continuation_read = Some(ContinuationRead {
981 op_code,
982 count: len - len_read,
983 is_fin_bit_set,
984 mask_key,
985 });
986 Ok((frame, continuation_read))
987 } else {
988 Ok((frame, None))
989 }
990}
991
992fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
993 match val {
994 0 => Ok(WebSocketOpCode::ContinuationFrame),
995 1 => Ok(WebSocketOpCode::TextFrame),
996 2 => Ok(WebSocketOpCode::BinaryFrame),
997 8 => Ok(WebSocketOpCode::ConnectionClose),
998 9 => Ok(WebSocketOpCode::Ping),
999 10 => Ok(WebSocketOpCode::Pong),
1000 _ => Err(Error::InvalidOpCode),
1001 }
1002}
1003
1004fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1006 let len = byte2 & 0x7F;
1007
1008 if len < 126 {
1009 return Ok((len as usize, 0));
1011 } else if len == 126 {
1012 if from_buffer.len() < 2 {
1014 return Err(Error::ReadFrameIncomplete);
1015 }
1016 let mut buf: [u8; 2] = [0; 2];
1017 buf.copy_from_slice(&from_buffer[..2]);
1018 return Ok((BigEndian::read_u16(&buf) as usize, 2));
1019 } else if len == 127 {
1020 if from_buffer.len() < 8 {
1022 return Err(Error::ReadFrameIncomplete);
1023 }
1024 let mut buf: [u8; 8] = [0; 8];
1025 buf.copy_from_slice(&from_buffer[..8]);
1026 return Ok((BigEndian::read_u64(&buf) as usize, 8));
1027 }
1028
1029 Err(Error::InvalidFrameLength)
1030}
1031
1032fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1033 if len >= 2 {
1034 let code = BigEndian::read_u16(buffer);
1036 let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1037
1038 return WebSocketFrame {
1039 num_bytes_from: num_bytes_read,
1040 num_bytes_to: len,
1041 op_code: WebSocketOpCode::ConnectionClose,
1042 close_status: Some(close_status_code),
1043 is_fin_bit_set: true,
1044 };
1045 }
1046
1047 build_client_disconnected_frame(num_bytes_read)
1048}
1049
1050fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1051 WebSocketFrame {
1052 num_bytes_from,
1053 num_bytes_to: 0,
1054 op_code: WebSocketOpCode::ConnectionClose,
1055 close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1056 is_fin_bit_set: true,
1057 }
1058}
1059
1060#[cfg(test)]
1065mod tests {
1066 extern crate std;
1067 use super::*;
1068
1069 #[test]
1070 fn opening_handshake() {
1071 let client_request = "GET /chat HTTP/1.1
1072Host: localhost:5000
1073User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
1074Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
1075Accept-Language: en-US,en;q=0.5
1076Accept-Encoding: gzip, deflate
1077Sec-WebSocket-Version: 13
1078Origin: http://localhost:5000
1079Sec-WebSocket-Extensions: permessage-deflate
1080Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
1081Sec-WebSocket-Protocol: chat
1082DNT: 1
1083Connection: keep-alive, Upgrade
1084Pragma: no-cache
1085Cache-Control: no-cache
1086Upgrade: websocket
1087
1088";
1089
1090 let mut headers = [httparse::EMPTY_HEADER; 16];
1091 let mut request = httparse::Request::new(&mut headers);
1092 request.parse(client_request.as_bytes()).unwrap();
1093 let headers = headers.iter().map(|f| (f.name, f.value));
1094 let web_socket_context = read_http_header(headers).unwrap().unwrap();
1095 assert_eq!(
1096 "Z7OY1UwHOx/nkSz38kfPwg==",
1097 web_socket_context.sec_websocket_key
1098 );
1099 assert_eq!(
1100 "chat",
1101 web_socket_context
1102 .sec_websocket_protocol_list
1103 .get(0)
1104 .unwrap()
1105 .as_str()
1106 );
1107 let mut web_socket = WebSocketServer::new_server();
1108
1109 let mut ws_buffer: [u8; 3000] = [0; 3000];
1110 let size = web_socket
1111 .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
1112 .unwrap();
1113 let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
1114 let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
1115 assert_eq!(client_response_expected, response);
1116 }
1117
1118 #[test]
1119 fn server_write_frame() {
1120 let mut buffer: [u8; 1000] = [0; 1000];
1121 let mut ws_server = WebSocketServer::new_server();
1122 let len = ws_server
1123 .write_frame(
1124 "hello".as_bytes(),
1125 &mut buffer,
1126 WebSocketOpCode::TextFrame,
1127 true,
1128 )
1129 .unwrap();
1130 let expected = [129, 5, 104, 101, 108, 108, 111];
1131 assert_eq!(&expected, &buffer[..len]);
1132 }
1133
1134 #[test]
1135 fn server_accept_should_write_sub_protocol() {
1136 let mut buffer: [u8; 1000] = [0; 1000];
1137 let mut ws_server = WebSocketServer::new_server();
1138 let ws_key = WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
1139 let sub_protocol = WebSocketSubProtocol::from("chat");
1140 let size = ws_server
1141 .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
1142 .unwrap();
1143 let response = std::str::from_utf8(&buffer[..size]).unwrap();
1144 assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
1145 }
1146
1147 #[test]
1148 fn closing_handshake() {
1149 let mut buffer1: [u8; 500] = [0; 500];
1150 let mut buffer2: [u8; 500] = [0; 500];
1151
1152 let mut rng = rand::thread_rng();
1153
1154 let mut ws_client = WebSocketClient::new_client(&mut rng);
1155 ws_client.state = WebSocketState::Open;
1156
1157 let mut ws_server = WebSocketServer::new_server();
1158 ws_server.state = WebSocketState::Open;
1159
1160 ws_client
1162 .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
1163 .unwrap();
1164
1165 let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
1167 assert_eq!(
1168 WebSocketReceiveMessageType::CloseMustReply,
1169 ws_result.message_type
1170 );
1171
1172 ws_server
1174 .write(
1175 WebSocketSendMessageType::CloseReply,
1176 true,
1177 &buffer2[..ws_result.len_to],
1178 &mut buffer1,
1179 )
1180 .unwrap();
1181 assert_eq!(WebSocketState::Closed, ws_server.state);
1182
1183 let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1185 assert_eq!(WebSocketState::Closed, ws_client.state);
1186
1187 assert_eq!(
1188 WebSocketReceiveMessageType::CloseCompleted,
1189 ws_result.message_type
1190 );
1191 }
1192
1193 #[test]
1194 fn send_message_from_client_to_server() {
1195 let mut buffer1: [u8; 1000] = [0; 1000];
1196 let mut buffer2: [u8; 1000] = [0; 1000];
1197
1198 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1200
1201 ws_client.state = WebSocketState::Open;
1202 let mut ws_server = WebSocketServer::new_server();
1203 ws_server.state = WebSocketState::Open;
1204
1205 let hello = "hello";
1207 let num_bytes = ws_client
1208 .write(
1209 WebSocketSendMessageType::Text,
1210 true,
1211 &hello.as_bytes(),
1212 &mut buffer1,
1213 )
1214 .unwrap();
1215
1216 let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1218 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1219 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1220 assert_eq!(hello, received);
1221 }
1222
1223 #[test]
1224 fn send_message_from_server_to_client() {
1225 let mut buffer1: [u8; 1000] = [0; 1000];
1226 let mut buffer2: [u8; 1000] = [0; 1000];
1227
1228 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1229 ws_client.state = WebSocketState::Open;
1230 let mut ws_server = WebSocketServer::new_server();
1231 ws_server.state = WebSocketState::Open;
1232
1233 let hello = "hello";
1235 let num_bytes = ws_server
1236 .write(
1237 WebSocketSendMessageType::Text,
1238 true,
1239 &hello.as_bytes(),
1240 &mut buffer1,
1241 )
1242 .unwrap();
1243
1244 let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1246 assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1247 let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1248 assert_eq!(hello, received);
1249 }
1250
1251 #[test]
1252 fn receive_buffer_too_small() {
1253 let mut buffer1: [u8; 1000] = [0; 1000];
1254 let mut buffer2: [u8; 1000] = [0; 1000];
1255
1256 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1257 ws_client.state = WebSocketState::Open;
1258 let mut ws_server = WebSocketServer::new_server();
1259 ws_server.state = WebSocketState::Open;
1260
1261 let hello = "hello";
1262 ws_server
1263 .write(
1264 WebSocketSendMessageType::Text,
1265 true,
1266 &hello.as_bytes(),
1267 &mut buffer1,
1268 )
1269 .unwrap();
1270
1271 match ws_client.read(&buffer1[..1], &mut buffer2) {
1272 Err(Error::ReadFrameIncomplete) => {
1273 }
1275 _ => {
1276 assert_eq!(true, false);
1277 }
1278 }
1279 }
1280
1281 #[test]
1282 fn receive_large_frame_with_small_receive_buffer() {
1283 let mut buffer1: [u8; 1000] = [0; 1000];
1284 let mut buffer2: [u8; 1000] = [0; 1000];
1285
1286 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1287 ws_client.state = WebSocketState::Open;
1288 let mut ws_server = WebSocketServer::new_server();
1289 ws_server.state = WebSocketState::Open;
1290
1291 let hello = "hello";
1292 ws_server
1293 .write(
1294 WebSocketSendMessageType::Text,
1295 true,
1296 &hello.as_bytes(),
1297 &mut buffer1,
1298 )
1299 .unwrap();
1300
1301 let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
1302 assert_eq!(0, ws_result.len_to);
1303 assert_eq!(false, ws_result.end_of_message);
1304 let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
1305 assert_eq!(1, ws_result.len_to);
1306 assert_eq!(
1307 "h",
1308 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1309 );
1310 assert_eq!(false, ws_result.end_of_message);
1311 let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
1312 assert_eq!(4, ws_result.len_to);
1313 assert_eq!(
1314 "ello",
1315 std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1316 );
1317 assert_eq!(true, ws_result.end_of_message);
1318 }
1319
1320 #[test]
1321 fn send_large_frame() {
1322 let buffer1 = [0u8; 15944];
1323 let mut buffer2 = [0u8; 64000];
1324 let mut buffer3 = [0u8; 64000];
1325
1326 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1327 ws_client.state = WebSocketState::Open;
1328 let mut ws_server = WebSocketServer::new_server();
1329 ws_server.state = WebSocketState::Open;
1330
1331 ws_client
1332 .write(
1333 WebSocketSendMessageType::Binary,
1334 true,
1335 &buffer1,
1336 &mut buffer2,
1337 )
1338 .unwrap();
1339
1340 let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1341 assert_eq!(true, ws_result.end_of_message);
1342 assert_eq!(buffer1.len(), ws_result.len_to);
1343 }
1344
1345 #[test]
1346 fn receive_large_frame_multi_read() {
1347 let mut buffer1 = [0_u8; 1000];
1348 let mut buffer2 = [0_u8; 1000];
1349
1350 let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1351 ws_client.state = WebSocketState::Open;
1352 let mut ws_server = WebSocketServer::new_server();
1353 ws_server.state = WebSocketState::Open;
1354
1355 let message = "Hello, world. This is a long message that takes multiple reads";
1356 ws_server
1357 .write(
1358 WebSocketSendMessageType::Text,
1359 true,
1360 &message.as_bytes(),
1361 &mut buffer1,
1362 )
1363 .unwrap();
1364
1365 let mut buffer2_cursor = 0;
1366 let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
1367 assert_eq!(false, ws_result.end_of_message);
1368 buffer2_cursor += ws_result.len_to;
1369 let ws_result = ws_client
1370 .read(
1371 &buffer1[ws_result.len_from..],
1372 &mut buffer2[buffer2_cursor..],
1373 )
1374 .unwrap();
1375 assert_eq!(true, ws_result.end_of_message);
1376 buffer2_cursor += ws_result.len_to;
1377
1378 assert_eq!(
1379 message,
1380 std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1381 );
1382 }
1383
/// Two complete client frames are written back to back into one buffer; the server must
/// decode them one `read` at a time, consuming exactly one frame per call, so that both
/// messages can be recovered in order from the shared receive buffer.
#[test]
fn multiple_messages_in_receive_buffer() {
    let mut buffer1 = [0_u8; 1000];
    let mut buffer2 = [0_u8; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let message1 = "Hello, world.";
    let len1 = ws_client
        .write(
            WebSocketSendMessageType::Text,
            true,
            message1.as_bytes(),
            &mut buffer1,
        )
        .unwrap();
    let message2 = "This is another message.";
    let len2 = ws_client
        .write(
            WebSocketSendMessageType::Text,
            true,
            message2.as_bytes(),
            &mut buffer1[len1..],
        )
        .unwrap();
    // Offer only the bytes actually written so no read can lean on trailing zero padding.
    let written = len1 + len2;

    // First read sees both frames but must stop exactly at the end of the first one.
    let ws_result1 = ws_server
        .read(&buffer1[..written], &mut buffer2)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
    assert!(ws_result1.end_of_message);
    assert_eq!(len1, ws_result1.len_from);
    let message1_len = ws_result1.len_to;

    // Second read picks up where the first stopped and must consume the whole second frame.
    let ws_result2 = ws_server
        .read(
            &buffer1[ws_result1.len_from..written],
            &mut buffer2[message1_len..],
        )
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
    assert!(ws_result2.end_of_message);
    assert_eq!(len2, ws_result2.len_from);

    assert_eq!(
        message1,
        std::str::from_utf8(&buffer2[..message1_len]).unwrap()
    );
    assert_eq!(
        message2,
        std::str::from_utf8(&buffer2[message1_len..message1_len + ws_result2.len_to]).unwrap()
    );
}
1436
/// A single five-byte text frame is drained through output windows of one byte and then
/// four bytes. The client must hand the payload back piecewise, resuming inside the same
/// frame, and report end-of-message only after the final byte has been delivered.
#[test]
fn receive_large_frame_with_small_send_buffer() {
    let mut frame_bytes = [0_u8; 1000];
    let mut payload_out = [0_u8; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let text = "hello";
    ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            text.as_bytes(),
            &mut frame_bytes,
        )
        .unwrap();

    // A one-byte output window: only the first payload byte fits.
    let head = ws_client
        .read(&frame_bytes, &mut payload_out[..1])
        .unwrap();
    assert_eq!(1, head.len_to);
    let head_text = std::str::from_utf8(&payload_out[..head.len_to]).unwrap();
    assert_eq!("h", head_text);
    assert_eq!(false, head.end_of_message);

    // Resume from where the first read stopped, with room for the remaining four bytes.
    let resume_at = head.len_from;
    let tail = ws_client
        .read(&frame_bytes[resume_at..], &mut payload_out[..4])
        .unwrap();
    assert_eq!(4, tail.len_to);
    let tail_text = std::str::from_utf8(&payload_out[..tail.len_to]).unwrap();
    assert_eq!("ello", tail_text);
    assert_eq!(true, tail.end_of_message);
}
1474
/// Sends "Hello, World!" as a fragmented text message: a first frame with the final flag
/// cleared, then a final frame. Checks that the client reads each frame as `Text` and
/// reassembles the payload contiguously across the two reads.
#[test]
fn send_two_frame_message() {
    let mut wire = [0_u8; 1000];
    let mut payload = [0_u8; 1000];
    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    // First frame leaves the message open (end_of_message = false).
    let first_frame_len = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            "Hello, ".as_bytes(),
            &mut wire,
        )
        .unwrap();

    // Second frame closes the message and is appended directly after the first.
    let second_frame_len = ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            "World!".as_bytes(),
            &mut wire[first_frame_len..],
        )
        .unwrap();
    let wire_end = first_frame_len + second_frame_len;

    let part1 = ws_client
        .read(&wire[..first_frame_len], &mut payload)
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, part1.message_type);
    assert_eq!(false, part1.end_of_message);

    // Decode the second frame straight after the bytes produced by the first.
    let part2 = ws_client
        .read(&wire[first_frame_len..wire_end], &mut payload[part1.len_to..])
        .unwrap();
    assert_eq!(WebSocketReceiveMessageType::Text, part2.message_type);
    assert_eq!(true, part2.end_of_message);

    let total_len = part1.len_to + part2.len_to;
    let received = std::str::from_utf8(&payload[..total_len]).unwrap();
    assert_eq!("Hello, World!", received);
}
1527
/// Sends one text message as three fragments (final flag clear, clear, set). Checks two things:
///
/// - The client delivers each fragment in turn as `Text`, consuming exactly one frame per
///   `read`.
/// - The frame headers carry the opcodes that fragmentation requires: `TextFrame` on the first
///   fragment and `ContinuationFrame` on the rest, with FIN set only on the last.
#[test]
fn send_multi_frame_message() {
    let mut buffer1: [u8; 1000] = [0; 1000];
    let mut buffer2: [u8; 1000] = [0; 1000];

    let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
    ws_client.state = WebSocketState::Open;
    let mut ws_server = WebSocketServer::new_server();
    ws_server.state = WebSocketState::Open;

    let fragment1 = "fragment1";
    let fragment1_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            fragment1.as_bytes(),
            &mut buffer1,
        )
        .unwrap();

    let fragment2 = "fragment2";
    let fragment2_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            false,
            fragment2.as_bytes(),
            &mut buffer1[fragment1_num_bytes..],
        )
        .unwrap();

    let fragment3 = "fragment3";
    let fragment3_num_bytes = ws_server
        .write(
            WebSocketSendMessageType::Text,
            true,
            fragment3.as_bytes(),
            &mut buffer1[fragment1_num_bytes + fragment2_num_bytes..],
        )
        .unwrap();
    let total_num_bytes = fragment1_num_bytes + fragment2_num_bytes + fragment3_num_bytes;

    // (payload, frame size on the wire, expected end_of_message) for each fragment, in order.
    let expected = [
        (fragment1, fragment1_num_bytes, false),
        (fragment2, fragment2_num_bytes, false),
        (fragment3, fragment3_num_bytes, true),
    ];
    let mut read_cursor = 0;
    for &(fragment, frame_num_bytes, is_last) in &expected {
        // Offer only the written bytes so a read cannot lean on trailing zero padding.
        let ws_result = ws_client
            .read(&buffer1[read_cursor..total_num_bytes], &mut buffer2)
            .unwrap();
        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
        assert_eq!(fragment, received);
        assert_eq!(is_last, ws_result.end_of_message);
        // Each read must consume exactly one frame: no stopping short, no spilling into the next.
        assert_eq!(frame_num_bytes, ws_result.len_from);
        read_cursor += ws_result.len_from;
    }

    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::TextFrame);

    let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
    assert!(!is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);

    let (is_fin_bit_set, op_code) =
        read_first_byte(buffer1[fragment1_num_bytes + fragment2_num_bytes]);
    assert!(is_fin_bit_set);
    assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
}
1614
1615 fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1616 const FIN_BIT_FLAG: u8 = 0x80;
1617 const OP_CODE_FLAG: u8 = 0x0F;
1618 let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1619 let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1620 (is_fin_bit_set, op_code)
1621 }
1622}