rama_ws/runtime/
stream.rs

1use std::{
2    io::{self, Read, Write},
3    pin::Pin,
4    task::{Context, Poll, ready},
5};
6
7use rama_core::stream::Stream;
8use rama_core::{
9    error::OpaqueError,
10    extensions::{Extensions, ExtensionsMut, ExtensionsRef},
11    futures::{self, SinkExt, StreamExt},
12    telemetry::tracing::{debug, trace},
13};
14use rama_http::io::upgrade;
15
16use crate::{
17    Message, ProtocolError,
18    protocol::{CloseFrame, Role, WebSocket, WebSocketConfig},
19    runtime::{
20        compat::{self, AllowStd, ContextWaker},
21        handshake::without_handshake,
22    },
23};
24
25/// A wrapper around an underlying raw stream which implements the WebSocket
26/// protocol.
27///
28/// A `AsyncWebSocket<S>` represents a handshake that has been completed
29/// successfully and both the server and the client are ready for receiving
30/// and sending data. Message from a `AsyncWebSocket<S>` are accessible
31/// through the respective `Stream` and `Sink`.
32#[derive(Debug)]
33pub struct AsyncWebSocket<S = upgrade::Upgraded> {
34    inner: WebSocket<AllowStd<S>>,
35    closing: bool,
36    ended: bool,
37    /// Tungstenite is probably ready to receive more data.
38    ///
39    /// `false` once start_send hits `WouldBlock` errors.
40    /// `true` initially and after `flush`ing.
41    ready: bool,
42}
43
44impl<S> AsyncWebSocket<S> {
45    /// Convert a raw socket into a AsyncWebSocket without performing a
46    /// handshake.
47    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
48    where
49        S: Stream + Unpin + ExtensionsMut,
50    {
51        without_handshake(stream, move |allow_std| {
52            WebSocket::from_raw_socket(allow_std, role, config)
53        })
54        .await
55    }
56
57    /// Convert a raw socket into a AsyncWebSocket without performing a
58    /// handshake.
59    pub async fn from_partially_read(
60        stream: S,
61        part: Vec<u8>,
62        role: Role,
63        config: Option<WebSocketConfig>,
64    ) -> Self
65    where
66        S: Stream + Unpin + ExtensionsMut,
67    {
68        without_handshake(stream, move |allow_std| {
69            WebSocket::from_partially_read(allow_std, part, role, config)
70        })
71        .await
72    }
73
74    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
75        Self {
76            inner: ws,
77            closing: false,
78            ended: false,
79            ready: true,
80        }
81    }
82
83    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
84    where
85        S: Unpin,
86        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
87        AllowStd<S>: Read + Write,
88    {
89        trace!("AsyncWebSocket.with_context");
90        if let Some((kind, ctx)) = ctx {
91            self.inner.get_mut().set_waker(kind, ctx.waker());
92        }
93        f(&mut self.inner)
94    }
95
96    /// Consumes the `WebSocketStream` and returns the underlying stream.
97    pub fn into_inner(self) -> S {
98        self.inner.into_inner().into_inner()
99    }
100
101    /// Returns a shared reference to the inner stream.
102    pub fn get_ref(&self) -> &S
103    where
104        S: Stream + Unpin,
105    {
106        self.inner.get_ref().get_ref()
107    }
108
109    /// Returns a mutable reference to the inner stream.
110    pub fn get_mut(&mut self) -> &mut S
111    where
112        S: Stream + Unpin,
113    {
114        self.inner.get_mut().get_mut()
115    }
116
117    /// Returns a reference to the configuration of the tungstenite stream.
118    pub fn get_config(&self) -> &WebSocketConfig {
119        self.inner.get_config()
120    }
121
122    /// Close the underlying web socket
123    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), ProtocolError>
124    where
125        S: Stream + Unpin,
126    {
127        self.send(Message::Close(msg)).await
128    }
129}
130
131impl<S: ExtensionsRef> ExtensionsRef for AsyncWebSocket<S> {
132    fn extensions(&self) -> &Extensions {
133        self.inner.extensions()
134    }
135}
136
137impl<S: ExtensionsMut> ExtensionsMut for AsyncWebSocket<S> {
138    fn extensions_mut(&mut self) -> &mut Extensions {
139        self.inner.extensions_mut()
140    }
141}
142
143impl<S: Stream + Unpin> AsyncWebSocket<S> {
144    #[inline]
145    /// Writes and immediately flushes a message.
146    pub fn send_message(
147        &mut self,
148        msg: Message,
149    ) -> impl Future<Output = Result<(), ProtocolError>> + Send + '_ {
150        self.send(msg)
151    }
152
153    pub async fn recv_message(&mut self) -> Result<Message, ProtocolError> {
154        self.next().await.ok_or_else(|| {
155            ProtocolError::Io(io::Error::new(
156                io::ErrorKind::ConnectionAborted,
157                OpaqueError::from_display(
158                    "Connection closed: no messages to be received any longer",
159                ),
160            ))
161        })?
162    }
163}
164
165impl<T> futures::Stream for AsyncWebSocket<T>
166where
167    T: Stream + Unpin,
168{
169    type Item = Result<Message, ProtocolError>;
170
171    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172        trace!("Stream.poll_next");
173
174        // The connection has been closed or a critical error has occurred.
175        // We have already returned the error to the user, the `Stream` is unusable,
176        // so we assume that the stream has been "fused".
177        if self.ended {
178            return Poll::Ready(None);
179        }
180
181        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
182            trace!("Stream.with_context poll_next -> read()");
183            compat::cvt(s.read())
184        })) {
185            Ok(v) => Poll::Ready(Some(Ok(v))),
186            Err(e) => {
187                self.ended = true;
188                if e.is_connection_error() {
189                    Poll::Ready(None)
190                } else {
191                    Poll::Ready(Some(Err(e)))
192                }
193            }
194        }
195    }
196}
197
198impl<T> futures::stream::FusedStream for AsyncWebSocket<T>
199where
200    T: Stream + Unpin,
201{
202    fn is_terminated(&self) -> bool {
203        self.ended
204    }
205}
206
207impl<T> futures::Sink<Message> for AsyncWebSocket<T>
208where
209    T: Stream + Unpin,
210{
211    type Error = ProtocolError;
212
213    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214        if self.ready {
215            Poll::Ready(Ok(()))
216        } else {
217            // Currently blocked so try to flush the blockage away
218            (*self)
219                .with_context(Some((ContextWaker::Write, cx)), |s| compat::cvt(s.flush()))
220                .map(|r| {
221                    self.ready = true;
222                    r
223                })
224        }
225    }
226
227    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
228        match (*self).with_context(None, |s| s.write(item)) {
229            Ok(()) => {
230                self.ready = true;
231                Ok(())
232            }
233            Err(ProtocolError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
234                // the message was accepted and queued so not an error
235                // but `poll_ready` will now start trying to flush the block
236                self.ready = false;
237                Ok(())
238            }
239            Err(e) => {
240                self.ready = true;
241                debug!("websocket start_send error: {e}");
242                Err(e)
243            }
244        }
245    }
246
247    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        (*self)
249            .with_context(Some((ContextWaker::Write, cx)), |s| compat::cvt(s.flush()))
250            .map(|r| {
251                self.ready = true;
252                match r {
253                    Err(err) if err.is_connection_error() => {
254                        // WebSocket connection has just been closed. Flushing completed, not an error.
255                        Ok(())
256                    }
257                    other => other,
258                }
259            })
260    }
261
262    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263        self.ready = true;
264        let res = if self.closing {
265            // After queueing it, we call `flush` to drive the close handshake to completion.
266            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
267        } else {
268            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
269        };
270
271        match res {
272            Ok(()) => Poll::Ready(Ok(())),
273            Err(ProtocolError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
274                trace!("WouldBlock");
275                self.closing = true;
276                Poll::Pending
277            }
278            Err(err) => {
279                if err.is_connection_error() {
280                    Poll::Ready(Ok(()))
281                } else {
282                    debug!("websocket close error: {}", err);
283                    Poll::Ready(Err(err))
284                }
285            }
286        }
287    }
288}
289
#[cfg(test)]
mod tests {
    use crate::runtime::{AsyncWebSocket, compat::AllowStd};
    use std::io::{Read, Write};

    /// Raw stream type used for the compile-time trait checks below.
    type Tcp = tokio::net::TcpStream;

    /// Compiles only if `T` implements both `Read` and `Write`.
    fn assert_read_write<T: Read + Write>() {}

    /// Compiles only if `T` is `Unpin`.
    fn assert_unpin<T: Unpin>() {}

    /// Purely a compile-time check: the adapter must be usable through the
    /// blocking I/O traits and the async socket must be `Unpin`.
    #[test]
    fn web_socket_stream_has_traits() {
        assert_read_write::<AllowStd<Tcp>>();
        assert_unpin::<AsyncWebSocket<Tcp>>();
    }
}