// actix_ws/stream.rs

1use std::{
2    collections::VecDeque,
3    future::poll_fn,
4    io, mem,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11    ws::{Codec, Frame, Message, ProtocolError},
12    Payload,
13};
14use actix_web::{
15    web::{Bytes, BytesMut},
16    Error,
17};
18use bytestring::ByteString;
19use futures_core::stream::Stream;
20use tokio::sync::mpsc::Receiver;
21
22use crate::AggregatedMessageStream;
23
24/// Response body for a WebSocket.
25pub struct StreamingBody {
26    session_rx: Receiver<Message>,
27    messages: VecDeque<Message>,
28    buf: BytesMut,
29    codec: Codec,
30    closing: bool,
31}
32
33impl StreamingBody {
34    pub(super) fn new(session_rx: Receiver<Message>) -> Self {
35        StreamingBody {
36            session_rx,
37            messages: VecDeque::new(),
38            buf: BytesMut::new(),
39            codec: Codec::new(),
40            closing: false,
41        }
42    }
43}
44
45/// Stream of messages from a WebSocket client.
46pub struct MessageStream {
47    payload: Payload,
48
49    messages: VecDeque<Message>,
50    buf: BytesMut,
51    codec: Codec,
52    closing: bool,
53}
54
55impl MessageStream {
56    pub(super) fn new(payload: Payload) -> Self {
57        MessageStream {
58            payload,
59            messages: VecDeque::new(),
60            buf: BytesMut::new(),
61            codec: Codec::new(),
62            closing: false,
63        }
64    }
65
66    /// Sets the maximum permitted size for received WebSocket frames, in bytes.
67    ///
68    /// By default, up to 64KiB is allowed.
69    ///
70    /// Any received frames larger than the permitted value will return
71    /// `Err(ProtocolError::Overflow)` instead.
72    ///
73    /// ```no_run
74    /// # use actix_ws::MessageStream;
75    /// # fn test(stream: MessageStream) {
76    /// // increase permitted frame size from 64KB to 1MB
77    /// let stream = stream.max_frame_size(1024 * 1024);
78    /// # }
79    /// ```
80    #[must_use]
81    pub fn max_frame_size(mut self, max_size: usize) -> Self {
82        self.codec = self.codec.max_size(max_size);
83        self
84    }
85
86    /// Returns a stream wrapper that collects continuation frames into their equivalent aggregated
87    /// forms, i.e., binary or text.
88    ///
89    /// By default, continuations will be aggregated up to 1MiB in size (customizable with
90    /// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an
91    /// error if this size is exceeded.
92    #[must_use]
93    pub fn aggregate_continuations(self) -> AggregatedMessageStream {
94        AggregatedMessageStream::new(self)
95    }
96
97    /// Waits for the next item from the message stream
98    ///
99    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
100    ///
101    /// ```no_run
102    /// # use actix_ws::MessageStream;
103    /// # async fn test(mut stream: MessageStream) {
104    /// while let Some(Ok(msg)) = stream.recv().await {
105    ///     // handle message
106    /// }
107    /// # }
108    /// ```
109    #[must_use]
110    pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
111        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
112    }
113}
114
115impl Stream for StreamingBody {
116    type Item = Result<Bytes, Error>;
117
118    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        let this = self.get_mut();
120
121        if this.closing {
122            return Poll::Ready(None);
123        }
124
125        loop {
126            match Pin::new(&mut this.session_rx).poll_recv(cx) {
127                Poll::Ready(Some(msg)) => {
128                    this.messages.push_back(msg);
129                }
130                Poll::Ready(None) => {
131                    this.closing = true;
132                    break;
133                }
134                Poll::Pending => break,
135            }
136        }
137
138        while let Some(msg) = this.messages.pop_front() {
139            if let Err(err) = this.codec.encode(msg, &mut this.buf) {
140                return Poll::Ready(Some(Err(err.into())));
141            }
142        }
143
144        if !this.buf.is_empty() {
145            return Poll::Ready(Some(Ok(mem::take(&mut this.buf).freeze())));
146        }
147
148        if this.closing {
149            return Poll::Ready(None);
150        }
151
152        Poll::Pending
153    }
154}
155
156impl Stream for MessageStream {
157    type Item = Result<Message, ProtocolError>;
158
159    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160        let this = self.get_mut();
161
162        // Return the first message in the queue if one exists
163        //
164        // This is faster than polling and parsing
165        if let Some(msg) = this.messages.pop_front() {
166            return Poll::Ready(Some(Ok(msg)));
167        }
168
169        if !this.closing {
170            // Read in bytes until there's nothing left to read
171            loop {
172                match Pin::new(&mut this.payload).poll_next(cx) {
173                    Poll::Ready(Some(Ok(bytes))) => {
174                        this.buf.extend_from_slice(&bytes);
175                    }
176                    Poll::Ready(Some(Err(err))) => {
177                        return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
178                    }
179                    Poll::Ready(None) => {
180                        this.closing = true;
181                        break;
182                    }
183                    Poll::Pending => break,
184                }
185            }
186        }
187
188        // Create messages until there's no more bytes left
189        while let Some(frame) = this.codec.decode(&mut this.buf)? {
190            let message = match frame {
191                Frame::Text(bytes) => {
192                    ByteString::try_from(bytes)
193                        .map(Message::Text)
194                        .map_err(|err| {
195                            ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
196                        })?
197                }
198                Frame::Binary(bytes) => Message::Binary(bytes),
199                Frame::Ping(bytes) => Message::Ping(bytes),
200                Frame::Pong(bytes) => Message::Pong(bytes),
201                Frame::Close(reason) => Message::Close(reason),
202                Frame::Continuation(item) => Message::Continuation(item),
203            };
204
205            this.messages.push_back(message);
206        }
207
208        // Return the first message in the queue
209        if let Some(msg) = this.messages.pop_front() {
210            return Poll::Ready(Some(Ok(msg)));
211        }
212
213        // If we've exhausted our message queue and we're closing, close the stream
214        if this.closing {
215            return Poll::Ready(None);
216        }
217
218        Poll::Pending
219    }
220}
221
#[cfg(test)]
pub(crate) mod tests {
    use std::{
        future::Future,
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use actix_http::error::PayloadError;
    use futures_core::Stream;
    use tokio::sync::mpsc::{Receiver, Sender};

    use super::{Bytes, BytesMut, Codec, Encoder, Message, MessageStream, Payload, StreamingBody};

    /// Receiving half of a test payload: yields every chunk sent through [`PayloadSender`].
    pub(crate) struct PayloadReceiver {
        rx: Receiver<Bytes>,
    }

    /// Sending half of a test payload: encodes messages with the codec built in
    /// [`payload_pair`] (client mode) and forwards the bytes to the paired [`PayloadReceiver`].
    pub(crate) struct PayloadSender {
        codec: Codec,
        tx: Sender<Bytes>,
    }

    impl PayloadSender {
        /// Encodes `message` and sends it as one payload chunk.
        pub(crate) async fn send(&mut self, message: Message) {
            self.send_many(vec![message]).await
        }

        /// Encodes all `messages` back-to-back and sends them as a single payload chunk.
        ///
        /// # Panics
        ///
        /// Panics if encoding fails or the receiving half has been dropped.
        pub(crate) async fn send_many(&mut self, messages: Vec<Message>) {
            let mut buf = BytesMut::new();

            for message in messages {
                self.codec.encode(message, &mut buf).unwrap();
            }

            self.tx.send(buf.freeze()).await.unwrap()
        }
    }

    impl Stream for PayloadReceiver {
        type Item = Result<Bytes, PayloadError>;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let opt = ready!(self.get_mut().rx.poll_recv(cx));

            Poll::Ready(opt.map(Ok))
        }
    }

