1use std::{
2 collections::VecDeque,
3 future::poll_fn,
4 io, mem,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11 ws::{Codec, Frame, Message, ProtocolError},
12 Payload,
13};
14use actix_web::{
15 web::{Bytes, BytesMut},
16 Error,
17};
18use bytestring::ByteString;
19use futures_core::stream::Stream;
20use tokio::sync::mpsc::Receiver;
21
22use crate::AggregatedMessageStream;
23
/// Stream of encoded WebSocket frame bytes produced from [`Message`]s received on a channel.
///
/// Each poll drains whatever messages are ready on `session_rx`, encodes them with `codec` and
/// yields the accumulated bytes as one chunk.
pub struct StreamingBody {
    /// Channel on which outgoing messages arrive.
    session_rx: Receiver<Message>,
    /// Messages received from `session_rx` but not yet encoded.
    messages: VecDeque<Message>,
    /// Encoded frame bytes waiting to be yielded.
    buf: BytesMut,
    /// Encoder turning each [`Message`] into frame bytes.
    codec: Codec,
    /// Set once `session_rx` has reported that the channel is closed.
    closing: bool,
}
32
33impl StreamingBody {
34 pub(super) fn new(session_rx: Receiver<Message>) -> Self {
35 StreamingBody {
36 session_rx,
37 messages: VecDeque::new(),
38 buf: BytesMut::new(),
39 codec: Codec::new(),
40 closing: false,
41 }
42 }
43}
44
/// Stream of [`Message`]s decoded from the WebSocket frames carried by a [`Payload`].
pub struct MessageStream {
    /// Source of raw frame bytes.
    payload: Payload,

    /// Decoded messages waiting to be handed to the caller.
    messages: VecDeque<Message>,
    /// Bytes read from `payload` that have not been decoded into frames yet.
    buf: BytesMut,
    /// Frame decoder; its maximum frame size is configurable via `max_frame_size`.
    codec: Codec,
    /// Set once `payload` has reported end-of-stream.
    closing: bool,
}
54
55impl MessageStream {
56 pub(super) fn new(payload: Payload) -> Self {
57 MessageStream {
58 payload,
59 messages: VecDeque::new(),
60 buf: BytesMut::new(),
61 codec: Codec::new(),
62 closing: false,
63 }
64 }
65
66 #[must_use]
81 pub fn max_frame_size(mut self, max_size: usize) -> Self {
82 self.codec = self.codec.max_size(max_size);
83 self
84 }
85
86 #[must_use]
93 pub fn aggregate_continuations(self) -> AggregatedMessageStream {
94 AggregatedMessageStream::new(self)
95 }
96
97 #[must_use]
110 pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
111 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
112 }
113}
114
115impl Stream for StreamingBody {
116 type Item = Result<Bytes, Error>;
117
118 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 let this = self.get_mut();
120
121 if this.closing {
122 return Poll::Ready(None);
123 }
124
125 loop {
126 match Pin::new(&mut this.session_rx).poll_recv(cx) {
127 Poll::Ready(Some(msg)) => {
128 this.messages.push_back(msg);
129 }
130 Poll::Ready(None) => {
131 this.closing = true;
132 break;
133 }
134 Poll::Pending => break,
135 }
136 }
137
138 while let Some(msg) = this.messages.pop_front() {
139 if let Err(err) = this.codec.encode(msg, &mut this.buf) {
140 return Poll::Ready(Some(Err(err.into())));
141 }
142 }
143
144 if !this.buf.is_empty() {
145 return Poll::Ready(Some(Ok(mem::take(&mut this.buf).freeze())));
146 }
147
148 if this.closing {
149 return Poll::Ready(None);
150 }
151
152 Poll::Pending
153 }
154}
155
156impl Stream for MessageStream {
157 type Item = Result<Message, ProtocolError>;
158
159 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160 let this = self.get_mut();
161
162 if let Some(msg) = this.messages.pop_front() {
166 return Poll::Ready(Some(Ok(msg)));
167 }
168
169 if !this.closing {
170 loop {
172 match Pin::new(&mut this.payload).poll_next(cx) {
173 Poll::Ready(Some(Ok(bytes))) => {
174 this.buf.extend_from_slice(&bytes);
175 }
176 Poll::Ready(Some(Err(err))) => {
177 return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
178 }
179 Poll::Ready(None) => {
180 this.closing = true;
181 break;
182 }
183 Poll::Pending => break,
184 }
185 }
186 }
187
188 while let Some(frame) = this.codec.decode(&mut this.buf)? {
190 let message = match frame {
191 Frame::Text(bytes) => {
192 ByteString::try_from(bytes)
193 .map(Message::Text)
194 .map_err(|err| {
195 ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
196 })?
197 }
198 Frame::Binary(bytes) => Message::Binary(bytes),
199 Frame::Ping(bytes) => Message::Ping(bytes),
200 Frame::Pong(bytes) => Message::Pong(bytes),
201 Frame::Close(reason) => Message::Close(reason),
202 Frame::Continuation(item) => Message::Continuation(item),
203 };
204
205 this.messages.push_back(message);
206 }
207
208 if let Some(msg) = this.messages.pop_front() {
210 return Poll::Ready(Some(Ok(msg)));
211 }
212
213 if this.closing {
215 return Poll::Ready(None);
216 }
217
218 Poll::Pending
219 }
220}
221
#[cfg(test)]
pub(crate) mod tests {
    //! Poll-level tests for `MessageStream` and `StreamingBody`.
    //!
    //! Each test runs entirely inside one `poll_fn`, so the same `Context` drives both the
    //! sending side and the stream under test synchronously. Readiness is therefore asserted
    //! step by step without yielding to the runtime.

