embedded_websocket_embedded_io/
lib.rs

1//! # Embedded Websocket
2//!
3//! `embedded_websocket` facilitates the encoding and decoding of websocket frames and can be used
4//! for both clients and servers. The library is intended to be used in constrained memory
5//! environments like embedded microcontrollers which cannot reference the rust standard library.
6//! It will work with arbitrarily small buffers regardless of websocket frame size as long as the
7//! websocket header can be read (2 - 14 bytes depending on the payload size and masking).
8//! Since the library is essentially an encoder or decoder of byte slices, the developer is free to
//! use whatever transport mechanism they choose. The examples in the source repository use the
10//! TcpStream from the standard library.
11
12// support for websockets without using the standard library
13#![cfg_attr(not(feature = "std"), no_std)]
14
15use base64::EncodeSliceError;
16use byteorder::{BigEndian, ByteOrder};
17use core::{cmp, result, str};
18use heapless::{String, Vec};
19use rand_core::RngCore;
20use sha1::{Digest, Sha1};
21
22mod http;
23pub mod random;
24pub use self::http::{read_http_header, WebSocketContext};
25pub use self::random::EmptyRng;
26
27// support for working with discrete websocket frames when using IO streams
28// start here!!
29pub mod framer;
30pub mod framer_async;
/// Length in bytes of the (32-bit) masking key that may follow the frame header.
const MASK_KEY_LEN: usize = core::mem::size_of::<u32>();
32
33/// Result returning a websocket specific error if encountered
34pub type Result<T> = result::Result<T, Error>;
35
36/// A fixed length 24-character string used to hold a websocket key for the opening handshake
37pub type WebSocketKey = String<24>;
38
39/// A maximum sized 24-character string used to store a sub protocol (e.g. `chat`)
40pub type WebSocketSubProtocol = String<24>;
41
42/// Websocket send message type used when sending a websocket frame
43#[derive(PartialEq, Eq, Debug, Copy, Clone)]
44pub enum WebSocketSendMessageType {
45    /// A UTF8 encoded text string
46    Text = 1,
47    /// Binary data
48    Binary = 2,
49    /// An unsolicited ping message
50    Ping = 9,
51    /// A pong message in response to a ping message
52    Pong = 10,
53    /// A close message in response to a close message from the other party. Used to complete a
54    /// closing handshake. If initiate a close handshake use the `close` function
55    CloseReply = 11,
56}
57
58impl WebSocketSendMessageType {
59    fn to_op_code(self) -> WebSocketOpCode {
60        match self {
61            WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
62            WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
63            WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
64            WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
65            WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
66        }
67    }
68}
69
/// Kind of message reported by `read` for a decoded websocket frame.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketReceiveMessageType {
    /// A UTF8 encoded text payload
    Text = 1,
    /// An arbitrary binary payload
    Binary = 2,
    /// The other party answered our close message: the closing handshake is complete
    CloseCompleted = 7,
    /// The other party started the closing handshake. Answer it with a
    /// `WebSocketSendMessageType::CloseReply` carrying the same payload as the received close
    /// message
    CloseMustReply = 8,
    /// A ping message; answer it with a `WebSocketSendMessageType::Pong` message carrying the
    /// same payload as the ping
    Ping = 9,
    /// A pong message sent in answer to a ping message
    Pong = 10,
}
89
/// Close status codes defined by the rfc6455 websocket spec, with a `Custom` variant for any
/// value that has no dedicated variant.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketCloseStatusCode {
    /// 1000 - normal closure: the purpose for which the connection was established has been
    /// fulfilled
    NormalClosure,
    /// 1001 - the endpoint is "going away", e.g. a server shutting down or a browser that
    /// navigated away from the page
    EndpointUnavailable,
    /// 1002 - the endpoint is terminating the connection because of a protocol error
    ProtocolError,
    /// 1003 - the endpoint received a type of data it cannot accept (e.g. an endpoint that
    /// only understands text data MAY send this when it receives a binary message)
    InvalidMessageType,
    /// 1004 - reserved for future use
    Reserved,
    /// 1005 - no status code was present
    Empty,
    /// 1007 - the endpoint received message data that was not consistent with the message type
    /// (e.g. non-UTF-8 data within a text message)
    InvalidPayloadData,
    /// 1008 - the endpoint received a message that violates its policy. A generic code for when
    /// no more suitable status code applies
    PolicyViolation,
    /// 1009 - the endpoint received a message too big for it to process
    MessageTooBig,
    /// 1010 - the client expected the server to negotiate one or more extensions, but the
    /// server did not return them in its opening handshake response
    MandatoryExtension,
    /// 1011 - the server hit an unexpected condition that prevented it from fulfilling the
    /// request
    InternalServerError,
    /// 1015 - the connection was closed because a TLS handshake could not be performed
    TlsHandshake,
    /// Any other close code
    Custom(u16),
}

