embedded_websocket/
framer_async.rs

1use core::{fmt::Debug, ops::Deref, str::Utf8Error};
2
3use futures::{Sink, SinkExt, Stream, StreamExt};
4use rand_core::RngCore;
5
6use crate::{
7    WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
8    WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
9};
10
/// A close frame received from the other end of the connection, as reported by
/// `Framer::read`.
pub struct CloseMessage<'a> {
    /// Status code carried by the peer's close frame.
    pub status_code: WebSocketCloseStatusCode,
    /// Decoded payload of the close frame, borrowed from the buffer passed to `Framer::read`.
    /// NOTE(review): this is the whole decoded payload; confirm against `WebSocket::read`
    /// whether it still includes the 2-byte status code ahead of the reason text.
    pub reason: &'a [u8],
}
15
/// A message or control frame returned by `Framer::read`.
///
/// All payloads borrow from the buffer passed to `Framer::read`.
pub enum ReadResult<'a> {
    /// A complete binary message.
    Binary(&'a [u8]),
    /// A complete text message that has been validated as UTF-8.
    Text(&'a str),
    /// We received a pong message in response to a ping message we sent earlier.
    /// This should contain the same data as was sent in the ping message.
    Pong(&'a [u8]),
    /// We received a ping message from the other end.
    /// A pong message with the same content is automatically sent back, as per the websocket spec.
    /// This message is still exposed so that you can see what was in the ping.
    Ping(&'a [u8]),
    /// The other end initiated a close handshake. The contents of this message are usually the reason for the close.
    /// A close response message is automatically sent back to the other end to complete the close handshake.
    /// This message is still exposed so that you can see why the connection is closing.
    Close(CloseMessage<'a>),
}
31
/// Errors returned by `Framer`; `E` is the error type of the underlying stream/sink.
#[derive(Debug)]
pub enum FramerError<E> {
    /// Sending to or receiving from the underlying stream failed.
    Io(E),
    /// A frame was too large. NOTE(review): not constructed by the `Framer` code in this file.
    FrameTooLarge(usize),
    /// A complete text message was not valid UTF-8.
    Utf8(Utf8Error),
    /// HTTP header parsing failed. NOTE(review): not constructed by the `Framer` code in this file.
    HttpHeader(httparse::Error),
    /// The underlying `WebSocket` rejected an encode/decode/handshake operation.
    WebSocket(crate::Error),
    /// The stream ended before the opening handshake completed.
    Disconnected,
    /// The caller-supplied buffer is too small; the value is the size that was required.
    RxBufferTooSmall(usize),
}
42
43pub struct Framer<TRng, TWebSocketType>
44where
45    TRng: RngCore,
46    TWebSocketType: WebSocketType,
47{
48    websocket: WebSocket<TRng, TWebSocketType>,
49    frame_cursor: usize,
50    rx_remainder_len: usize,
51}
52
53impl<TRng> Framer<TRng, crate::Client>
54where
55    TRng: RngCore,
56{
57    pub async fn connect<'a, B, E>(
58        &mut self,
59        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
60        buffer: &'a mut [u8],
61        websocket_options: &WebSocketOptions<'_>,
62    ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
63    where
64        B: AsRef<[u8]>,
65    {
66        let (tx_len, web_socket_key) = self
67            .websocket
68            .client_connect(websocket_options, buffer)
69            .map_err(FramerError::WebSocket)?;
70
71        let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
72        stream.send(tx_buf).await.map_err(FramerError::Io)?;
73        stream.flush().await.map_err(FramerError::Io)?;
74
75        loop {
76            match stream.next().await {
77                Some(buf) => {
78                    let buf = buf.map_err(FramerError::Io)?;
79                    let buf = buf.as_ref();
80
81                    match self.websocket.client_accept(&web_socket_key, buf) {
82                        Ok((len, sub_protocol)) => {
83                            // "consume" the HTTP header that we have read from the stream
84                            // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else
85
86                            // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s)
87                            let from = len;
88                            let to = buf.len();
89                            let remaining_len = to - from;
90
91                            if remaining_len > 0 {
92                                let rx_start = rx_buf.len() - remaining_len;
93                                rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
94                                self.rx_remainder_len = remaining_len;
95                            }
96
97                            return Ok(sub_protocol);
98                        }
99                        Err(crate::Error::HttpHeaderIncomplete) => {
100                            // TODO: continue reading HTTP header in loop
101                            panic!("oh no");
102                        }
103                        Err(e) => {
104                            return Err(FramerError::WebSocket(e));
105                        }
106                    }
107                }
108                None => return Err(FramerError::Disconnected),
109            }
110        }
111    }
112}
113
114impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
115where
116    TRng: RngCore,
117    TWebSocketType: WebSocketType,
118{
119    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
120        Self {
121            websocket,
122            frame_cursor: 0,
123            rx_remainder_len: 0,
124        }
125    }
126
127    pub fn encode<E>(
128        &mut self,
129        message_type: WebSocketSendMessageType,
130        end_of_message: bool,
131        from: &[u8],
132        to: &mut [u8],
133    ) -> Result<usize, FramerError<E>> {
134        let len = self
135            .websocket
136            .write(message_type, end_of_message, from, to)
137            .map_err(FramerError::WebSocket)?;
138
139        Ok(len)
140    }
141
142    pub async fn write<'b, E>(
143        &mut self,
144        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
145        tx_buf: &'b mut [u8],
146        message_type: WebSocketSendMessageType,
147        end_of_message: bool,
148        frame_buf: &[u8],
149    ) -> Result<(), FramerError<E>>
150    where
151        E: Debug,
152    {
153        let len = self
154            .websocket
155            .write(message_type, end_of_message, frame_buf, tx_buf)
156            .map_err(FramerError::WebSocket)?;
157
158        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
159        tx.flush().await.map_err(FramerError::Io)?;
160        Ok(())
161    }
162
163    pub async fn close<'b, E>(
164        &mut self,
165        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
166        tx_buf: &'b mut [u8],
167        close_status: WebSocketCloseStatusCode,
168        status_description: Option<&str>,
169    ) -> Result<(), FramerError<E>>
170    where
171        E: Debug,
172    {
173        let len = self
174            .websocket
175            .close(close_status, status_description, tx_buf)
176            .map_err(FramerError::WebSocket)?;
177
178        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
179        tx.flush().await.map_err(FramerError::Io)?;
180        Ok(())
181    }
182
183    // NOTE: any unused bytes read from the stream but not decoded are stored at the end
184    // of the buffer to be used next time this read function is called. This also applies to
185    // any unused bytes read when the connect handshake was made. Therefore it is important that
186    // the caller does not clear this buffer between calls or use it for anthing other than reads.
187    pub async fn read<'a, B: Deref<Target = [u8]>, E>(
188        &mut self,
189        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
190        buffer: &'a mut [u8],
191    ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
192    where
193        E: Debug,
194    {
195        if self.rx_remainder_len == 0 {
196            match stream.next().await {
197                Some(Ok(input)) => {
198                    if buffer.len() < input.len() {
199                        return Some(Err(FramerError::RxBufferTooSmall(input.len())));
200                    }
201
202                    let rx_start = buffer.len() - input.len();
203
204                    // copy to end of buffer
205                    buffer[rx_start..].copy_from_slice(&input);
206                    self.rx_remainder_len = input.len()
207                }
208                Some(Err(e)) => {
209                    return Some(Err(FramerError::Io(e)));
210                }
211                None => return None,
212            }
213        }
214
215        let rx_start = buffer.len() - self.rx_remainder_len;
216        let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
217
218        let ws_result = match self.websocket.read(rx_buf, frame_buf) {
219            Ok(ws_result) => ws_result,
220            Err(e) => return Some(Err(FramerError::WebSocket(e))),
221        };
222
223        self.rx_remainder_len -= ws_result.len_from;
224
225        match ws_result.message_type {
226            WebSocketReceiveMessageType::Binary => {
227                self.frame_cursor += ws_result.len_to;
228                if ws_result.end_of_message {
229                    let range = 0..self.frame_cursor;
230                    self.frame_cursor = 0;
231                    return Some(Ok(ReadResult::Binary(&frame_buf[range])));
232                }
233            }
234            WebSocketReceiveMessageType::Text => {
235                self.frame_cursor += ws_result.len_to;
236                if ws_result.end_of_message {
237                    let range = 0..self.frame_cursor;
238                    self.frame_cursor = 0;
239                    match core::str::from_utf8(&frame_buf[range]) {
240                        Ok(text) => return Some(Ok(ReadResult::Text(text))),
241                        Err(e) => return Some(Err(FramerError::Utf8(e))),
242                    }
243                }
244            }
245            WebSocketReceiveMessageType::CloseMustReply => {
246                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
247
248                // create a tx_buf from the end of the frame_buf
249                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
250                let split_at = frame_buf.len() - tx_buf_len;
251                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
252
253                match self.websocket.write(
254                    WebSocketSendMessageType::CloseReply,
255                    true,
256                    &frame_buf[range.start..range.end],
257                    tx_buf,
258                ) {
259                    Ok(len) => match stream.send(&tx_buf[..len]).await {
260                        Ok(()) => {
261                            self.frame_cursor = 0;
262                            let status_code = ws_result
263                                .close_status
264                                .expect("close message must have code");
265                            let reason = &frame_buf[range];
266                            return Some(Ok(ReadResult::Close(CloseMessage {
267                                status_code,
268                                reason,
269                            })));
270                        }
271                        Err(e) => return Some(Err(FramerError::Io(e))),
272                    },
273                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
274                }
275            }
276            WebSocketReceiveMessageType::CloseCompleted => return None,
277            WebSocketReceiveMessageType::Pong => {
278                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
279                return Some(Ok(ReadResult::Pong(&frame_buf[range])));
280            }
281            WebSocketReceiveMessageType::Ping => {
282                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
283
284                // create a tx_buf from the end of the frame_buf
285                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
286                let split_at = frame_buf.len() - tx_buf_len;
287                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
288
289                match self.websocket.write(
290                    WebSocketSendMessageType::Pong,
291                    true,
292                    &frame_buf[range.start..range.end],
293                    tx_buf,
294                ) {
295                    Ok(len) => match stream.send(&tx_buf[..len]).await {
296                        Ok(()) => {
297                            return Some(Ok(ReadResult::Ping(&frame_buf[range])));
298                        }
299                        Err(e) => return Some(Err(FramerError::Io(e))),
300                    },
301                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
302                }
303            }
304        }
305
306        None
307    }
308}