actix_ws/
aggregated.rs

1//! WebSocket stream for aggregating continuation frames.
2
3use std::{
4    future::poll_fn,
5    io, mem,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Payload type of the fragmented message currently being reassembled.
///
/// Recorded when the first fragment arrives so the final fragment knows whether to emit the
/// combined payload as text or as binary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContinuationKind {
    /// The fragments form a text message.
    Text,

    /// The fragments form a binary message.
    Binary,
}
21
22/// WebSocket message with any continuations aggregated together.
23#[derive(Debug, PartialEq, Eq)]
24pub enum AggregatedMessage {
25    /// Text message.
26    Text(ByteString),
27
28    /// Binary message.
29    Binary(Bytes),
30
31    /// Ping message.
32    Ping(Bytes),
33
34    /// Pong message.
35    Pong(Bytes),
36
37    /// Close message with optional reason.
38    Close(Option<CloseReason>),
39}
40
41/// Stream of messages from a WebSocket client, with continuations aggregated.
42pub struct AggregatedMessageStream {
43    stream: MessageStream,
44    current_size: usize,
45    max_size: usize,
46    continuations: Vec<Bytes>,
47    continuation_kind: ContinuationKind,
48}
49
50impl AggregatedMessageStream {
51    #[must_use]
52    pub(crate) fn new(stream: MessageStream) -> Self {
53        AggregatedMessageStream {
54            stream,
55            current_size: 0,
56            max_size: 1024 * 1024,
57            continuations: Vec::new(),
58            continuation_kind: ContinuationKind::Binary,
59        }
60    }
61
62    /// Sets the maximum allowed size for aggregated continuations, in bytes.
63    ///
64    /// By default, up to 1 MiB is allowed.
65    ///
66    /// ```no_run
67    /// # use actix_ws::AggregatedMessageStream;
68    /// # async fn test(stream: AggregatedMessageStream) {
69    /// // increase the allowed size from 1MB to 8MB
70    /// let mut stream = stream.max_continuation_size(8 * 1024 * 1024);
71    ///
72    /// while let Some(Ok(msg)) = stream.recv().await {
73    ///     // handle message
74    /// }
75    /// # }
76    /// ```
77    #[must_use]
78    pub fn max_continuation_size(mut self, max_size: usize) -> Self {
79        self.max_size = max_size;
80        self
81    }
82
83    /// Waits for the next item from the aggregated message stream.
84    ///
85    /// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
86    ///
87    /// ```no_run
88    /// # use actix_ws::AggregatedMessageStream;
89    /// # async fn test(mut stream: AggregatedMessageStream) {
90    /// while let Some(Ok(msg)) = stream.recv().await {
91    ///     // handle message
92    /// }
93    /// # }
94    /// ```
95    #[must_use]
96    pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
97        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
98    }
99}
100
101fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
102    Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
103        "Exceeded maximum continuation size",
104    )))))
105}
106
107impl Stream for AggregatedMessageStream {
108    type Item = Result<AggregatedMessage, ProtocolError>;
109
110    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        let this = self.get_mut();
112
113        let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
114            return Poll::Ready(None);
115        };
116
117        match msg {
118            Message::Continuation(item) => match item {
119                Item::FirstText(bytes) => {
120                    this.continuation_kind = ContinuationKind::Text;
121                    this.current_size += bytes.len();
122
123                    if this.current_size > this.max_size {
124                        this.continuations.clear();
125                        return size_error();
126                    }
127
128                    this.continuations.push(bytes);
129
130                    Poll::Pending
131                }
132
133                Item::FirstBinary(bytes) => {
134                    this.continuation_kind = ContinuationKind::Binary;
135                    this.current_size += bytes.len();
136
137                    if this.current_size > this.max_size {
138                        this.continuations.clear();
139                        return size_error();
140                    }
141
142                    this.continuations.push(bytes);
143
144                    Poll::Pending
145                }
146
147                Item::Continue(bytes) => {
148                    this.current_size += bytes.len();
149
150                    if this.current_size > this.max_size {
151                        this.continuations.clear();
152                        return size_error();
153                    }
154
155                    this.continuations.push(bytes);
156
157                    Poll::Pending
158                }
159
160                Item::Last(bytes) => {
161                    this.current_size += bytes.len();
162
163                    if this.current_size > this.max_size {
164                        // reset current_size, as this is the last message for
165                        // the current continuation
166                        this.current_size = 0;
167                        this.continuations.clear();
168
169                        return size_error();
170                    }
171
172                    this.continuations.push(bytes);
173                    let bytes = collect(&mut this.continuations);
174
175                    this.current_size = 0;
176
177                    match this.continuation_kind {
178                        ContinuationKind::Text => {
179                            Poll::Ready(Some(match ByteString::try_from(bytes) {
180                                Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
181                                Err(err) => Err(ProtocolError::Io(io::Error::new(
182                                    io::ErrorKind::InvalidData,
183                                    err.to_string(),
184                                ))),
185                            }))
186                        }
187                        ContinuationKind::Binary => {
188                            Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
189                        }
190                    }
191                }
192            },
193
194            Message::Text(text) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
195            Message::Binary(binary) => Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary)))),
196            Message::Ping(ping) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
197            Message::Pong(pong) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
198            Message::Close(close) => Poll::Ready(Some(Ok(AggregatedMessage::Close(close)))),
199
200            Message::Nop => unreachable!("MessageStream should not produce no-ops"),
201        }
202    }
203}
204
205fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
206    let continuations = mem::take(continuations);
207    let total_len = continuations.iter().map(|b| b.len()).sum();
208
209    let mut buf = BytesMut::with_capacity(total_len);
210
211    for chunk in continuations {
212        buf.extend(chunk);
213    }
214
215    buf.freeze()
216}