1use std::{
4 future::poll_fn,
5 io, mem,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use actix_http::ws::{CloseReason, Item, Message, ProtocolError};
11use actix_web::web::{Bytes, BytesMut};
12use bytestring::ByteString;
13use futures_core::Stream;
14
15use crate::MessageStream;
16
/// Kind of the fragmented message currently being reassembled, recorded from the
/// first frame (`Item::FirstText` / `Item::FirstBinary`) of that message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContinuationKind {
    /// Started by an `Item::FirstText` frame; the reassembled payload is decoded as UTF-8.
    Text,
    /// Started by an `Item::FirstBinary` frame; the reassembled payload is kept as raw bytes.
    Binary,
}
21
/// A WebSocket message in which any continuation (fragmented) frames have
/// already been joined into a single payload.
#[derive(Debug, PartialEq, Eq)]
pub enum AggregatedMessage {
    /// Text message: either received whole, or reassembled from text continuation frames.
    Text(ByteString),

    /// Binary message: either received whole, or reassembled from binary continuation frames.
    Binary(Bytes),

    /// Ping message carrying its payload.
    Ping(Bytes),

    /// Pong message carrying its payload.
    Pong(Bytes),

    /// Close message, with the close reason when one was supplied.
    Close(Option<CloseReason>),
}
40
/// A stream of WebSocket messages that reassembles fragmented (continuation)
/// messages into single [`AggregatedMessage::Text`] / [`AggregatedMessage::Binary`] items.
pub struct AggregatedMessageStream {
    /// Underlying stream of raw, possibly fragmented, messages.
    stream: MessageStream,
    /// Running total, in bytes, of the continuation frames seen for the current
    /// fragmented message; reset when its `Last` frame is processed.
    current_size: usize,
    /// Upper bound for `current_size`; exceeding it yields an error instead of a
    /// message. Configured via `max_continuation_size`.
    max_size: usize,
    /// Buffered payload chunks of the in-progress fragmented message.
    continuations: Vec<Bytes>,
    /// Whether the in-progress fragmented message started as text or binary.
    continuation_kind: ContinuationKind,
}
49
50impl AggregatedMessageStream {
51 #[must_use]
52 pub(crate) fn new(stream: MessageStream) -> Self {
53 AggregatedMessageStream {
54 stream,
55 current_size: 0,
56 max_size: 1024 * 1024,
57 continuations: Vec::new(),
58 continuation_kind: ContinuationKind::Binary,
59 }
60 }
61
62 #[must_use]
78 pub fn max_continuation_size(mut self, max_size: usize) -> Self {
79 self.max_size = max_size;
80 self
81 }
82
83 #[must_use]
96 pub async fn recv(&mut self) -> Option<<Self as Stream>::Item> {
97 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
98 }
99}
100
101fn size_error() -> Poll<Option<Result<AggregatedMessage, ProtocolError>>> {
102 Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(
103 "Exceeded maximum continuation size",
104 )))))
105}
106
107impl Stream for AggregatedMessageStream {
108 type Item = Result<AggregatedMessage, ProtocolError>;
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 let this = self.get_mut();
112
113 let Some(msg) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) else {
114 return Poll::Ready(None);
115 };
116
117 match msg {
118 Message::Continuation(item) => match item {
119 Item::FirstText(bytes) => {
120 this.continuation_kind = ContinuationKind::Text;
121 this.current_size += bytes.len();
122
123 if this.current_size > this.max_size {
124 this.continuations.clear();
125 return size_error();
126 }
127
128 this.continuations.push(bytes);
129
130 Poll::Pending
131 }
132
133 Item::FirstBinary(bytes) => {
134 this.continuation_kind = ContinuationKind::Binary;
135 this.current_size += bytes.len();
136
137 if this.current_size > this.max_size {
138 this.continuations.clear();
139 return size_error();
140 }
141
142 this.continuations.push(bytes);
143
144 Poll::Pending
145 }
146
147 Item::Continue(bytes) => {
148 this.current_size += bytes.len();
149
150 if this.current_size > this.max_size {
151 this.continuations.clear();
152 return size_error();
153 }
154
155 this.continuations.push(bytes);
156
157 Poll::Pending
158 }
159
160 Item::Last(bytes) => {
161 this.current_size += bytes.len();
162
163 if this.current_size > this.max_size {
164 this.current_size = 0;
167 this.continuations.clear();
168
169 return size_error();
170 }
171
172 this.continuations.push(bytes);
173 let bytes = collect(&mut this.continuations);
174
175 this.current_size = 0;
176
177 match this.continuation_kind {
178 ContinuationKind::Text => {
179 Poll::Ready(Some(match ByteString::try_from(bytes) {
180 Ok(bytestring) => Ok(AggregatedMessage::Text(bytestring)),
181 Err(err) => Err(ProtocolError::Io(io::Error::new(
182 io::ErrorKind::InvalidData,
183 err.to_string(),
184 ))),
185 }))
186 }
187 ContinuationKind::Binary => {
188 Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
189 }
190 }
191 }
192 },
193
194 Message::Text(text) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
195 Message::Binary(binary) => Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary)))),
196 Message::Ping(ping) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
197 Message::Pong(pong) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
198 Message::Close(close) => Poll::Ready(Some(Ok(AggregatedMessage::Close(close)))),
199
200 Message::Nop => unreachable!("MessageStream should not produce no-ops"),
201 }
202 }
203}
204
205fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
206 let continuations = mem::take(continuations);
207 let total_len = continuations.iter().map(|b| b.len()).sum();
208
209 let mut buf = BytesMut::with_capacity(total_len);
210
211 for chunk in continuations {
212 buf.extend(chunk);
213 }
214
215 buf.freeze()
216}