use std::{
future::poll_fn,
io, mem,
pin::Pin,
task::{ready, Context, Poll},
};
use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
use actix_web::web::{Bytes, BytesMut};
use bytestring::ByteString;
use futures_core::Stream;
use crate::MessageStream;
/// Kind of data frame that opened the continuation sequence currently being aggregated.
///
/// It is recorded from the first fragment (`FirstText` / `FirstBinary`). The reassembled
/// payload is then emitted as the matching [`AggregatedMessage`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContinuationKind {
    /// The sequence started with a text frame; the joined payload is validated as UTF-8.
    Text,
    /// The sequence started with a binary frame.
    Binary,
}
/// A WebSocket message in which fragmented (continuation) frames have already been joined.
#[derive(Debug)]
#[derive(PartialEq, Eq)]
pub enum AggregatedMessage {
    /// A complete text message, received whole or reassembled from text fragments.
    Text(ByteString),

    /// A complete binary message, received whole or reassembled from binary fragments.
    Binary(Bytes),

    /// A ping control frame and its payload.
    Ping(Bytes),

    /// A pong control frame and its payload.
    Pong(Bytes),

    /// A close frame carrying an optional [`CloseReason`].
    Close(Option<CloseReason>),
}
/// Stream adapter over a [`MessageStream`] that reassembles fragmented messages.
///
/// It is created with `MessageStream::aggregate_continuations`. It yields
/// [`AggregatedMessage`]s in which continuation frames have been joined into a single payload.
pub struct AggregatedMessageStream {
    /// Source of individual, possibly fragmented, messages.
    stream: MessageStream,

    // --- state of the continuation sequence currently being reassembled ---
    /// Whether the in-progress sequence started as text or binary.
    continuation_kind: ContinuationKind,
    /// Non-empty fragments of the in-progress message, in arrival order.
    continuations: Vec<Bytes>,
    /// Combined byte length of the fragments seen so far for the in-progress message.
    current_size: usize,

    // --- size limiting ---
    /// Largest allowed `current_size` before the message is rejected.
    max_size: usize,
    /// Set after an oversized message has been reported, until its final fragment arrives.
    overflowed: bool,
}
impl AggregatedMessageStream {
#[must_use]
pub(crate) fn new(stream: MessageStream) -> Self {
AggregatedMessageStream {
stream,
current_size: 0,
max_size: 1024 * 1024,
continuations: Vec::new(),
continuation_kind: ContinuationKind::Binary,
overflowed: false,
}
}
#[must_use]
pub fn max_continuation_size(mut self, max_size: usize) -> Self {
self.max_size = max_size;
self
}
#[must_use]
pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
}
/// Builds the item yielded when a fragmented message grows past the configured size limit.
///
/// The error is an I/O error wrapped in [`ProtocolError::Io`].
fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
    let err = io::Error::other("Exceeded maximum continuation size");
    Poll::Ready(Some(Err(ProtocolError::Io(err))))
}
impl Stream for AggregatedMessageStream {
type Item = Result<AggregatedMessage, ProtocolError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
return Poll::Ready(None);
};
match msg {
Message::Continuation(item) => match item {
Item::FirstText(bytes) => {
if this.overflowed {
continue;
}
this.continuation_kind = ContinuationKind::Text;
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.current_size = 0;
this.continuations.clear();
this.overflowed = true;
return size_error();
}
if !bytes.is_empty() {
this.continuations.push(bytes);
}
continue;
}
Item::FirstBinary(bytes) => {
if this.overflowed {
continue;
}
this.continuation_kind = ContinuationKind::Binary;
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.current_size = 0;
this.continuations.clear();
this.overflowed = true;
return size_error();
}
if !bytes.is_empty() {
this.continuations.push(bytes);
}
continue;
}
Item::Continue(bytes) => {
if this.overflowed {
continue;
}
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.current_size = 0;
this.continuations.clear();
this.overflowed = true;
return size_error();
}
if !bytes.is_empty() {
this.continuations.push(bytes);
}
continue;
}
Item::Last(bytes) => {
if this.overflowed {
this.current_size = 0;
this.continuations.clear();
this.overflowed = false;
continue;
}
this.current_size += bytes.len();
if this.current_size > this.max_size {
this.current_size = 0;
this.continuations.clear();
return size_error();
}
if !bytes.is_empty() {
this.continuations.push(bytes);
}
let bytes = collect(&mut this.continuations, this.current_size);
this.current_size = 0;
match this.continuation_kind {
ContinuationKind::Text => {
return Poll::Ready(Some(match ByteString::try_from(bytes) {
Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
Err(err) => Err(ProtocolError::Io(io::Error::new(
io::ErrorKind::InvalidData,
err.to_string(),
))),
}))
}
ContinuationKind::Binary => {
return Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
}
}
}
},
Message::Text(text) => return Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
Message::Binary(binary) => {
return Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
}
Message::Ping(ping) => return Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
Message::Pong(pong) => return Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
Message::Close(close) => {
return Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
}
Message::Nop => unreachable!("MessageStream should not produce no-ops"),
}
}
}
}
/// Concatenates the buffered continuation fragments into one contiguous payload.
///
/// `continuations` is left empty and its allocation is released. `total_len` is the expected
/// combined length of the fragments; it is only used to size the output buffer up front.
///
/// A single buffered fragment is returned as-is, without copying. No fragments yields an
/// empty payload.
fn collect(continuations: &mut Vec<Bytes>, total_len: usize) -> Bytes {
    let mut chunks = mem::take(continuations);

    // Zero or one fragment: there is nothing to join, so skip allocating and copying.
    if chunks.len() <= 1 {
        return chunks.pop().unwrap_or_default();
    }

    let mut buf = BytesMut::with_capacity(total_len);
    for chunk in &chunks {
        buf.extend_from_slice(chunk);
    }
    buf.freeze()
}
#[cfg(test)]
mod tests {
    use std::{future::Future, task::Poll};

