watermelon_net/connection/
websocket.rs1use std::{
2 future, io,
3 pin::Pin,
4 task::{Context, Poll, Waker},
5};
6
7use bytes::Bytes;
8use futures_core::Stream as _;
9use futures_sink::Sink;
10use http::Uri;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
13use watermelon_proto::proto::{
14 ClientOp, FramedEncoder, ServerOp, decode_frame, error::FrameDecoderError,
15};
16
17#[derive(Debug)]
18pub struct WebsocketConnection<S> {
19 socket: WebSocketStream<S>,
20 encoder: FramedEncoder,
21 residual_frame: Bytes,
22 should_flush: bool,
23}
24
25impl<S> WebsocketConnection<S>
26where
27 S: AsyncRead + AsyncWrite + Unpin,
28{
29 pub async fn new(uri: Uri, socket: S) -> io::Result<Self> {
35 let (socket, _resp) = ClientBuilder::from_uri(uri)
36 .connect_on(socket)
37 .await
38 .map_err(websockets_error_to_io)?;
39 Ok(Self {
40 socket,
41 encoder: FramedEncoder::new(),
42 residual_frame: Bytes::new(),
43 should_flush: false,
44 })
45 }
46
47 pub fn poll_read_next(
48 &mut self,
49 cx: &mut Context<'_>,
50 ) -> Poll<Result<ServerOp, WebsocketReadError>> {
51 loop {
52 if !self.residual_frame.is_empty() {
53 return Poll::Ready(
54 decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder),
55 );
56 }
57
58 match Pin::new(&mut self.socket).poll_next(cx) {
59 Poll::Pending => return Poll::Pending,
60 Poll::Ready(Some(Ok(message))) if message.is_binary() => {
61 self.residual_frame = message.into_payload().into();
62 }
63 Poll::Ready(Some(Ok(_message))) => {}
64 Poll::Ready(Some(Err(err))) => {
65 return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err))));
66 }
67 Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)),
68 }
69 }
70 }
71
72 pub async fn read_next(&mut self) -> Result<ServerOp, WebsocketReadError> {
78 future::poll_fn(|cx| self.poll_read_next(cx)).await
79 }
80
81 pub fn should_flush(&self) -> bool {
82 self.should_flush
83 }
84
85 pub fn may_enqueue_more_ops(&mut self) -> bool {
86 let mut cx = Context::from_waker(Waker::noop());
87 Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready()
88 }
89
90 #[expect(clippy::missing_panics_doc)]
92 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
93 let payload = self.encoder.encode(item);
94 Pin::new(&mut self.socket)
95 .start_send(Message::binary(payload))
96 .unwrap();
97 self.should_flush = true;
98 }
99
100 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101 Pin::new(&mut self.socket)
102 .poll_flush(cx)
103 .map_err(websockets_error_to_io)
104 }
105
106 pub async fn flush(&mut self) -> io::Result<()> {
112 future::poll_fn(|cx| self.poll_flush(cx)).await
113 }
114
115 pub async fn shutdown(&mut self) -> io::Result<()> {
122 future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx))
123 .await
124 .map_err(websockets_error_to_io)
125 }
126}
127
128#[derive(Debug, thiserror::Error)]
129pub enum WebsocketReadError {
130 #[error("decoder")]
131 Decoder(#[source] FrameDecoderError),
132 #[error("io")]
133 Io(#[source] io::Error),
134 #[error("closed")]
135 Closed,
136}
137
138fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error {
139 match err {
140 tokio_websockets::Error::Io(err) => err,
141 err => io::Error::other(err),
142 }
143}