Skip to main content

embedded_websocket/
lib.rs

1//! # Embedded Websocket
2//!
3//! `embedded_websocket` facilitates the encoding and decoding of websocket frames and can be used
4//! for both clients and servers. The library is intended to be used in constrained memory
5//! environments like embedded microcontrollers which cannot reference the rust standard library.
6//! It will work with arbitrarily small buffers regardless of websocket frame size as long as the
7//! websocket header can be read (2 - 14 bytes depending on the payload size and masking).
8//! Since the library is essentially an encoder or decoder of byte slices, the developer is free to
9//! use whatever transport mechanism they choose. The examples in the source repository use the
10//! TcpStream from the standard library.
11
12// support for websockets without using the standard library
13#![cfg_attr(not(feature = "std"), no_std)]
14
15use byteorder::{BigEndian, ByteOrder};
16use core::{cmp, result, str};
17use heapless::{String, Vec};
18use rand_core::RngCore;
19use sha1::{Digest, Sha1};
20
21mod http;
22pub mod random;
23pub use self::http::{read_http_header, WebSocketContext};
24pub use self::random::EmptyRng;
25
26// support for working with discrete websocket frames when using IO streams
27// start here!!
28pub mod framer;
29pub mod framer_async;
/// Length in bytes of a websocket frame masking key (a 32-bit value)
const MASK_KEY_LEN: usize = core::mem::size_of::<u32>();
31
32/// Result returning a websocket specific error if encountered
33pub type Result<T> = result::Result<T, Error>;
34
35/// A fixed length 24-character string used to hold a websocket key for the opening handshake
36pub type WebSocketKey = String<24>;
37
38/// A maximum sized 24-character string used to store a sub protocol (e.g. `chat`)
39pub type WebSocketSubProtocol = String<24>;
40
/// Websocket send message type used when sending a websocket frame
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketSendMessageType {
    /// A UTF8 encoded text string
    Text = 1,
    /// Binary data
    Binary = 2,
    /// An unsolicited ping message
    Ping = 9,
    /// A pong message in response to a ping message
    Pong = 10,
    /// A close message in response to a close message from the other party. Used to complete a
    /// closing handshake. To initiate a close handshake yourself, use the `close` function
    /// instead.
    ///
    /// Note that this discriminant (11) differs from the websocket close opcode (8); the frame
    /// is still encoded with the close opcode when written.
    CloseReply = 11,
}
56
57impl WebSocketSendMessageType {
58    fn to_op_code(self) -> WebSocketOpCode {
59        match self {
60            WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
61            WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
62            WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
63            WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
64            WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
65        }
66    }
67}
68
/// Websocket receive message type used when reading a websocket frame
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketReceiveMessageType {
    /// A UTF8 encoded text string
    Text = 1,
    /// Binary data
    Binary = 2,
    /// Signals that the close handshake is complete
    CloseCompleted = 7,
    /// Signals that the other party has initiated the close handshake. If you receive this message
    /// you should respond with a `WebSocketSendMessageType::CloseReply` with the same payload as
    /// the close message
    CloseMustReply = 8,
    /// A ping message that you should respond to with a `WebSocketSendMessageType::Pong` message
    /// including the same payload as the ping
    Ping = 9,
    /// A pong message in response to a ping message
    Pong = 10,
}
88
/// Status code carried in a websocket close frame, as listed in the rfc6455 websocket spec.
/// Codes without a dedicated variant are represented by `Custom`
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum WebSocketCloseStatusCode {
    /// 1000 - normal closure: the purpose for which the connection was established has been
    /// fulfilled
    NormalClosure,
    /// 1001 - endpoint unavailable: an endpoint is "going away", such as a server going down or
    /// a browser having navigated away from a page
    EndpointUnavailable,
    /// 1002 - protocol error: the endpoint is terminating the connection due to a protocol error
    ProtocolError,
    /// 1003 - invalid message type: the endpoint received a type of data it cannot accept (e.g.,
    /// an endpoint that understands only text data MAY send this if it receives a binary message)
    InvalidMessageType,
    /// 1004 - reserved for future use
    Reserved,
    /// 1005 - empty: no status code was present
    Empty,
    /// 1007 - invalid payload data: the endpoint received data within a message that was not
    /// consistent with the type of the message (e.g., non-UTF-8 data within a text message)
    InvalidPayloadData,
    /// 1008 - policy violation: the endpoint received a message that violates its policy. This is
    /// a generic status code for when no other, more suitable, status code applies
    PolicyViolation,
    /// 1009 - message too big: the endpoint received a message that is too big for it to process
    MessageTooBig,
    /// 1010 - mandatory extension: the client expected the server to negotiate one or more
    /// extensions, but the server didn't return them in the response message of the WebSocket
    /// handshake
    MandatoryExtension,
    /// 1011 - internal server error: the server encountered an unexpected condition that
    /// prevented it from fulfilling the request
    InternalServerError,
    /// 1015 - TLS handshake: the connection was closed due to a failure to perform a TLS handshake
    TlsHandshake,
    /// Any other close code, stored verbatim
    Custom(u16),
}

