actix_ws/
stream.rs

1use std::{
2    collections::VecDeque,
3    future::poll_fn,
4    io, mem,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11    ws::{Codec, Frame, Message, ProtocolError},
12    Payload,
13};
14use actix_web::{
15    web::{Bytes, BytesMut},
16    Error,
17};
18use bytestring::ByteString;
19use futures_core::stream::Stream;
20use tokio::sync::mpsc::Receiver;
21
22use crate::AggregatedMessageStream;
23
24/// Response body for a WebSocket.
25pub struct StreamingBody {
26    session_rx: Receiver<Message>,
27    messages: VecDeque<Message>,
28    buf: BytesMut,
29    codec: Codec,
30    closing: bool,
31}
32
33impl StreamingBody {
34    pub(super) fn new(session_rx: Receiver<Message>) -> Self {
35        StreamingBody {
36            session_rx,
37            messages: VecDeque::new(),
38            buf: BytesMut::new(),
39            codec: Codec::new(),
40            closing: false,
41        }
42    }
43}
44
45/// Stream of messages from a WebSocket client.
46pub struct MessageStream {
47    payload: Payload,
48
49    messages: VecDeque<Message>,
50    buf: BytesMut,
51    codec: Codec,
52    closing: bool,
53}
54
55impl MessageStream {
56    pub(super) fn new(payload: Payload) -> Self {
57        MessageStream {
58            payload,
59            messages: VecDeque::new(),
60            buf: BytesMut::new(),
61            codec: Codec::new(),
62            closing: false,
63        }
64    }
65
66    /// Sets the maximum permitted size for received WebSocket frames, in bytes.
67    ///
68    /// By default, up to 64KiB is allowed.
69    ///
70    /// Any received frames larger than the permitted value will return
71    /// `Err(ProtocolError::Overflow)` instead.
72    ///
73    /// ```no_run
74    /// # use actix_ws::MessageStream;
75    /// # fn test(stream: MessageStream) {
76    /// // increase permitted frame size from 64KB to 1MB
77    /// let stream = stream.max_frame_size(1024 * 1024);
78    /// # }
79    /// ```
80    #[must_use]
81    pub fn max_frame_size(mut self, max_size: usize) -> Self {
82        self.codec = self.codec.max_size(max_size);
83        self
84    }
85
86    /// Returns a stream wrapper that collects continuation frames into their equivalent aggregated
87    /// forms, i.e., binary or text.
88    ///
89    /// By default, continuations will be aggregated up to 1MiB in size (customizable with
90    /// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an
91    /// error if this size is exceeded.
92    #[must_use]
93    pub fn aggregate_continuations(self) -> AggregatedMessageStream {
94        AggregatedMessageStream::new(self)
95    }
96
97    /// Waits for the next item from the message stream
98    ///
99    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
100    ///
101    /// ```no_run
102    /// # use actix_ws::MessageStream;
103    /// # async fn test(mut stream: MessageStream) {
104    /// while let Some(Ok(msg)) = stream.recv().await {
105    ///     // handle message
106    /// }
107    /// # }
108    /// ```
109    #[must_use]
110    pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
111        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
112    }
113}
114
115impl Stream for StreamingBody {
116    type Item = Result<Bytes, Error>;
117
118    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        let this = self.get_mut();
120
121        if this.closing {
122            return Poll::Ready(None);
123        }
124
125        loop {
126            match Pin::new(&mut this.session_rx).poll_recv(cx) {
127                Poll::Ready(Some(msg)) => {
128                    this.messages.push_back(msg);
129                }
130                Poll::Ready(None) => {
131                    this.closing = true;
132                    break;
133                }
134                Poll::Pending => break,
135            }
136        }
137
138        while let Some(msg) = this.messages.pop_front() {
139            if let Err(err) = this.codec.encode(msg, &mut this.buf) {
140                return Poll::Ready(Some(Err(err.into())));
141            }
142        }
143
144        if !this.buf.is_empty() {
145            return Poll::Ready(Some(Ok(mem::take(&mut this.buf).freeze())));
146        }
147
148        Poll::Pending
149    }
150}
151
152impl Stream for MessageStream {
153    type Item = Result<Message, ProtocolError>;
154
155    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156        let this = self.get_mut();
157
158        // Return the first message in the queue if one exists
159        //
160        // This is faster than polling and parsing
161        if let Some(msg) = this.messages.pop_front() {
162            return Poll::Ready(Some(Ok(msg)));
163        }
164
165        if !this.closing {
166            // Read in bytes until there's nothing left to read
167            loop {
168                match Pin::new(&mut this.payload).poll_next(cx) {
169                    Poll::Ready(Some(Ok(bytes))) => {
170                        this.buf.extend_from_slice(&bytes);
171                    }
172                    Poll::Ready(Some(Err(err))) => {
173                        return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
174                    }
175                    Poll::Ready(None) => {
176                        this.closing = true;
177                        break;
178                    }
179                    Poll::Pending => break,
180                }
181            }
182        }
183
184        // Create messages until there's no more bytes left
185        while let Some(frame) = this.codec.decode(&mut this.buf)? {
186            let message = match frame {
187                Frame::Text(bytes) => {
188                    ByteString::try_from(bytes)
189                        .map(Message::Text)
190                        .map_err(|err| {
191                            ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
192                        })?
193                }
194                Frame::Binary(bytes) => Message::Binary(bytes),
195                Frame::Ping(bytes) => Message::Ping(bytes),
196                Frame::Pong(bytes) => Message::Pong(bytes),
197                Frame::Close(reason) => Message::Close(reason),
198                Frame::Continuation(item) => Message::Continuation(item),
199            };
200
201            this.messages.push_back(message);
202        }
203
204        // Return the first message in the queue
205        if let Some(msg) = this.messages.pop_front() {
206            return Poll::Ready(Some(Ok(msg)));
207        }
208
209        // If we've exhausted our message queue and we're closing, close the stream
210        if this.closing {
211            return Poll::Ready(None);
212        }
213
214        Poll::Pending
215    }
216}