// routerify_websocket/websocket.rs

1use crate::{CloseCode, Message, WebSocketConfig};
2use futures::{ready, FutureExt, Sink, Stream};
3use std::borrow::Cow;
4use std::fmt;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio_tungstenite::{
9    tungstenite::protocol::{CloseFrame, Role},
10    WebSocketStream,
11};
12
13/// The WebSocket input-output stream.
14///
15/// It implements the [`Stream`](https://docs.rs/futures/0.3.5/futures/stream/trait.Stream.html) and [`Sink`](https://docs.rs/futures/0.3.5/futures/sink/trait.Sink.html)
16/// traits, so the socket is just a stream of messages coming in and going out.
17pub struct WebSocket {
18    inner: WebSocketStream<hyper::upgrade::Upgraded>,
19    remote_addr: SocketAddr,
20}
21
22impl WebSocket {
23    pub(crate) async fn from_raw_socket(
24        upgraded: hyper::upgrade::Upgraded,
25        remote_addr: SocketAddr,
26        config: WebSocketConfig,
27    ) -> Self {
28        WebSocketStream::from_raw_socket(upgraded, Role::Server, Some(config))
29            .map(|inner| WebSocket { inner, remote_addr })
30            .await
31    }
32
33    /// Get the peer's remote address.
34    pub fn remote_addr(&self) -> SocketAddr {
35        self.remote_addr
36    }
37
38    /// Consumes the websocket connection and gracefully closes it.
39    pub async fn close(self) -> crate::Result<()> {
40        let mut this = self;
41        this.inner
42            .close(None)
43            .await
44            .map_err(|err| crate::WebsocketError::WebSocketClose(err.into()))
45    }
46
47    /// Consumes the websocket connection and gracefully closes it with a code and reason.
48    pub async fn close_with<R: Into<Cow<'static, str>>>(self, code: CloseCode, reason: R) -> crate::Result<()> {
49        let mut this = self;
50        this.inner
51            .close(Some(CloseFrame {
52                code,
53                reason: reason.into(),
54            }))
55            .await
56            .map_err(|err| crate::WebsocketError::WebSocketClose(err.into()))
57    }
58}
59
60impl Stream for WebSocket {
61    type Item = Result<Message, crate::WebsocketError>;
62
63    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
64        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
65            Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
66            Some(Err(err)) => Poll::Ready(Some(Err(crate::WebsocketError::MessageReceive(err.into())))),
67            None => Poll::Ready(None),
68        }
69    }
70}
71
72impl Sink<Message> for WebSocket {
73    type Error = crate::WebsocketError;
74
75    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
77            Ok(()) => Poll::Ready(Ok(())),
78            Err(err) => Poll::Ready(Err(crate::WebsocketError::ReadyStatus(err.into()))),
79        }
80    }
81
82    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
83        match Pin::new(&mut self.inner).start_send(item.inner) {
84            Ok(()) => Ok(()),
85            Err(err) => Err(crate::WebsocketError::MessageSend(err.into())),
86        }
87    }
88
89    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
90        match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
91            Ok(()) => Poll::Ready(Ok(())),
92            Err(err) => Poll::Ready(Err(crate::WebsocketError::MessageFlush(err.into()))),
93        }
94    }
95
96    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
97        match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
98            Ok(()) => Poll::Ready(Ok(())),
99            Err(err) => Poll::Ready(Err(crate::WebsocketError::WebSocketClose(err.into()))),
100        }
101    }
102}
103
104impl fmt::Debug for WebSocket {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        f.debug_struct("WebSocket").finish()
107    }
108}