stream_ws/
lib.rs

1//! [![docs.rs](https://img.shields.io/docsrs/stream-ws)](https://docs.rs/crate/stream-ws/latest)
2//! [![repo](https://img.shields.io/badge/repo-stream--ws-blue?logo=github)](https://github.com/idkidknow/stream-ws)
3//! [![crates-io](https://img.shields.io/badge/crates--io-stream--ws-blue)](https://crates.io/crates/stream-ws)
4//! 
//! A layer over WebSocket that carries a byte stream, for both native and WebAssembly targets.
6//!
//! Provides a wrapper around any WebSocket message stream implementation
//! that implements [`AsyncRead`], [`AsyncBufRead`] and [`AsyncWrite`].
9//!
10//! # Usage
11//!
//! Run `cargo add stream-ws` to bring it into your crate.
13//!
14//! Examples in `examples/`.
15//!
16//! ## Tungstenite
17//!
18//! With feature `tungstenite`.
19//!
20//! For `WebSocketStream` in either crate `tokio-tungstenite` or `async-tungstenite`,
21//! use
22//! 
//! ```rust,ignore
//! let stream = stream_ws::tungstenite::WsByteStream::new(inner);
//! ```
26//!
27//! ## Gloo (for WebAssembly)
28//!
29//! With feature `gloo`.
30//!
31//! use
32//! 
//! ```rust,ignore
//! let stream = stream_ws::gloo::WsByteStream::new(inner);
//! ```
36//!
37//! ## Wrapping underlying stream of other WebSocket implementation
38//!
//! Your WebSocket implementation should provide a type `S` satisfying the trait bound
//!
//! ```rust,ignore
//! Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin
//! ```
//!
//! where `Msg` and `E` are the message and error types of that implementation.
//!
//! Create a struct `Handler`, implement [`WsMessageHandle`] for it (which is easy), and then
//! call `Handler::wrap_stream(underlying_stream)` to get a [`WsByteStream`].
50//!
51//! # Crate features
52//!
53//! - `tokio`: impl `tokio`'s `AsyncRead`, `AsyncBufRead` and `AsyncWrite` variants
54//! - `tungstenite`: handlers for message and error type from crate `tungstenite`
55//! - `gloo`: handlers for message and error type from crate `gloo`
56
57#![cfg_attr(docsrs, feature(doc_cfg))]
58
59#[cfg(feature = "gloo")]
60#[cfg_attr(docsrs, doc(cfg(feature = "gloo")))]
61pub mod gloo;
62#[cfg(feature = "tungstenite")]
63#[cfg_attr(docsrs, doc(cfg(feature = "tungstenite")))]
64pub mod tungstenite;
65
66use futures::{ready, AsyncBufRead, AsyncRead, AsyncWrite, Sink, Stream};
67use pin_project::pin_project;
68use std::{io, marker::PhantomData, pin::Pin, task::Poll};
69#[cfg(feature = "tokio")]
70#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
71use tokio::io::{
72    AsyncBufRead as TokioAsyncBufRead, AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite,
73};
74
/// Classification of an incoming WebSocket message, as produced by
/// [`WsMessageHandle::message_into_kind`].
///
/// Derives `Clone`, `PartialEq` and `Eq` so that handler implementations can
/// compare and assert on classification results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessageKind {
    /// Messages considered as payload carriers. The contained bytes are what
    /// the byte stream hands to its reader.
    Bytes(Vec<u8>),
    /// A close message. Receiving one terminates the read side of the stream.
    Close,
    /// Messages that will be ignored
    Other,
}
83
/// Classification of an error reported by the underlying WebSocket, as
/// produced by [`WsMessageHandle::error_into_kind`].
#[derive(Debug)]
pub enum WsErrorKind {
    /// An I/O error; it is surfaced to the caller unchanged.
    Io(io::Error),
    /// Normally closed. Won't be considered as an error.
    Closed,
    /// The connection was used after it had already been closed; surfaced to
    /// the caller as an `io::ErrorKind::NotConnected` error.
    AlreadyClosed,
    /// Any other failure; surfaced to the caller wrapped in an
    /// `io::ErrorKind::Other` error.
    Other(Box<dyn std::error::Error + Send + Sync>),
}
92
93/// Classify messages and errors of WebSocket.
94pub trait WsMessageHandle<Msg, E> {
95    fn message_into_kind(msg: Msg) -> WsMessageKind;
96    fn error_into_kind(e: E) -> WsErrorKind;
97    /// These bytes will be carried by the `Msg` and sent by the underlying stream.
98    fn message_from_bytes<T: Into<Vec<u8>>>(bytes: T) -> Msg;
99
100    fn wrap_stream<S>(inner: S) -> WsByteStream<S, Msg, E, Self>
101    where
102        S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
103    {
104        WsByteStream::new(inner)
105    }
106}
107
108/// A wrapper implements [`AsyncRead`], [`AsyncBufRead`] and [`AsyncWrite`],
109/// around a established websocket connect which
110/// implemented `Stream` and `Sink` traits.
111///
112/// Bytes are transported on binary messages over WebSocket.
113/// Other messages except close messages will be ignored.
114///
115/// # Assumption:
116///
117/// The underlying stream should automatically handle ping and pong messages.
118#[pin_project]
119pub struct WsByteStream<S, Msg, E, H>
120where
121    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
122    H: WsMessageHandle<Msg, E> + ?Sized,
123{
124    #[pin]
125    inner: S,
126    state: State,
127    _marker: PhantomData<H>,
128}
129
130#[derive(Debug)]
131struct State {
132    read: ReadState,
133    write: WriteState,
134}
135
/// State of the read half of the byte stream.
#[derive(Debug)]
enum ReadState {
    /// Nothing is buffered; the next read has to pull a message from the WebSocket.
    Pending,
    /// A message payload is held in `buf` and has not been fully handed out yet.
    /// `amt_read` is the number of leading bytes of `buf` already consumed.
    Ready { buf: Vec<u8>, amt_read: usize },
    /// Connection was shut down correctly.
    Terminated,
}
145
/// State of the write half of the byte stream.
#[derive(Debug)]
enum WriteState {
    /// Writes are still accepted.
    Ready,
    /// The write half has been closed; further writes are rejected.
    Closed,
}
151
152impl<S, Msg, E, H> WsByteStream<S, Msg, E, H>
153where
154    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
155    H: WsMessageHandle<Msg, E> + ?Sized,
156{
157    pub fn new(inner: S) -> Self {
158        Self {
159            inner,
160            state: State {
161                read: ReadState::Pending,
162                write: WriteState::Ready,
163            },
164            _marker: PhantomData,
165        }
166    }
167
168    /// This function tries reading next message to fill the buf (set the state to `ReadState::Ready`).
169    ///
170    /// Returns `None` if the stream has been fused. (The state is set to `ReadState::Terminated`)
171    fn fill_buf_with_next_msg(
172        self: Pin<&mut Self>,
173        cx: &mut std::task::Context<'_>,
174    ) -> Poll<Option<io::Result<()>>> {
175        let mut this = self.project();
176        loop {
177            // Try read from inner stream.
178            let res = ready!(this.inner.as_mut().poll_next(cx));
179            let Some(res) = res else {
180                // `res` is None. Stream has been fused.
181                this.state.read = ReadState::Terminated;
182                return Poll::Ready(None);
183            };
184            match res {
185                Ok(msg) => {
186                    let msg = H::message_into_kind(msg);
187                    match msg {
188                        WsMessageKind::Bytes(msg) => {
189                            this.state.read = ReadState::Ready {
190                                buf: msg,
191                                amt_read: 0,
192                            };
193                            return Poll::Ready(Some(Ok(())));
194                        }
195                        WsMessageKind::Close => {
196                            this.state.read = ReadState::Terminated;
197                            return Poll::Ready(None);
198                        }
199                        WsMessageKind::Other => {
200                            // Simply ignore it.
201                            continue;
202                        }
203                    }
204                }
205                Err(e) => {
206                    let e = H::error_into_kind(e);
207                    match e {
208                        WsErrorKind::Io(e) => {
209                            return Poll::Ready(Some(Err(e)));
210                        }
211                        WsErrorKind::Closed => {
212                            this.state.read = ReadState::Terminated;
213                            return Poll::Ready(None);
214                        }
215                        WsErrorKind::AlreadyClosed => {
216                            this.state.read = ReadState::Terminated;
217                            let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
218                            return Poll::Ready(Some(Err(e)));
219                        }
220                        WsErrorKind::Other(e) => {
221                            return Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))));
222                        }
223                    }
224                }
225            }
226        }
227    }
228}
229
230impl<S, Msg, E, H> AsyncRead for WsByteStream<S, Msg, E, H>
231where
232    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
233    H: WsMessageHandle<Msg, E> + ?Sized,
234{
235    fn poll_read(
236        mut self: Pin<&mut Self>,
237        cx: &mut std::task::Context<'_>,
238        dst: &mut [u8],
239    ) -> Poll<io::Result<usize>> {
240        loop {
241            let this = self.as_mut().project();
242            match this.state.read {
243                ReadState::Pending => {
244                    let res = ready!(self.as_mut().fill_buf_with_next_msg(cx));
245                    match res {
246                        Some(Ok(())) => continue, // The state is assumed to be `Ready`
247                        Some(Err(e)) => return Poll::Ready(Err(e)),
248                        None => continue, // The state is assumed to be `Terminated`
249                    }
250                }
251                ReadState::Ready {
252                    ref buf,
253                    ref mut amt_read,
254                } => {
255                    let buf = &buf[*amt_read..];
256                    let len = std::cmp::min(dst.len(), buf.len());
257                    dst[..len].copy_from_slice(&buf[..len]);
258                    if len == buf.len() {
259                        this.state.read = ReadState::Pending;
260                    } else {
261                        *amt_read += len;
262                    }
263                    return Poll::Ready(Ok(len));
264                }
265                ReadState::Terminated => {
266                    return Poll::Ready(Ok(0));
267                }
268            }
269        }
270    }
271}
272
273impl<S, Msg, E, H> AsyncBufRead for WsByteStream<S, Msg, E, H>
274where
275    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
276    H: WsMessageHandle<Msg, E> + ?Sized,
277{
278    fn poll_fill_buf(
279        mut self: Pin<&mut Self>,
280        cx: &mut std::task::Context<'_>,
281    ) -> Poll<io::Result<&[u8]>> {
282        loop {
283            let this = self.as_mut().project();
284            match this.state.read {
285                ReadState::Pending => {
286                    let res = ready!(self.as_mut().fill_buf_with_next_msg(cx));
287                    match res {
288                        Some(Ok(())) => continue, // The state is assumed to be `Ready`
289                        Some(Err(e)) => return Poll::Ready(Err(e)),
290                        None => continue, // The state is assumed to be `Terminated`
291                    }
292                }
293                ReadState::Ready { .. } => {
294                    // Borrow and match again to meet the lifetime requirement.
295                    let this = self.project();
296                    let ReadState::Ready { ref buf, amt_read } = this.state.read else {
297                        unreachable!()
298                    };
299                    return Poll::Ready(Ok(&buf[amt_read..]));
300                }
301                ReadState::Terminated => {
302                    return Poll::Ready(Ok(&[]));
303                }
304            }
305        }
306    }
307
308    fn consume(mut self: Pin<&mut Self>, amt: usize) {
309        let ReadState::Ready {
310            ref buf,
311            ref mut amt_read,
312        } = self.state.read
313        else {
314            return;
315        };
316        *amt_read = std::cmp::min(buf.len(), *amt_read + amt);
317        if *amt_read == buf.len() {
318            self.state.read = ReadState::Pending;
319        }
320    }
321}
322
323impl<S, Msg, E, H> AsyncWrite for WsByteStream<S, Msg, E, H>
324where
325    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
326    H: WsMessageHandle<Msg, E> + ?Sized,
327{
328    fn poll_write(
329        mut self: Pin<&mut Self>,
330        cx: &mut std::task::Context<'_>,
331        buf: &[u8],
332    ) -> Poll<io::Result<usize>> {
333        let mut this = self.as_mut().project();
334        loop {
335            match this.state.write {
336                WriteState::Ready => {
337                    if let Err(e) = ready!(this.inner.as_mut().poll_ready(cx)) {
338                        let e = H::error_into_kind(e);
339                        match e {
340                            WsErrorKind::Io(e) => {
341                                return Poll::Ready(Err(e));
342                            }
343                            WsErrorKind::Other(e) => {
344                                return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
345                            }
346                            WsErrorKind::Closed => {
347                                this.state.write = WriteState::Closed;
348                                return Poll::Ready(Ok(0));
349                            }
350                            WsErrorKind::AlreadyClosed => {
351                                this.state.write = WriteState::Closed;
352                                let e =
353                                    io::Error::new(io::ErrorKind::NotConnected, "Already closed");
354                                return Poll::Ready(Err(e));
355                            }
356                        }
357                    }
358                    // Start sending
359                    let Err(e) = this.inner.as_mut().start_send(H::message_from_bytes(buf)) else {
360                        this.state.write = WriteState::Ready;
361                        return Poll::Ready(Ok(buf.len()));
362                    };
363                    let e = H::error_into_kind(e);
364                    match e {
365                        WsErrorKind::Io(e) => {
366                            return Poll::Ready(Err(e));
367                        }
368                        WsErrorKind::Other(e) => {
369                            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
370                        }
371                        WsErrorKind::Closed => {
372                            this.state.write = WriteState::Closed;
373                            return Poll::Ready(Ok(0));
374                        }
375                        WsErrorKind::AlreadyClosed => {
376                            this.state.write = WriteState::Closed;
377                            let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
378                            return Poll::Ready(Err(e));
379                        }
380                    }
381                }
382                WriteState::Closed => {
383                    let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
384                    return Poll::Ready(Err(e));
385                }
386            }
387        }
388    }
389
390    fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
391        let mut this = self.project();
392        if let Err(e) = ready!(this.inner.as_mut().poll_flush(cx)) {
393            let e = H::error_into_kind(e);
394            match e {
395                WsErrorKind::Io(e) => {
396                    return Poll::Ready(Err(e));
397                }
398                WsErrorKind::Other(e) => {
399                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
400                }
401                WsErrorKind::Closed => {
402                    this.state.write = WriteState::Closed;
403                    return Poll::Ready(Ok(()));
404                }
405                WsErrorKind::AlreadyClosed => {
406                    this.state.write = WriteState::Closed;
407                    let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
408                    return Poll::Ready(Err(e));
409                }
410            }
411        }
412        Poll::Ready(Ok(()))
413    }
414
415    fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
416        let mut this = self.project();
417        this.state.write = WriteState::Closed;
418        if let Err(e) = ready!(this.inner.as_mut().poll_close(cx)) {
419            let e = H::error_into_kind(e);
420            match e {
421                WsErrorKind::Io(e) => {
422                    return Poll::Ready(Err(e));
423                }
424                WsErrorKind::Other(e) => {
425                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
426                }
427                WsErrorKind::Closed => {
428                    return Poll::Ready(Ok(()));
429                }
430                WsErrorKind::AlreadyClosed => {
431                    let e = io::Error::new(io::ErrorKind::NotConnected, "Already closed");
432                    return Poll::Ready(Err(e));
433                }
434            }
435        }
436        Poll::Ready(Ok(()))
437    }
438}
439
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncRead for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    /// `tokio` flavour of `poll_read`: delegates to the `futures`
    /// [`AsyncRead`] implementation and records the number of bytes it
    /// produced in `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let unfilled = buf.initialize_unfilled();
        match <Self as AsyncRead>::poll_read(self, cx, unfilled) {
            Poll::Ready(Ok(copied)) => {
                buf.advance(copied);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}
458
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncBufRead for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    /// Forwards to the `futures` [`AsyncBufRead`] implementation.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<&[u8]>> {
        <Self as AsyncBufRead>::poll_fill_buf(self, cx)
    }

    /// Forwards to the `futures` [`AsyncBufRead`] implementation.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        <Self as AsyncBufRead>::consume(self, amt);
    }
}
477
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, Msg, E, H> TokioAsyncWrite for WsByteStream<S, Msg, E, H>
where
    S: Stream<Item = Result<Msg, E>> + Sink<Msg, Error = E> + Unpin,
    H: WsMessageHandle<Msg, E> + ?Sized,
{
    /// Forwards to the `futures` [`AsyncWrite`] implementation.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        <Self as AsyncWrite>::poll_write(self, cx, buf)
    }

    /// Forwards to the `futures` [`AsyncWrite`] implementation.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
        <Self as AsyncWrite>::poll_flush(self, cx)
    }

    /// `tokio` calls closing the write half "shutdown"; forwards to
    /// `poll_close` of the `futures` [`AsyncWrite`] implementation.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        <Self as AsyncWrite>::poll_close(self, cx)
    }
}
506}