embedded_websocket/
lib.rs

1//! # Embedded Websocket
2//!
3//! `embedded_websocket` facilitates the encoding and decoding of websocket frames and can be used
4//! for both clients and servers. The library is intended to be used in constrained memory
//! environments like embedded microcontrollers which cannot reference the Rust standard library.
//! It will work with arbitrarily small buffers regardless of websocket frame size as long as the
//! websocket header can be read (2 - 14 bytes depending on the payload size and masking).
//! Since the library is essentially an encoder or decoder of byte slices, the developer is free to
//! use whatever transport mechanism they choose. The examples in the source repository use the
//! `TcpStream` from the standard library.
11
12// support for websockets without using the standard library
13#![cfg_attr(not(feature = "std"), no_std)]
14
15use byteorder::{BigEndian, ByteOrder};
16use core::{cmp, result, str};
17use heapless::{String, Vec};
18use rand_core::RngCore;
19use sha1::{Digest, Sha1};
20
21mod http;
22pub mod random;
23pub use self::http::{read_http_header, WebSocketContext};
24pub use self::random::EmptyRng;
25
26// support for working with discrete websocket frames when using IO streams
27// start here!!
28pub mod framer;
29pub mod framer_async;
/// Length in bytes of the masking key carried in a masked frame header
/// (the key is a 32-bit value in RFC 6455).
const MASK_KEY_LEN: usize = core::mem::size_of::<u32>();
31
32/// Result returning a websocket specific error if encountered
33pub type Result<T> = result::Result<T, Error>;
34
35/// A fixed length 24-character string used to hold a websocket key for the opening handshake
36pub type WebSocketKey = String<24>;
37
38/// A maximum sized 24-character string used to store a sub protocol (e.g. `chat`)
39pub type WebSocketSubProtocol = String<24>;
40
/// The kind of websocket frame to produce when calling `write`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketSendMessageType {
    /// A UTF8 encoded text payload
    Text = 1,
    /// An arbitrary binary payload
    Binary = 2,
    /// An unsolicited ping message
    Ping = 9,
    /// A pong message answering a ping message
    Pong = 10,
    /// A close message answering a close message sent by the other party, completing the
    /// closing handshake. To initiate a close handshake yourself use the `close` function instead
    CloseReply = 11,
}
56
57impl WebSocketSendMessageType {
58    fn to_op_code(self) -> WebSocketOpCode {
59        match self {
60            WebSocketSendMessageType::Text => WebSocketOpCode::TextFrame,
61            WebSocketSendMessageType::Binary => WebSocketOpCode::BinaryFrame,
62            WebSocketSendMessageType::Ping => WebSocketOpCode::Ping,
63            WebSocketSendMessageType::Pong => WebSocketOpCode::Pong,
64            WebSocketSendMessageType::CloseReply => WebSocketOpCode::ConnectionClose,
65        }
66    }
67}
68
/// The kind of message reported back by `read` for a decoded websocket frame
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketReceiveMessageType {
    /// A UTF8 encoded text payload
    Text = 1,
    /// An arbitrary binary payload
    Binary = 2,
    /// The other party answered our close message, so the close handshake is complete
    CloseCompleted = 7,
    /// The other party started the close handshake. Answer it with a
    /// `WebSocketSendMessageType::CloseReply` that echoes the payload of the close message
    CloseMustReply = 8,
    /// A ping message; answer it with a `WebSocketSendMessageType::Pong` that echoes the ping
    /// payload
    Ping = 9,
    /// A pong message answering an earlier ping message
    Pong = 10,
}
88
/// Close status codes defined by the rfc6455 websocket spec, plus an escape hatch for any other
/// value
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketCloseStatusCode {
    /// Normal closure (1000): the purpose for which the connection was established has been
    /// fulfilled
    NormalClosure,
    /// Endpoint unavailable (1001): an endpoint is "going away", for example a server shutting
    /// down or a browser navigating away from a page
    EndpointUnavailable,
    /// Protocol error (1002): an endpoint is terminating the connection because of a protocol
    /// error
    ProtocolError,
    /// Invalid message type (1003): an endpoint is terminating the connection because it received
    /// a type of data it cannot accept (e.g. a text-only endpoint receiving a binary message)
    InvalidMessageType,
    /// Reserved (1004) for future use
    Reserved,
    /// Empty (1005): no status code was present
    Empty,
    /// Invalid payload data (1007): an endpoint is terminating the connection because a message
    /// contained data inconsistent with its type (e.g. non-UTF-8 data within a text message)
    InvalidPayloadData,
    /// Policy violation (1008): an endpoint is terminating the connection because a message
    /// violated its policy. A generic code for when no more suitable status code exists
    PolicyViolation,
    /// Message too big (1009): an endpoint is terminating the connection because a message was
    /// too big for it to process
    MessageTooBig,
    /// Mandatory extension (1010): a client is terminating the connection because it expected the
    /// server to negotiate one or more extensions, but the server's handshake response did not
    /// include them
    MandatoryExtension,
    /// Internal server error (1011): a server is terminating the connection because an unexpected
    /// condition prevented it from fulfilling the request
    InternalServerError,
    /// TLS handshake (1015): the connection was closed because a TLS handshake failed
    TlsHandshake,
    /// Any other close code, carried verbatim
    Custom(u16),
}
132
133impl WebSocketCloseStatusCode {
134    fn from_u16(value: u16) -> WebSocketCloseStatusCode {
135        match value {
136            1000 => WebSocketCloseStatusCode::NormalClosure,
137            1001 => WebSocketCloseStatusCode::EndpointUnavailable,
138            1002 => WebSocketCloseStatusCode::ProtocolError,
139            1003 => WebSocketCloseStatusCode::InvalidMessageType,
140            1004 => WebSocketCloseStatusCode::Reserved,
141            1005 => WebSocketCloseStatusCode::Empty,
142            1007 => WebSocketCloseStatusCode::InvalidPayloadData,
143            1008 => WebSocketCloseStatusCode::PolicyViolation,
144            1009 => WebSocketCloseStatusCode::MessageTooBig,
145            1010 => WebSocketCloseStatusCode::MandatoryExtension,
146            1011 => WebSocketCloseStatusCode::InternalServerError,
147            1015 => WebSocketCloseStatusCode::TlsHandshake,
148            _ => WebSocketCloseStatusCode::Custom(value),
149        }
150    }
151
152    fn to_u16(self) -> u16 {
153        match self {
154            WebSocketCloseStatusCode::NormalClosure => 1000,
155            WebSocketCloseStatusCode::EndpointUnavailable => 1001,
156            WebSocketCloseStatusCode::ProtocolError => 1002,
157            WebSocketCloseStatusCode::InvalidMessageType => 1003,
158            WebSocketCloseStatusCode::Reserved => 1004,
159            WebSocketCloseStatusCode::Empty => 1005,
160            WebSocketCloseStatusCode::InvalidPayloadData => 1007,
161            WebSocketCloseStatusCode::PolicyViolation => 1008,
162            WebSocketCloseStatusCode::MessageTooBig => 1009,
163            WebSocketCloseStatusCode::MandatoryExtension => 1010,
164            WebSocketCloseStatusCode::InternalServerError => 1011,
165            WebSocketCloseStatusCode::TlsHandshake => 1015,
166            WebSocketCloseStatusCode::Custom(value) => value,
167        }
168    }
169}
170
/// Lifecycle state of a websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketState {
    /// Freshly created with `new_client()` or `new_server()`
    None = 0,
    /// The client has produced its opening handshake request
    Connecting = 1,
    /// The opening handshake is complete: `server_accept()` succeeded on the server, or
    /// `client_accept()` succeeded on the client. `write()`, `read()` and `close()` may now be
    /// called
    Open = 2,
    /// `close()` has been called and we are waiting for the other party's reply
    CloseSent = 3,
    /// The other party sent a Close websocket frame
    CloseReceived = 4,
    /// The close handshake has finished
    Closed = 5,
    /// The opening handshake failed on the server or the client
    Aborted = 6,
}
191
192/// Websocket specific errors
193#[derive(PartialEq, Eq, Debug)]
194pub enum Error {
195    /// Websocket frame has an invalid opcode
196    InvalidOpCode,
197    InvalidFrameLength,
198    InvalidCloseStatusCode,
199    WebSocketNotOpen,
200    WebsocketAlreadyOpen,
201    Utf8Error,
202    Unknown,
203    HttpHeader(httparse::Error),
204    HttpHeaderNoPath,
205    HttpHeaderIncomplete,
206    WriteToBufferTooSmall,
207    ReadFrameIncomplete,
208    HttpResponseCodeInvalid(Option<u16>),
209    AcceptStringInvalid,
210    ConvertInfallible,
211    RandCore,
212    UnexpectedContinuationFrame,
213}
214
215impl From<httparse::Error> for Error {
216    fn from(err: httparse::Error) -> Error {
217        Error::HttpHeader(err)
218    }
219}
220
221impl From<str::Utf8Error> for Error {
222    fn from(_: str::Utf8Error) -> Error {
223        Error::Utf8Error
224    }
225}
226
227impl From<core::convert::Infallible> for Error {
228    fn from(_: core::convert::Infallible) -> Error {
229        Error::ConvertInfallible
230    }
231}
232
233impl From<()> for Error {
234    fn from(_: ()) -> Error {
235        Error::Unknown
236    }
237}
238
239impl core::fmt::Display for Error {
240    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241        match self {
242            Error::HttpHeader(error) => write!(f, "bad http header {error}"),
243            Error::HttpResponseCodeInvalid(Some(code)) => write!(f, "bad http response ({code})"),
244            _ => write!(f, "{:?}", self),
245        }
246    }
247}
248
#[cfg(feature = "std")]
impl std::error::Error for Error {
    /// Only `HttpHeader` wraps another error, so it is the only variant exposing a source
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpHeader(inner) => Some(inner),
            _ => None,
        }
    }
}
259
/// Frame opcodes as they appear in the websocket frame header
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WebSocketOpCode {
    ContinuationFrame = 0x0,
    TextFrame = 0x1,
    BinaryFrame = 0x2,
    ConnectionClose = 0x8,
    Ping = 0x9,
    Pong = 0xA,
}
269
270impl WebSocketOpCode {
271    fn to_message_type(self) -> Result<WebSocketReceiveMessageType> {
272        match self {
273            WebSocketOpCode::TextFrame => Ok(WebSocketReceiveMessageType::Text),
274            WebSocketOpCode::BinaryFrame => Ok(WebSocketReceiveMessageType::Binary),
275            _ => Err(Error::InvalidOpCode),
276        }
277    }
278}
279
280/// The metadata result of a `read` function call of a websocket
281#[derive(Debug)]
282pub struct WebSocketReadResult {
283    /// Number of bytes read from the `from` buffer
284    pub len_from: usize,
285    /// Number of bytes written to the `to` buffer
286    pub len_to: usize,
287    /// End of message flag is `true` if the `to` buffer contains an entire websocket frame
288    /// payload otherwise `false` if the user must continue calling the read function to get the
289    /// rest of the payload
290    pub end_of_message: bool,
291    /// Close status code (as per the websocket spec) if the message type is `CloseMustReply` or
292    /// `CloseCompleted`. If a close status is specified then a UTF8 encoded string could also
293    /// appear in the frame payload giving more detailed information about why the websocket was
294    /// closed.
295    pub close_status: Option<WebSocketCloseStatusCode>,
296    /// The websocket frame type
297    pub message_type: WebSocketReceiveMessageType,
298}
299
/// Settings a websocket client uses to build the opening handshake request it sends to a
/// websocket server
pub struct WebSocketOptions<'a> {
    /// The request uri (e.g. `/chat?id=123`) of the GET request. It identifies the websocket
    /// endpoint, allowing multiple domains to be served by a single server, and can also carry
    /// identifying information about the client
    pub path: &'a str,
    /// The hostname (e.g. `server.example.com`) that lets the client and the server confirm they
    /// agree on which host is in use
    pub host: &'a str,
    /// The origin (e.g. `http://example.com`), which protects against unauthorized cross-origin
    /// use of a websocket server by scripts running in a web browser. Browsers normally set it,
    /// but some servers require it, so it is exposed here
    pub origin: &'a str,
    /// Requested sub protocols, most preferred first. The server should answer with the first one
    /// it supports, or with none at all. Any sub protocol the server and client agree on is valid
    pub sub_protocols: Option<&'a [&'a str]>,
    /// Extra, non websocket-spec headers the server may require. Each entry is a complete http
    /// header without the trailing `\r\n` (e.g. `MyHeader: foo`)
    pub additional_headers: Option<&'a [&'a str]>,
}
323
324/// Used to return a sized type from `WebSocket::new_server()`
325pub type WebSocketServer = WebSocket<EmptyRng, Server>;
326
327/// Used to return a sized type from `WebSocketClient::new_client()`
328pub type WebSocketClient<T> = WebSocket<T, Client>;
329
// Typestate markers: the second type parameter of `WebSocket` records its role, so the
// role-specific methods (`server_accept` for servers, `client_connect`/`client_accept` for
// clients) only exist on the matching type, while `read`/`write`/`close` are shared by both.

