watermelon_net/connection/
mod.rs1#[cfg(not(feature = "websocket"))]
2use std::{convert::Infallible, marker::PhantomData};
3use std::{
4 io,
5 task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9#[cfg(feature = "websocket")]
10use watermelon_proto::proto::error::FrameDecoderError;
11use watermelon_proto::{
12 error::ServerError,
13 proto::{error::DecoderError, ClientOp, ServerOp},
14 Connect,
15};
16
17pub use self::streaming::{StreamingConnection, StreamingReadError};
18#[cfg(feature = "websocket")]
19pub use self::websocket::{WebsocketConnection, WebsocketReadError};
20
21mod streaming;
22#[cfg(feature = "websocket")]
23mod websocket;
24
25#[derive(Debug)]
26pub enum Connection<S1, S2> {
27 Streaming(StreamingConnection<S1>),
28 Websocket(WebsocketConnection<S2>),
29}
30
/// Stand-in for the real websocket transport, compiled only when the
/// `websocket` feature is disabled.
///
/// The `Infallible` field makes this struct uninhabited: no value of it can
/// ever be created, so `Connection::Websocket` is unreachable without the
/// feature.
#[cfg(not(feature = "websocket"))]
#[doc(hidden)]
#[derive(Debug)]
pub struct WebsocketConnection<S> {
    /// Mentions `S` so the type parameter is used without storing one.
    _socket: PhantomData<S>,
    /// Uninhabited marker; matching on it proves a branch is impossible.
    _impossible: Infallible,
}
37}
38
39#[derive(Debug, thiserror::Error)]
40pub enum ConnectionReadError {
41 #[error("streaming connection error")]
42 Streaming(#[source] StreamingReadError),
43 #[cfg(feature = "websocket")]
44 #[error("websocket connection error")]
45 Websocket(#[source] WebsocketReadError),
46}
47
48impl<S1, S2> Connection<S1, S2>
49where
50 S1: AsyncRead + AsyncWrite + Unpin,
51 S2: AsyncRead + AsyncWrite + Unpin,
52{
53 pub fn poll_read_next(
54 &mut self,
55 cx: &mut Context<'_>,
56 ) -> Poll<Result<ServerOp, ConnectionReadError>> {
57 match self {
58 Self::Streaming(streaming) => streaming
59 .poll_read_next(cx)
60 .map_err(ConnectionReadError::Streaming),
61 #[cfg(feature = "websocket")]
62 Self::Websocket(websocket) => websocket
63 .poll_read_next(cx)
64 .map_err(ConnectionReadError::Websocket),
65 #[cfg(not(feature = "websocket"))]
66 Self::Websocket(_) => unreachable!(),
67 }
68 }
69
70 pub async fn read_next(&mut self) -> Result<ServerOp, ConnectionReadError> {
76 match self {
77 Self::Streaming(streaming) => streaming
78 .read_next()
79 .await
80 .map_err(ConnectionReadError::Streaming),
81 #[cfg(feature = "websocket")]
82 Self::Websocket(websocket) => websocket
83 .read_next()
84 .await
85 .map_err(ConnectionReadError::Websocket),
86 #[cfg(not(feature = "websocket"))]
87 Self::Websocket(_) => unreachable!(),
88 }
89 }
90
91 pub fn flushes_automatically_when_full(&self) -> bool {
92 match self {
93 Self::Streaming(_streaming) => true,
94 #[cfg(feature = "websocket")]
95 Self::Websocket(_websocket) => false,
96 #[cfg(not(feature = "websocket"))]
97 Self::Websocket(_) => unreachable!(),
98 }
99 }
100
101 pub fn should_flush(&self) -> bool {
102 match self {
103 Self::Streaming(streaming) => streaming.may_flush(),
104 #[cfg(feature = "websocket")]
105 Self::Websocket(websocket) => websocket.should_flush(),
106 #[cfg(not(feature = "websocket"))]
107 Self::Websocket(_) => unreachable!(),
108 }
109 }
110
111 pub fn may_enqueue_more_ops(&mut self) -> bool {
112 match self {
113 Self::Streaming(streaming) => streaming.may_enqueue_more_ops(),
114 #[cfg(feature = "websocket")]
115 Self::Websocket(websocket) => websocket.may_enqueue_more_ops(),
116 #[cfg(not(feature = "websocket"))]
117 Self::Websocket(_) => unreachable!(),
118 }
119 }
120
121 pub fn enqueue_write_op(&mut self, item: &ClientOp) {
122 match self {
123 Self::Streaming(streaming) => streaming.enqueue_write_op(item),
124 #[cfg(feature = "websocket")]
125 Self::Websocket(websocket) => websocket.enqueue_write_op(item),
126 #[cfg(not(feature = "websocket"))]
127 Self::Websocket(_) => unreachable!(),
128 }
129 }
130
131 pub async fn write_and_flush(&mut self) -> io::Result<()> {
137 if let Self::Streaming(streaming) = self {
138 while streaming.may_write() {
139 streaming.write_next().await?;
140 }
141 }
142
143 self.flush().await
144 }
145
146 pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147 match self {
148 Self::Streaming(streaming) => streaming.poll_flush(cx),
149 #[cfg(feature = "websocket")]
150 Self::Websocket(websocket) => websocket.poll_flush(cx),
151 #[cfg(not(feature = "websocket"))]
152 Self::Websocket(_) => unreachable!(),
153 }
154 }
155
156 pub async fn flush(&mut self) -> io::Result<()> {
162 match self {
163 Self::Streaming(streaming) => streaming.flush().await,
164 #[cfg(feature = "websocket")]
165 Self::Websocket(websocket) => websocket.flush().await,
166 #[cfg(not(feature = "websocket"))]
167 Self::Websocket(_) => unreachable!(),
168 }
169 }
170
171 pub async fn shutdown(&mut self) -> io::Result<()> {
178 match self {
179 Self::Streaming(streaming) => streaming.shutdown().await,
180 #[cfg(feature = "websocket")]
181 Self::Websocket(websocket) => websocket.shutdown().await,
182 #[cfg(not(feature = "websocket"))]
183 Self::Websocket(_) => unreachable!(),
184 }
185 }
186}
187
188#[derive(Debug, thiserror::Error)]
189pub enum ConnectError {
190 #[error("proto")]
191 Proto(#[source] DecoderError),
192 #[error("server")]
193 ServerError(#[source] ServerError),
194 #[error("io")]
195 Io(#[source] io::Error),
196 #[error("unexpected ServerOp")]
197 UnexpectedOp,
198}
199
200pub async fn connect<S1, S2, F>(
206 conn: &mut Connection<S1, S2>,
207 connect: Connect,
208 after_connect: F,
209) -> Result<(), ConnectError>
210where
211 S1: AsyncRead + AsyncWrite + Unpin,
212 S2: AsyncRead + AsyncWrite + Unpin,
213 F: FnOnce(&mut Connection<S1, S2>),
214{
215 conn.enqueue_write_op(&ClientOp::Connect {
216 connect: Box::new(connect),
217 });
218 conn.write_and_flush().await.map_err(ConnectError::Io)?;
219
220 after_connect(conn);
221 conn.enqueue_write_op(&ClientOp::Ping);
222 conn.write_and_flush().await.map_err(ConnectError::Io)?;
223
224 loop {
225 match conn.read_next().await {
226 Ok(ServerOp::Success) => {
227 }
229 Ok(ServerOp::Pong) => {
230 return Ok(());
233 }
234 Ok(ServerOp::Ping) => {
235 conn.enqueue_write_op(&ClientOp::Pong);
237 }
238 Ok(ServerOp::Error { error }) => return Err(ConnectError::ServerError(error)),
239 Ok(ServerOp::Info { .. } | ServerOp::Message { .. }) => {
240 return Err(ConnectError::UnexpectedOp);
241 }
242 Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => {
243 return Err(ConnectError::Proto(err))
244 }
245 Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => {
246 return Err(ConnectError::Io(err))
247 }
248 #[cfg(feature = "websocket")]
249 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
250 FrameDecoderError::Decoder(err),
251 ))) => return Err(ConnectError::Proto(err)),
252 #[cfg(feature = "websocket")]
253 Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder(
254 FrameDecoderError::IncompleteFrame,
255 ))) => todo!(),
256 #[cfg(feature = "websocket")]
257 Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => {
258 return Err(ConnectError::Io(err))
259 }
260 #[cfg(feature = "websocket")]
261 Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(),
262 }
263 }
264}