1use std::{
2 collections::VecDeque,
3 future::poll_fn,
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11 ws::{Codec, Frame, Message, ProtocolError},
12 Payload,
13};
14use actix_web::{
15 web::{Bytes, BytesMut},
16 Error,
17};
18use futures_core::stream::Stream;
19use tokio::sync::mpsc::Receiver;
20
21pub struct StreamingBody {
23 session_rx: Receiver<Message>,
24
25 messages: VecDeque<Message>,
26 buf: BytesMut,
27 codec: Codec,
28 closing: bool,
29}
30
31pub struct MessageStream {
35 payload: Payload,
36
37 messages: VecDeque<Message>,
38 buf: BytesMut,
39 codec: Codec,
40 closing: bool,
41}
42
43impl StreamingBody {
44 pub(super) fn new(session_rx: Receiver<Message>) -> Self {
45 StreamingBody {
46 session_rx,
47 messages: VecDeque::new(),
48 buf: BytesMut::new(),
49 codec: Codec::new(),
50 closing: false,
51 }
52 }
53}
54
55impl MessageStream {
56 pub(super) fn new(payload: Payload) -> Self {
57 MessageStream {
58 payload,
59 messages: VecDeque::new(),
60 buf: BytesMut::new(),
61 codec: Codec::new(),
62 closing: false,
63 }
64 }
65
66 pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
74 poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
75 }
76}
77
78impl Stream for StreamingBody {
79 type Item = Result<Bytes, Error>;
80
81 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82 let this = self.get_mut();
83
84 if this.closing {
85 return Poll::Ready(None);
86 }
87
88 loop {
89 match Pin::new(&mut this.session_rx).poll_recv(cx) {
90 Poll::Ready(Some(msg)) => {
91 this.messages.push_back(msg);
92 }
93 Poll::Ready(None) => {
94 this.closing = true;
95 break;
96 }
97 Poll::Pending => break,
98 }
99 }
100
101 while let Some(msg) = this.messages.pop_front() {
102 if let Err(e) = this.codec.encode(msg, &mut this.buf) {
103 return Poll::Ready(Some(Err(e.into())));
104 }
105 }
106
107 if !this.buf.is_empty() {
108 return Poll::Ready(Some(Ok(this.buf.split().freeze())));
109 }
110
111 Poll::Pending
112 }
113}
114
115impl Stream for MessageStream {
116 type Item = Result<Message, ProtocolError>;
117
118 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 let this = self.get_mut();
120
121 if let Some(msg) = this.messages.pop_front() {
125 return Poll::Ready(Some(Ok(msg)));
126 }
127
128 if !this.closing {
129 loop {
131 match Pin::new(&mut this.payload).poll_next(cx) {
132 Poll::Ready(Some(Ok(bytes))) => {
133 this.buf.extend_from_slice(&bytes);
134 }
135 Poll::Ready(Some(Err(e))) => {
136 return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
137 io::ErrorKind::Other,
138 e.to_string(),
139 )))));
140 }
141 Poll::Ready(None) => {
142 this.closing = true;
143 break;
144 }
145 Poll::Pending => break,
146 }
147 }
148 }
149
150 while let Some(frame) = this.codec.decode(&mut this.buf)? {
152 let message = match frame {
153 Frame::Text(bytes) => {
154 let s = std::str::from_utf8(&bytes)
155 .map_err(|e| {
156 ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string()))
157 })?
158 .to_string();
159 Message::Text(s.into())
160 }
161 Frame::Binary(bytes) => Message::Binary(bytes),
162 Frame::Ping(bytes) => Message::Ping(bytes),
163 Frame::Pong(bytes) => Message::Pong(bytes),
164 Frame::Close(reason) => Message::Close(reason),
165 Frame::Continuation(item) => Message::Continuation(item),
166 };
167
168 this.messages.push_back(message);
169 }
170
171 if let Some(msg) = this.messages.pop_front() {
173 return Poll::Ready(Some(Ok(msg)));
174 }
175
176 if this.closing {
178 return Poll::Ready(None);
179 }
180
181 Poll::Pending
182 }
183}