/// Marker type for the server role of a `WebSocket`
pub enum Server {}

/// Marker type for the client role of a `WebSocket`
pub enum Client {}

/// Implemented by the role marker types (`Server` and `Client`) a `WebSocket` is parameterised
/// with
pub trait WebSocketType {}

impl WebSocketType for Client {}
impl WebSocketType for Server {}
338
339/// Websocket client and server implementation
340pub struct WebSocket<T, S: WebSocketType>
341where
342    T: RngCore,
343{
344    is_client: bool,
345    rng: T,
346    continuation_frame_op_code: Option<WebSocketOpCode>,
347    is_write_continuation: bool,
348    pub state: WebSocketState,
349    continuation_read: Option<ContinuationRead>,
350    marker: core::marker::PhantomData<S>,
351}
352
353impl<T, Type> WebSocket<T, Type>
354where
355    T: RngCore,
356    Type: WebSocketType,
357{
358    /// Creates a new websocket client by passing in a required random number generator
359    ///
360    /// # Examples
361    /// ```
362    /// use embedded_websocket as ws;
363    /// use rand;
364    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
365    ///
366    /// assert_eq!(ws::WebSocketState::None, ws_client.state);
367    /// ```
368    pub fn new_client(rng: T) -> WebSocketClient<T> {
369        WebSocket {
370            is_client: true,
371            rng,
372            continuation_frame_op_code: None,
373            is_write_continuation: false,
374            state: WebSocketState::None,
375            continuation_read: None,
376            marker: core::marker::PhantomData::<Client>,
377        }
378    }
379
380    /// Creates a new websocket server. Note that you must use the `WebSocketServer` type and
381    /// not the generic `WebSocket` type for this call or you will get a `'type annotations needed'`
382    /// compilation error.
383    ///
384    /// # Examples
385    /// ```
386    /// use embedded_websocket as ws;
387    /// let mut ws_server = ws::WebSocketServer::new_server();
388    ///
389    /// assert_eq!(ws::WebSocketState::None, ws_server.state);
390    /// ```
391    pub fn new_server() -> WebSocketServer {
392        let rng = EmptyRng::new();
393        WebSocket {
394            is_client: false,
395            rng,
396            continuation_frame_op_code: None,
397            is_write_continuation: false,
398            state: WebSocketState::None,
399            continuation_read: None,
400            marker: core::marker::PhantomData::<Server>,
401        }
402    }
403}
404
405impl<T> WebSocket<T, Server>
406where
407    T: RngCore,
408{
409    /// Used by the server to accept an incoming client connection and build a websocket upgrade
410    /// http response string. The client http header should be read with the `read_http_header`
411    /// function and the result should be passed to this function.
412    /// Websocket state will change from None -> Open if successful, otherwise None -> Aborted
413    ///
414    /// # Examples
415    ///
416    /// ```
417    /// use embedded_websocket as ws;
418    /// let mut buffer: [u8; 1000] = [0; 1000];
419    /// let mut ws_server = ws::WebSocketServer::new_server();
420    /// let ws_key = ws::WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
421    /// let sub_protocol = ws::WebSocketSubProtocol::from("chat");
422    /// let len = ws_server
423    ///     .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
424    ///     .unwrap();
425    /// let response = std::str::from_utf8(&buffer[..len]).unwrap();
426    ///
427    /// assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
428    /// ```
429    ///
430    /// # Errors
431    /// There should be no way for a user provided input to return the errors listed below as the
432    /// input is already constrained.
433    /// * The http response is built with a stack allocated 1KB buffer and it *should be impossible*
434    /// for it to return an  `Unknown` error if that buffer is too small. However, this is better
435    /// than a panic and it will do so if the response header is too large to fit in the buffer
436    /// * This function can return an `Utf8Error` if there was an error with the generation of the
437    /// accept string. This should also be impossible but an error is preferable to a panic
438    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
439    pub fn server_accept(
440        &mut self,
441        sec_websocket_key: &WebSocketKey,
442        sec_websocket_protocol: Option<&WebSocketSubProtocol>,
443        to: &mut [u8],
444    ) -> Result<usize> {
445        if self.state == WebSocketState::Open {
446            return Err(Error::WebsocketAlreadyOpen);
447        }
448
449        match http::build_connect_handshake_response(sec_websocket_key, sec_websocket_protocol, to)
450        {
451            Ok(http_response_len) => {
452                self.state = WebSocketState::Open;
453                Ok(http_response_len)
454            }
455            Err(e) => {
456                self.state = WebSocketState::Aborted;
457                Err(e)
458            }
459        }
460    }
461}
462
463impl<T> WebSocket<T, Client>
464where
465    T: RngCore,
466{
467    /// Used by the client to initiate a websocket opening handshake
468    ///
469    /// # Examples
470    /// ```
471    /// use embedded_websocket as ws;
472    /// let mut buffer: [u8; 2000] = [0; 2000];
473    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
474    /// let sub_protocols = ["chat", "superchat"];
475    /// let websocket_options = ws::WebSocketOptions {
476    ///     path: "/chat",
477    ///     host: "localhost",
478    ///     origin: "http://localhost",
479    ///     sub_protocols: Some(&sub_protocols),
480    ///     additional_headers: None,
481    /// };
482    ///
483    /// let (len, web_socket_key) = ws_client.client_connect(&websocket_options, &mut buffer).unwrap();
484    ///
485    /// let actual_http = std::str::from_utf8(&buffer[..len]).unwrap();
486    /// let mut expected_http = String::new();
487    /// expected_http.push_str("GET /chat HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ");
488    /// expected_http.push_str(web_socket_key.as_str());
489    /// expected_http.push_str("\r\nOrigin: http://localhost\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n");
490    /// assert_eq!(expected_http.as_str(), actual_http);
491    /// ```
492    ///
493    /// # Errors
494    /// * The http response is built with a stack allocated 1KB buffer and will return an
495    /// `Unknown` error if that buffer is too small. This would happen is the user supplied too many
496    /// additional headers or the sub-protocol string is too large
497    /// * This function can return an `Utf8Error` if there was an error with the generation of the
498    /// accept string. This should be impossible but an error is preferable to a panic
499    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
500    pub fn client_connect(
501        &mut self,
502        websocket_options: &WebSocketOptions,
503        to: &mut [u8],
504    ) -> Result<(usize, WebSocketKey)> {
505        if self.state == WebSocketState::Open {
506            return Err(Error::WebsocketAlreadyOpen);
507        }
508
509        match http::build_connect_handshake_request(websocket_options, &mut self.rng, to) {
510            Ok((request_len, sec_websocket_key)) => {
511                self.state = WebSocketState::Connecting;
512                Ok((request_len, sec_websocket_key))
513            }
514            Err(e) => Err(e),
515        }
516    }
517
518    /// Used by a websocket client for checking the server response to an opening handshake
519    /// (sent using the client_connect function). If the client requested one or more sub protocols
520    /// the server will choose one (or none) and you get that in the result
521    /// # Examples
522    /// ```
523    /// use embedded_websocket as ws;
524    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
525    /// let ws_key = ws::WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
526    /// let server_response_html = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";    ///
527    /// let (len, sub_protocol) = ws_client.client_accept(&ws_key, server_response_html.as_bytes())
528    ///     .unwrap();
529    ///
530    /// assert_eq!(159, len);
531    /// assert_eq!("chat", sub_protocol.unwrap());
532    /// ```
533    /// # Errors
534    /// * Returns `HttpResponseCodeInvalid` if the HTTP response code is not `101 Switching Protocols`
535    /// * Returns `AcceptStringInvalid` if the web server failed to return a valid accept string
536    /// * Returns `HttpHeader(Version)` or some other varient if the HTTP response is not well formed
537    /// * Returns `WebsocketAlreadyOpen` if called on a websocket that is already open
538    pub fn client_accept(
539        &mut self,
540        sec_websocket_key: &WebSocketKey,
541        from: &[u8],
542    ) -> Result<(usize, Option<WebSocketSubProtocol>)> {
543        if self.state == WebSocketState::Open {
544            return Err(Error::WebsocketAlreadyOpen);
545        }
546
547        match http::read_server_connect_handshake_response(sec_websocket_key, from) {
548            Ok((len, sec_websocket_protocol)) => {
549                self.state = WebSocketState::Open;
550                Ok((len, sec_websocket_protocol))
551            }
552            Err(Error::HttpHeaderIncomplete) => Err(Error::HttpHeaderIncomplete),
553            Err(e) => {
554                self.state = WebSocketState::Aborted;
555                Err(e)
556            }
557        }
558    }
559}
560
561impl<T, Type> WebSocket<T, Type>
562where
563    T: RngCore,
564    Type: WebSocketType,
565{
566    /// Reads the payload from a websocket frame in buffer `from` into a buffer `to` and returns
567    /// metadata about the frame. Since this function is designed to be called in a memory
568    /// constrained system we may not read the entire payload in one go. In each of the scenarios
569    /// below the `read_result.end_of_message` flag would be `false`:
570    /// * The payload is fragmented into multiple websocket frames (as per the websocket spec)
571    /// * The `from` buffer does not hold the entire websocket frame. For example if only part of
572    /// the frame was read or if the `from` buffer is too small to hold an entire websocket frame
573    /// * The `to` buffer is too small to hold the entire websocket frame payload
574    ///
575    /// If the function returns `read_result.end_of_message` `false` then the next
576    /// call to the function should not include data that has already been passed into the function.
577    /// The websocket *remembers* the websocket frame header and is able to process the rest of the
578    /// payload correctly. If the `from` buffer contains multiple websocket frames then only one of
579    /// them will be returned at a time and the user must make multiple calls to the function by
580    /// taking note of `read_result.len_from` which tells you how many bytes were read from the
581    /// `from` buffer
582    ///
583    /// # Examples
584    ///
585    /// ```
586    /// use embedded_websocket as ws;
587    /// //                    h   e   l   l   o
588    /// let buffer1 = [129,5,104,101,108,108,111];
589    /// let mut buffer2: [u8; 128] = [0; 128];
590    /// let mut ws_client = ws::WebSocketClient::new_client(rand::thread_rng());
591    /// ws_client.state = ws::WebSocketState::Open; // skip the opening handshake
592    /// let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
593    ///
594    /// assert_eq!("hello".as_bytes(), &buffer2[..ws_result.len_to]);
595    /// ```
596    /// # Errors
597    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
598    /// * Returns `InvalidOpCode` if the websocket frame contains an invalid opcode
599    /// * Returns `UnexpectedContinuationFrame` if we receive a continuation frame without first
600    /// receiving a non-continuation frame with an opcode describing the payload
601    /// * Returns `ReadFrameIncomplete` if the `from` buffer does not contain a full websocket
602    /// header (typically 2-14 bytes depending on the payload)
603    /// * Returns `InvalidFrameLength` if the frame length cannot be decoded
604    ///
605    pub fn read(&mut self, from: &[u8], to: &mut [u8]) -> Result<WebSocketReadResult> {
606        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseSent {
607            let frame = self.read_frame(from, to)?;
608
609            match frame.op_code {
610                WebSocketOpCode::Ping => Ok(frame.to_readresult(WebSocketReceiveMessageType::Ping)),
611                WebSocketOpCode::Pong => Ok(frame.to_readresult(WebSocketReceiveMessageType::Pong)),
612                WebSocketOpCode::TextFrame => {
613                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Text))
614                }
615                WebSocketOpCode::BinaryFrame => {
616                    Ok(frame.to_readresult(WebSocketReceiveMessageType::Binary))
617                }
618                WebSocketOpCode::ConnectionClose => match self.state {
619                    WebSocketState::CloseSent => {
620                        self.state = WebSocketState::Closed;
621                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseCompleted))
622                    }
623                    _ => {
624                        self.state = WebSocketState::CloseReceived;
625                        Ok(frame.to_readresult(WebSocketReceiveMessageType::CloseMustReply))
626                    }
627                },
628                WebSocketOpCode::ContinuationFrame => match self.continuation_frame_op_code {
629                    Some(cf_op_code) => Ok(frame.to_readresult(cf_op_code.to_message_type()?)),
630                    None => Err(Error::UnexpectedContinuationFrame),
631                },
632            }
633        } else {
634            Err(Error::WebSocketNotOpen)
635        }
636    }
637
638    /// Writes the payload in `from` to a websocket frame in `to`
639    /// * message_type - The type of message to send: Text, Binary or CloseReply
640    /// * end_of_message - False to fragment a frame into multiple smaller frames. The last frame
641    /// should set this to true
642    /// * from - The buffer containing the payload to encode
643    /// * to - The the buffer to save the websocket encoded payload to.
644    /// Returns the number of bytes written to the `to` buffer
645    /// # Examples
646    ///
647    /// ```
648    /// use embedded_websocket as ws;
649    /// let mut buffer: [u8; 1000] = [0; 1000];
650    /// let mut ws_server = ws::WebSocketServer::new_server();
651    /// ws_server.state = ws::WebSocketState::Open; // skip the opening handshake
652    /// let len = ws_server.write(ws::WebSocketSendMessageType::Text, true, "hello".as_bytes(),
653    ///     &mut buffer).unwrap();
654    ///
655    /// //                     h   e   l   l   o
656    /// let expected = [129,5,104,101,108,108,111];
657    /// assert_eq!(&expected, &buffer[..len]);
658    /// ```
659    /// # Errors
660    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
661    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
662    /// frame header (2-14 bytes) plus the payload. Consider fragmenting the messages by making
663    /// multiple write calls with `end_of_message` set to `false` and the final call set to `true`
664    pub fn write(
665        &mut self,
666        message_type: WebSocketSendMessageType,
667        end_of_message: bool,
668        from: &[u8],
669        to: &mut [u8],
670    ) -> Result<usize> {
671        if self.state == WebSocketState::Open || self.state == WebSocketState::CloseReceived {
672            let mut op_code = message_type.to_op_code();
673            if op_code == WebSocketOpCode::ConnectionClose {
674                self.state = WebSocketState::Closed
675            } else if self.is_write_continuation {
676                op_code = WebSocketOpCode::ContinuationFrame;
677            }
678
679            self.is_write_continuation = !end_of_message;
680            self.write_frame(from, to, op_code, end_of_message)
681        } else {
682            Err(Error::WebSocketNotOpen)
683        }
684    }
685
686    /// Initiates a close handshake.
687    /// Both the client and server may initiate a close handshake. If successful the function
688    /// changes the websocket state from Open -> CloseSent
689    /// # Errors
690    /// * Returns `WebSocketNotOpen` when the websocket is not open when this function is called
691    /// * Returns `WriteToBufferTooSmall` when the `to` buffer is too small to fit the websocket
692    /// frame header (2-14 bytes) plus the payload. Consider sending a smaller status_description
693    pub fn close(
694        &mut self,
695        close_status: WebSocketCloseStatusCode,
696        status_description: Option<&str>,
697        to: &mut [u8],
698    ) -> Result<usize> {
699        if self.state == WebSocketState::Open {
700            self.state = WebSocketState::CloseSent;
701            if let Some(status_description) = status_description {
702                let mut from_buffer: Vec<u8, 256> = Vec::new();
703                from_buffer.extend_from_slice(&close_status.to_u16().to_be_bytes())?;
704
705                // restrict the max size of the status_description
706                let len = if status_description.len() < 254 {
707                    status_description.len()
708                } else {
709                    254
710                };
711
712                from_buffer.extend_from_slice(status_description[..len].as_bytes())?;
713                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
714            } else {
715                let mut from_buffer: [u8; 2] = [0; 2];
716                BigEndian::write_u16(&mut from_buffer, close_status.to_u16());
717                self.write_frame(&from_buffer, to, WebSocketOpCode::ConnectionClose, true)
718            }
719        } else {
720            Err(Error::WebSocketNotOpen)
721        }
722    }
723
724    fn read_frame(&mut self, from_buffer: &[u8], to_buffer: &mut [u8]) -> Result<WebSocketFrame> {
725        match &mut self.continuation_read {
726            Some(continuation_read) => {
727                let result = read_continuation(continuation_read, from_buffer, to_buffer);
728                if result.is_fin_bit_set {
729                    self.continuation_read = None;
730                    self.continuation_frame_op_code = None;
731                }
732                Ok(result)
733            }
734            None => {
735                let (mut result, continuation_read) = read_frame(from_buffer, to_buffer)?;
736
737                // override the op code we get from the result with our continuation frame opcode if it exists
738                if let Some(continuation_frame_op_code) = self.continuation_frame_op_code {
739                    result.op_code = continuation_frame_op_code;
740                }
741
742                // reset the continuation frame op code to None if this is the last fragment (or there is no fragmentation)
743                self.continuation_frame_op_code = if result.is_fin_bit_set {
744                    None
745                } else {
746                    Some(result.op_code)
747                };
748
749                self.continuation_read = continuation_read;
750                Ok(result)
751            }
752        }
753    }
754
755    fn write_frame(
756        &mut self,
757        from_buffer: &[u8],
758        to_buffer: &mut [u8],
759        op_code: WebSocketOpCode,
760        end_of_message: bool,
761    ) -> Result<usize> {
762        let fin_bit_set_as_byte: u8 = if end_of_message { 0x80 } else { 0x00 };
763        let byte1: u8 = fin_bit_set_as_byte | op_code as u8;
764        let count = from_buffer.len();
765        const BYTE_HEADER_SIZE: usize = 2;
766        const SHORT_HEADER_SIZE: usize = 4;
767        const LONG_HEADER_SIZE: usize = 10;
768        const MASK_KEY_SIZE: usize = 4;
769        let header_size;
770        let mask_bit_set_as_byte = if self.is_client { 0x80 } else { 0x00 };
771        let payload_len = from_buffer.len() + if self.is_client { MASK_KEY_SIZE } else { 0 };
772
773        // write header followed by the payload
774        // header size depends on how large the payload is
775        if count < 126 {
776            if payload_len + BYTE_HEADER_SIZE > to_buffer.len() {
777                return Err(Error::WriteToBufferTooSmall);
778            }
779            to_buffer[0] = byte1;
780            to_buffer[1] = mask_bit_set_as_byte | count as u8;
781            header_size = BYTE_HEADER_SIZE;
782        } else if count < 65535 {
783            if payload_len + SHORT_HEADER_SIZE > to_buffer.len() {
784                return Err(Error::WriteToBufferTooSmall);
785            }
786            to_buffer[0] = byte1;
787            to_buffer[1] = mask_bit_set_as_byte | 126;
788            BigEndian::write_u16(&mut to_buffer[2..], count as u16);
789            header_size = SHORT_HEADER_SIZE;
790        } else {
791            if payload_len + LONG_HEADER_SIZE > to_buffer.len() {
792                return Err(Error::WriteToBufferTooSmall);
793            }
794            to_buffer[0] = byte1;
795            to_buffer[1] = mask_bit_set_as_byte | 127;
796            BigEndian::write_u64(&mut to_buffer[2..], count as u64);
797            header_size = LONG_HEADER_SIZE;
798        }
799
800        // sent by client - need to mask the data
801        // we need to mask the bytes to prevent web server caching
802        if self.is_client {
803            let mut mask_key = [0; MASK_KEY_SIZE];
804            self.rng.fill_bytes(&mut mask_key); // clients always have an rng instance
805            to_buffer[header_size..header_size + MASK_KEY_SIZE].copy_from_slice(&mask_key);
806            let to_buffer_start = header_size + MASK_KEY_SIZE;
807
808            // apply the mask key to every byte in the payload. This is a hot function
809            for (i, (from, to)) in from_buffer[..count]
810                .iter()
811                .zip(&mut to_buffer[to_buffer_start..to_buffer_start + count])
812                .enumerate()
813            {
814                *to = *from ^ mask_key[i % MASK_KEY_SIZE];
815            }
816
817            Ok(to_buffer_start + count)
818        } else {
819            to_buffer[header_size..header_size + count].copy_from_slice(&from_buffer[..count]);
820            Ok(header_size + count)
821        }
822    }
823}
824
// Continuation read is used when we cannot fit the entire websocket frame into the supplied
// buffer. It records what is needed to resume decoding the rest of that frame's payload on
// subsequent reads.
struct ContinuationRead {
    // op code from the header of the frame being read (the raw header value, not the op code
    // of the message a continuation fragment belongs to)
    op_code: WebSocketOpCode,
    // number of payload bytes of this frame that have not been read yet
    count: usize,
    // FIN bit from the frame header; only reported once the remaining payload has been read
    is_fin_bit_set: bool,
    // mask key if the frame was masked, rotated after each read so that index 0 applies to the
    // next unread payload byte; None for unmasked frames
    mask_key: Option<[u8; 4]>,
}
832
// Result of decoding a websocket frame, or the portion of one that fitted into the buffers
// supplied to a single read
struct WebSocketFrame {
    // false while payload bytes of the frame remain unread; otherwise the frame's FIN bit.
    // Frames produced by decode_close_frame always report true.
    is_fin_bit_set: bool,
    // op code of the frame; the method-level read_frame may replace it with the op code of
    // the message being reassembled
    op_code: WebSocketOpCode,
    // number of payload bytes written to the output buffer
    num_bytes_to: usize,
    // number of bytes consumed from the input buffer, including any header and mask key bytes
    // consumed by this read
    num_bytes_from: usize,
    // status code of a close frame; None for all other frames
    close_status: Option<WebSocketCloseStatusCode>,
}
840
841impl WebSocketFrame {
842    fn to_readresult(&self, message_type: WebSocketReceiveMessageType) -> WebSocketReadResult {
843        WebSocketReadResult {
844            len_from: self.num_bytes_from,
845            len_to: self.num_bytes_to,
846            end_of_message: self.is_fin_bit_set,
847            close_status: self.close_status,
848            message_type,
849        }
850    }
851}
852
/// Returns the smallest of three `usize` values.
fn min(num1: usize, num2: usize, num3: usize) -> usize {
    num1.min(num2).min(num3)
}
856
857fn read_into_buffer(
858    mask_key: &mut Option<[u8; 4]>,
859    from_buffer: &[u8],
860    to_buffer: &mut [u8],
861    len: usize,
862) -> usize {
863    // if we are trying to read more than number of bytes in either buffer
864    let len_to_read = min(len, to_buffer.len(), from_buffer.len());
865
866    match mask_key {
867        Some(mask_key) => {
868            // apply the mask key to every byte in the payload. This is a hot function.
869            for (i, (from, to)) in from_buffer[..len_to_read].iter().zip(to_buffer).enumerate() {
870                *to = *from ^ mask_key[i % MASK_KEY_LEN];
871            }
872            mask_key.rotate_left(len_to_read % MASK_KEY_LEN);
873        }
874        None => {
875            to_buffer[..len_to_read].copy_from_slice(&from_buffer[..len_to_read]);
876        }
877    }
878
879    len_to_read
880}
881
882fn read_continuation(
883    continuation_read: &mut ContinuationRead,
884    from_buffer: &[u8],
885    to_buffer: &mut [u8],
886) -> WebSocketFrame {
887    let len_read = read_into_buffer(
888        &mut continuation_read.mask_key,
889        from_buffer,
890        to_buffer,
891        continuation_read.count,
892    );
893
894    let is_complete = len_read == continuation_read.count;
895
896    let frame = match continuation_read.op_code {
897        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, len_read, len_read),
898        _ => WebSocketFrame {
899            num_bytes_from: len_read,
900            num_bytes_to: len_read,
901            op_code: continuation_read.op_code,
902            close_status: None,
903            is_fin_bit_set: if is_complete {
904                continuation_read.is_fin_bit_set
905            } else {
906                false
907            },
908        },
909    };
910
911    continuation_read.count -= len_read;
912    frame
913}
914
915fn read_frame(
916    from_buffer: &[u8],
917    to_buffer: &mut [u8],
918) -> Result<(WebSocketFrame, Option<ContinuationRead>)> {
919    if from_buffer.len() < 2 {
920        return Err(Error::ReadFrameIncomplete);
921    }
922
923    let byte1 = from_buffer[0];
924    let byte2 = from_buffer[1];
925
926    // process first byte
927    const FIN_BIT_FLAG: u8 = 0x80;
928    const OP_CODE_FLAG: u8 = 0x0F;
929    let is_fin_bit_set = (byte1 & FIN_BIT_FLAG) == FIN_BIT_FLAG;
930    let op_code = get_op_code(byte1 & OP_CODE_FLAG)?;
931
932    // process second byte
933    const MASK_FLAG: u8 = 0x80;
934    let is_mask_bit_set = (byte2 & MASK_FLAG) == MASK_FLAG;
935    let (len, mut num_bytes_read) = read_length(byte2, &from_buffer[2..])?;
936
937    num_bytes_read += 2;
938    let from_buffer = &from_buffer[num_bytes_read..];
939
940    // reads the mask key from the payload if the is_mask_bit_set flag is set
941    let mut mask_key = if is_mask_bit_set {
942        if from_buffer.len() < MASK_KEY_LEN {
943            return Err(Error::ReadFrameIncomplete);
944        }
945        let mut mask_key: [u8; MASK_KEY_LEN] = [0; MASK_KEY_LEN];
946        mask_key.copy_from_slice(&from_buffer[..MASK_KEY_LEN]);
947        num_bytes_read += MASK_KEY_LEN;
948        Some(mask_key)
949    } else {
950        None
951    };
952
953    let len_read = if is_mask_bit_set {
954        // start after the mask key
955        let from_buffer = &from_buffer[MASK_KEY_LEN..];
956        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
957    } else {
958        read_into_buffer(&mut mask_key, from_buffer, to_buffer, len)
959    };
960
961    let has_continuation = len_read < len;
962    num_bytes_read += len_read;
963
964    let frame = match op_code {
965        WebSocketOpCode::ConnectionClose => decode_close_frame(to_buffer, num_bytes_read, len_read),
966        _ => WebSocketFrame {
967            num_bytes_from: num_bytes_read,
968            num_bytes_to: len_read,
969            op_code,
970            close_status: None,
971            is_fin_bit_set: if has_continuation {
972                false
973            } else {
974                is_fin_bit_set
975            },
976        },
977    };
978
979    if has_continuation {
980        let continuation_read = Some(ContinuationRead {
981            op_code,
982            count: len - len_read,
983            is_fin_bit_set,
984            mask_key,
985        });
986        Ok((frame, continuation_read))
987    } else {
988        Ok((frame, None))
989    }
990}
991
992fn get_op_code(val: u8) -> Result<WebSocketOpCode> {
993    match val {
994        0 => Ok(WebSocketOpCode::ContinuationFrame),
995        1 => Ok(WebSocketOpCode::TextFrame),
996        2 => Ok(WebSocketOpCode::BinaryFrame),
997        8 => Ok(WebSocketOpCode::ConnectionClose),
998        9 => Ok(WebSocketOpCode::Ping),
999        10 => Ok(WebSocketOpCode::Pong),
1000        _ => Err(Error::InvalidOpCode),
1001    }
1002}
1003
1004// returns (len, how_many_bytes_were_read)
1005fn read_length(byte2: u8, from_buffer: &[u8]) -> Result<(usize, usize)> {
1006    let len = byte2 & 0x7F;
1007
1008    if len < 126 {
1009        // for messages smaller than 126 bytes
1010        return Ok((len as usize, 0));
1011    } else if len == 126 {
1012        // for messages smaller than 64KB
1013        if from_buffer.len() < 2 {
1014            return Err(Error::ReadFrameIncomplete);
1015        }
1016        let mut buf: [u8; 2] = [0; 2];
1017        buf.copy_from_slice(&from_buffer[..2]);
1018        return Ok((BigEndian::read_u16(&buf) as usize, 2));
1019    } else if len == 127 {
1020        // for messages larger than 64KB
1021        if from_buffer.len() < 8 {
1022            return Err(Error::ReadFrameIncomplete);
1023        }
1024        let mut buf: [u8; 8] = [0; 8];
1025        buf.copy_from_slice(&from_buffer[..8]);
1026        return Ok((BigEndian::read_u64(&buf) as usize, 8));
1027    }
1028
1029    Err(Error::InvalidFrameLength)
1030}
1031
1032fn decode_close_frame(buffer: &mut [u8], num_bytes_read: usize, len: usize) -> WebSocketFrame {
1033    if len >= 2 {
1034        // NOTE: for now, don't read the close status description
1035        let code = BigEndian::read_u16(buffer);
1036        let close_status_code = WebSocketCloseStatusCode::from_u16(code);
1037
1038        return WebSocketFrame {
1039            num_bytes_from: num_bytes_read,
1040            num_bytes_to: len,
1041            op_code: WebSocketOpCode::ConnectionClose,
1042            close_status: Some(close_status_code),
1043            is_fin_bit_set: true,
1044        };
1045    }
1046
1047    build_client_disconnected_frame(num_bytes_read)
1048}
1049
1050fn build_client_disconnected_frame(num_bytes_from: usize) -> WebSocketFrame {
1051    WebSocketFrame {
1052        num_bytes_from,
1053        num_bytes_to: 0,
1054        op_code: WebSocketOpCode::ConnectionClose,
1055        close_status: Some(WebSocketCloseStatusCode::InternalServerError),
1056        is_fin_bit_set: true,
1057    }
1058}
1059
1060// ************************************************************************************************
1061// **************************************** TESTS *************************************************
1062// ************************************************************************************************
1063
1064#[cfg(test)]
1065mod tests {
1066    extern crate std;
1067    use super::*;
1068
1069    #[test]
1070    fn opening_handshake() {
1071        let client_request = "GET /chat HTTP/1.1
1072Host: localhost:5000
1073User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:62.0) Gecko/20100101 Firefox/62.0
1074Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
1075Accept-Language: en-US,en;q=0.5
1076Accept-Encoding: gzip, deflate
1077Sec-WebSocket-Version: 13
1078Origin: http://localhost:5000
1079Sec-WebSocket-Extensions: permessage-deflate
1080Sec-WebSocket-Key: Z7OY1UwHOx/nkSz38kfPwg==
1081Sec-WebSocket-Protocol: chat
1082DNT: 1
1083Connection: keep-alive, Upgrade
1084Pragma: no-cache
1085Cache-Control: no-cache
1086Upgrade: websocket
1087
1088";
1089
1090        let mut headers = [httparse::EMPTY_HEADER; 16];
1091        let mut request = httparse::Request::new(&mut headers);
1092        request.parse(client_request.as_bytes()).unwrap();
1093        let headers = headers.iter().map(|f| (f.name, f.value));
1094        let web_socket_context = read_http_header(headers).unwrap().unwrap();
1095        assert_eq!(
1096            "Z7OY1UwHOx/nkSz38kfPwg==",
1097            web_socket_context.sec_websocket_key
1098        );
1099        assert_eq!(
1100            "chat",
1101            web_socket_context
1102                .sec_websocket_protocol_list
1103                .get(0)
1104                .unwrap()
1105                .as_str()
1106        );
1107        let mut web_socket = WebSocketServer::new_server();
1108
1109        let mut ws_buffer: [u8; 3000] = [0; 3000];
1110        let size = web_socket
1111            .server_accept(&web_socket_context.sec_websocket_key, None, &mut ws_buffer)
1112            .unwrap();
1113        let response = std::str::from_utf8(&ws_buffer[..size]).unwrap();
1114        let client_response_expected = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n";
1115        assert_eq!(client_response_expected, response);
1116    }
1117
1118    #[test]
1119    fn server_write_frame() {
1120        let mut buffer: [u8; 1000] = [0; 1000];
1121        let mut ws_server = WebSocketServer::new_server();
1122        let len = ws_server
1123            .write_frame(
1124                "hello".as_bytes(),
1125                &mut buffer,
1126                WebSocketOpCode::TextFrame,
1127                true,
1128            )
1129            .unwrap();
1130        let expected = [129, 5, 104, 101, 108, 108, 111];
1131        assert_eq!(&expected, &buffer[..len]);
1132    }
1133
1134    #[test]
1135    fn server_accept_should_write_sub_protocol() {
1136        let mut buffer: [u8; 1000] = [0; 1000];
1137        let mut ws_server = WebSocketServer::new_server();
1138        let ws_key = WebSocketKey::from("Z7OY1UwHOx/nkSz38kfPwg==");
1139        let sub_protocol = WebSocketSubProtocol::from("chat");
1140        let size = ws_server
1141            .server_accept(&ws_key, Some(&sub_protocol), &mut buffer)
1142            .unwrap();
1143        let response = std::str::from_utf8(&buffer[..size]).unwrap();
1144        assert_eq!("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Accept: ptPnPeDOTo6khJlzmLhOZSh2tAY=\r\n\r\n", response);
1145    }
1146
1147    #[test]
1148    fn closing_handshake() {
1149        let mut buffer1: [u8; 500] = [0; 500];
1150        let mut buffer2: [u8; 500] = [0; 500];
1151
1152        let mut rng = rand::thread_rng();
1153
1154        let mut ws_client = WebSocketClient::new_client(&mut rng);
1155        ws_client.state = WebSocketState::Open;
1156
1157        let mut ws_server = WebSocketServer::new_server();
1158        ws_server.state = WebSocketState::Open;
1159
1160        // client sends a close (initiates the close handshake)
1161        ws_client
1162            .close(WebSocketCloseStatusCode::NormalClosure, None, &mut buffer1)
1163            .unwrap();
1164
1165        // check that the client receives the close message
1166        let ws_result = ws_server.read(&buffer1, &mut buffer2).unwrap();
1167        assert_eq!(
1168            WebSocketReceiveMessageType::CloseMustReply,
1169            ws_result.message_type
1170        );
1171
1172        // server MUST respond to complete the handshake
1173        ws_server
1174            .write(
1175                WebSocketSendMessageType::CloseReply,
1176                true,
1177                &buffer2[..ws_result.len_to],
1178                &mut buffer1,
1179            )
1180            .unwrap();
1181        assert_eq!(WebSocketState::Closed, ws_server.state);
1182
1183        // check that the client receives the close message from the server
1184        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1185        assert_eq!(WebSocketState::Closed, ws_client.state);
1186
1187        assert_eq!(
1188            WebSocketReceiveMessageType::CloseCompleted,
1189            ws_result.message_type
1190        );
1191    }
1192
1193    #[test]
1194    fn send_message_from_client_to_server() {
1195        let mut buffer1: [u8; 1000] = [0; 1000];
1196        let mut buffer2: [u8; 1000] = [0; 1000];
1197
1198        // how to create a client
1199        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1200
1201        ws_client.state = WebSocketState::Open;
1202        let mut ws_server = WebSocketServer::new_server();
1203        ws_server.state = WebSocketState::Open;
1204
1205        // client sends a Text message
1206        let hello = "hello";
1207        let num_bytes = ws_client
1208            .write(
1209                WebSocketSendMessageType::Text,
1210                true,
1211                &hello.as_bytes(),
1212                &mut buffer1,
1213            )
1214            .unwrap();
1215
1216        // check that the Server receives the Text message
1217        let ws_result = ws_server.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1218        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1219        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1220        assert_eq!(hello, received);
1221    }
1222
1223    #[test]
1224    fn send_message_from_server_to_client() {
1225        let mut buffer1: [u8; 1000] = [0; 1000];
1226        let mut buffer2: [u8; 1000] = [0; 1000];
1227
1228        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1229        ws_client.state = WebSocketState::Open;
1230        let mut ws_server = WebSocketServer::new_server();
1231        ws_server.state = WebSocketState::Open;
1232
1233        // server sends a Text message
1234        let hello = "hello";
1235        let num_bytes = ws_server
1236            .write(
1237                WebSocketSendMessageType::Text,
1238                true,
1239                &hello.as_bytes(),
1240                &mut buffer1,
1241            )
1242            .unwrap();
1243
1244        // check that the client receives the Text message
1245        let ws_result = ws_client.read(&buffer1[..num_bytes], &mut buffer2).unwrap();
1246        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1247        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1248        assert_eq!(hello, received);
1249    }
1250
1251    #[test]
1252    fn receive_buffer_too_small() {
1253        let mut buffer1: [u8; 1000] = [0; 1000];
1254        let mut buffer2: [u8; 1000] = [0; 1000];
1255
1256        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1257        ws_client.state = WebSocketState::Open;
1258        let mut ws_server = WebSocketServer::new_server();
1259        ws_server.state = WebSocketState::Open;
1260
1261        let hello = "hello";
1262        ws_server
1263            .write(
1264                WebSocketSendMessageType::Text,
1265                true,
1266                &hello.as_bytes(),
1267                &mut buffer1,
1268            )
1269            .unwrap();
1270
1271        match ws_client.read(&buffer1[..1], &mut buffer2) {
1272            Err(Error::ReadFrameIncomplete) => {
1273                // test passes
1274            }
1275            _ => {
1276                assert_eq!(true, false);
1277            }
1278        }
1279    }
1280
1281    #[test]
1282    fn receive_large_frame_with_small_receive_buffer() {
1283        let mut buffer1: [u8; 1000] = [0; 1000];
1284        let mut buffer2: [u8; 1000] = [0; 1000];
1285
1286        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1287        ws_client.state = WebSocketState::Open;
1288        let mut ws_server = WebSocketServer::new_server();
1289        ws_server.state = WebSocketState::Open;
1290
1291        let hello = "hello";
1292        ws_server
1293            .write(
1294                WebSocketSendMessageType::Text,
1295                true,
1296                &hello.as_bytes(),
1297                &mut buffer1,
1298            )
1299            .unwrap();
1300
1301        let ws_result = ws_client.read(&buffer1[..2], &mut buffer2).unwrap();
1302        assert_eq!(0, ws_result.len_to);
1303        assert_eq!(false, ws_result.end_of_message);
1304        let ws_result = ws_client.read(&buffer1[2..3], &mut buffer2).unwrap();
1305        assert_eq!(1, ws_result.len_to);
1306        assert_eq!(
1307            "h",
1308            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1309        );
1310        assert_eq!(false, ws_result.end_of_message);
1311        let ws_result = ws_client.read(&buffer1[3..], &mut buffer2).unwrap();
1312        assert_eq!(4, ws_result.len_to);
1313        assert_eq!(
1314            "ello",
1315            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1316        );
1317        assert_eq!(true, ws_result.end_of_message);
1318    }
1319
1320    #[test]
1321    fn send_large_frame() {
1322        let buffer1 = [0u8; 15944];
1323        let mut buffer2 = [0u8; 64000];
1324        let mut buffer3 = [0u8; 64000];
1325
1326        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1327        ws_client.state = WebSocketState::Open;
1328        let mut ws_server = WebSocketServer::new_server();
1329        ws_server.state = WebSocketState::Open;
1330
1331        ws_client
1332            .write(
1333                WebSocketSendMessageType::Binary,
1334                true,
1335                &buffer1,
1336                &mut buffer2,
1337            )
1338            .unwrap();
1339
1340        let ws_result = ws_client.read(&buffer2, &mut buffer3).unwrap();
1341        assert_eq!(true, ws_result.end_of_message);
1342        assert_eq!(buffer1.len(), ws_result.len_to);
1343    }
1344
1345    #[test]
1346    fn receive_large_frame_multi_read() {
1347        let mut buffer1 = [0_u8; 1000];
1348        let mut buffer2 = [0_u8; 1000];
1349
1350        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1351        ws_client.state = WebSocketState::Open;
1352        let mut ws_server = WebSocketServer::new_server();
1353        ws_server.state = WebSocketState::Open;
1354
1355        let message = "Hello, world. This is a long message that takes multiple reads";
1356        ws_server
1357            .write(
1358                WebSocketSendMessageType::Text,
1359                true,
1360                &message.as_bytes(),
1361                &mut buffer1,
1362            )
1363            .unwrap();
1364
1365        let mut buffer2_cursor = 0;
1366        let ws_result = ws_client.read(&buffer1[..40], &mut buffer2).unwrap();
1367        assert_eq!(false, ws_result.end_of_message);
1368        buffer2_cursor += ws_result.len_to;
1369        let ws_result = ws_client
1370            .read(
1371                &buffer1[ws_result.len_from..],
1372                &mut buffer2[buffer2_cursor..],
1373            )
1374            .unwrap();
1375        assert_eq!(true, ws_result.end_of_message);
1376        buffer2_cursor += ws_result.len_to;
1377
1378        assert_eq!(
1379            message,
1380            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1381        );
1382    }
1383
1384    #[test]
1385    fn multiple_messages_in_receive_buffer() {
1386        let mut buffer1 = [0_u8; 1000];
1387        let mut buffer2 = [0_u8; 1000];
1388
1389        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1390        ws_client.state = WebSocketState::Open;
1391        let mut ws_server = WebSocketServer::new_server();
1392        ws_server.state = WebSocketState::Open;
1393
1394        let message1 = "Hello, world.";
1395        let len = ws_client
1396            .write(
1397                WebSocketSendMessageType::Text,
1398                true,
1399                &message1.as_bytes(),
1400                &mut buffer1,
1401            )
1402            .unwrap();
1403        let message2 = "This is another message.";
1404        ws_client
1405            .write(
1406                WebSocketSendMessageType::Text,
1407                true,
1408                &message2.as_bytes(),
1409                &mut buffer1[len..],
1410            )
1411            .unwrap();
1412
1413        let mut buffer1_cursor = 0;
1414        let mut buffer2_cursor = 0;
1415        let ws_result = ws_server
1416            .read(&buffer1[buffer1_cursor..], &mut buffer2)
1417            .unwrap();
1418        assert_eq!(true, ws_result.end_of_message);
1419        buffer1_cursor += ws_result.len_from;
1420        buffer2_cursor += ws_result.len_to;
1421        let ws_result = ws_server
1422            .read(&buffer1[buffer1_cursor..], &mut buffer2[buffer2_cursor..])
1423            .unwrap();
1424        assert_eq!(true, ws_result.end_of_message);
1425        assert_eq!(
1426            message1,
1427            std::str::from_utf8(&buffer2[..buffer2_cursor]).unwrap()
1428        );
1429
1430        assert_eq!(
1431            message2,
1432            std::str::from_utf8(&buffer2[buffer2_cursor..buffer2_cursor + ws_result.len_to])
1433                .unwrap()
1434        );
1435    }
1436
1437    #[test]
1438    fn receive_large_frame_with_small_send_buffer() {
1439        let mut buffer1: [u8; 1000] = [0; 1000];
1440        let mut buffer2: [u8; 1000] = [0; 1000];
1441
1442        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1443        ws_client.state = WebSocketState::Open;
1444        let mut ws_server = WebSocketServer::new_server();
1445        ws_server.state = WebSocketState::Open;
1446
1447        let hello = "hello";
1448        ws_server
1449            .write(
1450                WebSocketSendMessageType::Text,
1451                true,
1452                &hello.as_bytes(),
1453                &mut buffer1,
1454            )
1455            .unwrap();
1456
1457        let ws_result = ws_client.read(&buffer1, &mut buffer2[..1]).unwrap();
1458        assert_eq!(1, ws_result.len_to);
1459        assert_eq!(
1460            "h",
1461            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1462        );
1463        assert_eq!(false, ws_result.end_of_message);
1464        let ws_result = ws_client
1465            .read(&buffer1[ws_result.len_from..], &mut buffer2[..4])
1466            .unwrap();
1467        assert_eq!(4, ws_result.len_to);
1468        assert_eq!(
1469            "ello",
1470            std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap()
1471        );
1472        assert_eq!(true, ws_result.end_of_message);
1473    }
1474
1475    #[test]
1476    fn send_two_frame_message() {
1477        let mut buffer1: [u8; 1000] = [0; 1000];
1478        let mut buffer2: [u8; 1000] = [0; 1000];
1479        // let mut rng = rand::thread_rng();
1480
1481        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1482        ws_client.state = WebSocketState::Open;
1483        let mut ws_server = WebSocketServer::new_server();
1484        ws_server.state = WebSocketState::Open;
1485
1486        // client sends a fragmented Text message
1487        let hello = "Hello, ";
1488        let num_bytes_hello = ws_server
1489            .write(
1490                WebSocketSendMessageType::Text,
1491                false,
1492                &hello.as_bytes(),
1493                &mut buffer1,
1494            )
1495            .unwrap();
1496
1497        // client sends the remaining Text message
1498        let world = "World!";
1499        let num_bytes_world = ws_server
1500            .write(
1501                WebSocketSendMessageType::Text,
1502                true,
1503                &world.as_bytes(),
1504                &mut buffer1[num_bytes_hello..],
1505            )
1506            .unwrap();
1507
1508        // check that the Server receives the entire Text message
1509        let ws_result1 = ws_client
1510            .read(&buffer1[..num_bytes_hello], &mut buffer2)
1511            .unwrap();
1512        assert_eq!(WebSocketReceiveMessageType::Text, ws_result1.message_type);
1513        assert_eq!(false, ws_result1.end_of_message);
1514        let ws_result2 = ws_client
1515            .read(
1516                &buffer1[num_bytes_hello..num_bytes_hello + num_bytes_world],
1517                &mut buffer2[ws_result1.len_to..],
1518            )
1519            .unwrap();
1520        assert_eq!(WebSocketReceiveMessageType::Text, ws_result2.message_type);
1521        assert_eq!(true, ws_result2.end_of_message);
1522
1523        let received =
1524            std::str::from_utf8(&buffer2[..ws_result1.len_to + ws_result2.len_to]).unwrap();
1525        assert_eq!("Hello, World!", received);
1526    }
1527
1528    #[test]
1529    fn send_multi_frame_message() {
1530        let mut buffer1: [u8; 1000] = [0; 1000];
1531        let mut buffer2: [u8; 1000] = [0; 1000];
1532
1533        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
1534        ws_client.state = WebSocketState::Open;
1535        let mut ws_server = WebSocketServer::new_server();
1536        ws_server.state = WebSocketState::Open;
1537
1538        // server sends the first fragmented Text frame
1539        let fragment1 = "fragment1";
1540        let fragment1_num_bytes = ws_server
1541            .write(
1542                WebSocketSendMessageType::Text,
1543                false,
1544                &fragment1.as_bytes(),
1545                &mut buffer1,
1546            )
1547            .unwrap();
1548
1549        // send fragment2 as a continuation frame
1550        let fragment2 = "fragment2";
1551        let fragment2_num_bytes = ws_server
1552            .write(
1553                WebSocketSendMessageType::Text,
1554                false,
1555                &fragment2.as_bytes(),
1556                &mut buffer1[fragment1_num_bytes..],
1557            )
1558            .unwrap();
1559
1560        // send fragment3 as a continuation frame and indicate that this is the last frame
1561        let fragment3 = "fragment3";
1562        let _fragment3_num_bytes = ws_server
1563            .write(
1564                WebSocketSendMessageType::Text,
1565                true,
1566                &fragment3.as_bytes(),
1567                &mut buffer1[fragment1_num_bytes + fragment2_num_bytes..],
1568            )
1569            .unwrap();
1570
1571        // check that the client receives the "fragment1" Text message
1572        let ws_result = ws_client.read(&buffer1, &mut buffer2).unwrap();
1573        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1574        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1575        assert_eq!(fragment1, received);
1576        assert_eq!(ws_result.end_of_message, false);
1577        let mut read_cursor = ws_result.len_from;
1578
1579        // check that the client receives the "fragment2" Text message
1580        let ws_result = ws_client
1581            .read(&buffer1[read_cursor..], &mut buffer2)
1582            .unwrap();
1583        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1584        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1585        assert_eq!(fragment2, received);
1586        assert_eq!(ws_result.end_of_message, false);
1587        read_cursor += ws_result.len_from;
1588
1589        // check that the client receives the "fragment3" Text message
1590        let ws_result = ws_client
1591            .read(&buffer1[read_cursor..], &mut buffer2)
1592            .unwrap();
1593        assert_eq!(WebSocketReceiveMessageType::Text, ws_result.message_type);
1594        let received = std::str::from_utf8(&buffer2[..ws_result.len_to]).unwrap();
1595        assert_eq!(fragment3, received);
1596        assert_eq!(ws_result.end_of_message, true);
1597
1598        // check the actual bytes in the write buffer for fragment1
1599        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[0]);
1600        assert_eq!(is_fin_bit_set, false);
1601        assert_eq!(op_code, WebSocketOpCode::TextFrame);
1602
1603        // check the actual bytes in the write buffer for fragment2
1604        let (is_fin_bit_set, op_code) = read_first_byte(buffer1[fragment1_num_bytes]);
1605        assert_eq!(is_fin_bit_set, false);
1606        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1607
1608        // check the actual bytes in the write buffer for fragment3
1609        let (is_fin_bit_set, op_code) =
1610            read_first_byte(buffer1[fragment1_num_bytes + fragment2_num_bytes]);
1611        assert_eq!(is_fin_bit_set, true);
1612        assert_eq!(op_code, WebSocketOpCode::ContinuationFrame);
1613    }
1614
1615    fn read_first_byte(byte: u8) -> (bool, WebSocketOpCode) {
1616        const FIN_BIT_FLAG: u8 = 0x80;
1617        const OP_CODE_FLAG: u8 = 0x0F;
1618        let is_fin_bit_set = (byte & FIN_BIT_FLAG) == FIN_BIT_FLAG;
1619        let op_code = get_op_code(byte & OP_CODE_FLAG).unwrap();
1620        (is_fin_bit_set, op_code)
1621    }
1622}