    use std::{
        future::Future,
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use actix_http::error::PayloadError;
    use futures_core::Stream;
    use tokio::sync::mpsc::{Receiver, Sender};

    use super::{Bytes, BytesMut, Codec, Encoder, Message, MessageStream, Payload, StreamingBody};

    /// Test payload source that yields the byte chunks pushed by a [`PayloadSender`].
    pub(crate) struct PayloadReceiver {
        rx: Receiver<Bytes>,
    }
    /// Test helper that encodes messages with `codec` and pushes the bytes to a
    /// [`PayloadReceiver`].
    pub(crate) struct PayloadSender {
        codec: Codec,
        tx: Sender<Bytes>,
    }
    impl PayloadSender {
        /// Encodes `message` and sends it as one chunk.
        pub(crate) async fn send(&mut self, message: Message) {
            self.send_many(vec![message]).await
        }
        /// Encodes all `messages` into one buffer and sends it as a single chunk.
        ///
        /// # Panics
        ///
        /// Panics if encoding fails or if sending on the channel fails.
        pub(crate) async fn send_many(&mut self, messages: Vec<Message>) {
            let mut buf = BytesMut::new();

            for message in messages {
                self.codec.encode(message, &mut buf).unwrap();
            }

            self.tx.send(buf.freeze()).await.unwrap()
        }
    }
    impl Stream for PayloadReceiver {
        type Item = Result<Bytes, PayloadError>;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Chunks are always wrapped in `Ok`; the stream ends when the channel yields `None`.
            let opt = ready!(self.get_mut().rx.poll_recv(cx));

            Poll::Ready(opt.map(Ok))
        }
    }
    /// Creates a connected ([`PayloadSender`], [`Payload`]) pair backed by an mpsc channel
    /// holding up to `capacity` chunks.
    pub(crate) fn payload_pair(capacity: usize) -> (PayloadSender, Payload) {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        (
            PayloadSender {
                // NOTE(review): client mode presumably produces the masked frames a
                // server-side decoder expects; confirm against actix-http's `Codec`.
                codec: Codec::new().client_mode(),
                tx,
            },
            Payload::Stream {
                payload: Box::pin(PayloadReceiver { rx }),
            },
        )
    }

    /// Sending messages one at a time: the stream is pending before each send and ready after.
    #[tokio::test]
    async fn message_stream_yields_messages() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            let messages = [
                Message::Binary(Bytes::from(vec![0, 1, 2, 3])),
                Message::Ping(Bytes::from(vec![3, 2, 1, 0])),
                Message::Close(None),
            ];

            for msg in messages {
                let poll = stream.as_mut().poll_next(cx);
                assert!(
                    poll.is_pending(),
                    "Stream should be pending when no messages are present {poll:?}"
                );

                // Channel capacity (8) exceeds what is sent, so this completes on first poll.
                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
                assert!(
                    stream.as_mut().poll_next(cx).is_ready(),
                    "Stream should be ready"
                );
            }

