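//! `embedded_websocket/framer_async.rs`: runs a [`WebSocket`] over a `futures` [`Stream`] +
//! [`Sink`], covering the opening handshake, message framing and the automatic ping/close replies.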

use core::{fmt::Debug, ops::{Deref, Range}, str::Utf8Error};

use futures::{Sink, SinkExt, Stream, StreamExt};
use rand_core::RngCore;

use crate::{
    WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions,
    WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
};
10
11pub struct CloseMessage<'a> {
12    pub status_code: WebSocketCloseStatusCode,
13    pub reason: &'a [u8],
14}
15
16pub enum ReadResult<'a> {
17    Binary(&'a [u8]),
18    Text(&'a str),
19    /// We received a pong message in response to a ping message we sent earlier.
20    /// This should contain the same data as was sent in the ping message
21    Pong(&'a [u8]),
22    /// We received a ping message from the other end
23    /// A pong message with the same content will automatically be sent back as per the websocket spec
24    /// However, this message is exposed to you so that you can see what is in the ping message
25    Ping(&'a [u8]),
26    /// The other end initiated a close handshake. The contents of this message is usually the reason for the close
27    /// A close response message will automatically be sent back to the other end to complete the close handshake
28    /// However, this message is exposed to you so that you can see why
29    Close(CloseMessage<'a>),
30}
31
32#[derive(Debug)]
33pub enum FramerError<E> {
34    Io(E),
35    FrameTooLarge(usize),
36    Utf8(Utf8Error),
37    HttpHeader(httparse::Error),
38    WebSocket(crate::Error),
39    Disconnected,
40    RxBufferTooSmall(usize),
41}
42
43pub struct Framer<TRng, TWebSocketType>
44where
45    TRng: RngCore,
46    TWebSocketType: WebSocketType,
47{
48    websocket: WebSocket<TRng, TWebSocketType>,
49    frame_cursor: usize,
50    rx_remainder_len: usize,
51}
52
53impl<TRng> Framer<TRng, crate::Server>
54where
55    TRng: RngCore,
56{
57    pub async fn accept<'a, B, E>(
58        &mut self,
59        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
60        buffer: &'a mut [u8],
61        websocket_context: &WebSocketContext,
62    ) -> Result<(), FramerError<E>> {
63        let len = self
64            .websocket
65            .server_accept(&websocket_context.sec_websocket_key, None, buffer)
66            .map_err(FramerError::WebSocket)?;
67
68        stream.send(&buffer[..len]).await.map_err(FramerError::Io)?;
69        Ok(())
70    }
71}
72
73impl<TRng> Framer<TRng, crate::Client>
74where
75    TRng: RngCore,
76{
77    pub async fn connect<'a, B, E>(
78        &mut self,
79        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
80        buffer: &'a mut [u8],
81        websocket_options: &WebSocketOptions<'_>,
82    ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
83    where
84        B: AsRef<[u8]>,
85    {
86        let (tx_len, web_socket_key) = self
87            .websocket
88            .client_connect(websocket_options, buffer)
89            .map_err(FramerError::WebSocket)?;
90
91        let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
92        stream.send(tx_buf).await.map_err(FramerError::Io)?;
93        stream.flush().await.map_err(FramerError::Io)?;
94
95        match stream.next().await {
96            Some(buf) => {
97                let buf = buf.map_err(FramerError::Io)?;
98                let buf = buf.as_ref();
99
100                match self.websocket.client_accept(&web_socket_key, buf) {
101                    Ok((len, sub_protocol)) => {
102                        // "consume" the HTTP header that we have read from the stream
103                        // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else
104
105                        // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s)
106                        let from = len;
107                        let to = buf.len();
108                        let remaining_len = to - from;
109
110                        if remaining_len > 0 {
111                            let rx_start = rx_buf.len() - remaining_len;
112                            rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
113                            self.rx_remainder_len = remaining_len;
114                        }
115
116                        Ok(sub_protocol)
117                    }
118                    Err(crate::Error::HttpHeaderIncomplete) => {
119                        // TODO: continue reading HTTP header in loop
120                        panic!("http header not complete");
121                    }
122                    Err(e) => Err(FramerError::WebSocket(e)),
123                }
124            }
125            None => Err(FramerError::Disconnected),
126        }
127    }
128}
129
130impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
131where
132    TRng: RngCore,
133    TWebSocketType: WebSocketType,
134{
135    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
136        Self {
137            websocket,
138            frame_cursor: 0,
139            rx_remainder_len: 0,
140        }
141    }
142
143    pub fn encode<E>(
144        &mut self,
145        message_type: WebSocketSendMessageType,
146        end_of_message: bool,
147        from: &[u8],
148        to: &mut [u8],
149    ) -> Result<usize, FramerError<E>> {
150        let len = self
151            .websocket
152            .write(message_type, end_of_message, from, to)
153            .map_err(FramerError::WebSocket)?;
154
155        Ok(len)
156    }
157
158    pub async fn write<'b, E>(
159        &mut self,
160        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
161        tx_buf: &'b mut [u8],
162        message_type: WebSocketSendMessageType,
163        end_of_message: bool,
164        frame_buf: &[u8],
165    ) -> Result<(), FramerError<E>>
166    where
167        E: Debug,
168    {
169        let len = self
170            .websocket
171            .write(message_type, end_of_message, frame_buf, tx_buf)
172            .map_err(FramerError::WebSocket)?;
173
174        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
175        tx.flush().await.map_err(FramerError::Io)?;
176        Ok(())
177    }
178
179    pub async fn close<'b, E>(
180        &mut self,
181        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
182        tx_buf: &'b mut [u8],
183        close_status: WebSocketCloseStatusCode,
184        status_description: Option<&str>,
185    ) -> Result<(), FramerError<E>>
186    where
187        E: Debug,
188    {
189        let len = self
190            .websocket
191            .close(close_status, status_description, tx_buf)
192            .map_err(FramerError::WebSocket)?;
193
194        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
195        tx.flush().await.map_err(FramerError::Io)?;
196        Ok(())
197    }
198
199    // NOTE: any unused bytes read from the stream but not decoded are stored at the end
200    // of the buffer to be used next time this read function is called. This also applies to
201    // any unused bytes read when the connect handshake was made. Therefore it is important that
202    // the caller does not clear this buffer between calls or use it for anthing other than reads.
203    pub async fn read<'a, B: Deref<Target = [u8]>, E>(
204        &mut self,
205        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
206        buffer: &'a mut [u8],
207    ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
208    where
209        E: Debug,
210    {
211        if self.rx_remainder_len == 0 {
212            match stream.next().await {
213                Some(Ok(input)) => {
214                    if buffer.len() < input.len() {
215                        return Some(Err(FramerError::RxBufferTooSmall(input.len())));
216                    }
217
218                    let rx_start = buffer.len() - input.len();
219
220                    // copy to end of buffer
221                    buffer[rx_start..].copy_from_slice(&input);
222                    self.rx_remainder_len = input.len()
223                }
224                Some(Err(e)) => {
225                    return Some(Err(FramerError::Io(e)));
226                }
227                None => return None,
228            }
229        }
230
231        let rx_start = buffer.len() - self.rx_remainder_len;
232        let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
233
234        let ws_result = match self.websocket.read(rx_buf, frame_buf) {
235            Ok(ws_result) => ws_result,
236            Err(e) => return Some(Err(FramerError::WebSocket(e))),
237        };
238
239        self.rx_remainder_len -= ws_result.len_from;
240
241        match ws_result.message_type {
242            WebSocketReceiveMessageType::Binary => {
243                self.frame_cursor += ws_result.len_to;
244                if ws_result.end_of_message {
245                    let range = 0..self.frame_cursor;
246                    self.frame_cursor = 0;
247                    return Some(Ok(ReadResult::Binary(&frame_buf[range])));
248                }
249            }
250            WebSocketReceiveMessageType::Text => {
251                self.frame_cursor += ws_result.len_to;
252                if ws_result.end_of_message {
253                    let range = 0..self.frame_cursor;
254                    self.frame_cursor = 0;
255                    match core::str::from_utf8(&frame_buf[range]) {
256                        Ok(text) => return Some(Ok(ReadResult::Text(text))),
257                        Err(e) => return Some(Err(FramerError::Utf8(e))),
258                    }
259                }
260            }
261            WebSocketReceiveMessageType::CloseMustReply => {
262                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
263
264                // create a tx_buf from the end of the frame_buf
265                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
266                let split_at = frame_buf.len() - tx_buf_len;
267                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
268
269                match self.websocket.write(
270                    WebSocketSendMessageType::CloseReply,
271                    true,
272                    &frame_buf[range.start..range.end],
273                    tx_buf,
274                ) {
275                    Ok(len) => match stream.send(&tx_buf[..len]).await {
276                        Ok(()) => {
277                            self.frame_cursor = 0;
278                            let status_code = ws_result
279                                .close_status
280                                .expect("close message must have code");
281                            let reason = &frame_buf[range];
282                            return Some(Ok(ReadResult::Close(CloseMessage {
283                                status_code,
284                                reason,
285                            })));
286                        }
287                        Err(e) => return Some(Err(FramerError::Io(e))),
288                    },
289                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
290                }
291            }
292            WebSocketReceiveMessageType::CloseCompleted => return None,
293            WebSocketReceiveMessageType::Pong => {
294                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
295                return Some(Ok(ReadResult::Pong(&frame_buf[range])));
296            }
297            WebSocketReceiveMessageType::Ping => {
298                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
299
300                // create a tx_buf from the end of the frame_buf
301                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
302                let split_at = frame_buf.len() - tx_buf_len;
303                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
304
305                match self.websocket.write(
306                    WebSocketSendMessageType::Pong,
307                    true,
308                    &frame_buf[range.start..range.end],
309                    tx_buf,
310                ) {
311                    Ok(len) => match stream.send(&tx_buf[..len]).await {
312                        Ok(()) => {
313                            return Some(Ok(ReadResult::Ping(&frame_buf[range])));
314                        }
315                        Err(e) => return Some(Err(FramerError::Io(e))),
316                    },
317                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
318                }
319            }
320        }
321
322        None
323    }
324}