1use std::{
2 collections::VecDeque,
3 future::poll_fn,
4 io, mem,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11 ws::{Codec, Frame, Message, ProtocolError},
12 Payload,
13};
14use actix_web::{
15 web::{Bytes, BytesMut},
16 Error,
17};
18use bytestring::ByteString;
19use futures_core::stream::Stream;
20use tokio::sync::mpsc::Receiver;
21
22use crate::AggregatedMessageStream;
23
24pub struct StreamingBody {
26 session_rx: Receiver<Message>,
27 messages: VecDeque<Message>,
28 buf: BytesMut,
29 codec: Codec,
30 closing: bool,
31}
32
33impl StreamingBody {
34 pub(super) fn new(session_rx: Receiver<Message>) -> Self {
35 StreamingBody {
36 session_rx,
37 messages: VecDeque::new(),
38 buf: BytesMut::new(),
39 codec: Codec::new(),
40 closing: false,
41 }
42 }
43}
44
45pub struct MessageStream {
47 payload: Payload,
48
49 messages: VecDeque<Message>,
50 buf: BytesMut,
51 codec: Codec,
52 closing: bool,
53}
54
55impl MessageStream {
56 pub(super) fn new(payload: Payload) -> Self {
57 MessageStream {
58 payload,
59 messages: VecDeque::new(),
60 buf: BytesMut::new(),
61 codec: Codec::new(),
62 closing: false,
63 }
64 }
65
66 #[must_use]
81 pub fn max_frame_size(mut self, max_size: usize) -> Self {
82 self.codec = self.codec.max_size(max_size);
83 self
84 }
85
86 #[must_use]
93 pub fn aggregate_continuations(self) -> AggregatedMessageStream {
94 AggregatedMessageStream::new(self)
95 }
96
97 #[must_use]
110 pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
111 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
112 }
113}
114
115impl Stream for StreamingBody {
116 type Item = Result<Bytes, Error>;
117
118 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 let this = self.get_mut();
120
121 if this.closing {
122 return Poll::Ready(None);
123 }
124
125 loop {
126 match Pin::new(&mut this.session_rx).poll_recv(cx) {
127 Poll::Ready(Some(msg)) => {
128 this.messages.push_back(msg);
129 }
130 Poll::Ready(None) => {
131 this.closing = true;
132 break;
133 }
134 Poll::Pending => break,
135 }
136 }
137
138 while let Some(msg) = this.messages.pop_front() {
139 if let Err(err) = this.codec.encode(msg, &mut this.buf) {
140 return Poll::Ready(Some(Err(err.into())));
141 }
142 }
143
144 if !this.buf.is_empty() {
145 return Poll::Ready(Some(Ok(mem::take(&mut this.buf).freeze())));
146 }
147
148 Poll::Pending
149 }
150}
151
152impl Stream for MessageStream {
153 type Item = Result<Message, ProtocolError>;
154
155 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156 let this = self.get_mut();
157
158 if let Some(msg) = this.messages.pop_front() {
162 return Poll::Ready(Some(Ok(msg)));
163 }
164
165 if !this.closing {
166 loop {
168 match Pin::new(&mut this.payload).poll_next(cx) {
169 Poll::Ready(Some(Ok(bytes))) => {
170 this.buf.extend_from_slice(&bytes);
171 }
172 Poll::Ready(Some(Err(err))) => {
173 return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
174 }
175 Poll::Ready(None) => {
176 this.closing = true;
177 break;
178 }
179 Poll::Pending => break,
180 }
181 }
182 }
183
184 while let Some(frame) = this.codec.decode(&mut this.buf)? {
186 let message = match frame {
187 Frame::Text(bytes) => {
188 ByteString::try_from(bytes)
189 .map(Message::Text)
190 .map_err(|err| {
191 ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
192 })?
193 }
194 Frame::Binary(bytes) => Message::Binary(bytes),
195 Frame::Ping(bytes) => Message::Ping(bytes),
196 Frame::Pong(bytes) => Message::Pong(bytes),
197 Frame::Close(reason) => Message::Close(reason),
198 Frame::Continuation(item) => Message::Continuation(item),
199 };
200
201 this.messages.push_back(message);
202 }
203
204 if let Some(msg) = this.messages.pop_front() {
206 return Poll::Ready(Some(Ok(msg)));
207 }
208
209 if this.closing {
211 return Poll::Ready(None);
212 }
213
214 Poll::Pending
215 }
216}