            // `tx` is still alive, so the stream must wait rather than end.
            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Several messages arriving in one chunk are yielded on consecutive polls.
    #[tokio::test]
    async fn message_stream_yields_consecutive_messages() {
        std::future::poll_fn(move |cx| {
            let (mut tx, rx) = payload_pair(8);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            let messages = vec![
                Message::Binary(Bytes::from(vec![0, 1, 2, 3])),
                Message::Ping(Bytes::from(vec![3, 2, 1, 0])),
                Message::Close(None),
            ];

            let size = messages.len();

            // All messages are encoded into a single chunk.
            let fut = tx.send_many(messages);
            let fut = std::pin::pin!(fut);
            assert!(fut.poll(cx).is_ready(), "Sending should not yield");

            for _ in 0..size {
                assert!(
                    stream.as_mut().poll_next(cx).is_ready(),
                    "Stream should be ready"
                );
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Dropping the sender ends the message stream with `Ready(None)`.
    #[tokio::test]
    async fn message_stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = payload_pair(8);
            drop(tx);
            let message_stream = MessageStream::new(rx);
            let mut stream = std::pin::pin!(message_stream);

            let poll = stream.as_mut().poll_next(cx);
            assert!(
                matches!(poll, Poll::Ready(None)),
                "Stream should be ready when closing {poll:?}"
            );

            Poll::Ready(())
        })
        .await
    }

    /// Each message sent on the session channel produces one ready chunk of bytes.
    #[tokio::test]
    async fn stream_produces_bytes_from_messages() {
        std::future::poll_fn(move |cx| {
            // Capacity 1: each send must be consumed by the stream before the next one.
            let (tx, rx) = tokio::sync::mpsc::channel(1);

            let stream = StreamingBody::new(rx);

            let messages = [
                Message::Binary(Bytes::from(vec![0, 1, 2, 3])),
                Message::Ping(Bytes::from(vec![3, 2, 1, 0])),
                Message::Close(None),
            ];

            let mut stream = std::pin::pin!(stream);

            for msg in messages {
                assert!(
                    stream.as_mut().poll_next(cx).is_pending(),
                    "Stream should be pending when no messages are present"
                );

                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);

                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
                assert!(
                    stream.as_mut().poll_next(cx).is_ready(),
                    "Stream should be ready"
                );
            }

            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should be pending after processing messages"
            );

            Poll::Ready(())
        })
        .await;
    }

    /// Messages queued before a poll are all encoded into a single yielded chunk.
    #[tokio::test]
    async fn stream_processes_many_consecutive_messages() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = tokio::sync::mpsc::channel(3);

            let stream = StreamingBody::new(rx);

            let messages = [
                Message::Binary(Bytes::from(vec![0, 1, 2, 3])),
                Message::Ping(Bytes::from(vec![3, 2, 1, 0])),
                Message::Close(None),
            ];

            let mut stream = std::pin::pin!(stream);

            assert!(stream.as_mut().poll_next(cx).is_pending());

            // Capacity 3 fits all three messages without the stream draining in between.
            for msg in messages {
                let fut = tx.send(msg);
                let fut = std::pin::pin!(fut);
                assert!(fut.poll(cx).is_ready(), "Sending should not yield");
            }

            assert!(
                stream.as_mut().poll_next(cx).is_ready(),
                "Stream should be ready"
            );
            assert!(
                stream.as_mut().poll_next(cx).is_pending(),
                "Stream should have only been ready once"
            );

            Poll::Ready(())
        })
        .await;
    }

    /// Dropping the session sender ends the body stream with `Ready(None)`.
    #[tokio::test]
    async fn stream_closes() {
        std::future::poll_fn(move |cx| {
            let (tx, rx) = tokio::sync::mpsc::channel(3);

            drop(tx);
            let stream = StreamingBody::new(rx);

            let mut stream = std::pin::pin!(stream);

            let poll = stream.as_mut().poll_next(cx);

            assert!(
                matches!(poll, Poll::Ready(None)),
                "stream should close after dropped tx"
            );

            Poll::Ready(())
        })
        .await;
    }
}