Skip to main content

actix_ws/
aggregated.rs

1//! WebSocket stream for aggregating continuation frames.
2
3use std::{
4    future::poll_fn,
5    io, mem,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Kind of message a series of continuation frames is being assembled into, as declared by the
/// first frame of the series.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContinuationKind {
    /// Series started by a `FirstText` frame; the assembled payload must be valid UTF-8.
    Text,

    /// Series started by a `FirstBinary` frame.
    Binary,
}
21
22/// WebSocket message with any continuations aggregated together.
23#[derive(Debug, PartialEq, Eq)]
24pub enum AggregatedMessage {
25    /// Text message.
26    Text(ByteString),
27
28    /// Binary message.
29    Binary(Bytes),
30
31    /// Ping message.
32    Ping(Bytes),
33
34    /// Pong message.
35    Pong(Bytes),
36
37    /// Close message with optional reason.
38    Close(Option<CloseReason>),
39}
40
41/// Stream of messages from a WebSocket client, with continuations aggregated.
42pub struct AggregatedMessageStream {
43    stream: MessageStream,
44    current_size: usize,
45    max_size: usize,
46    continuations: Vec<Bytes>,
47    continuation_kind: ContinuationKind,
48    overflowed: bool,
49}
50
51impl AggregatedMessageStream {
52    #[must_use]
53    pub(crate) fn new(stream: MessageStream) -> Self {
54        AggregatedMessageStream {
55            stream,
56            current_size: 0,
57            max_size: 1024 * 1024,
58            continuations: Vec::new(),
59            continuation_kind: ContinuationKind::Binary,
60            overflowed: false,
61        }
62    }
63
64    /// Sets the maximum allowed size for aggregated continuations, in bytes.
65    ///
66    /// By default, up to 1 MiB is allowed.
67    ///
68    /// ```no_run
69    /// # use actix_ws::AggregatedMessageStream;
70    /// # async fn test(stream: AggregatedMessageStream) {
71    /// // increase the allowed size from 1MB to 8MB
72    /// let mut stream = stream.max_continuation_size(8 * 1024 * 1024);
73    ///
74    /// while let Some(Ok(msg)) = stream.recv().await {
75    ///     // handle message
76    /// }
77    /// # }
78    /// ```
79    #[must_use]
80    pub fn max_continuation_size(mut self, max_size: usize) -> Self {
81        self.max_size = max_size;
82        self
83    }
84
85    /// Waits for the next item from the aggregated message stream.
86    ///
87    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
88    ///
89    /// ```no_run
90    /// # use actix_ws::AggregatedMessageStream;
91    /// # async fn test(mut stream: AggregatedMessageStream) {
92    /// while let Some(Ok(msg)) = stream.recv().await {
93    ///     // handle message
94    /// }
95    /// # }
96    /// ```
97    #[must_use]
98    pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
99        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
100    }
101}
102
103fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
104    Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
105        "Exceeded maximum continuation size",
106    )))))
107}
108
109impl Stream for AggregatedMessageStream {
110    type Item = Result<AggregatedMessage, ProtocolError>;
111
112    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
113        let this = self.get_mut();
114
115        loop {
116            let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
117                return Poll::Ready(None);
118            };
119
120            match msg {
121                Message::Continuation(item) => match item {
122                    Item::FirstText(bytes) => {
123                        if this.overflowed {
124                            continue;
125                        }
126
127                        this.continuation_kind = ContinuationKind::Text;
128                        this.current_size += bytes.len();
129
130                        if this.current_size > this.max_size {
131                            this.current_size = 0;
132                            this.continuations.clear();
133                            this.overflowed = true;
134                            return size_error();
135                        }
136
137                        // Avoid unbounded growth when receiving unlimited empty continuation frames.
138                        if !bytes.is_empty() {
139                            this.continuations.push(bytes);
140                        }
141
142                        continue;
143                    }
144
145                    Item::FirstBinary(bytes) => {
146                        if this.overflowed {
147                            continue;
148                        }
149
150                        this.continuation_kind = ContinuationKind::Binary;
151                        this.current_size += bytes.len();
152
153                        if this.current_size > this.max_size {
154                            this.current_size = 0;
155                            this.continuations.clear();
156                            this.overflowed = true;
157                            return size_error();
158                        }
159
160                        // Avoid unbounded growth when receiving unlimited empty continuation frames.
161                        if !bytes.is_empty() {
162                            this.continuations.push(bytes);
163                        }
164
165                        continue;
166                    }
167
168                    Item::Continue(bytes) => {
169                        if this.overflowed {
170                            continue;
171                        }
172
173                        this.current_size += bytes.len();
174
175                        if this.current_size > this.max_size {
176                            this.current_size = 0;
177                            this.continuations.clear();
178                            this.overflowed = true;
179                            return size_error();
180                        }
181
182                        // Avoid unbounded growth when receiving unlimited empty continuation frames.
183                        if !bytes.is_empty() {
184                            this.continuations.push(bytes);
185                        }
186
187                        continue;
188                    }
189
190                    Item::Last(bytes) => {
191                        if this.overflowed {
192                            this.current_size = 0;
193                            this.continuations.clear();
194                            this.overflowed = false;
195                            continue;
196                        }
197
198                        this.current_size += bytes.len();
199
200                        if this.current_size > this.max_size {
201                            // reset current_size, as this is the last message for
202                            // the current continuation
203                            this.current_size = 0;
204                            this.continuations.clear();
205
206                            return size_error();
207                        }
208
209                        // Avoid unbounded growth when receiving unlimited empty continuation frames.
210                        if !bytes.is_empty() {
211                            this.continuations.push(bytes);
212                        }
213                        let bytes = collect(&mut this.continuations, this.current_size);
214
215                        this.current_size = 0;
216
217                        match this.continuation_kind {
218                            ContinuationKind::Text => {
219                                return Poll::Ready(Some(match ByteString::try_from(bytes) {
220                                    Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
221                                    Err(err) => Err(ProtocolError::Io(io::Error::new(
222                                        io::ErrorKind::InvalidData,
223                                        err.to_string(),
224                                    ))),
225                                }))
226                            }
227                            ContinuationKind::Binary => {
228                                return Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
229                            }
230                        }
231                    }
232                },
233
234                Message::Text(text) => return Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
235                Message::Binary(binary) => {
236                    return Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
237                }
238                Message::Ping(ping) => return Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
239                Message::Pong(pong) => return Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
240                Message::Close(close) => {
241                    return Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
242                }
243
244                Message::Nop => unreachable!("MessageStream should not produce no-ops"),
245            }
246        }
247    }
248}
249
250fn collect(continuations: &mut Vec<Bytes>, total_len: usize) -> Bytes {
251    let continuations = mem::take(continuations);
252    let mut buf = BytesMut::with_capacity(total_len);
253
254    for chunk in continuations {
255        buf.extend_from_slice(&chunk);
256    }
257
258    buf.freeze()
259}
260
#[cfg(test)]
mod tests {
    // Each test drives the aggregated stream by hand from inside `poll_fn`, so it can observe
    // exactly when the stream becomes ready relative to frames sent through `payload_pair`.

