embedded_websocket_embedded_io/
framer.rs

// This module helps you work with the websocket library using a stream of data rather than using it as a raw codec.
// This is the most common use case when working with websockets and is recommended because the protocol begins with
// an HTTP handshake and because an input buffer can contain multiple websocket frames or only a fragment of one.
// This module allows you to work with discrete websocket frames rather than the multiple fragments you read off a stream.
// NOTE: if you are using the standard library (the `std` feature) then the `Stream` trait below is already implemented
//       for `std::net::TcpStream` on top of std's Read and Write traits; otherwise you have to implement the `Stream`
//       trait specified below for your own stream type.
7
8use crate::{
9    WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions,
10    WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketState, WebSocketSubProtocol,
11    WebSocketType,
12};
13use core::{cmp::min, str::Utf8Error};
14use rand_core::RngCore;
15
// When the standard library is available, TcpStream gets a Stream implementation for free.
// In no_std builds you have to implement Stream for your own transport type.
#[cfg(feature = "std")]
impl Stream<std::io::Error> for std::net::TcpStream {
    /// Forwards to `std::io::Read::read` on the socket.
    fn read(&mut self, out: &mut [u8]) -> Result<usize, std::io::Error> {
        <Self as std::io::Read>::read(self, out)
    }

    /// Forwards to `std::io::Write::write_all` on the socket.
    fn write_all(&mut self, data: &[u8]) -> Result<(), std::io::Error> {
        <Self as std::io::Write>::write_all(self, data)
    }
}
28
/// A minimal byte stream that the framer reads websocket data from and writes it to.
///
/// `E` is the error type produced by the underlying transport. With the `std` feature this
/// trait is already implemented for `std::net::TcpStream` (with `E = std::io::Error`); in
/// `no_std` builds implement it for your own transport.
pub trait Stream<E> {
    /// Reads some bytes into `buf` and returns how many were read.
    ///
    /// `Framer::read` reports a return value of `Ok(0)` as `ReadResult::Closed`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, E>;

