watermelon_net/connection/
mod.rs1#[cfg(not(feature = "websocket"))]
2use std::{convert::Infallible, marker::PhantomData};
3use std::{
4 io,
5 task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9#[cfg(feature = "websocket")]
10use watermelon_proto::proto::error::FrameDecoderError;
11use watermelon_proto::{
12 Connect,
13 error::ServerError,
14 proto::{ClientOp, ServerOp, error::DecoderError},
15};
16
17pub use self::streaming::{StreamingConnection, StreamingReadError};
18#[cfg(feature = "websocket")]
19pub use self::websocket::{WebsocketConnection, WebsocketReadError};
20
21mod streaming;
22#[cfg(feature = "websocket")]
23mod websocket;
24
25#[derive(Debug)]
26#[expect(clippy::large_enum_variant)]
27pub enum Connection<S1, S2> {
28 Streaming(StreamingConnection<S1>),
29 Websocket(WebsocketConnection<S2>),
30}
31
// Stand-in for the real websocket connection, compiled only when the
// `websocket` feature is disabled.
//
// The `Infallible` field has no values, so this struct can never be built;
// it only exists so that `Connection::Websocket` keeps naming a valid type.
#[derive(Debug)]
#[cfg(not(feature = "websocket"))]
#[doc(hidden)]
pub struct WebsocketConnection<S> {
    // Keeps the `S` type parameter in use without storing a socket.
    _socket: PhantomData<S>,
    // Uninhabited marker: makes constructing this struct impossible.
    _impossible: Infallible,
}
39
40#[derive(Debug, thiserror::Error)]
41pub enum ConnectionReadError {
42 #[error("streaming connection error")]
43 Streaming(#[source] StreamingReadError),
44 #[cfg(feature = "websocket")]
45 #[error("websocket connection error")]
46 Websocket(#[source] WebsocketReadError),
47}
48
49impl<S1, S2> Connection<S1, S2>
50where
51 S1: AsyncRead + AsyncWrite + Unpin,
52 S2: AsyncRead + AsyncWrite + Unpin,
53{
54 pub fn poll_read_next(
55 &mut self,
56 cx: &mut Context<'_>,
57 ) -> Poll<Result<ServerOp, ConnectionReadError>> {
58 match self {
59 Self::Streaming(streaming) => streaming
60 .poll_read_next(cx)
61 .map_err(ConnectionReadError::Streaming),
62 #[cfg(feature = "websocket")]
63 Self::Websocket(websocket) => websocket
64 .poll_read_next(cx)
65 .map_err(ConnectionReadError::Websocket),
66 #[cfg(not(feature = "websocket"))]
67 Self::Websocket(_) => unreachable!(),
68 }
69 }
70
71 pub async fn read_next(&mut self) -> Result<ServerOp, ConnectionReadError> {
77 match self {
78 Self::Streaming(streaming) => streaming
79 .read_next()
80 .await
81 .map_err(ConnectionReadError::Streaming),
82 #[cfg(feature = "websocket")]
83 Self::Websocket(websocket) => websocket
84 .read_next()
85 .await
86 .map_err(ConnectionReadError::Websocket),
87 #[cfg(not(feature = "websocket"))]
88 Self::Websocket(_) => unreachable!(),
89 }
90 }
91
92 pub fn flushes_automatically_when_full(&self) -> bool {
93 match self {
94 Self::Streaming(_streaming) => true,
95 #[cfg(feature = "websocket")]
96 Self::Websocket(_websocket) => false,
97 #[cfg(not(feature = "websocket"))]
98 Self::Websocket(_) => unreachable!(),
99 }
100 }
101
102 pub fn should_flush(&self) -> bool {
103 match self {
104 Self::Streaming(streaming) => streaming.may_flush(),
105 #[cfg(feature = "websocket")]
106 Self::Websocket(websocket) => websocket.should_flush(),
107 #[cfg(not(feature = "websocket"))]
108 Self::Websocket(_) => unreachable!(),
109 }
110 }
111
112 pub fn may_enqueue_more_ops(&mut self) -> bool {
113 match self {
114 Self::Streaming(streaming) => streaming.may_enqueue_more_ops(),
115 #[cfg(feature = "websocket")]
116 Self::Websocket(websocket) => websocket.may_enqueue_more_ops(),
117 #[cfg(not(feature = "websocket"))]
118 Self::Websocket(_) => unreachable!(),
119 }
120 }
121
122 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
123 match self {
124 Self::Streaming(streaming) => streaming.enqueue_write_op(item),
125 #[cfg(feature = "websocket")]
126 Self::Websocket(websocket) => websocket.enqueue_write_op(item),
127 #[cfg(not(feature = "websocket"))]
128 Self::Websocket(_) => unreachable!(),
129 }
130 }
131
132 pub async fn write_and_flush(&mut self) -> io::Result<()> {
138 if let Self::Streaming(streaming) = self {
139 while streaming.may_write() {
140 streaming.write_next().await?;
141 }
142 }
143
144 self.flush().await
145 }
146
147 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148 match self {
149 Self::Streaming(streaming) => streaming.poll_flush(cx),
150 #[cfg(feature = "websocket")]
151 Self::Websocket(websocket) => websocket.poll_flush(cx),
152 #[cfg(not(feature = "websocket"))]
153 Self::Websocket(_) => unreachable!(),
154 }
155 }
156
157 pub async fn flush(&mut self) -> io::Result<()> {
163 match self {
164 Self::Streaming(streaming) => streaming.flush().await,
165 #[cfg(feature = "websocket")]
166 Self::Websocket(websocket) => websocket.flush().await,
167 #[cfg(not(feature = "websocket"))]
168 Self::Websocket(_) => unreachable!(),
169 }
170 }
171
172 pub async fn shutdown(&mut self) -> io::Result<()> {
179 match self {
180 Self::Streaming(streaming) => streaming.shutdown().await,
181 #[cfg(feature = "websocket")]
182 Self::Websocket(websocket) => websocket.shutdown().await,
183 #[cfg(not(feature = "websocket"))]
184 Self::Websocket(_) => unreachable!(),
185 }
186 }
187}
188
189#[derive(Debug, thiserror::Error)]
190pub enum ConnectError {
191 #[error("proto")]
192 Proto(#[source] DecoderError),
193 #[error("server")]
194 ServerError(#[source] ServerError),
195 #[error("io")]
196 Io(#[source] io::Error),
197 #[error("unexpected ServerOp")]
198 UnexpectedOp,
199}
200
201pub async fn connect<S1, S2, F>(
207 conn: &mut Connection<S1, S2>,
208 connect: Connect,
209 after_connect: F,
210) -> Result<(), ConnectError>
211where
212 S1: AsyncRead + AsyncWrite + Unpin,
213 S2: AsyncRead + AsyncWrite + Unpin,
214 F: FnOnce(&mut Connection<S1, S2>),
215{
216 conn.enqueue_write_op(&ClientOp::Connect {
217 connect: Box::new(connect),
218 });
219 conn.write_and_flush().await.map_err(ConnectError::Io)?;
220
221 after_connect(conn);
222 conn.enqueue_write_op(&ClientOp::Ping);
223 conn.write_and_flush().await.map_err(ConnectError::Io)?;
224
225 loop {
226 match conn.read_next().await {
227 Ok(ServerOp::Success) => {
228 }
230 Ok(ServerOp::Pong) => {
231 return Ok(());
234 }
235 Ok(ServerOp::Ping) => {
236 conn.enqueue_write_op(&ClientOp::Pong);
238 }
239 Ok(ServerOp::Error { error }) => return Err(ConnectError::ServerError(error)),
240 Ok(ServerOp::Info { .. } | ServerOp::Message { .. }) => {
241 return Err(ConnectError::UnexpectedOp);
242 }
243 Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
244 return Err(ConnectError::Proto(err));
245 }
246 Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
247 return Err(ConnectError::Io(err));
248 }
249 #[cfg(feature = "websocket")]
250 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
251 FrameDecoderError::Decoder(err),
252 ))) => return Err(ConnectError::Proto(err)),
253 #[cfg(feature = "websocket")]
254 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
255 FrameDecoderError::IncompleteFrame,
256 ))) => todo!(),
257 #[cfg(feature = "websocket")]
258 Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
259 return Err(ConnectError::Io(err));
260 }
261 #[cfg(feature = "websocket")]
262 Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
263 }
264 }
265}