http_ws/
stream.rs

1use core::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use alloc::sync::{Arc, Weak};
9
10use std::{error, io, sync::Mutex};
11
12use bytes::{Bytes, BytesMut};
13use futures_core::stream::Stream;
14use pin_project_lite::pin_project;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
17use super::{
18    codec::{Codec, Message},
19    error::ProtocolError,
20    proto::CloseReason,
21};
22
23pin_project! {
24    /// Decode `S` type into Stream of websocket [Message].
25    /// `S` type must impl `Stream` trait and output `Result<T, E>` as `Stream::Item`
26    /// where `T` type impl `AsRef<[u8]>` trait. (`&[u8]` is needed for parsing messages)
27    pub struct RequestStream<S> {
28        #[pin]
29        stream: S,
30        buf: BytesMut,
31        codec: Codec,
32    }
33}
34
35impl<S, T, E> RequestStream<S>
36where
37    S: Stream<Item = Result<T, E>>,
38    T: AsRef<[u8]>,
39{
40    pub fn new(stream: S) -> Self {
41        Self::with_codec(stream, Codec::new())
42    }
43
44    pub fn with_codec(stream: S, codec: Codec) -> Self {
45        Self {
46            stream,
47            buf: BytesMut::new(),
48            codec,
49        }
50    }
51
52    #[inline]
53    pub fn inner_mut(&mut self) -> &mut S {
54        &mut self.stream
55    }
56
57    #[inline]
58    pub fn codec_mut(&mut self) -> &mut Codec {
59        &mut self.codec
60    }
61
62    /// Make a [ResponseStream] from current DecodeStream.
63    ///
64    /// This API is to share the same codec for both decode and encode stream.
65    pub fn response_stream(&self) -> (ResponseStream, ResponseSender) {
66        let codec = self.codec.duplicate();
67        let cap = codec.capacity();
68        let (tx, rx) = channel(cap);
69        (ResponseStream(rx), ResponseSender::new(tx, codec))
70    }
71}
72
73pub enum WsError<E> {
74    Protocol(ProtocolError),
75    Stream(E),
76}
77
78impl<E> fmt::Debug for WsError<E> {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match *self {
81            Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
82            Self::Stream(..) => f.write_str("Input Stream error"),
83        }
84    }
85}
86
87impl<E> fmt::Display for WsError<E> {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match *self {
90            Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
91            Self::Stream(..) => f.write_str("Input Stream error"),
92        }
93    }
94}
95
96impl<E> error::Error for WsError<E> {}
97
98impl<E> From<ProtocolError> for WsError<E> {
99    fn from(e: ProtocolError) -> Self {
100        Self::Protocol(e)
101    }
102}
103
104impl<S, T, E> Stream for RequestStream<S>
105where
106    S: Stream<Item = Result<T, E>>,
107    T: AsRef<[u8]>,
108{
109    type Item = Result<Message, WsError<E>>;
110
111    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        let mut this = self.project();
113
114        loop {
115            if let Some(msg) = this.codec.decode(this.buf)? {
116                return Poll::Ready(Some(Ok(msg)));
117            }
118            match ready!(this.stream.as_mut().poll_next(cx)) {
119                Some(res) => {
120                    let item = res.map_err(WsError::Stream)?;
121                    this.buf.extend_from_slice(item.as_ref())
122                }
123                None => return Poll::Ready(None),
124            }
125        }
126    }
127}
128
129pub struct ResponseStream(Receiver<Item>);
130
131type Item = io::Result<Bytes>;
132
133impl Stream for ResponseStream {
134    type Item = Item;
135
136    #[inline]
137    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        self.get_mut().0.poll_recv(cx)
139    }
140}
141
142/// Encode [Message] into [Bytes] and send it to [ResponseStream].
143#[derive(Debug)]
144pub struct ResponseSender {
145    inner: Arc<_ResponseSender>,
146}
147
148impl ResponseSender {
149    fn new(tx: Sender<Item>, codec: Codec) -> Self {
150        Self {
151            inner: Arc::new(_ResponseSender {
152                encoder: Mutex::new(Encoder {
153                    codec,
154                    buf: BytesMut::with_capacity(codec.max_size()),
155                }),
156                tx,
157            }),
158        }
159    }
160
161    /// downgrade Self to a weak sender.
162    pub fn downgrade(&self) -> ResponseWeakSender {
163        ResponseWeakSender {
164            inner: Arc::downgrade(&self.inner),
165        }
166    }
167
168    /// encode [Message] and add to [ResponseStream].
169    #[inline]
170    pub fn send(&self, msg: Message) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
171        self.inner.send(msg)
172    }
173
174    /// add [io::Error] to [ResponseStream].
175    ///
176    /// the error should be used as a signal to the TCP connection associated with `ResponseStream`
177    /// to close immediately.
178    ///
179    /// # Examples
180    /// ```rust
181    /// use std::{future::poll_fn, pin::Pin, time::Duration};
182    ///
183    /// use futures_core::Stream;
184    /// use http_ws::{CloseCode, Message, RequestStream, ResponseSender, ResponseStream};
185    /// use tokio::{io::AsyncWriteExt, time::timeout, net::TcpStream};
186    ///
187    /// // thread1:
188    /// // read and write websocket message.
189    /// async fn sender<S, T, E>(tx: ResponseSender, mut rx: Pin<&mut RequestStream<S>>)
190    /// where
191    ///     S: Stream<Item = Result<T, E>>,
192    ///     T: AsRef<[u8]>,
193    /// {
194    ///     // send close message to client
195    ///     tx.send(Message::Close(Some(CloseCode::Away.into()))).await.unwrap();
196    ///
197    ///     // the client failed to respond to close message in 5 seconds time window.
198    ///     if let Err(_) = timeout(Duration::from_secs(5), poll_fn(|cx| rx.as_mut().poll_next(cx))).await {
199    ///         // send io error to thread2
200    ///         tx.send_error(std::io::ErrorKind::UnexpectedEof.into()).await.unwrap();
201    ///     }
202    /// }
203    ///
204    /// // thread2:
205    /// // receive websocket message from thread1 and transfer it on tcp connection.
206    /// async fn io_write(conn: &mut TcpStream, mut rx: Pin<&mut ResponseStream>) {
207    ///     // the first message is the "go away" close message in Ok branch.
208    ///     let msg = poll_fn(|cx| rx.as_mut().poll_next(cx)).await.unwrap().unwrap();
209    ///
210    ///     // send msg to client
211    ///     conn.write_all(&msg).await.unwrap();
212    ///
213    ///     // the second message is the io::Error in Err branch.
214    ///     let err = poll_fn(|cx| rx.as_mut().poll_next(cx)).await.unwrap().unwrap_err();
215    ///
216    ///     // at this point we should close the tcp connection by either graceful close or
217    ///     // just return immediately and drop the TcpStream.
218    ///     let _ = conn.shutdown().await;
219    /// }
220    ///
221    /// // thread3:
222    /// // receive message from tcp connection and send it to thread1:
223    /// async fn io_read(conn: &mut TcpStream) {
224    ///     // this part is ignored as it has no relation to the send_error api.
225    /// }
226    /// ```
227    #[inline]
228    pub fn send_error(&self, err: io::Error) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
229        self.inner.send_error(err)
230    }
231
232    /// encode [Message::Text] variant and add to [ResponseStream].
233    #[inline]
234    pub fn text(&self, txt: impl Into<String>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
235        self.send(Message::Text(Bytes::from(txt.into())))
236    }
237
238    /// encode [Message::Binary] variant and add to [ResponseStream].
239    #[inline]
240    pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
241        self.send(Message::Binary(bin.into()))
242    }
243
244    /// encode [Message::Close] variant and add to [ResponseStream].
245    /// take ownership of Self as after close message no more message can be sent to client.
246    pub async fn close(self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
247        self.send(Message::Close(reason.map(Into::into))).await
248    }
249}
250
251/// [Weak] version of [ResponseSender].
252#[derive(Debug)]
253pub struct ResponseWeakSender {
254    inner: Weak<_ResponseSender>,
255}
256
257impl ResponseWeakSender {
258    /// upgrade self to strong response sender.
259    /// return None when [ResponseSender] is already dropped.
260    pub fn upgrade(&self) -> Option<ResponseSender> {
261        self.inner.upgrade().map(|inner| ResponseSender { inner })
262    }
263}
264
265#[derive(Debug)]
266struct _ResponseSender {
267    encoder: Mutex<Encoder>,
268    tx: Sender<Item>,
269}
270
271#[derive(Debug)]
272struct Encoder {
273    codec: Codec,
274    buf: BytesMut,
275}
276
277impl _ResponseSender {
278    // send message to response stream. it would produce Ok(bytes) when succeed where
279    // the bytes is encoded binary websocket message ready to be sent to client.
280    async fn send(&self, msg: Message) -> Result<(), ProtocolError> {
281        let permit = self.tx.reserve().await.map_err(|_| ProtocolError::Closed)?;
282        let buf = {
283            let mut encoder = self.encoder.lock().unwrap();
284            let Encoder { codec, buf } = &mut *encoder;
285            codec.encode(msg, buf)?;
286            buf.split().freeze()
287        };
288        permit.send(Ok(buf));
289        Ok(())
290    }
291
292    // send error to response stream. it would produce Err(io::Error) when succeed where
293    // the error is a representation of io error to the stream consumer. in most cases
294    // the consumer observing the error should close the stream and the tcp connection
295    // the stream belongs to.
296    async fn send_error(&self, err: io::Error) -> Result<(), ProtocolError> {
297        self.tx.send(Err(err)).await.map_err(|_| ProtocolError::Closed)
298    }
299}