    use std::{future::Future, task::Poll};

    use futures_core::Stream;

    use super::{AggregatedMessage, Bytes, Item, Message, MessageStream};
    use crate::stream::tests::payload_pair;

    #[tokio::test]
    async fn aggregates_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = [
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];

            let len = messages.len();

            // Send one frame at a time; only the final frame may complete the message.
            for (idx, msg) in messages.into_iter().enumerate() {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                if idx == len - 1 {
                    // The final frame must yield every chunk joined in order.
                    match stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => {
                            assert_eq!(&text[..], "firstsecondthird");
                        }
                        poll => panic!("expected aggregated text message; got {poll:?}"),
                    }
                } else {
                    assert!(
                        stream.as_mut().poll_next(cx).is_pending(),
                        "Stream shouldn't be ready until continuations complete"
                    );
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn aggregates_consecutive_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            // All frames are available at once; a single poll must aggregate them.
            let fut = tx.send_many(messages);
            let fut = std::pin::pin!(fut);

            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => {
                    assert_eq!(&text[..], "firstsecondthird");
                }
                poll => panic!("expected aggregated text message; got {poll:?}"),
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn ignores_empty_continuation_chunks() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            // start continuation with empty chunk, then send a bunch of empty continuation chunks;
            // they should not be buffered (would otherwise cause unbounded `Vec` growth).
            let messages = std::iter::once(Message::Continuation(Item::FirstText(Bytes::new())))
                .chain((0..128).map(|_| Message::Continuation(Item::Continue(Bytes::new()))))
                .collect::<Vec<_>>();

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream shouldn't be ready until continuations complete"
            );
            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            // end continuation; this should yield an empty text message.
            {
                let fut = tx.send(Message::Continuation(Item::Last(Bytes::new())));
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => assert!(text.is_empty()),
                poll => panic!("expected empty text message; got {poll:?}"),
            }

            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            // Dropping the sender ends the underlying payload before anything is sent.
            let (tx, rx) = payload_pair(8);
            drop(tx);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should end once the payload sender is dropped"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn continuation_overflow_errors_once_and_recovers() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(4);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            // "1234" fits the 4-byte limit exactly; "5" pushes it over. The interleaved ping must
            // still be delivered, and the remaining frame "6" must be dropped silently.
            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"1234".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"5".to_vec()))),
                Message::Ping(Bytes::from(b"p".to_vec())),
                Message::Continuation(Item::Last(Bytes::from(b"6".to_vec()))),
                Message::Text("ok".into()),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                matches!(stream.as_mut().poll_next(cx), Poll::Ready(Some(Err(_)))),
                "expected one overflow error"
            );

            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Ping(_))))
                ),
                "expected ping frame after overflow"
            );

            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &text[..] == "ok"
                ),
                "expected text message after overflow continuation is terminated"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }
}