    /// Writes all of `buf` to the stream, or returns an error.
    fn write_all(&mut self, buf: &[u8]) -> Result<(), E>;
}
33
/// The outcome of a successful `Framer::read`.
///
/// Payload slices borrow the frame buffer passed to `read`. Every variant holds either a
/// shared reference or nothing, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadResult<'a> {
    /// A complete (possibly reassembled from fragments) binary message.
    Binary(&'a [u8]),
    /// A complete (possibly reassembled from fragments) text message, validated as UTF-8.
    Text(&'a str),
    /// The payload of a pong frame.
    Pong(&'a [u8]),
    /// The stream returned no data, or a close frame was processed.
    Closed,
}
40
41#[derive(Debug)]
42pub enum FramerError<E> {
43    Io(E),
44    FrameTooLarge(usize),
45    Utf8(Utf8Error),
46    HttpHeader(httparse::Error),
47    WebSocket(crate::Error),
48}
49
50pub struct Framer<'a, TRng, TWebSocketType>
51where
52    TRng: RngCore,
53    TWebSocketType: WebSocketType,
54{
55    read_buf: &'a mut [u8],
56    write_buf: &'a mut [u8],
57    read_cursor: &'a mut usize,
58    frame_cursor: usize,
59    read_len: usize,
60    websocket: &'a mut WebSocket<TRng, TWebSocketType>,
61}
62
63impl<'a, TRng> Framer<'a, TRng, crate::Client>
64where
65    TRng: RngCore,
66{
67    pub fn connect<E>(
68        &mut self,
69        stream: &mut impl Stream<E>,
70        websocket_options: &WebSocketOptions,
71    ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>> {
72        let (len, web_socket_key) = self
73            .websocket
74            .client_connect(websocket_options, self.write_buf)
75            .map_err(FramerError::WebSocket)?;
76        stream
77            .write_all(&self.write_buf[..len])
78            .map_err(FramerError::Io)?;
79        *self.read_cursor = 0;
80
81        loop {
82            // read the response from the server and check it to complete the opening handshake
83            let received_size = stream
84                .read(&mut self.read_buf[*self.read_cursor..])
85                .map_err(FramerError::Io)?;
86
87            match self.websocket.client_accept(
88                &web_socket_key,
89                &self.read_buf[..*self.read_cursor + received_size],
90            ) {
91                Ok((len, sub_protocol)) => {
92                    // "consume" the HTTP header that we have read from the stream
93                    // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else
94                    *self.read_cursor += received_size - len;
95                    return Ok(sub_protocol);
96                }
97                Err(crate::Error::HttpHeaderIncomplete) => {
98                    *self.read_cursor += received_size;
99                    // continue reading HTTP header in loop
100                }
101                Err(e) => {
102                    *self.read_cursor += received_size;
103                    return Err(FramerError::WebSocket(e));
104                }
105            }
106        }
107    }
108}
109
110impl<'a, TRng> Framer<'a, TRng, crate::Server>
111where
112    TRng: RngCore,
113{
114    pub fn accept<E>(
115        &mut self,
116        stream: &mut impl Stream<E>,
117        websocket_context: &WebSocketContext,
118    ) -> Result<(), FramerError<E>> {
119        let len = self
120            .websocket
121            .server_accept(&websocket_context.sec_websocket_key, None, self.write_buf)
122            .map_err(FramerError::WebSocket)?;
123
124        stream
125            .write_all(&self.write_buf[..len])
126            .map_err(FramerError::Io)?;
127        Ok(())
128    }
129}
130
131impl<'a, TRng, TWebSocketType> Framer<'a, TRng, TWebSocketType>
132where
133    TRng: RngCore,
134    TWebSocketType: WebSocketType,
135{
136    // read and write buffers are usually quite small (4KB) and can be smaller
137    // than the frame buffer but use whatever is is appropriate for your stream
138    pub fn new(
139        read_buf: &'a mut [u8],
140        read_cursor: &'a mut usize,
141        write_buf: &'a mut [u8],
142        websocket: &'a mut WebSocket<TRng, TWebSocketType>,
143    ) -> Self {
144        Self {
145            read_buf,
146            write_buf,
147            read_cursor,
148            frame_cursor: 0,
149            read_len: 0,
150            websocket,
151        }
152    }
153
154    pub fn state(&self) -> WebSocketState {
155        self.websocket.state
156    }
157
158    // calling close on a websocket that has already been closed by the other party has no effect
159    pub fn close<E>(
160        &mut self,
161        stream: &mut impl Stream<E>,
162        close_status: WebSocketCloseStatusCode,
163        status_description: Option<&str>,
164    ) -> Result<(), FramerError<E>> {
165        let len = self
166            .websocket
167            .close(close_status, status_description, self.write_buf)
168            .map_err(FramerError::WebSocket)?;
169        stream
170            .write_all(&self.write_buf[..len])
171            .map_err(FramerError::Io)?;
172        Ok(())
173    }
174
175    pub fn write<E>(
176        &mut self,
177        stream: &mut impl Stream<E>,
178        message_type: WebSocketSendMessageType,
179        end_of_message: bool,
180        frame_buf: &[u8],
181    ) -> Result<(), FramerError<E>> {
182        let len = self
183            .websocket
184            .write(message_type, end_of_message, frame_buf, self.write_buf)
185            .map_err(FramerError::WebSocket)?;
186        stream
187            .write_all(&self.write_buf[..len])
188            .map_err(FramerError::Io)?;
189        Ok(())
190    }
191
192    // frame_buf should be large enough to hold an entire websocket text frame
193    // this function will block until it has recieved a full websocket frame.
194    // It will wait until the last fragmented frame has arrived.
195    pub fn read<'b, E>(
196        &mut self,
197        stream: &mut impl Stream<E>,
198        frame_buf: &'b mut [u8],
199    ) -> Result<ReadResult<'b>, FramerError<E>> {
200        loop {
201            if *self.read_cursor == 0 || *self.read_cursor == self.read_len {
202                self.read_len = stream.read(self.read_buf).map_err(FramerError::Io)?;
203                *self.read_cursor = 0;
204            }
205
206            if self.read_len == 0 {
207                return Ok(ReadResult::Closed);
208            }
209
210            loop {
211                if *self.read_cursor == self.read_len {
212                    break;
213                }
214
215                if self.frame_cursor == frame_buf.len() {
216                    return Err(FramerError::FrameTooLarge(frame_buf.len()));
217                }
218
219                let ws_result = self
220                    .websocket
221                    .read(
222                        &self.read_buf[*self.read_cursor..self.read_len],
223                        &mut frame_buf[self.frame_cursor..],
224                    )
225                    .map_err(FramerError::WebSocket)?;
226
227                *self.read_cursor += ws_result.len_from;
228
229                match ws_result.message_type {
230                    WebSocketReceiveMessageType::Binary => {
231                        self.frame_cursor += ws_result.len_to;
232                        if ws_result.end_of_message {
233                            let frame = &frame_buf[..self.frame_cursor];
234                            self.frame_cursor = 0;
235                            return Ok(ReadResult::Binary(frame));
236                        }
237                    }
238                    WebSocketReceiveMessageType::Text => {
239                        self.frame_cursor += ws_result.len_to;
240                        if ws_result.end_of_message {
241                            let frame = &frame_buf[..self.frame_cursor];
242                            self.frame_cursor = 0;
243                            let text = core::str::from_utf8(frame).map_err(FramerError::Utf8)?;
244                            return Ok(ReadResult::Text(text));
245                        }
246                    }
247                    WebSocketReceiveMessageType::CloseMustReply => {
248                        self.send_back(
249                            stream,
250                            frame_buf,
251                            ws_result.len_to,
252                            WebSocketSendMessageType::CloseReply,
253                        )?;
254                        return Ok(ReadResult::Closed);
255                    }
256                    WebSocketReceiveMessageType::CloseCompleted => return Ok(ReadResult::Closed),
257                    WebSocketReceiveMessageType::Ping => {
258                        self.send_back(
259                            stream,
260                            frame_buf,
261                            ws_result.len_to,
262                            WebSocketSendMessageType::Pong,
263                        )?;
264                    }
265                    WebSocketReceiveMessageType::Pong => {
266                        let bytes =
267                            &frame_buf[self.frame_cursor..self.frame_cursor + ws_result.len_to];
268                        return Ok(ReadResult::Pong(bytes));
269                    }
270                }
271            }
272        }
273    }
274
275    fn send_back<E>(
276        &mut self,
277        stream: &mut impl Stream<E>,
278        frame_buf: &'_ mut [u8],
279        len_to: usize,
280        send_message_type: WebSocketSendMessageType,
281    ) -> Result<(), FramerError<E>> {
282        let payload_len = min(self.write_buf.len(), len_to);
283        let from = &frame_buf[self.frame_cursor..self.frame_cursor + payload_len];
284        let len = self
285            .websocket
286            .write(send_message_type, true, from, self.write_buf)
287            .map_err(FramerError::WebSocket)?;
288        stream
289            .write_all(&self.write_buf[..len])
290            .map_err(FramerError::Io)?;
291        Ok(())
292    }
293}