    /// Creates a connected sender / payload pair backed by a channel of the given `capacity`.
    pub(crate) fn payload_pair(capacity: usize) -> (PayloadSender, Payload) {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        (
            PayloadSender {
                codec: Codec::new().client_mode(),
                tx,
            },
            Payload::Stream {
                payload: Box::pin(PayloadReceiver { rx }),
            },
        )
    }

    /// Returns a fresh copy of the messages used by the tests below.
    ///
    /// Called once to feed the stream under test and again to build the expected values, so the
    /// tests do not rely on `Message` being `Clone`.
    fn sample_messages() -> Vec<Message> {
        vec![
            Message::Binary(Bytes::from(vec![0, 1, 2, 3])),
            Message::Ping(Bytes::from(vec![3, 2, 1, 0])),
            Message::Close(None),
        ]
    }

    /// Encodes `messages` with a fresh `Codec::new()`, the same codec `StreamingBody` uses.
    fn server_encode(messages: impl IntoIterator<Item = Message>) -> Bytes {
        let mut codec = Codec::new();
        let mut buf = BytesMut::new();

        for message in messages {
            codec.encode(message, &mut buf).unwrap();
        }

        buf.freeze()
    }

    #[tokio::test]
    async fn message_stream_yields_messages() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            for (msg, expected) in sample_messages().into_iter().zip(sample_messages()) {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                // Check the decoded message itself, not just readiness.
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(got))) => assert_eq!(got, expected),
                    poll => panic!("Stream should yield {expected:?}, got {poll:?}"),
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn message_stream_yields_consecutive_messages() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            // All messages arrive in a single payload chunk.
            let fut = tx.send_many(sample_messages());
            let fut = std::pin::pin!(fut);
            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            // They must come out one per poll, in the order they were sent.
            for expected in sample_messages() {
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(got))) => assert_eq!(got, expected),
                    poll => panic!("Stream should yield {expected:?}, got {poll:?}"),
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn message_stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = payload_pair(8);
            drop(tx);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should be ready when closing {poll:?}"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn stream_produces_bytes_from_messages() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = tokio::sync::mpsc::channel(1);

            let stream = StreamingBody::new(rx);
            let mut stream = std::pin::pin!(stream);

            for (msg, expected) in sample_messages().into_iter().zip(sample_messages()) {
                assert!(
                    stream.as_mut().poll_next(cx).is_pending(),
                    "Stream should be pending when no messages are present"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                // Each message is flushed as its own chunk of encoded frame bytes.
                let expected = server_encode(std::iter::once(expected));
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(chunk))) => assert_eq!(chunk, expected),
                    poll => panic!("Stream should yield encoded bytes, got {poll:?}"),
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn stream_processes_many_consecutive_messages() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = tokio::sync::mpsc::channel(3);

            let stream = StreamingBody::new(rx);
            let mut stream = std::pin::pin!(stream);

            assert!(stream.as_mut().poll_next(cx).is_pending());

            for msg in sample_messages() {
                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            // All queued messages are encoded into one chunk, in order.
            let expected = server_encode(sample_messages());
            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(chunk))) => assert_eq!(chunk, expected),
                poll => panic!("Stream should yield encoded bytes, got {poll:?}"),
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should have only been ready once"
            );

            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = tokio::sync::mpsc::channel(3);

            drop(tx);
            let stream = StreamingBody::new(rx);

            let mut stream = std::pin::pin!(stream);

            let poll = stream.as_mut().poll_next(cx);

            assert!(
                matches!(poll, Poll::Ready(None)),
                "stream should close after dropped tx"
            );

            Poll::Ready(())
        })
        .await;
    }
}