// embedded_websocket_embedded_io/framer_async.rs

1#[cfg(not(feature = "embedded-io-async"))]
2use core::ops::Deref;
3use core::{fmt::Debug, str::Utf8Error};
4#[cfg(feature = "embedded-io-async")]
5use embedded_io_async::{ErrorType, Read, Write};
6#[cfg(not(feature = "embedded-io-async"))]
7use futures::{Sink, SinkExt, Stream, StreamExt};
8use rand_core::RngCore;
9
10use crate::{
11    WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
12    WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
13};
14
15pub struct CloseMessage<'a> {
16    pub status_code: WebSocketCloseStatusCode,
17    pub reason: &'a [u8],
18}
19
20pub enum ReadResult<'a> {
21    Binary(&'a [u8]),
22    Text(&'a str),
23    /// We received a pong message in response to a ping message we sent earlier.
24    /// This should contain the same data as was sent in the ping message
25    Pong(&'a [u8]),
26    /// We received a ping message from the other end
27    /// A pong message with the same content will automatically be sent back as per the websocket spec
28    /// However, this message is exposed to you so that you can see what is in the ping message
29    Ping(&'a [u8]),
30    /// The other end initiated a close handshake. The contents of this message is usually the reason for the close
31    /// A close response message will automatically be sent back to the other end to complete the close handshake
32    /// However, this message is exposed to you so that you can see why
33    Close(CloseMessage<'a>),
34}
35
36#[derive(Debug)]
37pub enum FramerError<E> {
38    Io(E),
39    FrameTooLarge(usize),
40    Utf8(Utf8Error),
41    HttpHeader(httparse::Error),
42    WebSocket(crate::Error),
43    Disconnected,
44    RxBufferTooSmall(usize),
45}
46
47pub struct Framer<TRng, TWebSocketType>
48where
49    TRng: RngCore,
50    TWebSocketType: WebSocketType,
51{
52    websocket: WebSocket<TRng, TWebSocketType>,
53    frame_cursor: usize,
54    rx_remainder_len: usize,
55}
56
57#[cfg(not(feature = "embedded-io-async"))]
58impl<TRng> Framer<TRng, crate::Client>
59where
60    TRng: RngCore,
61{
62    pub async fn connect<'a, B, E>(
63        &mut self,
64        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
65        buffer: &'a mut [u8],
66        websocket_options: &WebSocketOptions<'_>,
67    ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
68    where
69        B: AsRef<[u8]>,
70    {
71        let (tx_len, web_socket_key) = self
72            .websocket
73            .client_connect(websocket_options, buffer)
74            .map_err(FramerError::WebSocket)?;
75
76        let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
77        stream.send(tx_buf).await.map_err(FramerError::Io)?;
78        stream.flush().await.map_err(FramerError::Io)?;
79
80        loop {
81            match stream.next().await {
82                Some(buf) => {
83                    let buf = buf.map_err(FramerError::Io)?;
84                    let buf = buf.as_ref();
85
86                    match self.websocket.client_accept(&web_socket_key, buf) {
87                        Ok((len, sub_protocol)) => {
88                            // "consume" the HTTP header that we have read from the stream
89                            // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else
90
91                            // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s)
92                            let from = len;
93                            let to = buf.len();
94                            let remaining_len = to - from;
95
96                            if remaining_len > 0 {
97                                let rx_start = rx_buf.len() - remaining_len;
98                                rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
99                                self.rx_remainder_len = remaining_len;
100                            }
101
102                            return Ok(sub_protocol);
103                        }
104                        Err(crate::Error::HttpHeaderIncomplete) => {
105                            // TODO: continue reading HTTP header in loop
106                            panic!("oh no");
107                        }
108                        Err(e) => {
109                            return Err(FramerError::WebSocket(e));
110                        }
111                    }
112                }
113                None => return Err(FramerError::Disconnected),
114            }
115        }
116    }
117}
118
119#[cfg(not(feature = "embedded-io-async"))]
120impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
121where
122    TRng: RngCore,
123    TWebSocketType: WebSocketType,
124{
125    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
126        Self {
127            websocket,
128            frame_cursor: 0,
129            rx_remainder_len: 0,
130        }
131    }
132
133    pub fn encode<E>(
134        &mut self,
135        message_type: WebSocketSendMessageType,
136        end_of_message: bool,
137        from: &[u8],
138        to: &mut [u8],
139    ) -> Result<usize, FramerError<E>> {
140        let len = self
141            .websocket
142            .write(message_type, end_of_message, from, to)
143            .map_err(FramerError::WebSocket)?;
144
145        Ok(len)
146    }
147
148    pub async fn write<'b, E>(
149        &mut self,
150        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
151        tx_buf: &'b mut [u8],
152        message_type: WebSocketSendMessageType,
153        end_of_message: bool,
154        frame_buf: &[u8],
155    ) -> Result<(), FramerError<E>>
156    where
157        E: Debug,
158    {
159        let len = self
160            .websocket
161            .write(message_type, end_of_message, frame_buf, tx_buf)
162            .map_err(FramerError::WebSocket)?;
163
164        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
165        tx.flush().await.map_err(FramerError::Io)?;
166        Ok(())
167    }
168
169    pub async fn close<'b, E>(
170        &mut self,
171        tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
172        tx_buf: &'b mut [u8],
173        close_status: WebSocketCloseStatusCode,
174        status_description: Option<&str>,
175    ) -> Result<(), FramerError<E>>
176    where
177        E: Debug,
178    {
179        let len = self
180            .websocket
181            .close(close_status, status_description, tx_buf)
182            .map_err(FramerError::WebSocket)?;
183
184        tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
185        tx.flush().await.map_err(FramerError::Io)?;
186        Ok(())
187    }
188
189    // NOTE: any unused bytes read from the stream but not decoded are stored at the end
190    // of the buffer to be used next time this read function is called. This also applies to
191    // any unused bytes read when the connect handshake was made. Therefore it is important that
192    // the caller does not clear this buffer between calls or use it for anthing other than reads.
193    pub async fn read<'a, B: Deref<Target = [u8]>, E>(
194        &mut self,
195        stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
196        buffer: &'a mut [u8],
197    ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
198    where
199        E: Debug,
200    {
201        if self.rx_remainder_len == 0 {
202            match stream.next().await {
203                Some(Ok(input)) => {
204                    if buffer.len() < input.len() {
205                        return Some(Err(FramerError::RxBufferTooSmall(input.len())));
206                    }
207
208                    let rx_start = buffer.len() - input.len();
209
210                    // copy to end of buffer
211                    buffer[rx_start..].copy_from_slice(&input);
212                    self.rx_remainder_len = input.len()
213                }
214                Some(Err(e)) => {
215                    return Some(Err(FramerError::Io(e)));
216                }
217                None => return None,
218            }
219        }
220
221        let rx_start = buffer.len() - self.rx_remainder_len;
222        let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
223
224        let ws_result = match self.websocket.read(rx_buf, frame_buf) {
225            Ok(ws_result) => ws_result,
226            Err(e) => return Some(Err(FramerError::WebSocket(e))),
227        };
228
229        self.rx_remainder_len -= ws_result.len_from;
230
231        match ws_result.message_type {
232            WebSocketReceiveMessageType::Binary => {
233                self.frame_cursor += ws_result.len_to;
234                if ws_result.end_of_message {
235                    let range = 0..self.frame_cursor;
236                    self.frame_cursor = 0;
237                    return Some(Ok(ReadResult::Binary(&frame_buf[range])));
238                }
239            }
240            WebSocketReceiveMessageType::Text => {
241                self.frame_cursor += ws_result.len_to;
242                if ws_result.end_of_message {
243                    let range = 0..self.frame_cursor;
244                    self.frame_cursor = 0;
245                    match core::str::from_utf8(&frame_buf[range]) {
246                        Ok(text) => return Some(Ok(ReadResult::Text(text))),
247                        Err(e) => return Some(Err(FramerError::Utf8(e))),
248                    }
249                }
250            }
251            WebSocketReceiveMessageType::CloseMustReply => {
252                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
253
254                // create a tx_buf from the end of the frame_buf
255                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
256                let split_at = frame_buf.len() - tx_buf_len;
257                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
258
259                match self.websocket.write(
260                    WebSocketSendMessageType::CloseReply,
261                    true,
262                    &frame_buf[range.start..range.end],
263                    tx_buf,
264                ) {
265                    Ok(len) => match stream.send(&tx_buf[..len]).await {
266                        Ok(()) => {
267                            self.frame_cursor = 0;
268                            let status_code = ws_result
269                                .close_status
270                                .expect("close message must have code");
271                            let reason = &frame_buf[range];
272                            return Some(Ok(ReadResult::Close(CloseMessage {
273                                status_code,
274                                reason,
275                            })));
276                        }
277                        Err(e) => return Some(Err(FramerError::Io(e))),
278                    },
279                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
280                }
281            }
282            WebSocketReceiveMessageType::CloseCompleted => return None,
283            WebSocketReceiveMessageType::Pong => {
284                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
285                return Some(Ok(ReadResult::Pong(&frame_buf[range])));
286            }
287            WebSocketReceiveMessageType::Ping => {
288                let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
289
290                // create a tx_buf from the end of the frame_buf
291                let tx_buf_len = ws_result.len_to + 14; // for extra websocket header
292                let split_at = frame_buf.len() - tx_buf_len;
293                let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
294
295                match self.websocket.write(
296                    WebSocketSendMessageType::Pong,
297                    true,
298                    &frame_buf[range.start..range.end],
299                    tx_buf,
300                ) {
301                    Ok(len) => match stream.send(&tx_buf[..len]).await {
302                        Ok(()) => {
303                            return Some(Ok(ReadResult::Ping(&frame_buf[range])));
304                        }
305                        Err(e) => return Some(Err(FramerError::Io(e))),
306                    },
307                    Err(e) => return Some(Err(FramerError::WebSocket(e))),
308                }
309            }
310        }
311
312        None
313    }
314}
315
#[cfg(feature = "embedded-io-async")]
impl<TRng> Framer<TRng, crate::Client>
where
    TRng: RngCore,
{
    /// Performs the client side of the websocket opening handshake over `stream`.
    ///
    /// The HTTP upgrade request is encoded into `buffer` and written in full. The server's
    /// response is then read into `buffer`, over as many reads as needed, until a complete
    /// HTTP header has been received.
    ///
    /// Any bytes that followed the header already belong to the first websocket frame(s).
    /// They are moved to the very end of `buffer`, which is where `read` looks for undecoded
    /// input. The same buffer must therefore be passed to the subsequent `read` calls.
    ///
    /// # Errors
    /// - `FramerError::Io` if writing or reading fails.
    /// - `FramerError::Disconnected` if the stream reaches EOF before the handshake completes.
    /// - `FramerError::RxBufferTooSmall` if the response header does not fit in `buffer`.
    /// - `FramerError::WebSocket` if the request cannot be built or the response is rejected.
    pub async fn connect<'a, S>(
        &mut self,
        stream: &mut S,
        buffer: &'a mut [u8],
        websocket_options: &WebSocketOptions<'_>,
    ) -> Result<Option<WebSocketSubProtocol>, FramerError<<S as ErrorType>::Error>>
    where
        S: Read + Write + Unpin,
    {
        let (tx_len, web_socket_key) = self
            .websocket
            .client_connect(websocket_options, buffer)
            .map_err(FramerError::WebSocket)?;

        // `Write::write` may accept only part of the request; `write_all` sends all of it.
        stream
            .write_all(&buffer[..tx_len])
            .await
            .map_err(FramerError::Io)?;
        stream.flush().await.map_err(FramerError::Io)?;

        // The request has been sent, so the whole buffer is free to accumulate the response.
        let mut received = 0;
        loop {
            if received == buffer.len() {
                // Reading into an empty slice may return `Ok(0)`, which would look like EOF.
                return Err(FramerError::RxBufferTooSmall(received + 1));
            }
            let read_len = stream
                .read(&mut buffer[received..])
                .await
                .map_err(FramerError::Io)?;
            if read_len == 0 {
                return Err(FramerError::Disconnected);
            }
            received += read_len;

            // Parse only the bytes actually received; anything after them is stale data.
            match self
                .websocket
                .client_accept(&web_socket_key, &buffer[..received])
            {
                Ok((header_len, sub_protocol)) => {
                    // "Consume" the HTTP header. Whatever follows it is the start of the next
                    // websocket frame(s). Move it to the end of the buffer, where `read`
                    // expects undecoded input. `copy_within` handles overlapping ranges.
                    let remaining_len = received - header_len;
                    let rx_start = buffer.len() - remaining_len;
                    buffer.copy_within(header_len..received, rx_start);
                    self.rx_remainder_len = remaining_len;
                    self.frame_cursor = 0;
                    return Ok(sub_protocol);
                }
                // The header is split across reads: keep reading and parse again.
                Err(crate::Error::HttpHeaderIncomplete) => continue,
                Err(e) => return Err(FramerError::WebSocket(e)),
            }
        }
    }
}