impl WebSocketCloseStatusCode {
    /// Converts a raw close code into a status, falling back to `Custom` for every value that has
    /// no named variant (for example 1006 or application specific codes)
    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
        match value {
            1000 => Self::NormalClosure,
            1001 => Self::EndpointUnavailable,
            1002 => Self::ProtocolError,
            1003 => Self::InvalidMessageType,
            1004 => Self::Reserved,
            1005 => Self::Empty,
            1007 => Self::InvalidPayloadData,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1010 => Self::MandatoryExtension,
            1011 => Self::InternalServerError,
            1015 => Self::TlsHandshake,
            other => Self::Custom(other),
        }
    }

    /// Converts a status back into the raw close code written on the wire
    fn to_u16(self) -> u16 {
        match self {
            Self::Custom(code) => code,
            Self::NormalClosure => 1000,
            Self::EndpointUnavailable => 1001,
            Self::ProtocolError => 1002,
            Self::InvalidMessageType => 1003,
            Self::Reserved => 1004,
            Self::Empty => 1005,
            Self::InvalidPayloadData => 1007,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::MandatoryExtension => 1010,
            Self::InternalServerError => 1011,
            Self::TlsHandshake => 1015,
        }
    }
}
170
/// Lifecycle stage of a websocket, from construction through the opening and closing handshakes
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum WebSocketState {
    /// Freshly built via `new_client()` or `new_server()`; no handshake has happened yet
    None = 0,
    /// The client has produced its opening handshake request
    Connecting = 1,
    /// The opening handshake finished, either on the server through `server_accept()` or on the
    /// client through `client_accept()`. From here on `write()`, `read()` and `close()` may be
    /// called
    Open = 2,
    /// `close()` has been called and we are waiting for the other party's close frame
    CloseSent = 3,
    /// The other party sent a Close websocket frame
    CloseReceived = 4,
    /// Both sides have exchanged close frames
    Closed = 5,
    /// The opening handshake failed on either the server or the client
    Aborted = 6,
}
191
192/// Websocket specific errors
193#[derive(PartialEq, Eq, Debug)]
194pub enum Error {
195    /// Websocket frame has an invalid opcode
196    InvalidOpCode,
197    InvalidFrameLength,
198    InvalidCloseStatusCode,
199    WebSocketNotOpen,
200    WebsocketAlreadyOpen,
201    Utf8Error,
202    Unknown,
203    HttpHeader(httparse::Error),
204    HttpHeaderNoPath,
205    HttpHeaderIncomplete,
206    WriteToBufferTooSmall,
207    ReadFrameIncomplete,
208    HttpResponseCodeInvalid(Option<u16>),
209    AcceptStringInvalid,
210    ConvertInfallible,
211    RandCore,
212    UnexpectedContinuationFrame,
213}
214
215impl From<httparse::Error> for Error {
216    fn from(err: httparse::Error) -> Error {
217        Error::HttpHeader(err)
218    }
219}
220
221impl From<str::Utf8Error> for Error {
222    fn from(_: str::Utf8Error) -> Error {
223        Error::Utf8Error
224    }
225}
226
227impl From<core::convert::Infallible> for Error {
228    fn from(_: core::convert::Infallible) -> Error {
229        Error::ConvertInfallible
230    }
231}
232
233impl From<()> for Error {
234    fn from(_: ()) -> Error {
235        Error::Unknown
236    }
237}
238
239impl core::fmt::Display for Error {
240    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241        match self {
242            Error::HttpHeader(error) => write!(f, "bad http header {error}"),
243            Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
244            _ => write!(f, "{:?}", self),
245        }
246    }
247}
248
#[cfg(feature = "std")]
impl std::error::Error for Error {
    /// Only HTTP header errors wrap an underlying error; all other variants have no source
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpHeader(error) => Some(error),
            _ => None,
        }
    }
}
259
260#[derive(Copy, Clone, Debug, PartialEq, Eq)]
261enum WebSocketOpCode {
262    ContinuationFrame = 0,
263    TextFrame = 1,
264    BinaryFrame = 2,
265    ConnectionClose = 8,
266    Ping = 9,
267    Pong = 10,
268}
269
270impl WebSocketOpCode {
271    fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
272        match self {
273            WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
274            WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
275            _ => Err(Error::InvalidOpCode),
276        }
277    }
278}
279
280/// The metadata result of a `read` function call of a websocket
281#[derive(Debug)]
282pub struct WebSocketReadResult {
283    /// Number of bytes read from the `from` buffer
284    pub len_from: usize,
285    /// Number of bytes written to the `to` buffer
286    pub len_to: usize,
287    /// End of message flag is `true` if the `to` buffer contains an entire websocket frame
288    /// payload otherwise `false` if the user must continue calling the read function to get the
289    /// rest of the payload
290    pub end_of_message: bool,
291    /// Close status code (as per the websocket spec) if the message type is `CloseMustReply` or
292    /// `CloseCompleted`. If a close status is specified then a UTF8 encoded string could also
293    /// appear in the frame payload giving more detailed information about why the websocket was
294    /// closed.
295    pub close_status: Option<WebSocketCloseStatusCode>,
296    /// The websocket frame type
297    pub message_type: WebSocketReceiveMessageType,
298}
299
/// Websocket options used by a websocket client to initiate an opening handshake with a
/// websocket server
#[derive(Debug, Clone, Copy)]
pub struct WebSocketOptions<'a> {
    /// The request uri (e.g. `/chat?id=123`) of the GET method used to identify the endpoint of the
    /// websocket connection. This allows a single server to expose multiple websocket endpoints.
    /// It could also be used to send identifiable information about the client
    pub path: &'a str,
    /// The hostname (e.g. `server.example.com`) is used so that both the client and the server
    /// can verify that they agree on which host is in use. This also allows multiple domains to
    /// be served by a single server
    pub host: &'a str,
    /// The origin (e.g. `http://example.com`) is used to protect against unauthorized
    /// cross-origin use of a WebSocket server by scripts using the WebSocket API in a web
    /// browser. This field is usually only set by browser clients but servers may require it
    /// so it has been exposed here.
    pub origin: &'a str,
    /// A list of requested sub protocols in order of preference. The server should return the
    /// first sub protocol it supports or none at all. A sub protocol can be anything agreed
    /// between the server and client
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Any additional headers the server may require that are not part of the websocket
    /// spec. These should be fully formed http headers without the `\r\n` (e.g. `MyHeader: foo`)
    pub additional_headers: Option<&'a [&'a str]>,
}
323
324/// Used to return a sized type from `WebSocket::new_server()`
325pub type WebSocketServer = WebSocket<EmptyRng, Server>;
326
327/// Used to return a sized type from `WebSocketClient::new_client()`
328pub type WebSocketClient<T> = WebSocket<T, Client>;
329
// Typestate markers: the role of a websocket is encoded in its type so that role specific
// functions (e.g. `server_accept`, `client_connect`) only exist where they make sense, while the
// shared read/read_frame/etc.. code is reused by both roles.

/// Marker type for a websocket acting as a server
pub enum Server {}

/// Marker type for a websocket acting as a client
pub enum Client {}

/// Implemented by the role marker types (`Server` and `Client`)
pub trait WebSocketType {}

impl WebSocketType for Server {}

