// actix_ws_ng/fut.rs

1use std::{
2    collections::VecDeque,
3    future::poll_fn,
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use actix_codec::{Decoder, Encoder};
10use actix_http::{
11    ws::{Codec, Frame, Message, ProtocolError},
12    Payload,
13};
14use actix_web::{
15    web::{Bytes, BytesMut},
16    Error,
17};
18use futures_core::stream::Stream;
19use tokio::sync::mpsc::Receiver;
20
21/// A response body for Websocket HTTP Requests
22pub struct StreamingBody {
23    session_rx: Receiver<Message>,
24
25    messages: VecDeque<Message>,
26    buf: BytesMut,
27    codec: Codec,
28    closing: bool,
29}
30
31/// A stream of Messages from a websocket client
32///
33/// Messages can be accessed via the stream's `.next()` method
34pub struct MessageStream {
35    payload: Payload,
36
37    messages: VecDeque<Message>,
38    buf: BytesMut,
39    codec: Codec,
40    closing: bool,
41}
42
43impl StreamingBody {
44    pub(super) fn new(session_rx: Receiver<Message>) -> Self {
45        StreamingBody {
46            session_rx,
47            messages: VecDeque::new(),
48            buf: BytesMut::new(),
49            codec: Codec::new(),
50            closing: false,
51        }
52    }
53}
54
55impl MessageStream {
56    pub(super) fn new(payload: Payload) -> Self {
57        MessageStream {
58            payload,
59            messages: VecDeque::new(),
60            buf: BytesMut::new(),
61            codec: Codec::new(),
62            closing: false,
63        }
64    }
65
66    /// Wait for the next item from the message stream
67    ///
68    /// ```rust,ignore
69    /// while let Some(Ok(msg)) = stream.recv().await {
70    ///     // handle message
71    /// }
72    /// ```
73    pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
74        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
75    }
76}
77
78impl Stream for StreamingBody {
79    type Item = Result<Bytes, Error>;
80
81    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82        let this = self.get_mut();
83
84        if this.closing {
85            return Poll::Ready(None);
86        }
87
88        loop {
89            match Pin::new(&mut this.session_rx).poll_recv(cx) {
90                Poll::Ready(Some(msg)) => {
91                    this.messages.push_back(msg);
92                }
93                Poll::Ready(None) => {
94                    this.closing = true;
95                    break;
96                }
97                Poll::Pending => break,
98            }
99        }
100
101        while let Some(msg) = this.messages.pop_front() {
102            if let Err(e) = this.codec.encode(msg, &mut this.buf) {
103                return Poll::Ready(Some(Err(e.into())));
104            }
105        }
106
107        if !this.buf.is_empty() {
108            return Poll::Ready(Some(Ok(this.buf.split().freeze())));
109        }
110
111        Poll::Pending
112    }
113}
114
115impl Stream for MessageStream {
116    type Item = Result<Message, ProtocolError>;
117
118    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        let this = self.get_mut();
120
121        // Return the first message in the queue if one exists
122        //
123        // This is faster than polling and parsing
124        if let Some(msg) = this.messages.pop_front() {
125            return Poll::Ready(Some(Ok(msg)));
126        }
127
128        if !this.closing {
129            // Read in bytes until there's nothing left to read
130            loop {
131                match Pin::new(&mut this.payload).poll_next(cx) {
132                    Poll::Ready(Some(Ok(bytes))) => {
133                        this.buf.extend_from_slice(&bytes);
134                    }
135                    Poll::Ready(Some(Err(e))) => {
136                        return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
137                            io::ErrorKind::Other,
138                            e.to_string(),
139                        )))));
140                    }
141                    Poll::Ready(None) => {
142                        this.closing = true;
143                        break;
144                    }
145                    Poll::Pending => break,
146                }
147            }
148        }
149
150        // Create messages until there's no more bytes left
151        while let Some(frame) = this.codec.decode(&mut this.buf)? {
152            let message = match frame {
153                Frame::Text(bytes) => {
154                    let s = std::str::from_utf8(&bytes)
155                        .map_err(|e| {
156                            ProtocolError::Io(io::Error::new(io::ErrorKind::Other, e.to_string()))
157                        })?
158                        .to_string();
159                    Message::Text(s.into())
160                }
161                Frame::Binary(bytes) => Message::Binary(bytes),
162                Frame::Ping(bytes) => Message::Ping(bytes),
163                Frame::Pong(bytes) => Message::Pong(bytes),
164                Frame::Close(reason) => Message::Close(reason),
165                Frame::Continuation(item) => Message::Continuation(item),
166            };
167
168            this.messages.push_back(message);
169        }
170
171        // Return the first message in the queue
172        if let Some(msg) = this.messages.pop_front() {
173            return Poll::Ready(Some(Ok(msg)));
174        }
175
176        // If we've exhausted our message queue and we're closing, close the stream
177        if this.closing {
178            return Poll::Ready(None);
179        }
180
181        Poll::Pending
182    }
183}