1use std::{
4 future::poll_fn,
5 io, mem,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Kind of the fragmented message currently being aggregated.
///
/// Recorded from the first fragment of a message and consulted when its `Last` fragment
/// arrives to decide how the concatenated payload is emitted.
pub(crate) enum ContinuationKind {
    /// The message began with an `Item::FirstText` frame; the concatenated payload is
    /// converted to a `ByteString`, which fails if it is not valid UTF-8.
    Text,
    /// The message began with an `Item::FirstBinary` frame; the payload is emitted as-is.
    Binary,
}
21
/// A WebSocket message as yielded by `AggregatedMessageStream`.
///
/// Fragmented (continuation) messages have already been reassembled into a single `Text` or
/// `Binary` value; all other messages are passed through unchanged.
#[derive(Debug, PartialEq, Eq)]
pub enum AggregatedMessage {
    /// Text message: either a single text frame, or a fragmented text message whose
    /// concatenated payload was validated as UTF-8.
    Text(ByteString),

    /// Binary message: either a single binary frame, or the concatenated payload of a
    /// fragmented binary message.
    Binary(Bytes),

    /// Ping message and its payload.
    Ping(Bytes),

    /// Pong message and its payload.
    Pong(Bytes),

    /// Close message, carrying the optional close reason.
    Close(Option<CloseReason>),
}
40
/// A stream of `AggregatedMessage`s layered over a `MessageStream`.
///
/// Continuation frames are buffered until their `Last` fragment arrives and are then emitted
/// as one message. The total buffered payload per message is bounded by `max_size`.
pub struct AggregatedMessageStream {
    /// Underlying per-frame message stream.
    stream: MessageStream,
    /// Payload bytes counted so far for the fragmented message in progress.
    current_size: usize,
    /// Upper bound on `current_size`; exceeding it yields an error (1 MiB unless changed via
    /// `max_continuation_size`).
    max_size: usize,
    /// Non-empty payload chunks buffered for the fragmented message in progress.
    continuations: Vec<Bytes>,
    /// Whether the fragmented message in progress is text or binary.
    continuation_kind: ContinuationKind,
    /// Set once a non-final fragment pushed the message past `max_size`; the remaining
    /// fragments of that message are skipped until its `Last` frame clears the flag.
    overflowed: bool,
}
50
51impl AggregatedMessageStream {
52 #[must_use]
53 pub(crate) fn new(stream: MessageStream) -> Self {
54 AggregatedMessageStream {
55 stream,
56 current_size: 0,
57 max_size: 1024 * 1024,
58 continuations: Vec::new(),
59 continuation_kind: ContinuationKind::Binary,
60 overflowed: false,
61 }
62 }
63
64 #[must_use]
80 pub fn max_continuation_size(mut self, max_size: usize) -> Self {
81 self.max_size = max_size;
82 self
83 }
84
85 #[must_use]
98 pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
99 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
100 }
101}
102
103fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
104 Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
105 "Exceeded maximum continuation size",
106 )))))
107}
108
109impl Stream for AggregatedMessageStream {
110 type Item = Result<AggregatedMessage, ProtocolError>;
111
112 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
113 let this = self.get_mut();
114
115 loop {
116 let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
117 return Poll::Ready(None);
118 };
119
120 match msg {
121 Message::Continuation(item) => match item {
122 Item::FirstText(bytes) => {
123 if this.overflowed {
124 continue;
125 }
126
127 this.continuation_kind = ContinuationKind::Text;
128 this.current_size += bytes.len();
129
130 if this.current_size > this.max_size {
131 this.current_size = 0;
132 this.continuations.clear();
133 this.overflowed = true;
134 return size_error();
135 }
136
137 if !bytes.is_empty() {
139 this.continuations.push(bytes);
140 }
141
142 continue;
143 }
144
145 Item::FirstBinary(bytes) => {
146 if this.overflowed {
147 continue;
148 }
149
150 this.continuation_kind = ContinuationKind::Binary;
151 this.current_size += bytes.len();
152
153 if this.current_size > this.max_size {
154 this.current_size = 0;
155 this.continuations.clear();
156 this.overflowed = true;
157 return size_error();
158 }
159
160 if !bytes.is_empty() {
162 this.continuations.push(bytes);
163 }
164
165 continue;
166 }
167
168 Item::Continue(bytes) => {
169 if this.overflowed {
170 continue;
171 }
172
173 this.current_size += bytes.len();
174
175 if this.current_size > this.max_size {
176 this.current_size = 0;
177 this.continuations.clear();
178 this.overflowed = true;
179 return size_error();
180 }
181
182 if !bytes.is_empty() {
184 this.continuations.push(bytes);
185 }
186
187 continue;
188 }
189
190 Item::Last(bytes) => {
191 if this.overflowed {
192 this.current_size = 0;
193 this.continuations.clear();
194 this.overflowed = false;
195 continue;
196 }
197
198 this.current_size += bytes.len();
199
200 if this.current_size > this.max_size {
201 this.current_size = 0;
204 this.continuations.clear();
205
206 return size_error();
207 }
208
209 if !bytes.is_empty() {
211 this.continuations.push(bytes);
212 }
213 let bytes = collect(&mut this.continuations, this.current_size);
214
215 this.current_size = 0;
216
217 match this.continuation_kind {
218 ContinuationKind::Text => {
219 return Poll::Ready(Some(match ByteString::try_from(bytes) {
220 Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
221 Err(err) => Err(ProtocolError::Io(io::Error::new(
222 io::ErrorKind::InvalidData,
223 err.to_string(),
224 ))),
225 }))
226 }
227 ContinuationKind::Binary => {
228 return Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
229 }
230 }
231 }
232 },
233
234 Message::Text(text) => return Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
235 Message::Binary(binary) => {
236 return Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
237 }
238 Message::Ping(ping) => return Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
239 Message::Pong(pong) => return Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
240 Message::Close(close) => {
241 return Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
242 }
243
244 Message::Nop => unreachable!("MessageStream should not produce no-ops"),
245 }
246 }
247 }
248}
249
250fn collect(continuations: &mut Vec<Bytes>, total_len: usize) -> Bytes {
251 let continuations = mem::take(continuations);
252 let mut buf = BytesMut::with_capacity(total_len);
253
254 for chunk in continuations {
255 buf.extend_from_slice(&chunk);
256 }
257
258 buf.freeze()
259}
260
#[cfg(test)]
mod tests {
    use std::{future::Future, task::Poll};