impl WebSocketType for Client {}
338
339/// Websocket client and server implementation
340pub struct WebSocket<T, S: WebSocketType>
341where
342    T: RngCore,
343{
344    is_client: bool,
345    rng: T,
346    continuation_frame_op_code: Option<WebSocketOpCode>,
347    is_write_continuation: bool,
348    pub state: WebSocketState,
349    continuation_read: Option<ContinuationRead>,
350    marker: core::marker::PhantomData<S>,
351}
352
353impl<T, Type> WebSocket<T, Type>
354where
355    T: RngCore,
356    Type: WebSocketType,
357{
358    /// Creates a new websocket client by passing in a required random number generator
359    ///
360    /// # Examples
361    /// ```
362    /// use embedded_websocket as ws;
363    /// use rand;
364    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
365    ///
366    /// assert_eq!(ws::WebSocketState::None, ws_client.state);
367    /// ```
368    pub fn new_client(rng: T) -> WebSocketClient<T> {
369        WebSocket {
370            is_client: true,
371            rng,
372            continuation_frame_op_code: None,
373            is_write_continuation: false,
374            state: WebSocketState::None,
375            continuation_read: None,
376            marker: core::marker::PhantomData::<Client>,
377        }
378    }
379
380    /// Creates a new websocket server. Note that you must use the `WebSocketServer` type and
381    /// not the generic `WebSocket` type for this call or you will get a `'type annotations needed'`
382    /// compilation error.
383    ///
384    /// # Examples
385    /// ```
386    /// use embedded_websocket as ws;
387    /// let mut ws_server = ws::WebSocketServer::new_server();
388    ///
389    /// assert_eq!(ws::WebSocketState::None, ws_server.state);
390    /// ```
391    pub fn new_server() -> WebSocketServer {
392        let rng = EmptyRng::new();
393        WebSocket {
394            is_client: false,
395            rng,
396            continuation_frame_op_code: None,
397            is_write_continuation: false,
398            state: WebSocketState::None,
399            continuation_read: None,
400            marker: core::marker::PhantomData::<Server>,
401        }
402    }
403}
404
405impl<T> WebSocket<T, Server>
406where
407    T: RngCore,
408{
409    /// Used by the server to accept an incoming client connection and build a websocket upgrade
410    /// http response string. The client http header should be read with the `read_http_header`
411    /// function and the result should be passed to this function.
412    /// Websocket state will change from None -> Open if successful, otherwise None -> Aborted
413    ///
414    /// # Examples
415    ///
416    /// ```
417    /// use embedded_websocket as ws;
418    /// let mut buffer: [u8; 1000] = [0; 1000];
419    /// let mut ws_server = ws::WebSocketServer::new_server();
420    /// let ws_key = ws::WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
421    /// let sub_protocol = ws::WebSocketSubProtocol::from("chat");
422    /// let len = ws_server
423    ///     .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
424    ///     .unwrap();
425    /// let response = std::str::from_utf8(&buffer[..len]).unwrap();
426    ///
427    /// assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
428    /// ```
429    ///
430    /// # Errors
431    /// There should be no way for a user provided input to return the errors listed below as the
432    /// input is already constrained.
433    /// * The http response is built with a stack allocated 1KB buffer and it *should be impossible*
434    ///   for it to return an  `Unknown` error if that buffer is too small. However, this is better
435    ///   than a panic and it will do so if the response header is too large to fit in the buffer
436    /// * This function can return an `Utf8Error` if there was an error with the generation of the
437    ///   accept string. This should also be impossible but an error is preferable to a panic
438    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
439    pub fn server_accept(
440        &mut self,
441        sec_websocket_key: &WebSocketKey,
442        sec_websocket_protocol: Option<&WebSocketSubProtocol>,
443        to: &mut [u8],
444    ) -> Result<usize> {
445        if self.state == WebSocketState::Open {
446            return Err(Error::WebsocketAlreadyOpen);
447        }
448
449        match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
450        {
451            Ok(http_response_len) => {
452                self.state = WebSocketState::Open;
453                Ok(http_response_len)
454            }
455            Err(e) => {
456                self.state = WebSocketState::Aborted;
457                Err(e)
458            }
459        }
460    }
461}
462
463impl<T> WebSocket<T, Client>
464where
465    T: RngCore,
466{
467    /// Used by the client to initiate a websocket opening handshake
468    ///
469    /// # Examples
470    /// ```
471    /// use embedded_websocket as ws;
472    /// let mut buffer: [u8; 2000] = [0; 2000];
473    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
474    /// let sub_protocols = ["chat", "superchat"];
475    /// let websocket_options = ws::WebSocketOptions {
476    ///     path: "/chat",
477    ///     host: "localhost",
478    ///     origin: "http://localhost",
479    ///     sub_protocols: Some(&sub_protocols),
480    ///     additional_headers: None,
481    /// };
482    ///
483    /// let (len, web_socket_key) = ws_client.client_connect(&websocket_options, &mut buffer).unwrap();
484    ///
485    /// let actual_http = std::str::from_utf8(&buffer[..len]).unwrap();
486    /// let mut expected_http = String::new();
487    /// expected_http.push_str("GET /chat HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ");
488    /// expected_http.push_str(web_socket_key.as_str());
489    /// expected_http.push_str("\r\nOrigin: http://localhost\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n");
490    /// assert_eq!(expected_http.as_str(), actual_http);
491    /// ```
492    ///
493    /// # Errors
494    /// * The http response is built with a stack allocated 1KB buffer and will return an `Unknown` error
495    ///   if that buffer is too small. This would happen is the user supplied too many
496    ///   additional headers or the sub-protocol string is too large
497    /// * This function can return an `Utf8Error` if there was an error with the generation of the
498    ///   accept string. This should be impossible but an error is preferable to a panic
499    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
500    pub fn client_connect(
501        &mut self,
502        websocket_options: &WebSocketOptions,
503        to: &mut [u8],
504    ) -> Result<(usize, WebSocketKey)> {
505        if self.state == WebSocketState::Open {
506            return Err(Error::WebsocketAlreadyOpen);
507        }
508
509        match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
510            Ok((request_len, sec_websocket_key)) => {
511                self.state = WebSocketState::Connecting;
512                Ok((request_len, sec_websocket_key))
513            }
514            Err(e) => Err(e),
515        }
516    }
517
518    /// Used by a websocket client for checking the server response to an opening handshake
519    /// (sent using the client_connect function). If the client requested one or more sub protocols
520    /// the server will choose one (or none) and you get that in the result
521    /// # Examples
522    /// ```
523    /// use embedded_websocket as ws;
524    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
525    /// let ws_key = ws::WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
526    /// let server_response_html = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";    ///
527    /// let (len, sub_protocol) = ws_client.client_accept(&ws_key, server_response_html.as_bytes())
528    ///     .unwrap();
529    ///
530    /// assert_eq!(159, len);
531    /// assert_eq!("chat", sub_protocol.unwrap());
532    /// ```
533    /// # Errors
534    /// * Returns `HttpResponseCodeInvalid` if the HTTP response code is not `101 Switching Protocols`
535    /// * Returns `AcceptStringInvalid` if the web server failed to return a valid accept string
536    /// * Returns `HttpHeader(Version)` or some other varient if the HTTP response is not well formed
537    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
538    pub fn client_accept(
539        &mut self,
540        sec_websocket_key: &WebSocketKey,
541        from: &[u8],
542    ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
543        if self.state == WebSocketState::Open {
544            return Err(Error::WebsocketAlreadyOpen);
545        }
546
547        match http::read_server_connect_handshake_response(sec_websocket_key, from) {
548            Ok((len, sec_websocket_protocol)) => {
549                self.state = WebSocketState::Open;
550                Ok((len, sec_websocket_protocol))
551            }
552            Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
553            Err(e) => {
554                self.state = WebSocketState::Aborted;
555                Err(e)
556            }
557        }
558    }
559}
560
561impl<T, Type> WebSocket<T, Type>
562where
563    T: RngCore,
564    Type: WebSocketType,
565{
566    /// Reads the payload from a websocket frame in buffer `from` into a buffer `to` and returns
567    /// metadata about the frame. Since this function is designed to be called in a memory
568    /// constrained system we may not read the entire payload in one go. In each of the scenarios
569    /// below the `read_result.end_of_message` flag would be `false`:
570    /// * The payload is fragmented into multiple websocket frames (as per the websocket spec)
571    /// * The `from` buffer does not hold the entire websocket frame. For example if only part of
572    ///   the frame was read or if the `from` buffer is too small to hold an entire websocket frame
573    /// * The `to` buffer is too small to hold the entire websocket frame payload
574    ///
575    /// If the function returns `read_result.end_of_message` `false` then the next
576    /// call to the function should not include data that has already been passed into the function.
577    /// The websocket *remembers* the websocket frame header and is able to process the rest of the
578    /// payload correctly. If the `from` buffer contains multiple websocket frames then only one of
579    /// them will be returned at a time and the user must make multiple calls to the function by
580    /// taking note of `read_result.len_from` which tells you how many bytes were read from the
581    /// `from` buffer
582    ///
583    /// # Examples
584    ///
585    /// ```
586    /// use embedded_websocket as ws;
587    /// //                    h   e   l   l   o
588    /// let buffer1 = [129,5,104,101,108,108,111];
589    /// let mut buffer2: [u8; 128] = [0; 128];
590    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
591    /// ws_client.state = ws::WebSocketState::Open; // skip the opening handshake
592    /// let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
593    ///
594    /// assert_eq!("hello".as_bytes(), &buffer2[..ws_result.len_to]);
595    /// ```
596    /// # Errors
597    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
598    /// * Returns `InvalidOpCode` if the websocket frame contains an invalid opcode
599    /// * Returns `UnexpectedContinuationFrame` if we receive a continuation frame without first
600    ///   receiving a non-continuation frame with an opcode describing the payload
601    /// * Returns `ReadFrameIncomplete` if the `from` buffer does not contain a full websocket
602    ///   header (typically 2-14 bytes depending on the payload)
603    /// * Returns `InvalidFrameLength` if the frame length cannot be decoded
604    ///
605    pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
606        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
607            let frame = self.read_frame(from, to)?;
608
609            match frame.op_code {
610                WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
611                WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
612                WebSocketOpCode::TextFrame => {
613                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
614                }
615                WebSocketOpCode::BinaryFrame => {
616                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
617                }
618                WebSocketOpCode::ConnectionClose => match self.state {
619                    WebSocketState::CloseSent => {
620                        self.state = WebSocketState::Closed;
621                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
622                    }
623                    _ => {
624                        self.state = WebSocketState::CloseReceived;
625                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
626                    }
627                },
628                WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
629                    Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
630                    None => Err(Error::UnexpectedContinuationFrame),
631                },
632            }
633        } else {
634            Err(Error::WebSocketNotOpen)
635        }
636    }
637
638    /// Writes the payload in `from` to a websocket frame in `to`
639    /// * message_type - The type of message to send: Text, Binary or CloseReply
640    /// * end_of_message - False to fragment a frame into multiple smaller frames. The last frame
641    ///   should set this to true
642    /// * from - The buffer containing the payload to encode
643    /// * to - The the buffer to save the websocket encoded payload to.
644    ///   Returns the number of bytes written to the `to` buffer
645    /// # Examples
646    ///
647    /// ```
648    /// use embedded_websocket as ws;
649    /// let mut buffer: [u8; 1000] = [0; 1000];
650    /// let mut ws_server = ws::WebSocketServer::new_server();
651    /// ws_server.state = ws::WebSocketState::Open; // skip the opening handshake
652    /// let len = ws_server.write(ws::WebSocketSendMessageType::Text, true, "hello".as_bytes(),
653    ///     &mut buffer).unwrap();
654    ///
655    /// //                     h   e   l   l   o
656    /// let expected = [129,5,104,101,108,108,111];
657    /// assert_eq!(&expected, &buffer[..len]);
658    /// ```
659    /// # Errors
660    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
661    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
662    ///   frame header (2-14 bytes) plus the payload. Consider fragmenting the messages by making
663    ///   multiple write calls with `end_of_message` set to `false` and the final call set to `true`
664    pub fn write(
665        &mut self,
666        message_type: WebSocketSendMessageType,
667        end_of_message: bool,
668        from: &[u8],
669        to: &mut [u8],
670    ) -> Result<usize> {
671        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
672            let mut op_code = message_type.to_op_code();
673            if op_code == WebSocketOpCode::ConnectionClose {
674                self.state = WebSocketState::Closed
675            } else if self.is_write_continuation {
676                op_code = WebSocketOpCode::ContinuationFrame;
677            }
678
679            self.is_write_continuation = !end_of_message;
680            self.write_frame(from, to, op_code, end_of_message)
681        } else {
682            Err(Error::WebSocketNotOpen)
683        }
684    }
685
686    /// Initiates a close handshake.
687    /// Both the client and server may initiate a close handshake. If successful the function
688    /// changes the websocket state from Open -> CloseSent
689    /// # Errors
690    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
691    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
692    ///   frame header (2-14 bytes) plus the payload. Consider sending a smaller status_description
693    pub fn close(
694        &mut self,
695        close_status: WebSocketCloseStatusCode,
696        status_description: Option<&str>,
697        to: &mut [u8],
698    ) -> Result<usize> {
699        if self.state == WebSocketState::Open {
700            self.state = WebSocketState::CloseSent;
701            if let Some(status_description) = status_description {
702                let mut from_buffer: Vec<u8, 256> = Vec::new();
703                from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
704
705                // restrict the max size of the status_description
706                let len = if status_description.len() < 254 {
707                    status_description.len()
708                } else {
709                    254
710                };
711
712                from_buffer.extend_from_slice(&status_description.as_bytes()[..len])?;
713                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
714            } else {
715                let mut from_buffer: [u8; 2] = [0; 2];
716                BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
717                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
718            }
719        } else {
720            Err(Error::WebSocketNotOpen)
721        }
722    }
723
724    fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
725        match &mut self.continuation_read {
726            Some(continuation_read) => {
727                let result = read_continuation(continuation_read, from_buffer, to_buffer);
728                if result.is_fin_bit_set {
729                    self.continuation_read = None;
730                    self.continuation_frame_op_code = None;
731                }
732                Ok(result)
733            }
734            None => {
735                let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
736
737                // handle fragmented messages
738                match result.op_code {
739                    WebSocketOpCode::BinaryFrame | WebSocketOpCode::TextFrame => {
740                        // set the continuation frame op code if this is a fragmented message
741                        self.continuation_frame_op_code = if result.is_fin_bit_set {
742                            None
743                        } else {
744                            Some(result.op_code)
745                        };
746                    }
747                    WebSocketOpCode::ContinuationFrame => {
748                        // override the op code we get from the result with our continuation frame opcode if it exists
749                        if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
750                            result.op_code = continuation_frame_op_code;
751                        }
752                    }
753                    _ => {
754                        // control frames should not affect the continuation frame mechanics
755                    }
756                }
757
758                self.continuation_read = continuation_read;
759                Ok(result)
760            }
761        }
762    }
763
764    fn write_frame(
765        &mut self,
766        from_buffer: &[u8],
767        to_buffer: &mut [u8],
768        op_code: WebSocketOpCode,
769        end_of_message: bool,
770    ) -> Result<usize> {
771        let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
772        let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
773        let count = from_buffer.len();
774        const BYTE_HEADER_SIZE: usize = 2;
775        const SHORT_HEADER_SIZE: usize = 4;
776        const LONG_HEADER_SIZE: usize = 10;
777        const MASK_KEY_SIZE: usize = 4;
778        let header_size;
779        let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
780        let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
781
782        // write header followed by the payload
783        // header size depends on how large the payload is
784        if count < 126 {
785            if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
786                return Err(Error::WriteToBufferTooSmall);
787            }
788            to_buffer[0] = byte1;
789            to_buffer[1] = mask_bit_set_as_byte | count as u8;
790            header_size = BYTE_HEADER_SIZE;
791        } else if count < 65535 {
792            if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
793                return Err(Error::WriteToBufferTooSmall);
794            }
795            to_buffer[0] = byte1;
796            to_buffer[1] = mask_bit_set_as_byte | 126;
797            BigEndian::write_u16(&mut to_buffer[2..], count as u16);
798            header_size = SHORT_HEADER_SIZE;
799        } else {
800            if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
801                return Err(Error::WriteToBufferTooSmall);
802            }
803            to_buffer[0] = byte1;
804            to_buffer[1] = mask_bit_set_as_byte | 127;
805            BigEndian::write_u64(&mut to_buffer[2..], count as u64);
806            header_size = LONG_HEADER_SIZE;
807        }
808
809        // sent by client - need to mask the data
810        // we need to mask the bytes to prevent web server caching
811        if self.is_client {
812            let mut mask_key = [0; MASK_KEY_SIZE];
813            self.rng.fill_bytes(&mut mask_key); // clients always have an rng instance
814            to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
815            let to_buffer_start = header_size + MASK_KEY_SIZE;
816
817            // apply the mask key to every byte in the payload. This is a hot function
818            for (i, (from, to)) in from_buffer[..count]
819                .iter()
820                .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
821                .enumerate()
822            {
823                *to = *from ^ mask_key[i % MASK_KEY_SIZE];
824            }
825
826            Ok(to_buffer_start + count)
827        } else {
828            to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
829            Ok(header_size + count)
830        }
831    }
832}
833
834// Continuation read is used when we cannot fit the entire websocket frame into the supplied buffer
835struct ContinuationRead {
836    op_code: WebSocketOpCode,
837    count: usize,
838    is_fin_bit_set: bool,
839    mask_key: Option<[u8; 4]>,
840}
841
842struct WebSocketFrame {
843    is_fin_bit_set: bool,
844    op_code: WebSocketOpCode,
845    num_bytes_to: usize,
846    num_bytes_from: usize,
847    close_status: Option<WebSocketCloseStatusCode>,
848}
849
850impl WebSocketFrame {
851    fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
852        WebSocketReadResult {
853            len_from: self.num_bytes_from,
854            len_to: self.num_bytes_to,
855            end_of_message: self.is_fin_bit_set,
856            close_status: self.close_status,
857            message_type,
858        }
859    }
860}
861
/// Returns the smallest of three lengths.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    num1.min(num2).min(num3)
}
865
866fn read_into_buffer(
867    mask_key: &mut Option<[u8; 4]>,
868    from_buffer: &[u8],
869    to_buffer: &mut [u8],
870    len: usize,
871) -> usize {
872    // if we are trying to read more than number of bytes in either buffer
873    let len_to_read = min(len, to_buffer.len(), from_buffer.len());
874
875    match mask_key {
876        Some(mask_key) => {
877            // apply the mask key to every byte in the payload. This is a hot function.
878            for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
879                *to = *from ^ mask_key[i % MASK_KEY_LEN];
880            }
881            mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
882        }
883        None => {
884            to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
885        }
886    }
887
888    len_to_read
889}
890
891fn read_continuation(
892    continuation_read: &mut ContinuationRead,
893    from_buffer: &[u8],
894    to_buffer: &mut [u8],
895) -> WebSocketFrame {
896    let len_read = read_into_buffer(
897        &mut continuation_read.mask_key,
898        from_buffer,
899        to_buffer,
900        continuation_read.count,
901    );
902
903    let is_complete = len_read == continuation_read.count;
904
905    let frame = match continuation_read.op_code {
906        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
907        _ => WebSocketFrame {
908            num_bytes_from: len_read,
909            num_bytes_to: len_read,
910            op_code: continuation_read.op_code,
911            close_status: None,
912            is_fin_bit_set: if is_complete {
913                continuation_read.is_fin_bit_set
914            } else {
915                false
916            },
917        },
918    };
919
920    continuation_read.count -= len_read;
921    frame
922}
923
924fn read_frame(
925    from_buffer: &[u8],
926    to_buffer: &mut [u8],
927) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
928    if from_buffer.len() < 2 {
929        return Err(Error::ReadFrameIncomplete);
930    }
931
932    let byte1 = from_buffer[0];
933    let byte2 = from_buffer[1];
934
935    // process first byte
936    const FIN_BIT_FLAG: u8 = 0x80;
937    const OP_CODE_FLAG: u8 = 0x0F;
938    let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
939    let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
940
941    // process second byte
942    const MASK_FLAG: u8 = 0x80;
943    let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
944    let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
945
946    num_bytes_read += 2;
947    let from_buffer = &from_buffer[num_bytes_read..];
948
949    // reads the mask key from the payload if the is_mask_bit_set flag is set
950    let mut mask_key = if is_mask_bit_set {
951        if from_buffer.len() < MASK_KEY_LEN {
952            return Err(Error::ReadFrameIncomplete);
953        }
954        let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
955        mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
956        num_bytes_read += MASK_KEY_LEN;
957        Some(mask_key)
958    } else {
959        None
960    };
961
962    let len_read = if is_mask_bit_set {
963        // start after the mask key
964        let from_buffer = &from_buffer[MASK_KEY_LEN..];
965        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
966    } else {
967        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
968    };
969
970    let has_continuation = len_read < len;
971    num_bytes_read += len_read;
972
973    let frame = match op_code {
974        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
975        _ => WebSocketFrame {
976            num_bytes_from: num_bytes_read,
977            num_bytes_to: len_read,
978            op_code,
979            close_status: None,
980            is_fin_bit_set: if has_continuation {
981                false
982            } else {
983                is_fin_bit_set
984            },
985        },
986    };
987
988    if has_continuation {
989        let continuation_read = Some(ContinuationRead {
990            op_code,
991            count: len - len_read,
992            is_fin_bit_set,
993            mask_key,
994        });
995        Ok((frame, continuation_read))
996    } else {
997        Ok((frame, None))
998    }
999}
1000
1001fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
1002    match val {
1003        0 => Ok(WebSocketOpCode::ContinuationFrame),
1004        1 => Ok(WebSocketOpCode::TextFrame),
1005        2 => Ok(WebSocketOpCode::BinaryFrame),
1006        8 => Ok(WebSocketOpCode::ConnectionClose),
1007        9 => Ok(WebSocketOpCode::Ping),
1008        10 => Ok(WebSocketOpCode::Pong),
1009        _ => Err(Error::InvalidOpCode),
1010    }
1011}
1012
1013// returns (len, how_many_bytes_were_read)
1014fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1015    let len = byte2 & 0x7F;
1016
1017    if len < 126 {
1018        // for messages smaller than 126 bytes
1019        return Ok((len as usize, 0));
1020    } else if len == 126 {
1021        // for messages smaller than 64KB
1022        if from_buffer.len() < 2 {
1023            return Err(Error::ReadFrameIncomplete);
1024        }
1025        let mut buf: [u8; 2] = [0; 2];
1026        buf.copy_from_slice(&from_buffer[..2]);
1027        return Ok((BigEndian::read_u16(&buf) as usize, 2));
1028    } else if len == 127 {
1029        // for messages larger than 64KB
1030        if from_buffer.len() < 8 {
1031            return Err(Error::ReadFrameIncomplete);
1032        }
1033        let mut buf: [u8; 8] = [0; 8];
1034        buf.copy_from_slice(&from_buffer[..8]);
1035        return Ok((BigEndian::read_u64(&buf) as usize, 8));
1036    }
1037
1038    Err(Error::InvalidFrameLength)
1039}
1040
1041fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1042    if len >= 2 {
1043        // NOTE: for now, don't read the close status description
1044        let code = BigEndian::read_u16(buffer);
1045        let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1046
1047        return WebSocketFrame {
1048            num_bytes_from: num_bytes_read,
1049            num_bytes_to: len,
1050            op_code: WebSocketOpCode::ConnectionClose,
1051            close_status: Some(close_status_code),
1052            is_fin_bit_set: true,
1053        };
1054    }
1055
1056    build_client_disconnected_frame(num_bytes_read)
1057}
1058
1059fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1060    WebSocketFrame {
1061        num_bytes_from,
1062        num_bytes_to: 0,
1063        op_code: WebSocketOpCode::ConnectionClose,
1064        close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1065        is_fin_bit_set: true,
1066    }
1067}
1068
1069// ************************************************************************************************
1070// **************************************** TESTS *************************************************
1071// ************************************************************************************************
1072
1073#[cfg(test)]
1074mod tests {
1075    extern crate std;
1076    use super::*;
1077
1078    #[test]
1079    fn opening_handshake() {
1080        let client_request = "GET /chat HTTP/1.1
1081Host: localhost:5000
1082User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
1083Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
1084Accept-Language: en-US,en;q=0.5
1085Accept-Encoding: gzip, deflate
1086Sec-WebSocket-Version: 13
1087Origin: http://localhost:5000
1088Sec-WebSocket-Extensions: permessage-deflate
1089Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
1090Sec-WebSocket-Protocol: chat
1091DNT: 1
1092Connection: keep-alive, Upgrade
1093Pragma: no-cache
1094Cache-Control: no-cache
1095Upgrade: websocket
1096
1097";
1098
1099        let mut headers = [httparse::EMPTY_HEADER; 16];
1100        let mut request = httparse::Request::new(&mut headers);
1101        request.parse(client_request.as_bytes()).unwrap();
1102        let headers = headers.iter().map(|f| (f.name, f.value));
1103        let web_socket_context = read_http_header(headers).unwrap().unwrap();
1104        assert_eq!(
1105            "Z7OY1UwHOx/nkSz38kfPwg==",
1106            web_socket_context.sec_websocket_key
1107        );
1108        assert_eq!(
1109            "chat",
1110            web_socket_context
1111                .sec_websocket_protocol_list
1112                .get(0)
1113                .unwrap()
1114                .as_str()
1115        );
1116        let mut web_socket = WebSocketServer::new_server();
1117
1118        let mut ws_buffer: [u8; 3000] = [0; 3000];
1119        let size = web_socket
1120            .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
1121            .unwrap();
1122        let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
1123        let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
1124        assert_eq!(client_response_expected, response);
1125    }
1126
1127    #[test]
1128    fn server_write_frame() {
1129        let mut buffer: [u8; 1000] = [0; 1000];
1130        let mut ws_server = WebSocketServer::new_server();
1131        let len = ws_server
1132            .write_frame(
1133                "hello".as_bytes(),
1134                &mut buffer,
1135                WebSocketOpCode::TextFrame,
1136                true,
1137            )
1138            .unwrap();
1139        let expected = [129, 5, 104, 101, 108, 108, 111];
1140        assert_eq!(&expected, &buffer[..len]);
1141    }
1142
1143    #[test]
1144    fn server_accept_should_write_sub_protocol() {
1145        let mut buffer: [u8; 1000] = [0; 1000];
1146        let mut ws_server = WebSocketServer::new_server();
1147        let ws_key = WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
1148        let sub_protocol = WebSocketSubProtocol::from("chat");
1149        let size = ws_server
1150            .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
1151            .unwrap();
1152        let response = std::str::from_utf8(&buffer[..size]).unwrap();
1153        assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
1154    }
1155
1156    #[test]
1157    fn closing_handshake() {
1158        let mut buffer1: [u8; 500] = [0; 500];
1159        let mut buffer2: [u8; 500] = [0; 500];
1160
1161        let mut rng = rand::thread_rng();
1162
1163        let mut ws_client = WebSocketClient::new_client(&mut rng);
1164        ws_client.state = WebSocketState::Open;
1165
1166        let mut ws_server = WebSocketServer::new_server();
1167        ws_server.state = WebSocketState::Open;
1168
1169        // client sends a close (initiates the close handshake)
1170        ws_client
1171            .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
1172            .unwrap();
1173
1174        // check that the client receives the close message
1175        let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
1176        assert_eq!(
1177            WebSocketReceiveMessageType::CloseMustReply,
1178            ws_result.message_type
1179        );
1180
1181        // server MUST respond to complete the handshake
1182        ws_server
1183            .write(
1184                WebSocketSendMessageType::CloseReply,
1185                true,
1186                &buffer2[..ws_result.len_to],
1187                &mut buffer1,
1188            )
1189            .unwrap();
1190        assert_eq!(WebSocketState::Closed, ws_server.state);
1191
1192        // check that the client receives the close message from the server
1193        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1194        assert_eq!(WebSocketState::Closed, ws_client.state);
1195
1196        assert_eq!(
1197            WebSocketReceiveMessageType::CloseCompleted,
1198            ws_result.message_type
1199        );
1200    }
1201
1202    #[test]
1203    fn send_message_from_client_to_server() {
1204        let mut buffer1: [u8; 1000] = [0; 1000];
1205        let mut buffer2: [u8; 1000] = [0; 1000];
1206
1207        // how to create a client
1208        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1209
1210        ws_client.state = WebSocketState::Open;
1211        let mut ws_server = WebSocketServer::new_server();
1212        ws_server.state = WebSocketState::Open;
1213
1214        // client sends a Text message
1215        let hello = "hello";
1216        let num_bytes = ws_client
1217            .write(
1218                WebSocketSendMessageType::Text,
1219                true,
1220                &hello.as_bytes(),
1221                &mut buffer1,
1222            )
1223            .unwrap();
1224
1225        // check that the Server receives the Text message
1226        let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1227        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1228        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1229        assert_eq!(hello, received);
1230    }
1231
1232    #[test]
1233    fn send_message_from_server_to_client() {
1234        let mut buffer1: [u8; 1000] = [0; 1000];
1235        let mut buffer2: [u8; 1000] = [0; 1000];
1236
1237        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1238        ws_client.state = WebSocketState::Open;
1239        let mut ws_server = WebSocketServer::new_server();
1240        ws_server.state = WebSocketState::Open;
1241
1242        // server sends a Text message
1243        let hello = "hello";
1244        let num_bytes = ws_server
1245            .write(
1246                WebSocketSendMessageType::Text,
1247                true,
1248                &hello.as_bytes(),
1249                &mut buffer1,
1250            )
1251            .unwrap();
1252
1253        // check that the client receives the Text message
1254        let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1255        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1256        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1257        assert_eq!(hello, received);
1258    }
1259
1260    #[test]
1261    fn receive_buffer_too_small() {
1262        let mut buffer1: [u8; 1000] = [0; 1000];
1263        let mut buffer2: [u8; 1000] = [0; 1000];
1264
1265        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1266        ws_client.state = WebSocketState::Open;
1267        let mut ws_server = WebSocketServer::new_server();
1268        ws_server.state = WebSocketState::Open;
1269
1270        let hello = "hello";
1271        ws_server
1272            .write(
1273                WebSocketSendMessageType::Text,
1274                true,
1275                &hello.as_bytes(),
1276                &mut buffer1,
1277            )
1278            .unwrap();
1279
1280        match ws_client.read(&buffer1[..1], &mut buffer2) {
1281            Err(Error::ReadFrameIncomplete) => {
1282                // test passes
1283            }
1284            _ => {
1285                assert_eq!(true, false);
1286            }
1287        }
1288    }
1289
1290    #[test]
1291    fn receive_large_frame_with_small_receive_buffer() {
1292        let mut buffer1: [u8; 1000] = [0; 1000];
1293        let mut buffer2: [u8; 1000] = [0; 1000];
1294
1295        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1296        ws_client.state = WebSocketState::Open;
1297        let mut ws_server = WebSocketServer::new_server();
1298        ws_server.state = WebSocketState::Open;
1299
1300        let hello = "hello";
1301        ws_server
1302            .write(
1303                WebSocketSendMessageType::Text,
1304                true,
1305                &hello.as_bytes(),
1306                &mut buffer1,
1307            )
1308            .unwrap();
1309
1310        let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
1311        assert_eq!(0, ws_result.len_to);
1312        assert_eq!(false, ws_result.end_of_message);
1313        let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
1314        assert_eq!(1, ws_result.len_to);
1315        assert_eq!(
1316            "h",
1317            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1318        );
1319        assert_eq!(false, ws_result.end_of_message);
1320        let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
1321        assert_eq!(4, ws_result.len_to);
1322        assert_eq!(
1323            "ello",
1324            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1325        );
1326        assert_eq!(true, ws_result.end_of_message);
1327    }
1328
1329    #[test]
1330    fn send_large_frame() {
1331        let buffer1 = [0u8; 15944];
1332        let mut buffer2 = [0u8; 64000];
1333        let mut buffer3 = [0u8; 64000];
1334
1335        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1336        ws_client.state = WebSocketState::Open;
1337
1338        ws_client
1339            .write(
1340                WebSocketSendMessageType::Binary,
1341                true,
1342                &buffer1,
1343                &mut buffer2,
1344            )
1345            .unwrap();
1346
1347        let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1348        assert_eq!(true, ws_result.end_of_message);
1349        assert_eq!(buffer1.len(), ws_result.len_to);
1350    }
1351
1352    #[test]
1353    fn receive_large_frame_multi_read() {
1354        let mut buffer1 = [0_u8; 1000];
1355        let mut buffer2 = [0_u8; 1000];
1356
1357        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1358        ws_client.state = WebSocketState::Open;
1359        let mut ws_server = WebSocketServer::new_server();
1360        ws_server.state = WebSocketState::Open;
1361
1362        let message = "Hello, world. This is a long message that takes multiple reads";
1363        ws_server
1364            .write(
1365                WebSocketSendMessageType::Text,
1366                true,
1367                &message.as_bytes(),
1368                &mut buffer1,
1369            )
1370            .unwrap();
1371
1372        let mut buffer2_cursor = 0;
1373        let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
1374        assert_eq!(false, ws_result.end_of_message);
1375        buffer2_cursor += ws_result.len_to;
1376        let ws_result = ws_client
1377            .read(
1378                &buffer1[ws_result.len_from..],
1379                &mut buffer2[buffer2_cursor..],
1380            )
1381            .unwrap();
1382        assert_eq!(true, ws_result.end_of_message);
1383        buffer2_cursor += ws_result.len_to;
1384
1385        assert_eq!(
1386            message,
1387            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1388        );
1389    }
1390
1391    #[test]
1392    fn multiple_messages_in_receive_buffer() {
1393        let mut buffer1 = [0_u8; 1000];
1394        let mut buffer2 = [0_u8; 1000];
1395
1396        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1397        ws_client.state = WebSocketState::Open;
1398        let mut ws_server = WebSocketServer::new_server();
1399        ws_server.state = WebSocketState::Open;
1400
1401        let message1 = "Hello, world.";
1402        let len = ws_client
1403            .write(
1404                WebSocketSendMessageType::Text,
1405                true,
1406                &message1.as_bytes(),
1407                &mut buffer1,
1408            )
1409            .unwrap();
1410        let message2 = "This is another message.";
1411        ws_client
1412            .write(
1413                WebSocketSendMessageType::Text,
1414                true,
1415                &message2.as_bytes(),
1416                &mut buffer1[len..],
1417            )
1418            .unwrap();
1419
1420        let mut buffer1_cursor = 0;
1421        let mut buffer2_cursor = 0;
1422        let ws_result = ws_server
1423            .read(&buffer1[buffer1_cursor..], &mut buffer2)
1424            .unwrap();
1425        assert_eq!(true, ws_result.end_of_message);
1426        buffer1_cursor += ws_result.len_from;
1427        buffer2_cursor += ws_result.len_to;
1428        let ws_result = ws_server
1429            .read(&buffer1[buffer1_cursor..], &mut buffer2[buffer2_cursor..])
1430            .unwrap();
1431        assert_eq!(true, ws_result.end_of_message);
1432        assert_eq!(
1433            message1,
1434            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1435        );
1436
1437        assert_eq!(
1438            message2,
1439            std::str::from_utf8(&buffer2[buffer2_cursor..buffer2_cursor + ws_result.len_to])
1440                .unwrap()
1441        );
1442    }
1443
1444    #[test]
1445    fn receive_large_frame_with_small_send_buffer() {
1446        let mut buffer1: [u8; 1000] = [0; 1000];
1447        let mut buffer2: [u8; 1000] = [0; 1000];
1448
1449        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1450        ws_client.state = WebSocketState::Open;
1451        let mut ws_server = WebSocketServer::new_server();
1452        ws_server.state = WebSocketState::Open;
1453
1454        let hello = "hello";
1455        ws_server
1456            .write(
1457                WebSocketSendMessageType::Text,
1458                true,
1459                &hello.as_bytes(),
1460                &mut buffer1,
1461            )
1462            .unwrap();
1463
1464        let ws_result = ws_client.read(&buffer1, &mut buffer2[..1]).unwrap();
1465        assert_eq!(1, ws_result.len_to);
1466        assert_eq!(
1467            "h",
1468            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1469        );
1470        assert_eq!(false, ws_result.end_of_message);
1471        let ws_result = ws_client
1472            .read(&buffer1[ws_result.len_from..], &mut buffer2[..4])
1473            .unwrap();
1474        assert_eq!(4, ws_result.len_to);
1475        assert_eq!(
1476            "ello",
1477            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1478        );
1479        assert_eq!(true, ws_result.end_of_message);
1480    }
1481
1482    #[test]
1483    fn send_two_frame_message() {
1484        let mut buffer1: [u8; 1000] = [0; 1000];
1485        let mut buffer2: [u8; 1000] = [0; 1000];
1486        // let mut rng = rand::thread_rng();
1487
1488        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1489        ws_client.state = WebSocketState::Open;
1490        let mut ws_server = WebSocketServer::new_server();
1491        ws_server.state = WebSocketState::Open;
1492
1493        // client sends a fragmented Text message
1494        let hello = "Hello, ";
1495        let num_bytes_hello = ws_server
1496            .write(
1497                WebSocketSendMessageType::Text,
1498                false,
1499                &hello.as_bytes(),
1500                &mut buffer1,
1501            )
1502            .unwrap();
1503
1504        // client sends the remaining Text message
1505        let world = "World!";
1506        let num_bytes_world = ws_server
1507            .write(
1508                WebSocketSendMessageType::Text,
1509                true,
1510                &world.as_bytes(),
1511                &mut buffer1[num_bytes_hello..],
1512            )
1513            .unwrap();
1514
1515        // check that the Server receives the entire Text message
1516        let ws_result1 = ws_client
1517            .read(&buffer1[..num_bytes_hello], &mut buffer2)
1518            .unwrap();
1519        assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
1520        assert_eq!(false, ws_result1.end_of_message);
1521        let ws_result2 = ws_client
1522            .read(
1523                &buffer1[num_bytes_hello..num_bytes_hello + num_bytes_world],
1524                &mut buffer2[ws_result1.len_to..],
1525            )
1526            .unwrap();
1527        assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
1528        assert_eq!(true, ws_result2.end_of_message);
1529
1530        let received =
1531            std::str::from_utf8(&buffer2[..ws_result1.len_to + ws_result2.len_to]).unwrap();
1532        assert_eq!("Hello, World!", received);
1533    }
1534
1535    #[test]
1536    fn send_multi_frame_message() {
1537        let mut buffer1: [u8; 1000] = [0; 1000];
1538        let mut buffer2: [u8; 1000] = [0; 1000];
1539
1540        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1541        ws_client.state = WebSocketState::Open;
1542        let mut ws_server = WebSocketServer::new_server();
1543        ws_server.state = WebSocketState::Open;
1544
1545        // server sends the first fragmented Text frame
1546        let fragment1 = "fragment1";
1547        let fragment1_num_bytes = ws_server
1548            .write(
1549                WebSocketSendMessageType::Text,
1550                false,
1551                &fragment1.as_bytes(),
1552                &mut buffer1,
1553            )
1554            .unwrap();
1555
1556        // send fragment2 as a continuation frame
1557        let fragment2 = "fragment2";
1558        let fragment2_num_bytes = ws_server
1559            .write(
1560                WebSocketSendMessageType::Text,
1561                false,
1562                &fragment2.as_bytes(),
1563                &mut buffer1[fragment1_num_bytes..],
1564            )
1565            .unwrap();
1566
1567        // send fragment3 as a continuation frame and indicate that this is the last frame
1568        let fragment3 = "fragment3";
1569        let _fragment3_num_bytes = ws_server
1570            .write(
1571                WebSocketSendMessageType::Text,
1572                true,
1573                &fragment3.as_bytes(),
1574                &mut buffer1[fragment1_num_bytes + fragment2_num_bytes..],
1575            )
1576            .unwrap();
1577
1578        // check that the client receives the "fragment1" Text message
1579        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1580        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1581        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1582        assert_eq!(fragment1, received);
1583        assert_eq!(ws_result.end_of_message, false);
1584        let mut read_cursor = ws_result.len_from;
1585
1586        // check that the client receives the "fragment2" Text message
1587        let ws_result = ws_client
1588            .read(&buffer1[read_cursor..], &mut buffer2)
1589            .unwrap();
1590        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1591        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1592        assert_eq!(fragment2, received);
1593        assert_eq!(ws_result.end_of_message, false);
1594        read_cursor += ws_result.len_from;
1595
1596        // check that the client receives the "fragment3" Text message
1597        let ws_result = ws_client
1598            .read(&buffer1[read_cursor..], &mut buffer2)
1599            .unwrap();
1600        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1601        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1602        assert_eq!(fragment3, received);
1603        assert_eq!(ws_result.end_of_message, true);
1604
1605        // check the actual bytes in the write buffer for fragment1
1606        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
1607        assert_eq!(is_fin_bit_set, false);
1608        assert_eq!(op_code, WebSocketOpCode::TextFrame);
1609
1610        // check the actual bytes in the write buffer for fragment2
1611        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
1612        assert_eq!(is_fin_bit_set, false);
1613        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1614
1615        // check the actual bytes in the write buffer for fragment3
1616        let (is_fin_bit_set, op_code) =
1617            read_first_byte(buffer1[fragment1_num_bytes + fragment2_num_bytes]);
1618        assert_eq!(is_fin_bit_set, true);
1619        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1620    }
1621
1622    fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1623        const FIN_BIT_FLAG: u8 = 0x80;
1624        const OP_CODE_FLAG: u8 = 0x0F;
1625        let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1626        let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1627        (is_fin_bit_set, op_code)
1628    }
1629}