actix_ws/
aggregated.rs

1//! WebSocket stream for aggregating continuation frames.
2
3use std::{
4    future::poll_fn,
5    io, mem,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Payload type of the fragmented message currently being aggregated.
///
/// Determined by the first frame of the fragmented message (`FirstText` or `FirstBinary`) and
/// consulted once the final frame arrives to decide how the aggregated bytes are surfaced.
pub(crate) enum ContinuationKind {
    /// Fragments form a UTF-8 text message.
    Text,

    /// Fragments form a binary message.
    Binary,
}
21
22/// WebSocket message with any continuations aggregated together.
23#[derive(Debug, PartialEq, Eq)]
24pub enum AggregatedMessage {
25    /// Text message.
26    Text(ByteString),
27
28    /// Binary message.
29    Binary(Bytes),
30
31    /// Ping message.
32    Ping(Bytes),
33
34    /// Pong message.
35    Pong(Bytes),
36
37    /// Close message with optional reason.
38    Close(Option<CloseReason>),
39}
40
41/// Stream of messages from a WebSocket client, with continuations aggregated.
42pub struct AggregatedMessageStream {
43    stream: MessageStream,
44    current_size: usize,
45    max_size: usize,
46    continuations: Vec<Bytes>,
47    continuation_kind: ContinuationKind,
48}
49
50impl AggregatedMessageStream {
51    #[must_use]
52    pub(crate) fn new(stream: MessageStream) -> Self {
53        AggregatedMessageStream {
54            stream,
55            current_size: 0,
56            max_size: 1024 * 1024,
57            continuations: Vec::new(),
58            continuation_kind: ContinuationKind::Binary,
59        }
60    }
61
62    /// Sets the maximum allowed size for aggregated continuations, in bytes.
63    ///
64    /// By default, up to 1 MiB is allowed.
65    ///
66    /// ```no_run
67    /// # use actix_ws::AggregatedMessageStream;
68    /// # async fn test(stream: AggregatedMessageStream) {
69    /// // increase the allowed size from 1MB to 8MB
70    /// let mut stream = stream.max_continuation_size(8 * 1024 * 1024);
71    ///
72    /// while let Some(Ok(msg)) = stream.recv().await {
73    ///     // handle message
74    /// }
75    /// # }
76    /// ```
77    #[must_use]
78    pub fn max_continuation_size(mut self, max_size: usize) -> Self {
79        self.max_size = max_size;
80        self
81    }
82
83    /// Waits for the next item from the aggregated message stream.
84    ///
85    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
86    ///
87    /// ```no_run
88    /// # use actix_ws::AggregatedMessageStream;
89    /// # async fn test(mut stream: AggregatedMessageStream) {
90    /// while let Some(Ok(msg)) = stream.recv().await {
91    ///     // handle message
92    /// }
93    /// # }
94    /// ```
95    #[must_use]
96    pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
97        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
98    }
99}
100
101fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
102    Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
103        "Exceeded maximum continuation size",
104    )))))
105}
106
107impl Stream for AggregatedMessageStream {
108    type Item = Result<AggregatedMessage, ProtocolError>;
109
110    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        let this = self.get_mut();
112
113        loop {
114            let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
115                return Poll::Ready(None);
116            };
117
118            match msg {
119                Message::Continuation(item) => match item {
120                    Item::FirstText(bytes) => {
121                        this.continuation_kind = ContinuationKind::Text;
122                        this.current_size += bytes.len();
123
124                        if this.current_size > this.max_size {
125                            this.continuations.clear();
126                            return size_error();
127                        }
128
129                        this.continuations.push(bytes);
130
131                        continue;
132                    }
133
134                    Item::FirstBinary(bytes) => {
135                        this.continuation_kind = ContinuationKind::Binary;
136                        this.current_size += bytes.len();
137
138                        if this.current_size > this.max_size {
139                            this.continuations.clear();
140                            return size_error();
141                        }
142
143                        this.continuations.push(bytes);
144
145                        continue;
146                    }
147
148                    Item::Continue(bytes) => {
149                        this.current_size += bytes.len();
150
151                        if this.current_size > this.max_size {
152                            this.continuations.clear();
153                            return size_error();
154                        }
155
156                        this.continuations.push(bytes);
157
158                        continue;
159                    }
160
161                    Item::Last(bytes) => {
162                        this.current_size += bytes.len();
163
164                        if this.current_size > this.max_size {
165                            // reset current_size, as this is the last message for
166                            // the current continuation
167                            this.current_size = 0;
168                            this.continuations.clear();
169
170                            return size_error();
171                        }
172
173                        this.continuations.push(bytes);
174                        let bytes = collect(&mut this.continuations);
175
176                        this.current_size = 0;
177
178                        match this.continuation_kind {
179                            ContinuationKind::Text => {
180                                return Poll::Ready(Some(match ByteString::try_from(bytes) {
181                                    Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
182                                    Err(err) => Err(ProtocolError::Io(io::Error::new(
183                                        io::ErrorKind::InvalidData,
184                                        err.to_string(),
185                                    ))),
186                                }))
187                            }
188                            ContinuationKind::Binary => {
189                                return Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
190                            }
191                        }
192                    }
193                },
194
195                Message::Text(text) => return Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
196                Message::Binary(binary) => {
197                    return Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
198                }
199                Message::Ping(ping) => return Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
200                Message::Pong(pong) => return Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
201                Message::Close(close) => {
202                    return Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
203                }
204
205                Message::Nop => unreachable!("MessageStream should not produce no-ops"),
206            }
207        }
208    }
209}
210
211fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
212    let continuations = mem::take(continuations);
213    let total_len = continuations.iter().map(|b| b.len()).sum();
214
215    let mut buf = BytesMut::with_capacity(total_len);
216
217    for chunk in continuations {
218        buf.extend(chunk);
219    }
220
221    buf.freeze()
222}
223
224#[cfg(test)]
225mod tests {
226    use std::{future::Future, task::Poll};
227
228    use futures_core::Stream;
229
230    use super::{Bytes, Item, Message, MessageStream};
231    use crate::stream::tests::payload_pair;
232
233    #[tokio::test]
234    async fn aggregates_continuations() {
235        std::future::poll_fn(move |cx| {
236            let (mut tx, rx) = payload_pair(8);
237            let message_stream = MessageStream::new(rx).aggregate_continuations();
238            let mut stream = std::pin::pin!(message_stream);
239
240            let messages = [
241                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
242                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
243                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
244            ];
245
246            let len = messages.len();
247
248            for (idx, msg) in messages.into_iter().enumerate() {
249                let poll = stream.as_mut().poll_next(cx);
250                assert!(
251                    poll.is_pending(),
252                    "Stream should be pending when no messages are present {poll:?}"
253                );
254
255                let fut = tx.send(msg);
256                let fut = std::pin::pin!(fut);
257
258                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
259
260                if idx == len - 1 {
261                    assert!(
262                        stream.as_mut().poll_next(cx).is_ready(),
263                        "Stream should be ready"
264                    );
265                } else {
266                    assert!(
267                        stream.as_mut().poll_next(cx).is_pending(),
268                        "Stream shouldn't be ready until continuations complete"
269                    );
270                }
271            }
272
273            assert!(
274                stream.as_mut().poll_next(cx).is_pending(),
275                "Stream should be pending after processing messages"
276            );
277
278            Poll::Ready(())
279        })
280        .await
281    }
282
283    #[tokio::test]
284    async fn aggregates_consecutive_continuations() {
285        std::future::poll_fn(move |cx| {
286            let (mut tx, rx) = payload_pair(8);
287            let message_stream = MessageStream::new(rx).aggregate_continuations();
288            let mut stream = std::pin::pin!(message_stream);
289
290            let messages = vec![
291                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
292                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
293                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
294            ];
295
296            let poll = stream.as_mut().poll_next(cx);
297            assert!(
298                poll.is_pending(),
299                "Stream should be pending when no messages are present {poll:?}"
300            );
301
302            let fut = tx.send_many(messages);
303            let fut = std::pin::pin!(fut);
304
305            assert!(fut.poll(cx).is_ready(), "Sending should not yield");
306
307            assert!(
308                stream.as_mut().poll_next(cx).is_ready(),
309                "Stream should be ready when all continuations have been sent"
310            );
311
312            assert!(
313                stream.as_mut().poll_next(cx).is_pending(),
314                "Stream should be pending after processing messages"
315            );
316
317            Poll::Ready(())
318        })
319        .await
320    }
321
322    #[tokio::test]
323    async fn stream_closes() {
324        std::future::poll_fn(move |cx| {
325            let (tx, rx) = payload_pair(8);
326            drop(tx);
327            let message_stream = MessageStream::new(rx).aggregate_continuations();
328            let mut stream = std::pin::pin!(message_stream);
329
330            let poll = stream.as_mut().poll_next(cx);
331            assert!(
332                matches!(poll, Poll::Ready(None)),
333                "Stream should be ready when all continuations have been sent"
334            );
335
336            Poll::Ready(())
337        })
338        .await
339    }
340}