    use futures_core::Stream;

    use super::{AggregatedMessage, Bytes, Item, Message, MessageStream};
    use crate::stream::tests::payload_pair;

    /// Fragments sent one at a time are held back until the final one arrives. They are then
    /// emitted as a single text message containing the concatenated payload.
    #[tokio::test]
    async fn aggregates_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = [
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];
            let len = messages.len();

            for (idx, msg) in messages.into_iter().enumerate() {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");

                if idx == len - 1 {
                    match stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => {
                            assert_eq!(&text[..], "firstsecondthird");
                        }
                        poll => panic!("expected aggregated text message; got {poll:?}"),
                    }
                } else {
                    assert!(
                        stream.as_mut().poll_next(cx).is_pending(),
                        "Stream shouldn't be ready until continuations complete"
                    );
                }
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// All fragments delivered in one batch produce exactly one aggregated message.
    #[tokio::test]
    async fn aggregates_consecutive_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"first".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"second".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"third".to_vec()))),
            ];

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let fut = tx.send_many(messages);
            let fut = std::pin::pin!(fut);
            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            assert!(
                stream.as_mut().poll_next(cx).is_ready(),
                "Stream should be ready when all continuations have been sent"
            );

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Binary fragments are joined and emitted as a single binary message.
    #[tokio::test]
    async fn aggregates_binary_continuations() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = vec![
                Message::Continuation(Item::FirstBinary(Bytes::from(b"ab".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"cd".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"ef".to_vec()))),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes)))) => {
                    assert_eq!(&bytes[..], &b"abcdef"[..]);
                }
                poll => panic!("expected aggregated binary message; got {poll:?}"),
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Empty fragments are counted but never buffered. An all-empty sequence yields empty text.
    #[tokio::test]
    async fn ignores_empty_continuation_chunks() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = std::iter::once(Message::Continuation(Item::FirstText(Bytes::new())))
                .chain((0..128).map(|_| Message::Continuation(Item::Continue(Bytes::new()))))
                .collect::<Vec<_>>();

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream shouldn't be ready until continuations complete"
            );
            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            {
                let fut = tx.send(Message::Continuation(Item::Last(Bytes::new())));
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) => assert!(text.is_empty()),
                poll => panic!("expected empty text message; got {poll:?}"),
            }
            assert_eq!(stream.as_mut().get_mut().continuations.len(), 0);

            Poll::Ready(())
        })
        .await
    }

    /// Dropping the sending side ends the aggregated stream.
    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = payload_pair(8);
            drop(tx);

            let message_stream = MessageStream::new(rx).aggregate_continuations();
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should end once the sender has been dropped"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Overflowing mid-sequence yields one error and skips the rest of that sequence. Control
    /// frames interleaved with it are still delivered, and later messages flow normally.
    #[tokio::test]
    async fn continuation_overflow_errors_once_and_recovers() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(4);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = vec![
                Message::Continuation(Item::FirstText(Bytes::from(b"1234".to_vec()))),
                Message::Continuation(Item::Continue(Bytes::from(b"5".to_vec()))),
                Message::Ping(Bytes::from(b"p".to_vec())),
                Message::Continuation(Item::Last(Bytes::from(b"6".to_vec()))),
                Message::Text("ok".into()),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                matches!(stream.as_mut().poll_next(cx), Poll::Ready(Some(Err(_)))),
                "expected one overflow error"
            );
            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Ping(_))))
                ),
                "expected ping frame after overflow"
            );
            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &text[..] == "ok"
                ),
                "expected text message after overflow continuation is terminated"
            );
            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Overflowing on the final fragment yields one error. Because that fragment ends the
    /// sequence, the very next message is delivered normally.
    #[tokio::test]
    async fn continuation_overflow_on_last_fragment_errors_and_recovers() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx)
                .aggregate_continuations()
                .max_continuation_size(4);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                poll.is_pending(),
                "Stream should be pending when no messages are present {poll:?}"
            );

            let messages = vec![
                Message::Continuation(Item::FirstBinary(Bytes::from(b"12".to_vec()))),
                Message::Continuation(Item::Last(Bytes::from(b"345".to_vec()))),
                Message::Text("ok".into()),
            ];

            {
                let fut = tx.send_many(messages);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                matches!(stream.as_mut().poll_next(cx), Poll::Ready(Some(Err(_)))),
                "expected one overflow error"
            );
            assert!(
                matches!(
                    stream.as_mut().poll_next(cx),
                    Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))) if &text[..] == "ok"
                ),
                "expected text message right after an overflowing final fragment"
            );
            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }
}