watermelon_net/connection/
websocket.rs1use std::{
2 future, io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bytes::Bytes;
8use futures_core::Stream as _;
9use futures_sink::Sink;
10use futures_util::task::noop_waker_ref;
11use http::Uri;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
14use watermelon_proto::proto::{
15 decode_frame, error::FrameDecoderError, ClientOp, FramedEncoder, ServerOp,
16};
17
18#[derive(Debug)]
19pub struct WebsocketConnection<S> {
20 socket: WebSocketStream<S>,
21 encoder: FramedEncoder,
22 residual_frame: Bytes,
23 should_flush: bool,
24}
25
26impl<S> WebsocketConnection<S>
27where
28 S: AsyncRead + AsyncWrite + Unpin,
29{
30 pub async fn new(uri: Uri, socket: S) -> io::Result<Self> {
36 let (socket, _resp) = ClientBuilder::from_uri(uri)
37 .connect_on(socket)
38 .await
39 .map_err(websockets_error_to_io)?;
40 Ok(Self {
41 socket,
42 encoder: FramedEncoder::new(),
43 residual_frame: Bytes::new(),
44 should_flush: false,
45 })
46 }
47
48 pub fn poll_read_next(
49 &mut self,
50 cx: &mut Context<'_>,
51 ) -> Poll<Result<ServerOp, WebsocketReadError>> {
52 loop {
53 if !self.residual_frame.is_empty() {
54 return Poll::Ready(
55 decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder),
56 );
57 }
58
59 match Pin::new(&mut self.socket).poll_next(cx) {
60 Poll::Pending => return Poll::Pending,
61 Poll::Ready(Some(Ok(message))) if message.is_binary() => {
62 self.residual_frame = message.into_payload().into();
63 }
64 Poll::Ready(Some(Ok(_message))) => {}
65 Poll::Ready(Some(Err(err))) => {
66 return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err))))
67 }
68 Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)),
69 }
70 }
71 }
72
73 pub async fn read_next(&mut self) -> Result<ServerOp, WebsocketReadError> {
79 future::poll_fn(|cx| self.poll_read_next(cx)).await
80 }
81
82 pub fn should_flush(&self) -> bool {
83 self.should_flush
84 }
85
86 pub fn may_enqueue_more_ops(&mut self) -> bool {
87 let mut cx = Context::from_waker(noop_waker_ref());
89 Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready()
90 }
91
92 #[expect(clippy::missing_panics_doc)]
94 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
95 let payload = self.encoder.encode(item);
96 Pin::new(&mut self.socket)
97 .start_send(Message::binary(payload))
98 .unwrap();
99 self.should_flush = true;
100 }
101
102 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
103 Pin::new(&mut self.socket)
104 .poll_flush(cx)
105 .map_err(websockets_error_to_io)
106 }
107
108 pub async fn flush(&mut self) -> io::Result<()> {
114 future::poll_fn(|cx| self.poll_flush(cx)).await
115 }
116
117 pub async fn shutdown(&mut self) -> io::Result<()> {
124 future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx))
125 .await
126 .map_err(websockets_error_to_io)
127 }
128}
129
130#[derive(Debug, thiserror::Error)]
131pub enum WebsocketReadError {
132 #[error("decoder")]
133 Decoder(#[source] FrameDecoderError),
134 #[error("io")]
135 Io(#[source] io::Error),
136 #[error("closed")]
137 Closed,
138}
139
140fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error {
141 match err {
142 tokio_websockets::Error::Io(err) => err,
143 err => io::Error::other(err),
144 }
145}