watermelon_net/connection/websocket.rs

1use std::{
2    future, io,
3    pin::Pin,
4    task::{Context, Poll, Waker},
5};
6
7use bytes::Bytes;
8use futures_core::Stream as _;
9use futures_sink::Sink;
10use http::Uri;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
13use watermelon_proto::proto::{
14    ClientOp, FramedEncoder, ServerOp, decode_frame, error::FrameDecoderError,
15};
16
17#[derive(Debug)]
18pub struct WebsocketConnection<S> {
19    socket: WebSocketStream<S>,
20    encoder: FramedEncoder,
21    residual_frame: Bytes,
22    should_flush: bool,
23}
24
25impl<S> WebsocketConnection<S>
26where
27    S: AsyncRead + AsyncWrite + Unpin,
28{
29    /// Construct a websocket stream to a pre-established connection `socket`.
30    ///
31    /// # Errors
32    ///
33    /// Returns an error if the websocket handshake fails.
34    pub async fn new(uri: Uri, socket: S) -> io::Result<Self> {
35        let (socket, _resp) = ClientBuilder::from_uri(uri)
36            .connect_on(socket)
37            .await
38            .map_err(websockets_error_to_io)?;
39        Ok(Self {
40            socket,
41            encoder: FramedEncoder::new(),
42            residual_frame: Bytes::new(),
43            should_flush: false,
44        })
45    }
46
47    pub fn poll_read_next(
48        &mut self,
49        cx: &mut Context<'_>,
50    ) -> Poll<Result<ServerOp, WebsocketReadError>> {
51        loop {
52            if !self.residual_frame.is_empty() {
53                return Poll::Ready(
54                    decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder),
55                );
56            }
57
58            match Pin::new(&mut self.socket).poll_next(cx) {
59                Poll::Pending => return Poll::Pending,
60                Poll::Ready(Some(Ok(message))) if message.is_binary() => {
61                    self.residual_frame = message.into_payload().into();
62                }
63                Poll::Ready(Some(Ok(_message))) => {}
64                Poll::Ready(Some(Err(err))) => {
65                    return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err))));
66                }
67                Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)),
68            }
69        }
70    }
71
72    /// Reads the next [`ServerOp`].
73    ///
74    /// # Errors
75    ///
76    /// It returns an error if the content cannot be decoded or if an I/O error occurs.
77    pub async fn read_next(&mut self) -> Result<ServerOp, WebsocketReadError> {
78        future::poll_fn(|cx| self.poll_read_next(cx)).await
79    }
80
81    pub fn should_flush(&self) -> bool {
82        self.should_flush
83    }
84
85    pub fn may_enqueue_more_ops(&mut self) -> bool {
86        let mut cx = Context::from_waker(Waker::noop());
87        Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready()
88    }
89
90    /// Enqueue `item` to be written.
91    #[expect(clippy::missing_panics_doc)]
92    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
93        let payload = self.encoder.encode(item);
94        Pin::new(&mut self.socket)
95            .start_send(Message::binary(payload))
96            .unwrap();
97        self.should_flush = true;
98    }
99
100    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101        Pin::new(&mut self.socket)
102            .poll_flush(cx)
103            .map_err(websockets_error_to_io)
104    }
105
106    /// Flush any buffered writes to the connection
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if flushing fails
111    pub async fn flush(&mut self) -> io::Result<()> {
112        future::poll_fn(|cx| self.poll_flush(cx)).await
113    }
114
115    /// Shutdown the connection
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if shutting down the connection fails.
120    /// Implementations usually ignore this error.
121    pub async fn shutdown(&mut self) -> io::Result<()> {
122        future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx))
123            .await
124            .map_err(websockets_error_to_io)
125    }
126}
127
128#[derive(Debug, thiserror::Error)]
129pub enum WebsocketReadError {
130    #[error("decoder")]
131    Decoder(#[source] FrameDecoderError),
132    #[error("io")]
133    Io(#[source] io::Error),
134    #[error("closed")]
135    Closed,
136}
137
138fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error {
139    match err {
140        tokio_websockets::Error::Io(err) => err,
141        err => io::Error::other(err),
142    }
143}