impl WebSocketCloseStatusCode {
    /// Maps a numeric close code onto its named variant, or `Custom(value)` when there is none.
    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
        match value {
            1000 => Self::NormalClosure,
            1001 => Self::EndpointUnavailable,
            1002 => Self::ProtocolError,
            1003 => Self::InvalidMessageType,
            1004 => Self::Reserved,
            1005 => Self::Empty,
            1007 => Self::InvalidPayloadData,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1010 => Self::MandatoryExtension,
            1011 => Self::InternalServerError,
            1015 => Self::TlsHandshake,
            other => Self::Custom(other),
        }
    }

    /// Numeric close code for this status, as written into a close frame.
    fn to_u16(self) -> u16 {
        match self {
            Self::NormalClosure => 1000,
            Self::EndpointUnavailable => 1001,
            Self::ProtocolError => 1002,
            Self::InvalidMessageType => 1003,
            Self::Reserved => 1004,
            Self::Empty => 1005,
            Self::InvalidPayloadData => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalServerError => 1011,
            Self::TlsHandshake => 1015,
            Self::Custom(code) => code,
        }
    }
}
171
/// Lifecycle state of a websocket connection.
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum WebSocketState {
    /// Freshly created by `new_client()` or `new_server()`; no handshake has happened yet
    None = 0,
    /// The client has produced its opening handshake request (`client_connect()`)
    Connecting = 1,
    /// The opening handshake completed, either on the server via `server_accept()` or on the
    /// client via `client_accept()`. `write()`, `read()` and `close()` may now be called
    Open = 2,
    /// `close()` has been called and the other party's close frame is awaited
    CloseSent = 3,
    /// A close frame was received from the other party, which now expects a close reply
    CloseReceived = 4,
    /// The closing handshake is finished
    Closed = 5,
    /// The server or client opening handshake failed
    Aborted = 6,
}
192
193/// Websocket specific errors
194#[derive(PartialEq, Eq, Debug)]
195pub enum Error {
196    /// Websocket frame has an invalid opcode
197    InvalidOpCode,
198    InvalidFrameLength,
199    InvalidCloseStatusCode,
200    WebSocketNotOpen,
201    WebsocketAlreadyOpen,
202    Utf8Error,
203    Unknown,
204    HttpHeader(httparse::Error),
205    HttpHeaderNoPath,
206    HttpHeaderIncomplete,
207    WriteToBufferTooSmall,
208    ReadFrameIncomplete,
209    HttpResponseCodeInvalid(Option<u16>),
210    AcceptStringInvalid,
211    ConvertInfallible,
212    RandCore,
213    UnexpectedContinuationFrame,
214    EncodeSliceError(EncodeSliceError),
215}
216
217impl From<EncodeSliceError> for Error {
218    fn from(err: EncodeSliceError) -> Error {
219        Error::EncodeSliceError(err)
220    }
221}
222
223impl From<httparse::Error> for Error {
224    fn from(err: httparse::Error) -> Error {
225        Error::HttpHeader(err)
226    }
227}
228
229impl From<str::Utf8Error> for Error {
230    fn from(_: str::Utf8Error) -> Error {
231        Error::Utf8Error
232    }
233}
234
235impl From<core::convert::Infallible> for Error {
236    fn from(_: core::convert::Infallible) -> Error {
237        Error::ConvertInfallible
238    }
239}
240
241impl From<()> for Error {
242    fn from(_: ()) -> Error {
243        Error::Unknown
244    }
245}
246
247impl core::fmt::Display for Error {
248    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
249        match self {
250            Error::HttpHeader(error) => write!(f, "bad http header {error}"),
251            Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
252            _ => write!(f, "{:?}", self),
253        }
254    }
255}
256
#[cfg(feature = "std")]
impl std::error::Error for Error {
    /// Only `HttpHeader` wraps another error type that is exposed as the source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpHeader(inner) => Some(inner),
            _ => None,
        }
    }
}
267
268#[derive(Copy, Clone, Debug, PartialEq, Eq)]
269enum WebSocketOpCode {
270    ContinuationFrame = 0,
271    TextFrame = 1,
272    BinaryFrame = 2,
273    ConnectionClose = 8,
274    Ping = 9,
275    Pong = 10,
276}
277
278impl WebSocketOpCode {
279    fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
280        match self {
281            WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
282            WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
283            _ => Err(Error::InvalidOpCode),
284        }
285    }
286}
287
288/// The metadata result of a `read` function call of a websocket
289#[derive(Debug)]
290pub struct WebSocketReadResult {
291    /// Number of bytes read from the `from` buffer
292    pub len_from: usize,
293    /// Number of bytes written to the `to` buffer
294    pub len_to: usize,
295    /// End of message flag is `true` if the `to` buffer contains an entire websocket frame
296    /// payload otherwise `false` if the user must continue calling the read function to get the
297    /// rest of the payload
298    pub end_of_message: bool,
299    /// Close status code (as per the websocket spec) if the message type is `CloseMustReply` or
300    /// `CloseCompleted`. If a close status is specified then a UTF8 encoded string could also
301    /// appear in the frame payload giving more detailed information about why the websocket was
302    /// closed.
303    pub close_status: Option<WebSocketCloseStatusCode>,
304    /// The websocket frame type
305    pub message_type: WebSocketReceiveMessageType,
306}
307
/// Options used by a websocket client to build the opening handshake request it sends to a
/// websocket server (see `client_connect`).
///
/// All fields are borrowed, so the options are cheap to copy, clone and print for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketOptions<'a> {
    /// The request uri (e.g. `/chat?id=123`) of the GET method used to identify the endpoint of
    /// the websocket connection. This allows multiple endpoints to be served by a single
    /// server and can also carry identifiable information about the client
    pub path: &'a str,
    /// The hostname (e.g. `server.example.com`) so that both the client and the server can
    /// verify that they agree on which host is in use
    pub host: &'a str,
    /// The origin (e.g. `http://example.com`) is used to protect against unauthorized
    /// cross-origin use of a WebSocket server by scripts using the WebSocket API in a web
    /// browser. This field is usually only set by browser clients but servers may require it
    /// so it has been exposed here.
    pub origin: &'a str,
    /// Requested sub protocols in order of preference. The server should return the first sub
    /// protocol it supports or none at all. A sub protocol can be anything agreed between the
    /// server and client
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Additional headers the server may require that are not part of the websocket spec.
    /// These should be fully formed http headers without the trailing `\r\n`
    /// (e.g. `MyHeader: foo`)
    pub additional_headers: Option<&'a [&'a str]>,
}
331
332/// Used to return a sized type from `WebSocket::new_server()`
333pub type WebSocketServer = WebSocket<EmptyRng, Server>;
334
335/// Used to return a sized type from `WebSocketClient::new_client()`
336pub type WebSocketClient<T> = WebSocket<T, Client>;
337
// Typestate markers: `WebSocket<_, Server>` and `WebSocket<_, Client>` expose different
// handshake functions while sharing the read/write/close implementation, so calling the wrong
// side's handshake is a compile error rather than a runtime panic.

/// Typestate marker selecting the server side of a websocket connection.
pub enum Server {}

/// Typestate marker selecting the client side of a websocket connection.
pub enum Client {}

/// Implemented by the typestate markers that may parameterize a `WebSocket`.
pub trait WebSocketType {}

