watermelon_net/connection/mod.rs

1#[cfg(not(feature = "websocket"))]
2use std::{convert::Infallible, marker::PhantomData};
3use std::{
4    io,
5    task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9#[cfg(feature = "websocket")]
10use watermelon_proto::proto::error::FrameDecoderError;
11use watermelon_proto::{
12    error::ServerError,
13    proto::{error::DecoderError, ClientOp, ServerOp},
14    Connect,
15};
16
17pub use self::streaming::{StreamingConnection, StreamingReadError};
18#[cfg(feature = "websocket")]
19pub use self::websocket::{WebsocketConnection, WebsocketReadError};
20
21mod streaming;
22#[cfg(feature = "websocket")]
23mod websocket;
24
25#[derive(Debug)]
26pub enum Connection<S1, S2> {
27    Streaming(StreamingConnection<S1>),
28    Websocket(WebsocketConnection<S2>),
29}
30
/// Placeholder for the websocket transport when the `websocket` feature is off.
///
/// The `Infallible` field has no values, so this type can never be
/// instantiated; it only exists so that [`Connection`] keeps the same shape
/// regardless of enabled features. `PhantomData<S>` keeps the type parameter used.
#[cfg(not(feature = "websocket"))]
#[doc(hidden)]
#[derive(Debug)]
pub struct WebsocketConnection<S> {
    _socket: PhantomData<S>,
    _impossible: Infallible,
}
38
39#[derive(Debug, thiserror::Error)]
40pub enum ConnectionReadError {
41    #[error("streaming connection error")]
42    Streaming(#[source] StreamingReadError),
43    #[cfg(feature = "websocket")]
44    #[error("websocket connection error")]
45    Websocket(#[source] WebsocketReadError),
46}
47
48impl<S1, S2> Connection<S1, S2>
49where
50    S1: AsyncRead + AsyncWrite + Unpin,
51    S2: AsyncRead + AsyncWrite + Unpin,
52{
53    pub fn poll_read_next(
54        &mut self,
55        cx: &mut Context<'_>,
56    ) -> Poll<Result<ServerOp, ConnectionReadError>> {
57        match self {
58            Self::Streaming(streaming) => streaming
59                .poll_read_next(cx)
60                .map_err(ConnectionReadError::Streaming),
61            #[cfg(feature = "websocket")]
62            Self::Websocket(websocket) => websocket
63                .poll_read_next(cx)
64                .map_err(ConnectionReadError::Websocket),
65            #[cfg(not(feature = "websocket"))]
66            Self::Websocket(_) => unreachable!(),
67        }
68    }
69
70    /// Read the next incoming server operation.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if reading or decoding the message fails.
75    pub async fn read_next(&mut self) -> Result<ServerOp, ConnectionReadError> {
76        match self {
77            Self::Streaming(streaming) => streaming
78                .read_next()
79                .await
80                .map_err(ConnectionReadError::Streaming),
81            #[cfg(feature = "websocket")]
82            Self::Websocket(websocket) => websocket
83                .read_next()
84                .await
85                .map_err(ConnectionReadError::Websocket),
86            #[cfg(not(feature = "websocket"))]
87            Self::Websocket(_) => unreachable!(),
88        }
89    }
90
91    pub fn flushes_automatically_when_full(&self) -> bool {
92        match self {
93            Self::Streaming(_streaming) => true,
94            #[cfg(feature = "websocket")]
95            Self::Websocket(_websocket) => false,
96            #[cfg(not(feature = "websocket"))]
97            Self::Websocket(_) => unreachable!(),
98        }
99    }
100
101    pub fn should_flush(&self) -> bool {
102        match self {
103            Self::Streaming(streaming) => streaming.may_flush(),
104            #[cfg(feature = "websocket")]
105            Self::Websocket(websocket) => websocket.should_flush(),
106            #[cfg(not(feature = "websocket"))]
107            Self::Websocket(_) => unreachable!(),
108        }
109    }
110
111    pub fn may_enqueue_more_ops(&mut self) -> bool {
112        match self {
113            Self::Streaming(streaming) => streaming.may_enqueue_more_ops(),
114            #[cfg(feature = "websocket")]
115            Self::Websocket(websocket) => websocket.may_enqueue_more_ops(),
116            #[cfg(not(feature = "websocket"))]
117            Self::Websocket(_) => unreachable!(),
118        }
119    }
120
121    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
122        match self {
123            Self::Streaming(streaming) => streaming.enqueue_write_op(item),
124            #[cfg(feature = "websocket")]
125            Self::Websocket(websocket) => websocket.enqueue_write_op(item),
126            #[cfg(not(feature = "websocket"))]
127            Self::Websocket(_) => unreachable!(),
128        }
129    }
130
131    /// Convenience function for writing enqueued messages and flushing.
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if writing or flushing fails.
136    pub async fn write_and_flush(&mut self) -> io::Result<()> {
137        if let Self::Streaming(streaming) = self {
138            while streaming.may_write() {
139                streaming.write_next().await?;
140            }
141        }
142
143        self.flush().await
144    }
145
146    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147        match self {
148            Self::Streaming(streaming) => streaming.poll_flush(cx),
149            #[cfg(feature = "websocket")]
150            Self::Websocket(websocket) => websocket.poll_flush(cx),
151            #[cfg(not(feature = "websocket"))]
152            Self::Websocket(_) => unreachable!(),
153        }
154    }
155
156    /// Flush any buffered writes to the connection
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if flushing fails
161    pub async fn flush(&mut self) -> io::Result<()> {
162        match self {
163            Self::Streaming(streaming) => streaming.flush().await,
164            #[cfg(feature = "websocket")]
165            Self::Websocket(websocket) => websocket.flush().await,
166            #[cfg(not(feature = "websocket"))]
167            Self::Websocket(_) => unreachable!(),
168        }
169    }
170
171    /// Shutdown the connection
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if shutting down the connection fails.
176    /// Implementations usually ignore this error.
177    pub async fn shutdown(&mut self) -> io::Result<()> {
178        match self {
179            Self::Streaming(streaming) => streaming.shutdown().await,
180            #[cfg(feature = "websocket")]
181            Self::Websocket(websocket) => websocket.shutdown().await,
182            #[cfg(not(feature = "websocket"))]
183            Self::Websocket(_) => unreachable!(),
184        }
185    }
186}
187
188#[derive(Debug, thiserror::Error)]
189pub enum ConnectError {
190    #[error("proto")]
191    Proto(#[source] DecoderError),
192    #[error("server")]
193    ServerError(#[source] ServerError),
194    #[error("io")]
195    Io(#[source] io::Error),
196    #[error("unexpected ServerOp")]
197    UnexpectedOp,
198}
199
200/// Send the `CONNECT` command to a pre-establised connection `conn`.
201///
202/// # Errors
203///
204/// Returns an error if connecting fails
205pub async fn connect<S1, S2, F>(
206    conn: &mut Connection<S1, S2>,
207    connect: Connect,
208    after_connect: F,
209) -> Result<(), ConnectError>
210where
211    S1: AsyncRead + AsyncWrite + Unpin,
212    S2: AsyncRead + AsyncWrite + Unpin,
213    F: FnOnce(&mut Connection<S1, S2>),
214{
215    conn.enqueue_write_op(&ClientOp::Connect {
216        connect: Box::new(connect),
217    });
218    conn.write_and_flush().await.map_err(ConnectError::Io)?;
219
220    after_connect(conn);
221    conn.enqueue_write_op(&ClientOp::Ping);
222    conn.write_and_flush().await.map_err(ConnectError::Io)?;
223
224    loop {
225        match conn.read_next().await {
226            Ok(ServerOp::Success) => {
227                // Success. Repeat to receive the PONG
228            }
229            Ok(ServerOp::Pong) => {
230                // Success. We've received the PONG,
231                // possibly after having received OK.
232                return Ok(());
233            }
234            Ok(ServerOp::Ping) => {
235                // I guess this could somehow happen. Handle it and repeat
236                conn.enqueue_write_op(&ClientOp::Pong);
237            }
238            Ok(ServerOp::Error { error }) => return Err(ConnectError::ServerError(error)),
239            Ok(ServerOp::Info { .. } | ServerOp::Message { .. }) => {
240                return Err(ConnectError::UnexpectedOp);
241            }
242            Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
243                return Err(ConnectError::Proto(err))
244            }
245            Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
246                return Err(ConnectError::Io(err))
247            }
248            #[cfg(feature = "websocket")]
249            Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
250                FrameDecoderError::Decoder(err),
251            ))) => return Err(ConnectError::Proto(err)),
252            #[cfg(feature = "websocket")]
253            Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
254                FrameDecoderError::IncompleteFrame,
255            ))) => todo!(),
256            #[cfg(feature = "websocket")]
257            Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
258                return Err(ConnectError::Io(err))
259            }
260            #[cfg(feature = "websocket")]
261            Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
262        }
263    }
264}