    use futures_core::Stream;

    use super::{AggregatedMessage, Bytes, Item, Message, MessageStream};
    use crate::stream::tests::payload_pair;

    /// Fragments sent one at a time are only yielded once `Last` arrives, and the yielded
    /// message is the concatenation of every fragment.
    #[tokio::test]
    async fn aggregates_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = [
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];

            let len = messages.len();

            for (idx, msg) in messages.into_iter().enumerate() {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                if idx == len - 1 {
                    // The final fragment completes the message: check the payload, not just
                    // readiness.
                    match stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => {
                            assert_eq!(&text[..], "firstsecondthird");
                        }
                        poll => panic!("expected aggregated text message; got {poll:?}"),
                    }
                } else {
                    assert!(
                        stream.as_mut().poll_next(cx).is_pending(),
                        "Stream shouldn't be ready until continuations complete"
                    );
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// All fragments delivered in one batch are aggregated within a single poll.
    #[tokio::test]
    async fn aggregates_consecutive_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let fut = tx.send_many(messages);
            let fut = std::pin::pin!(fut);

            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => {
                    assert_eq!(&text[..], "firstsecondthird");
                }
                poll => panic!("expected aggregated text message; got {poll:?}"),
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Binary fragments are aggregated into one `Binary` message without UTF-8 validation.
    #[tokio::test]
    async fn aggregates_binary_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            // Deliberately not valid UTF-8, so a text conversion would fail.
            let messages = vec![
                Message::Continuation(Item::FirstBinary(Bytes::from_static(b"\xff\x00"))),
                Message::Continuation(Item::Continue(Bytes::from_static(b"\xfe"))),
                Message::Continuation(Item::Last(Bytes::from_static(b"\x01"))),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes)))) => {
                    assert_eq!(bytes, Bytes::from_static(b"\xff\x00\xfe\x01"));
                }
                poll => panic!("expected aggregated binary message; got {poll:?}"),
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Empty fragments are never buffered, and an all-empty message yields empty text.
    #[tokio::test]
    async fn ignores_empty_continuation_chunks() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = std::iter::once(Message::Continuation(Item::FirstText(Bytes::new())))
                .chain((0..128).map(|_| Message::Continuation(Item::Continue(Bytes::new()))))
                .collect::<Vec<_>>();

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream shouldn't be ready until continuations complete"
            );
            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            {
                let fut = tx.send(Message::Continuation(Item::Last(Bytes::new())));
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => assert!(text.is_empty()),
                poll => panic!("expected empty text message; got {poll:?}"),
            }

            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            Poll::Ready(())
        })
        .await
    }

    /// Dropping the sender ends the aggregated stream.
    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = payload_pair(8);
            drop(tx);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should end once the sender has been dropped; got {poll:?}"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Overflow on a non-final fragment errors once, skips the rest of that message (while
    /// still passing control frames through), then recovers.
    #[tokio::test]
    async fn continuation_overflow_errors_once_and_recovers() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(4);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"1234".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"5".to_vec()))),
                Message::Ping(Bytes::from(b"p".to_vec())),
                Message::Continuation(Item::Last(Bytes::from(b"6".to_vec()))),
                Message::Text("ok".into()),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                matches!(stream.as_mut().poll_next(cx), Poll::Ready(Some(Err(_)))),
                "expected one overflow error"
            );

            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Ping(_))))
                ),
                "expected ping frame after overflow"
            );

            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &text[..] == "ok"
                ),
                "expected text message after overflow continuation is terminated"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Overflow detected on the `Last` fragment errors once and does not swallow the
    /// following message.
    #[tokio::test]
    async fn overflow_on_last_fragment_errors_and_recovers() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(4);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = vec![
                Message::Continuation(Item::FirstBinary(Bytes::from_static(b"12"))),
                Message::Continuation(Item::Last(Bytes::from_static(b"345"))),
                Message::Text("ok".into()),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                matches!(stream.as_mut().poll_next(cx), Poll::Ready(Some(Err(_)))),
                "expected overflow error on the final fragment"
            );

            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &text[..] == "ok"
                ),
                "expected the message after an oversized one to be delivered"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }
}