#[cfg(feature = "embedded-io-async")]
impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
where
    TRng: RngCore,
    TWebSocketType: WebSocketType,
{
    /// Wraps `websocket` in a framer with no buffered input and no partially assembled message.
    pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
        Self {
            websocket,
            frame_cursor: 0,
            rx_remainder_len: 0,
        }
    }

    /// Encodes `from` as a websocket frame into `to` without sending it.
    ///
    /// Returns the number of bytes written to `to`.
    ///
    /// # Errors
    /// `FramerError::WebSocket` if the websocket layer cannot encode the frame.
    pub fn encode<E>(
        &mut self,
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        from: &[u8],
        to: &mut [u8],
    ) -> Result<usize, FramerError<E>> {
        let len = self
            .websocket
            .write(message_type, end_of_message, from, to)
            .map_err(FramerError::WebSocket)?;

        Ok(len)
    }

    /// Encodes `frame_buf` as a websocket frame into `tx_buf`, writes all of it to `tx`
    /// and flushes.
    ///
    /// # Errors
    /// `FramerError::WebSocket` if encoding fails; `FramerError::Io` if writing or flushing
    /// fails. I/O errors are returned, not panicked on.
    pub async fn write<'b, T>(
        &mut self,
        tx: &mut T,
        tx_buf: &'b mut [u8],
        message_type: WebSocketSendMessageType,
        end_of_message: bool,
        frame_buf: &[u8],
    ) -> Result<(), FramerError<<T as ErrorType>::Error>>
    where
        T: Write + Unpin,
    {
        let len = self
            .websocket
            .write(message_type, end_of_message, frame_buf, tx_buf)
            .map_err(FramerError::WebSocket)?;

        tx.write_all(&tx_buf[..len])
            .await
            .map_err(FramerError::Io)?;
        tx.flush().await.map_err(FramerError::Io)?;
        Ok(())
    }

    /// Encodes a close frame (`close_status` plus optional `status_description`) into `tx_buf`,
    /// writes all of it to `tx` and flushes.
    ///
    /// # Errors
    /// `FramerError::WebSocket` if encoding fails; `FramerError::Io` if writing or flushing
    /// fails.
    pub async fn close<'b, T>(
        &mut self,
        tx: &mut T,
        tx_buf: &'b mut [u8],
        close_status: WebSocketCloseStatusCode,
        status_description: Option<&str>,
    ) -> Result<(), FramerError<<T as ErrorType>::Error>>
    where
        T: Write + Unpin,
    {
        let len = self
            .websocket
            .close(close_status, status_description, tx_buf)
            .map_err(FramerError::WebSocket)?;

        tx.write_all(&tx_buf[..len])
            .await
            .map_err(FramerError::Io)?;
        tx.flush().await.map_err(FramerError::Io)?;
        Ok(())
    }

    /// Reads from `stream` until a complete message or a control frame can be returned.
    ///
    /// `buffer` holds two regions that persist between calls:
    /// - Decoded message bytes are assembled from its start.
    /// - Bytes read but not yet decoded are kept at its end. This includes bytes left over
    ///   from `connect`.
    ///
    /// The caller must therefore pass the same, unmodified buffer to every call and must not
    /// use it for anything other than reads.
    ///
    /// Fragments of an incomplete data message are accumulated internally. The call only
    /// returns once the final fragment arrives, so `None` means the stream reached EOF or the
    /// close handshake completed.
    ///
    /// Pings and close requests are answered automatically (pong / close reply, written in
    /// full and flushed) before they are returned.
    ///
    /// # Errors
    /// - `Io`: reading or writing failed.
    /// - `WebSocket`: decoding or encoding failed.
    /// - `Utf8`: a text message was not UTF-8.
    /// - `RxBufferTooSmall` / `FrameTooLarge`: `buffer` cannot hold the data.
    pub async fn read<'a, S>(
        &mut self,
        stream: &mut S,
        buffer: &'a mut [u8],
    ) -> Option<Result<ReadResult<'a>, FramerError<<S as ErrorType>::Error>>>
    where
        S: Read + Write + Unpin,
    {
        // Decode until there is something to hand back. Only short-lived borrows of `buffer`
        // are taken inside this loop; the `'a` borrow backing the returned slices is taken
        // once, after the loop.
        let ws_result = loop {
            if self.rx_remainder_len == 0 {
                // Read into the space after the message assembled so far. Then move the new
                // bytes to the end of the buffer, where undecoded input is kept.
                let free = &mut buffer[self.frame_cursor..];
                if free.is_empty() {
                    return Some(Err(FramerError::FrameTooLarge(self.frame_cursor)));
                }
                let read_len = match stream.read(free).await {
                    Ok(0) => return None,
                    Ok(read_len) => read_len,
                    Err(e) => return Some(Err(FramerError::Io(e))),
                };
                let rx_start = buffer.len() - read_len;
                buffer.copy_within(self.frame_cursor..self.frame_cursor + read_len, rx_start);
                self.rx_remainder_len = read_len;
            }

            let rx_start = buffer.len() - self.rx_remainder_len;
            let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);

            // Decode after the bytes already assembled for the current message, so that
            // fragments are appended instead of overwriting one another.
            let ws_result = match self
                .websocket
                .read(rx_buf, &mut frame_buf[self.frame_cursor..])
            {
                Ok(ws_result) => ws_result,
                Err(e) => return Some(Err(FramerError::WebSocket(e))),
            };
            self.rx_remainder_len -= ws_result.len_from;

            let is_data = matches!(
                ws_result.message_type,
                WebSocketReceiveMessageType::Binary | WebSocketReceiveMessageType::Text
            );
            if !is_data || ws_result.end_of_message {
                break ws_result;
            }

            // Part of a message that is not complete yet: keep it and continue reading.
            self.frame_cursor += ws_result.len_to;
            if ws_result.len_from == 0 && self.rx_remainder_len > 0 {
                // Nothing was consumed and no new input will be read, so looping again would
                // spin forever. Presumably the frame area of `buffer` is full.
                return Some(Err(FramerError::FrameTooLarge(self.frame_cursor)));
            }
        };

        // Everything before the still-undecoded input is decoded frame data (or input that
        // has already been consumed), so it may be handed out or reused to stage replies.
        let frame_len = buffer.len() - self.rx_remainder_len;
        let (frame_buf, _) = buffer.split_at_mut(frame_len);
        let payload = self.frame_cursor..self.frame_cursor + ws_result.len_to;

        match ws_result.message_type {
            WebSocketReceiveMessageType::Binary => {
                self.frame_cursor = 0;
                Some(Ok(ReadResult::Binary(&frame_buf[..payload.end])))
            }
            WebSocketReceiveMessageType::Text => {
                self.frame_cursor = 0;
                match core::str::from_utf8(&frame_buf[..payload.end]) {
                    Ok(text) => Some(Ok(ReadResult::Text(text))),
                    Err(e) => Some(Err(FramerError::Utf8(e))),
                }
            }
            WebSocketReceiveMessageType::CloseMustReply => {
                let reason = match self
                    .send_control_reply(
                        stream,
                        frame_buf,
                        payload.start,
                        payload.end,
                        WebSocketSendMessageType::CloseReply,
                    )
                    .await
                {
                    Ok(reason) => reason,
                    Err(e) => return Some(Err(e)),
                };
                self.frame_cursor = 0;
                // NOTE(review): panics if the websocket layer reports a close frame without a
                // status code as `None` (an empty close payload is legal) — confirm.
                let status_code = ws_result
                    .close_status
                    .expect("close message must have code");
                Some(Ok(ReadResult::Close(CloseMessage {
                    status_code,
                    reason,
                })))
            }
            WebSocketReceiveMessageType::CloseCompleted => None,
            WebSocketReceiveMessageType::Pong => Some(Ok(ReadResult::Pong(&frame_buf[payload]))),
            WebSocketReceiveMessageType::Ping => Some(
                self.send_control_reply(
                    stream,
                    frame_buf,
                    payload.start,
                    payload.end,
                    WebSocketSendMessageType::Pong,
                )
                .await
                .map(ReadResult::Ping),
            ),
        }
    }

    /// Encodes `frame_buf[payload_start..payload_end]` as a single `reply_type` frame, writes
    /// it in full to `stream` and flushes. The encoded reply is staged in spare space at the
    /// end of `frame_buf`.
    ///
    /// Returns the payload slice so it can be handed to the caller of `read`.
    async fn send_control_reply<'a, S>(
        &mut self,
        stream: &mut S,
        frame_buf: &'a mut [u8],
        payload_start: usize,
        payload_end: usize,
        reply_type: WebSocketSendMessageType,
    ) -> Result<&'a [u8], FramerError<<S as ErrorType>::Error>>
    where
        S: Write,
    {
        // 14 bytes covers the largest websocket frame header (2 + 8 extended length + 4 mask).
        let tx_buf_len = payload_end - payload_start + 14;
        // The reply area must lie entirely after the payload. Otherwise the subtraction below
        // could underflow, or the split could cut through the payload.
        let split_at = match frame_buf.len().checked_sub(tx_buf_len) {
            Some(split_at) if split_at >= payload_end => split_at,
            _ => return Err(FramerError::RxBufferTooSmall(payload_end + tx_buf_len)),
        };
        let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
        let payload = &frame_buf[payload_start..payload_end];

        let len = self
            .websocket
            .write(reply_type, true, payload, tx_buf)
            .map_err(FramerError::WebSocket)?;
        stream
            .write_all(&tx_buf[..len])
            .await
            .map_err(FramerError::Io)?;
        stream.flush().await.map_err(FramerError::Io)?;
        Ok(payload)
    }
}