1use std::{
4 future::poll_fn,
5 io, mem,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Which kind of data frame opened the continuation sequence currently being buffered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContinuationKind {
    /// Sequence started by a `FirstText` fragment; the reassembled payload is UTF-8 validated.
    Text,
    /// Sequence started by a `FirstBinary` fragment.
    Binary,
}
21
22#[derive(Debug, PartialEq, Eq)]
24pub enum AggregatedMessage {
25 Text(ByteString),
27
28 Binary(Bytes),
30
31 Ping(Bytes),
33
34 Pong(Bytes),
36
37 Close(Option<CloseReason>),
39}
40
41pub struct AggregatedMessageStream {
43 stream: MessageStream,
44 current_size: usize,
45 max_size: usize,
46 continuations: Vec<Bytes>,
47 continuation_kind: ContinuationKind,
48}
49
50impl AggregatedMessageStream {
51 #[must_use]
52 pub(crate) fn new(stream: MessageStream) -> Self {
53 AggregatedMessageStream {
54 stream,
55 current_size: 0,
56 max_size: 1024 * 1024,
57 continuations: Vec::new(),
58 continuation_kind: ContinuationKind::Binary,
59 }
60 }
61
62 #[must_use]
78 pub fn max_continuation_size(mut self, max_size: usize) -> Self {
79 self.max_size = max_size;
80 self
81 }
82
83 #[must_use]
96 pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
97 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
98 }
99}
100
101fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
102 Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
103 "Exceeded maximum continuation size",
104 )))))
105}
106
107impl Stream for AggregatedMessageStream {
108 type Item = Result<AggregatedMessage, ProtocolError>;
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 let this = self.get_mut();
112
113 loop {
114 let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
115 return Poll::Ready(None);
116 };
117
118 match msg {
119 Message::Continuation(item) => match item {
120 Item::FirstText(bytes) => {
121 this.continuation_kind = ContinuationKind::Text;
122 this.current_size += bytes.len();
123
124 if this.current_size > this.max_size {
125 this.continuations.clear();
126 return size_error();
127 }
128
129 this.continuations.push(bytes);
130
131 continue;
132 }
133
134 Item::FirstBinary(bytes) => {
135 this.continuation_kind = ContinuationKind::Binary;
136 this.current_size += bytes.len();
137
138 if this.current_size > this.max_size {
139 this.continuations.clear();
140 return size_error();
141 }
142
143 this.continuations.push(bytes);
144
145 continue;
146 }
147
148 Item::Continue(bytes) => {
149 this.current_size += bytes.len();
150
151 if this.current_size > this.max_size {
152 this.continuations.clear();
153 return size_error();
154 }
155
156 this.continuations.push(bytes);
157
158 continue;
159 }
160
161 Item::Last(bytes) => {
162 this.current_size += bytes.len();
163
164 if this.current_size > this.max_size {
165 this.current_size = 0;
168 this.continuations.clear();
169
170 return size_error();
171 }
172
173 this.continuations.push(bytes);
174 let bytes = collect(&mut this.continuations);
175
176 this.current_size = 0;
177
178 match this.continuation_kind {
179 ContinuationKind::Text => {
180 return Poll::Ready(Some(match ByteString::try_from(bytes) {
181 Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
182 Err(err) => Err(ProtocolError::Io(io::Error::new(
183 io::ErrorKind::InvalidData,
184 err.to_string(),
185 ))),
186 }))
187 }
188 ContinuationKind::Binary => {
189 return Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
190 }
191 }
192 }
193 },
194
195 Message::Text(text) => return Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
196 Message::Binary(binary) => {
197 return Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
198 }
199 Message::Ping(ping) => return Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
200 Message::Pong(pong) => return Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
201 Message::Close(close) => {
202 return Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
203 }
204
205 Message::Nop => unreachable!("MessageStream should not produce no-ops"),
206 }
207 }
208 }
209}
210
211fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
212 let continuations = mem::take(continuations);
213 let total_len = continuations.iter().map(|b| b.len()).sum();
214
215 let mut buf = BytesMut::with_capacity(total_len);
216
217 for chunk in continuations {
218 buf.extend(chunk);
219 }
220
221 buf.freeze()
222}
223
#[cfg(test)]
mod tests {
    use std::{future::Future, task::Poll};

    use futures_core::Stream;

    use super::{AggregatedMessage, Bytes, Item, Message, MessageStream};
    use crate::stream::tests::payload_pair;

    /// Builds a three-fragment text continuation sequence spelling "firstsecondthird".
    fn text_continuations() -> Vec<Message> {
        vec![
            Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
            Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
            Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
        ]
    }

    /// Returns `true` when `poll` yielded an aggregated text message equal to `expected`.
    fn is_text<E>(poll: &Poll<Option<Result<AggregatedMessage, E>>>, expected: &str) -> bool {
        matches!(
            poll,
            Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &**text == expected
        )
    }

    #[tokio::test]
    async fn aggregates_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = text_continuations();
            let len = messages.len();

            for (idx, msg) in messages.into_iter().enumerate() {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                let poll = stream.as_mut().poll_next(cx);
                if idx == len - 1 {
                    assert!(
                        is_text(&poll, "firstsecondthird"),
                        "Stream should yield the aggregated text message {poll:?}"
                    );
                } else {
                    assert!(
                        poll.is_pending(),
                        "Stream shouldn't be ready until continuations complete {poll:?}"
                    );
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn aggregates_consecutive_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let fut = tx.send_many(text_continuations());
            let fut = std::pin::pin!(fut);

            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                is_text(&poll, "firstsecondthird"),
                "Stream should yield the aggregated text once all continuations are sent {poll:?}"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn rejects_oversized_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(8);
            let mut stream = std::pin::pin!(message_stream);

            let fut = tx.send_many(text_continuations());
            let fut = std::pin::pin!(fut);

            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            // "first" + "second" is 11 bytes, exceeding the 8 byte limit.
            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(Some(Err(_)))),
                "Exceeding the continuation limit should yield an error {poll:?}"
            );

            // The remaining fragment of the oversized sequence is rejected as well.
            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(Some(Err(_)))),
                "Remaining fragments of an oversized sequence should yield an error {poll:?}"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = payload_pair(8);
            drop(tx);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should end once the sender has been dropped {poll:?}"
            );

            Poll::Ready(())
        })
        .await
    }
}