impl WebSocketType for Server {}
impl WebSocketType for Client {}
346
347/// Websocket client and server implementation
348pub struct WebSocket<T, S: WebSocketType>
349where
350    T: RngCore,
351{
352    is_client: bool,
353    rng: T,
354    continuation_frame_op_code: Option<WebSocketOpCode>,
355    is_write_continuation: bool,
356    pub state: WebSocketState,
357    continuation_read: Option<ContinuationRead>,
358    marker: core::marker::PhantomData<S>,
359}
360
361impl<T, Type> WebSocket<T, Type>
362where
363    T: RngCore,
364    Type: WebSocketType,
365{
366    /// Creates a new websocket client by passing in a required random number generator
367    ///
368    /// # Examples
369    /// ```
370    /// use embedded_websocket as ws;
371    /// use rand;
372    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
373    ///
374    /// assert_eq!(ws::WebSocketState::None, ws_client.state);
375    /// ```
376    pub fn new_client(rng: T) -> WebSocketClient<T> {
377        WebSocket {
378            is_client: true,
379            rng,
380            continuation_frame_op_code: None,
381            is_write_continuation: false,
382            state: WebSocketState::None,
383            continuation_read: None,
384            marker: core::marker::PhantomData::<Client>,
385        }
386    }
387
388    /// Creates a new websocket server. Note that you must use the `WebSocketServer` type and
389    /// not the generic `WebSocket` type for this call or you will get a `'type annotations needed'`
390    /// compilation error.
391    ///
392    /// # Examples
393    /// ```
394    /// use embedded_websocket as ws;
395    /// let mut ws_server = ws::WebSocketServer::new_server();
396    ///
397    /// assert_eq!(ws::WebSocketState::None, ws_server.state);
398    /// ```
399    pub fn new_server() -> WebSocketServer {
400        let rng = EmptyRng::new();
401        WebSocket {
402            is_client: false,
403            rng,
404            continuation_frame_op_code: None,
405            is_write_continuation: false,
406            state: WebSocketState::None,
407            continuation_read: None,
408            marker: core::marker::PhantomData::<Server>,
409        }
410    }
411}
412
413impl<T> WebSocket<T, Server>
414where
415    T: RngCore,
416{
417    /// Used by the server to accept an incoming client connection and build a websocket upgrade
418    /// http response string. The client http header should be read with the `read_http_header`
419    /// function and the result should be passed to this function.
420    /// Websocket state will change from None -> Open if successful, otherwise None -> Aborted
421    ///
422    /// # Examples
423    ///
424    /// ```
425    /// use embedded_websocket as ws;
426    /// use core::str::FromStr;
427    /// use heapless::String;
428    /// let mut buffer: [u8; 1000] = [0; 1000];
429    /// let mut ws_server = ws::WebSocketServer::new_server();
430    /// let ws_key: ws::WebSocketKey = String::from_str("Z7OY1UwHOx/nkSz38kfPwg==").unwrap();
431    /// let sub_protocol: ws::WebSocketSubProtocol = String::from_str("chat").unwrap();
432    /// let len = ws_server
433    ///     .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
434    ///     .unwrap();
435    /// let response = std::str::from_utf8(&buffer[..len]).unwrap();
436    ///
437    /// assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
438    /// ```
439    ///
440    /// # Errors
441    /// There should be no way for a user provided input to return the errors listed below as the
442    /// input is already constrained.
443    /// * The http response is built with a stack allocated 1KB buffer and it *should be impossible*
444    /// for it to return an  `Unknown` error if that buffer is too small. However, this is better
445    /// than a panic and it will do so if the response header is too large to fit in the buffer
446    /// * This function can return an `Utf8Error` if there was an error with the generation of the
447    /// accept string. This should also be impossible but an error is preferable to a panic
448    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
449    pub fn server_accept(
450        &mut self,
451        sec_websocket_key: &WebSocketKey,
452        sec_websocket_protocol: Option<&WebSocketSubProtocol>,
453        to: &mut [u8],
454    ) -> Result<usize> {
455        if self.state == WebSocketState::Open {
456            return Err(Error::WebsocketAlreadyOpen);
457        }
458
459        match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
460        {
461            Ok(http_response_len) => {
462                self.state = WebSocketState::Open;
463                Ok(http_response_len)
464            }
465            Err(e) => {
466                self.state = WebSocketState::Aborted;
467                Err(e)
468            }
469        }
470    }
471}
472
473impl<T> WebSocket<T, Client>
474where
475    T: RngCore,
476{
477    /// Used by the client to initiate a websocket opening handshake
478    ///
479    /// # Examples
480    /// ```
481    /// use embedded_websocket as ws;
482    /// let mut buffer: [u8; 2000] = [0; 2000];
483    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
484    /// let sub_protocols = ["chat", "superchat"];
485    /// let websocket_options = ws::WebSocketOptions {
486    ///     path: "/chat",
487    ///     host: "localhost",
488    ///     origin: "http://localhost",
489    ///     sub_protocols: Some(&sub_protocols),
490    ///     additional_headers: None,
491    /// };
492    ///
493    /// let (len, web_socket_key) = ws_client.client_connect(&websocket_options, &mut buffer).unwrap();
494    ///
495    /// let actual_http = std::str::from_utf8(&buffer[..len]).unwrap();
496    /// let mut expected_http = String::new();
497    /// expected_http.push_str("GET /chat HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ");
498    /// expected_http.push_str(web_socket_key.as_str());
499    /// expected_http.push_str("\r\nOrigin: http://localhost\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n");
500    /// assert_eq!(expected_http.as_str(), actual_http);
501    /// ```
502    ///
503    /// # Errors
504    /// * The http response is built with a stack allocated 1KB buffer and will return an
505    /// `Unknown` error if that buffer is too small. This would happen is the user supplied too many
506    /// additional headers or the sub-protocol string is too large
507    /// * This function can return an `Utf8Error` if there was an error with the generation of the
508    /// accept string. This should be impossible but an error is preferable to a panic
509    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
510    pub fn client_connect(
511        &mut self,
512        websocket_options: &WebSocketOptions,
513        to: &mut [u8],
514    ) -> Result<(usize, WebSocketKey)> {
515        if self.state == WebSocketState::Open {
516            return Err(Error::WebsocketAlreadyOpen);
517        }
518
519        match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
520            Ok((request_len, sec_websocket_key)) => {
521                self.state = WebSocketState::Connecting;
522                Ok((request_len, sec_websocket_key))
523            }
524            Err(e) => Err(e),
525        }
526    }
527
528    /// Used by a websocket client for checking the server response to an opening handshake
529    /// (sent using the client_connect function). If the client requested one or more sub protocols
530    /// the server will choose one (or none) and you get that in the result
531    /// # Examples
532    /// ```
533    /// use embedded_websocket as ws;
534    /// use core::str::FromStr;
535    /// use heapless::String;
536    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
537    /// let ws_key: ws::WebSocketKey = String::from_str("Z7OY1UwHOx/nkSz38kfPwg==").unwrap();
538    /// let server_response_html = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";    ///
539    /// let (len, sub_protocol) = ws_client.client_accept(&ws_key, server_response_html.as_bytes())
540    ///     .unwrap();
541    ///
542    /// assert_eq!(159, len);
543    /// assert_eq!("chat", sub_protocol.unwrap());
544    /// ```
545    /// # Errors
546    /// * Returns `HttpResponseCodeInvalid` if the HTTP response code is not `101 Switching Protocols`
547    /// * Returns `AcceptStringInvalid` if the web server failed to return a valid accept string
548    /// * Returns `HttpHeader(Version)` or some other varient if the HTTP response is not well formed
549    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
550    pub fn client_accept(
551        &mut self,
552        sec_websocket_key: &WebSocketKey,
553        from: &[u8],
554    ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
555        if self.state == WebSocketState::Open {
556            return Err(Error::WebsocketAlreadyOpen);
557        }
558
559        match http::read_server_connect_handshake_response(sec_websocket_key, from) {
560            Ok((len, sec_websocket_protocol)) => {
561                self.state = WebSocketState::Open;
562                Ok((len, sec_websocket_protocol))
563            }
564            Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
565            Err(e) => {
566                self.state = WebSocketState::Aborted;
567                Err(e)
568            }
569        }
570    }
571}
572
573impl<T, Type> WebSocket<T, Type>
574where
575    T: RngCore,
576    Type: WebSocketType,
577{
578    /// Reads the payload from a websocket frame in buffer `from` into a buffer `to` and returns
579    /// metadata about the frame. Since this function is designed to be called in a memory
580    /// constrained system we may not read the entire payload in one go. In each of the scenarios
581    /// below the `read_result.end_of_message` flag would be `false`:
582    /// * The payload is fragmented into multiple websocket frames (as per the websocket spec)
583    /// * The `from` buffer does not hold the entire websocket frame. For example if only part of
584    /// the frame was read or if the `from` buffer is too small to hold an entire websocket frame
585    /// * The `to` buffer is too small to hold the entire websocket frame payload
586    ///
587    /// If the function returns `read_result.end_of_message` `false` then the next
588    /// call to the function should not include data that has already been passed into the function.
589    /// The websocket *remembers* the websocket frame header and is able to process the rest of the
590    /// payload correctly. If the `from` buffer contains multiple websocket frames then only one of
591    /// them will be returned at a time and the user must make multiple calls to the function by
592    /// taking note of `read_result.len_from` which tells you how many bytes were read from the
593    /// `from` buffer
594    ///
595    /// # Examples
596    ///
597    /// ```
598    /// use embedded_websocket as ws;
599    /// //                    h   e   l   l   o
600    /// let buffer1 = [129,5,104,101,108,108,111];
601    /// let mut buffer2: [u8; 128] = [0; 128];
602    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
603    /// ws_client.state = ws::WebSocketState::Open; // skip the opening handshake
604    /// let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
605    ///
606    /// assert_eq!("hello".as_bytes(), &buffer2[..ws_result.len_to]);
607    /// ```
608    /// # Errors
609    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
610    /// * Returns `InvalidOpCode` if the websocket frame contains an invalid opcode
611    /// * Returns `UnexpectedContinuationFrame` if we receive a continuation frame without first
612    /// receiving a non-continuation frame with an opcode describing the payload
613    /// * Returns `ReadFrameIncomplete` if the `from` buffer does not contain a full websocket
614    /// header (typically 2-14 bytes depending on the payload)
615    /// * Returns `InvalidFrameLength` if the frame length cannot be decoded
616    ///
617    pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
618        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
619            let frame = self.read_frame(from, to)?;
620
621            match frame.op_code {
622                WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
623                WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
624                WebSocketOpCode::TextFrame => {
625                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
626                }
627                WebSocketOpCode::BinaryFrame => {
628                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
629                }
630                WebSocketOpCode::ConnectionClose => match self.state {
631                    WebSocketState::CloseSent => {
632                        self.state = WebSocketState::Closed;
633                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
634                    }
635                    _ => {
636                        self.state = WebSocketState::CloseReceived;
637                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
638                    }
639                },
640                WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
641                    Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
642                    None => Err(Error::UnexpectedContinuationFrame),
643                },
644            }
645        } else {
646            Err(Error::WebSocketNotOpen)
647        }
648    }
649
650    /// Writes the payload in `from` to a websocket frame in `to`
651    /// * message_type - The type of message to send: Text, Binary or CloseReply
652    /// * end_of_message - False to fragment a frame into multiple smaller frames. The last frame
653    /// should set this to true
654    /// * from - The buffer containing the payload to encode
655    /// * to - The the buffer to save the websocket encoded payload to.
656    /// Returns the number of bytes written to the `to` buffer
657    /// # Examples
658    ///
659    /// ```
660    /// use embedded_websocket as ws;
661    /// let mut buffer: [u8; 1000] = [0; 1000];
662    /// let mut ws_server = ws::WebSocketServer::new_server();
663    /// ws_server.state = ws::WebSocketState::Open; // skip the opening handshake
664    /// let len = ws_server.write(ws::WebSocketSendMessageType::Text, true, "hello".as_bytes(),
665    ///     &mut buffer).unwrap();
666    ///
667    /// //                     h   e   l   l   o
668    /// let expected = [129,5,104,101,108,108,111];
669    /// assert_eq!(&expected, &buffer[..len]);
670    /// ```
671    /// # Errors
672    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
673    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
674    /// frame header (2-14 bytes) plus the payload. Consider fragmenting the messages by making
675    /// multiple write calls with `end_of_message` set to `false` and the final call set to `true`
676    pub fn write(
677        &mut self,
678        message_type: WebSocketSendMessageType,
679        end_of_message: bool,
680        from: &[u8],
681        to: &mut [u8],
682    ) -> Result<usize> {
683        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
684            let mut op_code = message_type.to_op_code();
685            if op_code == WebSocketOpCode::ConnectionClose {
686                self.state = WebSocketState::Closed
687            } else if self.is_write_continuation {
688                op_code = WebSocketOpCode::ContinuationFrame;
689            }
690
691            self.is_write_continuation = !end_of_message;
692            self.write_frame(from, to, op_code, end_of_message)
693        } else {
694            Err(Error::WebSocketNotOpen)
695        }
696    }
697
698    /// Initiates a close handshake.
699    /// Both the client and server may initiate a close handshake. If successful the function
700    /// changes the websocket state from Open -> CloseSent
701    /// # Errors
702    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
703    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
704    /// frame header (2-14 bytes) plus the payload. Consider sending a smaller status_description
705    pub fn close(
706        &mut self,
707        close_status: WebSocketCloseStatusCode,
708        status_description: Option<&str>,
709        to: &mut [u8],
710    ) -> Result<usize> {
711        if self.state == WebSocketState::Open {
712            self.state = WebSocketState::CloseSent;
713            if let Some(status_description) = status_description {
714                let mut from_buffer: Vec<u8, 256> = Vec::new();
715                from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
716
717                // restrict the max size of the status_description
718                let len = if status_description.len() < 254 {
719                    status_description.len()
720                } else {
721                    254
722                };
723
724                from_buffer.extend_from_slice(status_description[..len].as_bytes())?;
725                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
726            } else {
727                let mut from_buffer: [u8; 2] = [0; 2];
728                BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
729                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
730            }
731        } else {
732            Err(Error::WebSocketNotOpen)
733        }
734    }
735
736    fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
737        match &mut self.continuation_read {
738            Some(continuation_read) => {
739                let result = read_continuation(continuation_read, from_buffer, to_buffer);
740                if result.is_fin_bit_set {
741                    self.continuation_read = None;
742                }
743                Ok(result)
744            }
745            None => {
746                let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
747
748                // override the op code we get from the result with our continuation frame opcode if it exists
749                if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
750                    result.op_code = continuation_frame_op_code;
751                }
752
753                // reset the continuation frame op code to None if this is the last fragment (or there is no fragmentation)
754                self.continuation_frame_op_code = if result.is_fin_bit_set {
755                    None
756                } else {
757                    Some(result.op_code)
758                };
759
760                self.continuation_read = continuation_read;
761                Ok(result)
762            }
763        }
764    }
765
766    fn write_frame(
767        &mut self,
768        from_buffer: &[u8],
769        to_buffer: &mut [u8],
770        op_code: WebSocketOpCode,
771        end_of_message: bool,
772    ) -> Result<usize> {
773        let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
774        let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
775        let count = from_buffer.len();
776        const BYTE_HEADER_SIZE: usize = 2;
777        const SHORT_HEADER_SIZE: usize = 4;
778        const LONG_HEADER_SIZE: usize = 10;
779        const MASK_KEY_SIZE: usize = 4;
780        let header_size;
781        let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
782        let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
783
784        // write header followed by the payload
785        // header size depends on how large the payload is
786        if count < 126 {
787            if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
788                return Err(Error::WriteToBufferTooSmall);
789            }
790            to_buffer[0] = byte1;
791            to_buffer[1] = mask_bit_set_as_byte | count as u8;
792            header_size = BYTE_HEADER_SIZE;
793        } else if count < 65535 {
794            if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
795                return Err(Error::WriteToBufferTooSmall);
796            }
797            to_buffer[0] = byte1;
798            to_buffer[1] = mask_bit_set_as_byte | 126;
799            BigEndian::write_u16(&mut to_buffer[2..], count as u16);
800            header_size = SHORT_HEADER_SIZE;
801        } else {
802            if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
803                return Err(Error::WriteToBufferTooSmall);
804            }
805            to_buffer[0] = byte1;
806            to_buffer[1] = mask_bit_set_as_byte | 127;
807            BigEndian::write_u64(&mut to_buffer[2..], count as u64);
808            header_size = LONG_HEADER_SIZE;
809        }
810
811        // sent by client - need to mask the data
812        // we need to mask the bytes to prevent web server caching
813        if self.is_client {
814            let mut mask_key = [0; MASK_KEY_SIZE];
815            self.rng.fill_bytes(&mut mask_key); // clients always have an rng instance
816            to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
817            let to_buffer_start = header_size + MASK_KEY_SIZE;
818
819            // apply the mask key to every byte in the payload. This is a hot function
820            for (i, (from, to)) in from_buffer[..count]
821                .iter()
822                .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
823                .enumerate()
824            {
825                *to = *from ^ mask_key[i % MASK_KEY_SIZE];
826            }
827
828            Ok(to_buffer_start + count)
829        } else {
830            to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
831            Ok(header_size + count)
832        }
833    }
834}
835
/// Continuation read is used when we cannot fit the entire websocket frame into the supplied
/// buffer: it records what is needed to read the rest of the frame's payload on later calls
/// (see `read_continuation`).
struct ContinuationRead {
    /// op code reported for the remaining payload of the frame
    op_code: WebSocketOpCode,
    /// number of payload bytes of the frame that have not been read yet
    count: usize,
    /// the FIN bit from the frame header, reported once the remaining payload has been read
    is_fin_bit_set: bool,
    /// the frame's mask key (rotated as bytes are consumed), or `None` for unmasked frames
    mask_key: Option<[u8; 4]>,
}
843
/// Result of decoding (part of) a single websocket frame.
struct WebSocketFrame {
    /// FIN bit of the frame; for non-close frames this is `false` while payload of the frame
    /// is still pending
    is_fin_bit_set: bool,
    /// op code of the frame
    op_code: WebSocketOpCode,
    /// number of payload bytes written to the destination buffer
    num_bytes_to: usize,
    /// number of bytes consumed from the source buffer (header, mask key and payload)
    num_bytes_from: usize,
    /// status code for close frames, `None` for all other frames
    close_status: Option<WebSocketCloseStatusCode>,
}
851
852impl WebSocketFrame {
853    fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
854        WebSocketReadResult {
855            len_from: self.num_bytes_from,
856            len_to: self.num_bytes_to,
857            end_of_message: self.is_fin_bit_set,
858            close_status: self.close_status,
859            message_type,
860        }
861    }
862}
863
/// Returns the smallest of three values.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    num1.min(num2).min(num3)
}
867
868fn read_into_buffer(
869    mask_key: &mut Option<[u8; 4]>,
870    from_buffer: &[u8],
871    to_buffer: &mut [u8],
872    len: usize,
873) -> usize {
874    // if we are trying to read more than number of bytes in either buffer
875    let len_to_read = min(len, to_buffer.len(), from_buffer.len());
876
877    match mask_key {
878        Some(mask_key) => {
879            // apply the mask key to every byte in the payload. This is a hot function.
880            for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
881                *to = *from ^ mask_key[i % MASK_KEY_LEN];
882            }
883            mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
884        }
885        None => {
886            to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
887        }
888    }
889
890    len_to_read
891}
892
893fn read_continuation(
894    continuation_read: &mut ContinuationRead,
895    from_buffer: &[u8],
896    to_buffer: &mut [u8],
897) -> WebSocketFrame {
898    let len_read = read_into_buffer(
899        &mut continuation_read.mask_key,
900        from_buffer,
901        to_buffer,
902        continuation_read.count,
903    );
904
905    let is_complete = len_read == continuation_read.count;
906
907    let frame = match continuation_read.op_code {
908        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
909        _ => WebSocketFrame {
910            num_bytes_from: len_read,
911            num_bytes_to: len_read,
912            op_code: continuation_read.op_code,
913            close_status: None,
914            is_fin_bit_set: if is_complete {
915                continuation_read.is_fin_bit_set
916            } else {
917                false
918            },
919        },
920    };
921
922    continuation_read.count -= len_read;
923    frame
924}
925
926fn read_frame(
927    from_buffer: &[u8],
928    to_buffer: &mut [u8],
929) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
930    if from_buffer.len() < 2 {
931        return Err(Error::ReadFrameIncomplete);
932    }
933
934    let byte1 = from_buffer[0];
935    let byte2 = from_buffer[1];
936
937    // process first byte
938    const FIN_BIT_FLAG: u8 = 0x80;
939    const OP_CODE_FLAG: u8 = 0x0F;
940    let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
941    let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
942
943    // process second byte
944    const MASK_FLAG: u8 = 0x80;
945    let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
946    let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
947
948    num_bytes_read += 2;
949    let from_buffer = &from_buffer[num_bytes_read..];
950
951    // reads the mask key from the payload if the is_mask_bit_set flag is set
952    let mut mask_key = if is_mask_bit_set {
953        if from_buffer.len() < MASK_KEY_LEN {
954            return Err(Error::ReadFrameIncomplete);
955        }
956        let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
957        mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
958        num_bytes_read += MASK_KEY_LEN;
959        Some(mask_key)
960    } else {
961        None
962    };
963
964    let len_read = if is_mask_bit_set {
965        // start after the mask key
966        let from_buffer = &from_buffer[MASK_KEY_LEN..];
967        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
968    } else {
969        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
970    };
971
972    let has_continuation = len_read < len;
973    num_bytes_read += len_read;
974
975    let frame = match op_code {
976        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
977        _ => WebSocketFrame {
978            num_bytes_from: num_bytes_read,
979            num_bytes_to: len_read,
980            op_code,
981            close_status: None,
982            is_fin_bit_set: if has_continuation {
983                false
984            } else {
985                is_fin_bit_set
986            },
987        },
988    };
989
990    if has_continuation {
991        let continuation_read = Some(ContinuationRead {
992            op_code,
993            count: len - len_read,
994            is_fin_bit_set,
995            mask_key,
996        });
997        Ok((frame, continuation_read))
998    } else {
999        Ok((frame, None))
1000    }
1001}
1002
1003fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
1004    match val {
1005        0 => Ok(WebSocketOpCode::ContinuationFrame),
1006        1 => Ok(WebSocketOpCode::TextFrame),
1007        2 => Ok(WebSocketOpCode::BinaryFrame),
1008        8 => Ok(WebSocketOpCode::ConnectionClose),
1009        9 => Ok(WebSocketOpCode::Ping),
1010        10 => Ok(WebSocketOpCode::Pong),
1011        _ => Err(Error::InvalidOpCode),
1012    }
1013}
1014
1015// returns (len, how_many_bytes_were_read)
1016fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1017    let len = byte2 & 0x7F;
1018
1019    if len < 126 {
1020        // for messages smaller than 126 bytes
1021        return Ok((len as usize, 0));
1022    } else if len == 126 {
1023        // for messages smaller than 64KB
1024        if from_buffer.len() < 2 {
1025            return Err(Error::ReadFrameIncomplete);
1026        }
1027        let mut buf: [u8; 2] = [0; 2];
1028        buf.copy_from_slice(&from_buffer[..2]);
1029        return Ok((BigEndian::read_u16(&buf) as usize, 2));
1030    } else if len == 127 {
1031        // for messages larger than 64KB
1032        if from_buffer.len() < 8 {
1033            return Err(Error::ReadFrameIncomplete);
1034        }
1035        let mut buf: [u8; 8] = [0; 8];
1036        buf.copy_from_slice(&from_buffer[..8]);
1037        return Ok((BigEndian::read_u64(&buf) as usize, 8));
1038    }
1039
1040    Err(Error::InvalidFrameLength)
1041}
1042
1043fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1044    if len >= 2 {
1045        // NOTE: for now, don't read the close status description
1046        let code = BigEndian::read_u16(buffer);
1047        let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1048
1049        return WebSocketFrame {
1050            num_bytes_from: num_bytes_read,
1051            num_bytes_to: len,
1052            op_code: WebSocketOpCode::ConnectionClose,
1053            close_status: Some(close_status_code),
1054            is_fin_bit_set: true,
1055        };
1056    }
1057
1058    build_client_disconnected_frame(num_bytes_read)
1059}
1060
1061fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1062    WebSocketFrame {
1063        num_bytes_from,
1064        num_bytes_to: 0,
1065        op_code: WebSocketOpCode::ConnectionClose,
1066        close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1067        is_fin_bit_set: true,
1068    }
1069}
1070
1071// ************************************************************************************************
1072// **************************************** TESTS *************************************************
1073// ************************************************************************************************
1074
1075#[cfg(test)]
1076mod tests {
1077    extern crate std;
1078    use core::str::FromStr;
1079
1080    use super::*;
1081
    /// Parses a browser's opening handshake request with `httparse`, extracts the websocket
    /// context from its headers and checks the exact response written by `server_accept`.
    #[test]
    fn opening_handshake() {
        let client_request = "GET /chat HTTP/1.1
Host: localhost:5000
User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
Accept-Language: en-US,en;q=0.5
Accept-Encoding: gzip, deflate
Sec-WebSocket-Version: 13
Origin: http://localhost:5000
Sec-WebSocket-Extensions: permessage-deflate
Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
Sec-WebSocket-Protocol: chat
DNT: 1
Connection: keep-alive, Upgrade
Pragma: no-cache
Cache-Control: no-cache
Upgrade: websocket

";

        let mut headers = [httparse::EMPTY_HEADER; 16];
        let mut request = httparse::Request::new(&mut headers);
        request.parse(client_request.as_bytes()).unwrap();
        // hand the parsed (name, value) pairs to the library's header reader
        let headers = headers.iter().map(|f| (f.name, f.value));
        let web_socket_context = read_http_header(headers).unwrap().unwrap();
        assert_eq!(
            "Z7OY1UwHOx/nkSz38kfPwg==",
            web_socket_context.sec_websocket_key
        );
        assert_eq!(
            "chat",
            web_socket_context
                .sec_websocket_protocol_list
                .get(0)
                .unwrap()
                .as_str()
        );
        let mut web_socket = WebSocketServer::new_server();

        let mut ws_buffer: [u8; 3000] = [0; 3000];
        let size = web_socket
            .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
            .unwrap();
        let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
        let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
        assert_eq!(client_response_expected, response);
    }
1130
    /// A server frame is unmasked: byte 129 (FIN bit + Text op code), length 5, then the raw
    /// payload bytes.
    #[test]
    fn server_write_frame() {
        let mut buffer: [u8; 1000] = [0; 1000];
        let mut ws_server = WebSocketServer::new_server();
        let len = ws_server
            .write_frame(
                "hello".as_bytes(),
                &mut buffer,
                WebSocketOpCode::TextFrame,
                true,
            )
            .unwrap();
        let expected = [129, 5, 104, 101, 108, 108, 111];
        assert_eq!(&expected, &buffer[..len]);
    }
1146
    /// When a sub protocol is passed to `server_accept`, the handshake response must include a
    /// `Sec-WebSocket-Protocol` header for it.
    #[test]
    fn server_accept_should_write_sub_protocol() {
        let mut buffer: [u8; 1000] = [0; 1000];
        let mut ws_server = WebSocketServer::new_server();
        let ws_key: WebSocketKey = String::from_str("Z7OY1UwHOx/nkSz38kfPwg==").unwrap();
        let sub_protocol: WebSocketSubProtocol = String::from_str("chat").unwrap();
        let size = ws_server
            .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
            .unwrap();
        let response = std::str::from_utf8(&buffer[..size]).unwrap();
        assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
    }
1159
    /// Full close handshake initiated by the client: client sends close, server replies, and
    /// both ends end up in the `Closed` state.
    #[test]
    fn closing_handshake() {
        let mut buffer1: [u8; 500] = [0; 500];
        let mut buffer2: [u8; 500] = [0; 500];

        let mut rng = rand::thread_rng();

        let mut ws_client = WebSocketClient::new_client(&mut rng);
        ws_client.state = WebSocketState::Open;

        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        // client sends a close (initiates the close handshake)
        ws_client
            .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
            .unwrap();

        // check that the server receives the close message from the client
        let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
        assert_eq!(
            WebSocketReceiveMessageType::CloseMustReply,
            ws_result.message_type
        );

        // server MUST respond to complete the handshake (echoing the received close payload)
        ws_server
            .write(
                WebSocketSendMessageType::CloseReply,
                true,
                &buffer2[..ws_result.len_to],
                &mut buffer1,
            )
            .unwrap();
        assert_eq!(WebSocketState::Closed, ws_server.state);

        // check that the client receives the close message from the server
        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
        assert_eq!(WebSocketState::Closed, ws_client.state);

        assert_eq!(
            WebSocketReceiveMessageType::CloseCompleted,
            ws_result.message_type
        );
    }
1205
    /// A Text message written by the client is read back intact by the server.
    #[test]
    fn send_message_from_client_to_server() {
        let mut buffer1: [u8; 1000] = [0; 1000];
        let mut buffer2: [u8; 1000] = [0; 1000];

        // how to create a client
        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());

        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        // client sends a Text message
        let hello = "hello";
        let num_bytes = ws_client
            .write(
                WebSocketSendMessageType::Text,
                true,
                &hello.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        // check that the Server receives the Text message
        let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
        assert_eq!(hello, received);
    }
1235
    /// A Text message written by the server is read back intact by the client.
    #[test]
    fn send_message_from_server_to_client() {
        let mut buffer1: [u8; 1000] = [0; 1000];
        let mut buffer2: [u8; 1000] = [0; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        // server sends a Text message
        let hello = "hello";
        let num_bytes = ws_server
            .write(
                WebSocketSendMessageType::Text,
                true,
                &hello.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        // check that the client receives the Text message
        let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
        assert_eq!(hello, received);
    }
1263
1264    #[test]
1265    fn receive_buffer_too_small() {
1266        let mut buffer1: [u8; 1000] = [0; 1000];
1267        let mut buffer2: [u8; 1000] = [0; 1000];
1268
1269        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1270        ws_client.state = WebSocketState::Open;
1271        let mut ws_server = WebSocketServer::new_server();
1272        ws_server.state = WebSocketState::Open;
1273
1274        let hello = "hello";
1275        ws_server
1276            .write(
1277                WebSocketSendMessageType::Text,
1278                true,
1279                &hello.as_bytes(),
1280                &mut buffer1,
1281            )
1282            .unwrap();
1283
1284        match ws_client.read(&buffer1[..1], &mut buffer2) {
1285            Err(Error::ReadFrameIncomplete) => {
1286                // test passes
1287            }
1288            _ => {
1289                assert_eq!(true, false);
1290            }
1291        }
1292    }
1293
    /// A frame delivered in pieces (header only, one payload byte, then the rest) is reassembled
    /// across reads, and the end of the message is reported only on the last piece.
    #[test]
    fn receive_large_frame_with_small_receive_buffer() {
        let mut buffer1: [u8; 1000] = [0; 1000];
        let mut buffer2: [u8; 1000] = [0; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        let hello = "hello";
        ws_server
            .write(
                WebSocketSendMessageType::Text,
                true,
                &hello.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        // only the 2-byte header is available: no payload yet and the message is not finished
        let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
        assert_eq!(0, ws_result.len_to);
        assert_eq!(false, ws_result.end_of_message);
        // one payload byte
        let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
        assert_eq!(1, ws_result.len_to);
        assert_eq!(
            "h",
            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
        );
        assert_eq!(false, ws_result.end_of_message);
        // the remaining payload completes the frame
        let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
        assert_eq!(4, ws_result.len_to);
        assert_eq!(
            "ello",
            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
        );
        assert_eq!(true, ws_result.end_of_message);
    }
1332
1333    #[test]
1334    fn send_large_frame() {
1335        let buffer1 = [0u8; 15944];
1336        let mut buffer2 = [0u8; 64000];
1337        let mut buffer3 = [0u8; 64000];
1338
1339        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1340        ws_client.state = WebSocketState::Open;
1341        let mut ws_server = WebSocketServer::new_server();
1342        ws_server.state = WebSocketState::Open;
1343
1344        ws_client
1345            .write(
1346                WebSocketSendMessageType::Binary,
1347                true,
1348                &buffer1,
1349                &mut buffer2,
1350            )
1351            .unwrap();
1352
1353        let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1354        assert_eq!(true, ws_result.end_of_message);
1355        assert_eq!(buffer1.len(), ws_result.len_to);
1356    }
1357
    /// A message larger than the first chunk of input is completed by a second read that starts
    /// where the first one stopped consuming (`len_from`).
    #[test]
    fn receive_large_frame_multi_read() {
        let mut buffer1 = [0_u8; 1000];
        let mut buffer2 = [0_u8; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        let message = "Hello, world. This is a long message that takes multiple reads";
        ws_server
            .write(
                WebSocketSendMessageType::Text,
                true,
                &message.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        let mut buffer2_cursor = 0;
        // the first 40 bytes hold the header and only part of the payload
        let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
        assert_eq!(false, ws_result.end_of_message);
        buffer2_cursor += ws_result.len_to;
        let ws_result = ws_client
            .read(
                &buffer1[ws_result.len_from..],
                &mut buffer2[buffer2_cursor..],
            )
            .unwrap();
        assert_eq!(true, ws_result.end_of_message);
        buffer2_cursor += ws_result.len_to;

        assert_eq!(
            message,
            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
        );
    }
1396
    /// Two complete messages back to back in one input buffer are read one at a time by
    /// advancing the input cursor by `len_from` after each read.
    #[test]
    fn multiple_messages_in_receive_buffer() {
        let mut buffer1 = [0_u8; 1000];
        let mut buffer2 = [0_u8; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        let message1 = "Hello, world.";
        let len = ws_client
            .write(
                WebSocketSendMessageType::Text,
                true,
                &message1.as_bytes(),
                &mut buffer1,
            )
            .unwrap();
        let message2 = "This is another message.";
        ws_client
            .write(
                WebSocketSendMessageType::Text,
                true,
                &message2.as_bytes(),
                &mut buffer1[len..],
            )
            .unwrap();

        let mut buffer1_cursor = 0;
        let mut buffer2_cursor = 0;
        let ws_result = ws_server
            .read(&buffer1[buffer1_cursor..], &mut buffer2)
            .unwrap();
        assert_eq!(true, ws_result.end_of_message);
        buffer1_cursor += ws_result.len_from;
        buffer2_cursor += ws_result.len_to;
        // the second read starts right after the first frame
        let ws_result = ws_server
            .read(&buffer1[buffer1_cursor..], &mut buffer2[buffer2_cursor..])
            .unwrap();
        assert_eq!(true, ws_result.end_of_message);
        assert_eq!(
            message1,
            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
        );

        assert_eq!(
            message2,
            std::str::from_utf8(&buffer2[buffer2_cursor..buffer2_cursor + ws_result.len_to])
                .unwrap()
        );
    }
1449
    /// When the destination buffer is smaller than the payload, the frame is delivered in
    /// pieces, and only the last piece reports the end of the message.
    #[test]
    fn receive_large_frame_with_small_send_buffer() {
        let mut buffer1: [u8; 1000] = [0; 1000];
        let mut buffer2: [u8; 1000] = [0; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        let hello = "hello";
        ws_server
            .write(
                WebSocketSendMessageType::Text,
                true,
                &hello.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        // destination has room for a single payload byte
        let ws_result = ws_client.read(&buffer1, &mut buffer2[..1]).unwrap();
        assert_eq!(1, ws_result.len_to);
        assert_eq!(
            "h",
            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
        );
        assert_eq!(false, ws_result.end_of_message);
        let ws_result = ws_client
            .read(&buffer1[ws_result.len_from..], &mut buffer2[..4])
            .unwrap();
        assert_eq!(4, ws_result.len_to);
        assert_eq!(
            "ello",
            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
        );
        assert_eq!(true, ws_result.end_of_message);
    }
1487
    /// A Text message written by the server as two fragments is received by the client as two
    /// Text reads, the second of which ends the message.
    #[test]
    fn send_two_frame_message() {
        let mut buffer1: [u8; 1000] = [0; 1000];
        let mut buffer2: [u8; 1000] = [0; 1000];

        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut ws_server = WebSocketServer::new_server();
        ws_server.state = WebSocketState::Open;

        // server sends the first fragment of a Text message (end_of_message = false)
        let hello = "Hello, ";
        let num_bytes_hello = ws_server
            .write(
                WebSocketSendMessageType::Text,
                false,
                &hello.as_bytes(),
                &mut buffer1,
            )
            .unwrap();

        // server sends the final fragment, which is written as a continuation frame
        let world = "World!";
        let num_bytes_world = ws_server
            .write(
                WebSocketSendMessageType::Text,
                true,
                &world.as_bytes(),
                &mut buffer1[num_bytes_hello..],
            )
            .unwrap();

        // check that the client receives both fragments, each reported as Text
        let ws_result1 = ws_client
            .read(&buffer1[..num_bytes_hello], &mut buffer2)
            .unwrap();
        assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
        assert_eq!(false, ws_result1.end_of_message);
        let ws_result2 = ws_client
            .read(
                &buffer1[num_bytes_hello..num_bytes_hello + num_bytes_world],
                &mut buffer2[ws_result1.len_to..],
            )
            .unwrap();
        assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
        assert_eq!(true, ws_result2.end_of_message);

        let received =
            std::str::from_utf8(&buffer2[..ws_result1.len_to + ws_result2.len_to]).unwrap();
        assert_eq!("Hello, World!", received);
    }
1540
1541    #[test]
1542    fn send_multi_frame_message() {
1543        let mut buffer1: [u8; 1000] = [0; 1000];
1544        let mut buffer2: [u8; 1000] = [0; 1000];
1545
1546        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1547        ws_client.state = WebSocketState::Open;
1548        let mut ws_server = WebSocketServer::new_server();
1549        ws_server.state = WebSocketState::Open;
1550
1551        // server sends the first fragmented Text frame
1552        let fragment1 = "fragment1";
1553        let fragment1_num_bytes = ws_server
1554            .write(
1555                WebSocketSendMessageType::Text,
1556                false,
1557                &fragment1.as_bytes(),
1558                &mut buffer1,
1559            )
1560            .unwrap();
1561
1562        // send fragment2 as a continuation frame
1563        let fragment2 = "fragment2";
1564        let fragment2_num_bytes = ws_server
1565            .write(
1566                WebSocketSendMessageType::Text,
1567                false,
1568                &fragment2.as_bytes(),
1569                &mut buffer1[fragment1_num_bytes..],
1570            )
1571            .unwrap();
1572
1573        // send fragment3 as a continuation frame and indicate that this is the last frame
1574        let fragment3 = "fragment3";
1575        let _fragment3_num_bytes = ws_server
1576            .write(
1577                WebSocketSendMessageType::Text,
1578                true,
1579                &fragment3.as_bytes(),
1580                &mut buffer1[fragment1_num_bytes + fragment2_num_bytes..],
1581            )
1582            .unwrap();
1583
1584        // check that the client receives the "fragment1" Text message
1585        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1586        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1587        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1588        assert_eq!(fragment1, received);
1589        assert_eq!(ws_result.end_of_message, false);
1590        let mut read_cursor = ws_result.len_from;
1591
1592        // check that the client receives the "fragment2" Text message
1593        let ws_result = ws_client
1594            .read(&buffer1[read_cursor..], &mut buffer2)
1595            .unwrap();
1596        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1597        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1598        assert_eq!(fragment2, received);
1599        assert_eq!(ws_result.end_of_message, false);
1600        read_cursor += ws_result.len_from;
1601
1602        // check that the client receives the "fragment3" Text message
1603        let ws_result = ws_client
1604            .read(&buffer1[read_cursor..], &mut buffer2)
1605            .unwrap();
1606        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1607        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1608        assert_eq!(fragment3, received);
1609        assert_eq!(ws_result.end_of_message, true);
1610
1611        // check the actual bytes in the write buffer for fragment1
1612        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
1613        assert_eq!(is_fin_bit_set, false);
1614        assert_eq!(op_code, WebSocketOpCode::TextFrame);
1615
1616        // check the actual bytes in the write buffer for fragment2
1617        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
1618        assert_eq!(is_fin_bit_set, false);
1619        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1620
1621        // check the actual bytes in the write buffer for fragment3
1622        let (is_fin_bit_set, op_code) =
1623            read_first_byte(buffer1[fragment1_num_bytes + fragment2_num_bytes]);
1624        assert_eq!(is_fin_bit_set, true);
1625        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1626    }
1627
1628    fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1629        const FIN_BIT_FLAG: u8 = 0x80;
1630        const OP_CODE_FLAG: u8 = 0x0F;
1631        let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1632        let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1633        (is_fin_bit_set, op_code)
1634    }
1635}