Skip to main content

http_ws/
stream.rs

1use core::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use alloc::sync::{Arc, Weak};
9
10use std::{error, io, sync::Mutex};
11
12use bytes::{Bytes, BytesMut};
13use futures_core::stream::Stream;
14use pin_project_lite::pin_project;
15use tokio::sync::mpsc::{channel, Receiver, Sender};
16
17use super::{
18    codec::{Codec, Message},
19    error::ProtocolError,
20    proto::CloseReason,
21};
22
23pin_project! {
24    /// Decode `S` type into Stream of websocket [Message].
25    /// `S` type must impl `Stream` trait and output `Result<T, E>` as `Stream::Item`
26    /// where `T` type impl `AsRef<[u8]>` trait. (`&[u8]` is needed for parsing messages)
27    ///
28    /// # Stream termination
29    ///
30    /// This stream never returns `None`. Callers should expect one of the following outcomes:
31    ///
32    /// - `Ok(Message::Close(_))`: The remote peer initiated a clean close. The caller should
33    ///   send a close frame back via [ResponseSender] and then stop polling.
34    /// - `Err(WsError::Protocol(ProtocolError::RecvClosed))`: The stream was polled after a
35    ///   close frame had already been received. The caller should have stopped polling after
36    ///   observing `Message::Close`.
37    /// - `Err(WsError::Protocol(ProtocolError::UnexpectedEof))`: The underlying transport
38    ///   ended without a close frame. This is an abnormal closure and the associated
39    ///   connection should not be reused.
40    /// - `Err(WsError::Protocol(_))`: A protocol violation occurred (e.g. bad opcode,
41    ///   continuation error). The connection should be dropped.
42    /// - `Err(WsError::Stream(_))`: The underlying stream produced an error.
43    pub struct RequestStream<S> {
44        #[pin]
45        stream: S,
46        buf: BytesMut,
47        codec: Codec,
48    }
49}
50
51impl<S, T, E> RequestStream<S>
52where
53    S: Stream<Item = Result<T, E>>,
54    T: AsRef<[u8]>,
55{
56    pub fn new(stream: S) -> Self {
57        Self::with_codec(stream, Codec::new())
58    }
59
60    pub fn with_codec(stream: S, codec: Codec) -> Self {
61        Self {
62            stream,
63            buf: BytesMut::new(),
64            codec,
65        }
66    }
67
68    #[inline]
69    pub fn inner_mut(&mut self) -> &mut S {
70        &mut self.stream
71    }
72
73    #[inline]
74    pub fn codec_mut(&mut self) -> &mut Codec {
75        &mut self.codec
76    }
77
78    /// Make a [ResponseStream] from current DecodeStream.
79    ///
80    /// This API is to share the same codec for both decode and encode stream.
81    pub fn response_stream(&self) -> (ResponseStream, ResponseSender) {
82        let codec = self.codec.duplicate();
83        let cap = codec.capacity();
84        let (tx, rx) = channel(cap);
85        (ResponseStream(rx), ResponseSender::new(tx, codec))
86    }
87}
88
89pub enum WsError<E> {
90    Protocol(ProtocolError),
91    Stream(E),
92}
93
94impl<E> fmt::Debug for WsError<E> {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        match *self {
97            Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
98            Self::Stream(..) => f.write_str("Input Stream error"),
99        }
100    }
101}
102
103impl<E> fmt::Display for WsError<E> {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match *self {
106            Self::Protocol(ref e) => fmt::Debug::fmt(e, f),
107            Self::Stream(..) => f.write_str("Input Stream error"),
108        }
109    }
110}
111
112impl<E> error::Error for WsError<E> {}
113
114impl<E> From<ProtocolError> for WsError<E> {
115    fn from(e: ProtocolError) -> Self {
116        Self::Protocol(e)
117    }
118}
119
120impl<S, T, E> Stream for RequestStream<S>
121where
122    S: Stream<Item = Result<T, E>>,
123    T: AsRef<[u8]>,
124{
125    type Item = Result<Message, WsError<E>>;
126
127    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128        let mut this = self.project();
129
130        loop {
131            match this.codec.decode(this.buf)? {
132                Some(msg) => return Poll::Ready(Some(Ok(msg))),
133                None => match ready!(this.stream.as_mut().poll_next(cx)) {
134                    Some(res) => {
135                        let item = res.map_err(WsError::Stream)?;
136                        this.buf.extend_from_slice(item.as_ref())
137                    }
138                    None => return Poll::Ready(Some(Err(WsError::Protocol(ProtocolError::UnexpectedEof)))),
139                },
140            }
141        }
142    }
143}
144
145pub struct ResponseStream(Receiver<Item>);
146
147type Item = io::Result<Bytes>;
148
149impl Stream for ResponseStream {
150    type Item = Item;
151
152    #[inline]
153    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154        self.get_mut().0.poll_recv(cx)
155    }
156}
157
158/// Encode [Message] into [Bytes] and send it to [ResponseStream].
159#[derive(Debug)]
160pub struct ResponseSender {
161    inner: Arc<_ResponseSender>,
162}
163
164impl ResponseSender {
165    fn new(tx: Sender<Item>, codec: Codec) -> Self {
166        let buf = BytesMut::with_capacity(codec.max_size());
167        Self {
168            inner: Arc::new(_ResponseSender {
169                encoder: Mutex::new(Encoder { codec, buf }),
170                tx,
171            }),
172        }
173    }
174
175    /// downgrade Self to a weak sender.
176    pub fn downgrade(&self) -> ResponseWeakSender {
177        ResponseWeakSender {
178            inner: Arc::downgrade(&self.inner),
179        }
180    }
181
182    /// encode [`Message::Text`] variant and add to [ResponseStream].
183    #[inline]
184    pub async fn text(&self, txt: impl Into<Bytes>) -> Result<(), ProtocolError> {
185        let bytes = txt.into();
186        core::str::from_utf8(&bytes).map_err(|_| ProtocolError::BadOpCode)?;
187        self.send(Message::Text(bytes)).await
188    }
189
190    /// encode [`Message::Binary`] variant and add to [ResponseStream].
191    #[inline]
192    pub fn binary(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
193        self.send(Message::Binary(bin.into()))
194    }
195
196    /// encode [`Message::Continuation`] variant and add to [ResponseStream].
197    #[inline]
198    pub fn continuation(&self, item: super::codec::Item) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
199        self.send(Message::Continuation(item))
200    }
201
202    /// encode [`Message::Ping`] variant and add to [ResponseStream].
203    #[inline]
204    pub fn ping(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
205        self.send(Message::Ping(bin.into()))
206    }
207
208    /// encode [`Message::Pong`] variant and add to [ResponseStream].
209    #[inline]
210    pub fn pong(&self, bin: impl Into<Bytes>) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
211        self.send(Message::Pong(bin.into()))
212    }
213
214    /// encode [`Message::Close`] variant and add to [ResponseStream].
215    ///
216    /// This method can only be executed once.
217    /// Concurrent callers would race for executing and at most one would succeed.
218    /// Other callers failing the race would observe [`ProtocolError::SendClosed`]
219    pub async fn close(&mut self, reason: Option<impl Into<CloseReason>>) -> Result<(), ProtocolError> {
220        self.send(Message::Close(reason.map(Into::into))).await
221    }
222
223    #[inline]
224    fn send(&self, msg: Message) -> impl Future<Output = Result<(), ProtocolError>> + '_ {
225        self.inner.send(msg)
226    }
227}
228
229/// [Weak] version of [ResponseSender].]
230///
231/// Used for duplicating sender for concurrency
232#[derive(Debug)]
233pub struct ResponseWeakSender {
234    inner: Weak<_ResponseSender>,
235}
236
237impl ResponseWeakSender {
238    /// upgrade self to strong response sender.
239    /// return None when [ResponseSender] is already dropped or session (Send side) is already closed
240    pub fn upgrade(&self) -> Option<ResponseSender> {
241        self.inner.upgrade().and_then(|inner| {
242            let closed = inner.encoder.lock().unwrap().codec.send_closed();
243            (!closed).then(|| ResponseSender { inner })
244        })
245    }
246}
247
248#[derive(Debug)]
249struct _ResponseSender {
250    encoder: Mutex<Encoder>,
251    tx: Sender<Item>,
252}
253
254#[derive(Debug)]
255struct Encoder {
256    codec: Codec,
257    buf: BytesMut,
258}
259
260impl _ResponseSender {
261    // send message to response stream. it would produce Ok(bytes) when succeed where
262    // the bytes is encoded binary websocket message ready to be sent to client.
263    async fn send(&self, msg: Message) -> Result<(), ProtocolError> {
264        let permit = self.tx.reserve().await.map_err(|_| ProtocolError::UnexpectedEof)?;
265        let buf = {
266            let mut encoder = self.encoder.lock().unwrap();
267            let Encoder { codec, buf } = &mut *encoder;
268            codec.encode(msg, buf)?;
269            buf.split().freeze()
270        };
271        permit.send(Ok(buf));
272        Ok(())
273    }
274}