watermelon_net/connection/mod.rs

1#[cfg(not(feature = "websocket"))]
2use std::{convert::Infallible, marker::PhantomData};
3use std::{
4    io,
5    task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9#[cfg(feature = "websocket")]
10use watermelon_proto::proto::error::FrameDecoderError;
11use watermelon_proto::{
12    Connect,
13    error::ServerError,
14    proto::{ClientOp, ServerOp, error::DecoderError},
15};
16
17pub use self::streaming::{StreamingConnection, StreamingReadError};
18#[cfg(feature = "websocket")]
19pub use self::websocket::{WebsocketConnection, WebsocketReadError};
20
21mod streaming;
22#[cfg(feature = "websocket")]
23mod websocket;
24
25#[derive(Debug)]
26#[expect(clippy::large_enum_variant)]
27pub enum Connection<S1, S2> {
28    Streaming(StreamingConnection<S1>),
29    Websocket(WebsocketConnection<S2>),
30}
31
/// Stand-in for the websocket transport when the `websocket` feature is off.
///
/// It contains an [`Infallible`] field, so no value of this type can ever be
/// built; `Connection::Websocket` is therefore unreachable in such builds.
#[cfg(not(feature = "websocket"))]
#[derive(Debug)]
#[doc(hidden)]
pub struct WebsocketConnection<S> {
    // Marks `S` as used without storing any socket.
    _socket: PhantomData<S>,
    // Uninhabited field: makes the whole struct impossible to construct.
    _impossible: Infallible,
}
39
40#[derive(Debug, thiserror::Error)]
41pub enum ConnectionReadError {
42    #[error("streaming connection error")]
43    Streaming(#[source] StreamingReadError),
44    #[cfg(feature = "websocket")]
45    #[error("websocket connection error")]
46    Websocket(#[source] WebsocketReadError),
47}
48
49impl<S1, S2> Connection<S1, S2>
50where
51    S1: AsyncRead + AsyncWrite + Unpin,
52    S2: AsyncRead + AsyncWrite + Unpin,
53{
54    pub fn poll_read_next(
55        &mut self,
56        cx: &mut Context<'_>,
57    ) -> Poll<Result<ServerOp, ConnectionReadError>> {
58        match self {
59            Self::Streaming(streaming) => streaming
60                .poll_read_next(cx)
61                .map_err(ConnectionReadError::Streaming),
62            #[cfg(feature = "websocket")]
63            Self::Websocket(websocket) => websocket
64                .poll_read_next(cx)
65                .map_err(ConnectionReadError::Websocket),
66            #[cfg(not(feature = "websocket"))]
67            Self::Websocket(_) => unreachable!(),
68        }
69    }
70
71    /// Read the next incoming server operation.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if reading or decoding the message fails.
76    pub async fn read_next(&mut self) -> Result<ServerOp, ConnectionReadError> {
77        match self {
78            Self::Streaming(streaming) => streaming
79                .read_next()
80                .await
81                .map_err(ConnectionReadError::Streaming),
82            #[cfg(feature = "websocket")]
83            Self::Websocket(websocket) => websocket
84                .read_next()
85                .await
86                .map_err(ConnectionReadError::Websocket),
87            #[cfg(not(feature = "websocket"))]
88            Self::Websocket(_) => unreachable!(),
89        }
90    }
91
92    pub fn flushes_automatically_when_full(&self) -> bool {
93        match self {
94            Self::Streaming(_streaming) => true,
95            #[cfg(feature = "websocket")]
96            Self::Websocket(_websocket) => false,
97            #[cfg(not(feature = "websocket"))]
98            Self::Websocket(_) => unreachable!(),
99        }
100    }
101
102    pub fn should_flush(&self) -> bool {
103        match self {
104            Self::Streaming(streaming) => streaming.may_flush(),
105            #[cfg(feature = "websocket")]
106            Self::Websocket(websocket) => websocket.should_flush(),
107            #[cfg(not(feature = "websocket"))]
108            Self::Websocket(_) => unreachable!(),
109        }
110    }
111
112    pub fn may_enqueue_more_ops(&mut self) -> bool {
113        match self {
114            Self::Streaming(streaming) => streaming.may_enqueue_more_ops(),
115            #[cfg(feature = "websocket")]
116            Self::Websocket(websocket) => websocket.may_enqueue_more_ops(),
117            #[cfg(not(feature = "websocket"))]
118            Self::Websocket(_) => unreachable!(),
119        }
120    }
121
122    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
123        match self {
124            Self::Streaming(streaming) => streaming.enqueue_write_op(item),
125            #[cfg(feature = "websocket")]
126            Self::Websocket(websocket) => websocket.enqueue_write_op(item),
127            #[cfg(not(feature = "websocket"))]
128            Self::Websocket(_) => unreachable!(),
129        }
130    }
131
132    /// Convenience function for writing enqueued messages and flushing.
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if writing or flushing fails.
137    pub async fn write_and_flush(&mut self) -> io::Result<()> {
138        if let Self::Streaming(streaming) = self {
139            while streaming.may_write() {
140                streaming.write_next().await?;
141            }
142        }
143
144        self.flush().await
145    }
146
147    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148        match self {
149            Self::Streaming(streaming) => streaming.poll_flush(cx),
150            #[cfg(feature = "websocket")]
151            Self::Websocket(websocket) => websocket.poll_flush(cx),
152            #[cfg(not(feature = "websocket"))]
153            Self::Websocket(_) => unreachable!(),
154        }
155    }
156
157    /// Flush any buffered writes to the connection
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if flushing fails
162    pub async fn flush(&mut self) -> io::Result<()> {
163        match self {
164            Self::Streaming(streaming) => streaming.flush().await,
165            #[cfg(feature = "websocket")]
166            Self::Websocket(websocket) => websocket.flush().await,
167            #[cfg(not(feature = "websocket"))]
168            Self::Websocket(_) => unreachable!(),
169        }
170    }
171
172    /// Shutdown the connection
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if shutting down the connection fails.
177    /// Implementations usually ignore this error.
178    pub async fn shutdown(&mut self) -> io::Result<()> {
179        match self {
180            Self::Streaming(streaming) => streaming.shutdown().await,
181            #[cfg(feature = "websocket")]
182            Self::Websocket(websocket) => websocket.shutdown().await,
183            #[cfg(not(feature = "websocket"))]
184            Self::Websocket(_) => unreachable!(),
185        }
186    }
187}
188
189#[derive(Debug, thiserror::Error)]
190pub enum ConnectError {
191    #[error("proto")]
192    Proto(#[source] DecoderError),
193    #[error("server")]
194    ServerError(#[source] ServerError),
195    #[error("io")]
196    Io(#[source] io::Error),
197    #[error("unexpected ServerOp")]
198    UnexpectedOp,
199}
200
201/// Send the `CONNECT` command to a pre-establised connection `conn`.
202///
203/// # Errors
204///
205/// Returns an error if connecting fails
206pub async fn connect<S1, S2, F>(
207    conn: &mut Connection<S1, S2>,
208    connect: Connect,
209    after_connect: F,
210) -> Result<(), ConnectError>
211where
212    S1: AsyncRead + AsyncWrite + Unpin,
213    S2: AsyncRead + AsyncWrite + Unpin,
214    F: FnOnce(&mut Connection<S1, S2>),
215{
216    conn.enqueue_write_op(&ClientOp::Connect {
217        connect: Box::new(connect),
218    });
219    conn.write_and_flush().await.map_err(ConnectError::Io)?;
220
221    after_connect(conn);
222    conn.enqueue_write_op(&ClientOp::Ping);
223    conn.write_and_flush().await.map_err(ConnectError::Io)?;
224
225    loop {
226        match conn.read_next().await {
227            Ok(ServerOp::Success) => {
228                // Success. Repeat to receive the PONG
229            }
230            Ok(ServerOp::Pong) => {
231                // Success. We've received the PONG,
232                // possibly after having received OK.
233                return Ok(());
234            }
235            Ok(ServerOp::Ping) => {
236                // I guess this could somehow happen. Handle it and repeat
237                conn.enqueue_write_op(&ClientOp::Pong);
238            }
239            Ok(ServerOp::Error { error }) => return Err(ConnectError::ServerError(error)),
240            Ok(ServerOp::Info { .. } | ServerOp::Message { .. }) => {
241                return Err(ConnectError::UnexpectedOp);
242            }
243            Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
244                return Err(ConnectError::Proto(err));
245            }
246            Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
247                return Err(ConnectError::Io(err));
248            }
249            #[cfg(feature = "websocket")]
250            Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
251                FrameDecoderError::Decoder(err),
252            ))) => return Err(ConnectError::Proto(err)),
253            #[cfg(feature = "websocket")]
254            Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
255                FrameDecoderError::IncompleteFrame,
256            ))) => todo!(),
257            #[cfg(feature = "websocket")]
258            Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
259                return Err(ConnectError::Io(err));
260            }
261            #[cfg(feature = "websocket")]
262            Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
263        }
264    }
265}