pocket_relay_client_shared/servers/tunnel.rs
1//! Tunneling server
2//!
3//! Provides a local tunnel that connects clients by tunneling through the Pocket Relay
4//! server. This allows clients with more strict NATs to host games without common issues
5//! faced when trying to connect
6//!
7//! Details can be found on the GitHub issue: https://github.com/PocketRelay/Server/issues/64
8
9use self::codec::{TunnelCodec, TunnelMessage};
10use crate::{
11 api::create_server_tunnel,
12 ctx::ClientContext,
13 servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT},
14};
15use bytes::Bytes;
16use futures::{Sink, Stream};
17use log::{debug, error};
18use reqwest::Upgraded;
19use std::{
20 future::Future,
21 io::ErrorKind,
22 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
23 pin::Pin,
24 sync::Arc,
25 task::{ready, Context, Poll},
26 time::Duration,
27};
28use tokio::{io::ReadBuf, net::UdpSocket, sync::mpsc, try_join};
29use tokio_util::codec::Framed;
30
/// The fixed size of socket pool to use
///
/// Must match the number of sockets started in [`Socket::allocate_pool`]
const SOCKET_POOL_SIZE: usize = 4;
/// Max tunnel creation attempts that can be an error before cancelling
const MAX_ERROR_ATTEMPTS: u8 = 5;

/// Local address the client uses to send packets
/// (packets received by the pool sockets are replayed to the local game host port)
static LOCAL_SEND_TARGET: SocketAddr =
    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, GAME_HOST_PORT));
39
40/// Starts the tunnel socket pool and creates the tunnel
41/// connection to the server
42///
43/// ## Arguments
44/// * `ctx` - The client context
45pub async fn start_tunnel_server(ctx: Arc<ClientContext>) -> std::io::Result<()> {
46 let association = match Option::as_ref(&ctx.association) {
47 Some(value) => value,
48 // Don't try and tunnel without a token
49 None => return Ok(()),
50 };
51
52 // Last encountered error
53 let mut last_error: Option<std::io::Error> = None;
54 // Number of attempts that errored
55 let mut attempt_errors: u8 = 0;
56
57 // Looping to attempt reconnecting if lost
58 while attempt_errors < MAX_ERROR_ATTEMPTS {
59 // Create the tunnel (Future will end if tunnel stopped)
60 let reconnect_time = if let Err(err) = create_tunnel(ctx.clone(), association).await {
61 error!("Failed to create tunnel: {}", err);
62
63 // Set last error
64 last_error = Some(err);
65
66 // Increase error attempts
67 attempt_errors += 1;
68
69 // Error should be delayed by the number of errors already hit
70 Duration::from_millis(1000 * attempt_errors as u64)
71 } else {
72 // Reset error attempts
73 attempt_errors = 0;
74
75 // Non errored reconnect can be quick
76 Duration::from_millis(1000)
77 };
78
79 debug!(
80 "Next tunnel create attempt in: {}s",
81 reconnect_time.as_secs()
82 );
83
84 // Wait before attempting to re-create the tunnel
85 tokio::time::sleep(reconnect_time).await;
86 }
87
88 Err(last_error.unwrap_or(std::io::Error::new(
89 ErrorKind::Other,
90 "Reached error connect limit",
91 )))
92}
93
94/// Creates a new tunnel
95///
96/// ## Arguments
97/// * `ctx` - The client context
98/// * `association` - The client association token
99async fn create_tunnel(ctx: Arc<ClientContext>, association: &str) -> std::io::Result<()> {
100 // Create the tunnel with the server
101 let io = create_server_tunnel(&ctx.http_client, &ctx.base_url, association)
102 .await
103 // Wrap the tunnel with the [`TunnelCodec`] framing
104 .map(|io| Framed::new(io, TunnelCodec::default()))
105 // Wrap the error into an [`std::io::Error`]
106 .map_err(|err| std::io::Error::new(ErrorKind::Other, err))?;
107 debug!("Created server tunnel");
108
109 // Allocate the socket pool for the tunnel
110 let (tx, rx) = mpsc::unbounded_channel();
111 let pool = Socket::allocate_pool(tx).await?;
112 debug!("Allocated tunnel pool");
113
114 // Start the tunnel
115 Tunnel {
116 io,
117 rx,
118 pool,
119 write_state: Default::default(),
120 }
121 .await;
122
123 Ok(())
124}
125
/// Represents a tunnel and its pool of connections that it can
/// send data to and receive data from
struct Tunnel {
    /// Tunnel connection to the Pocket Relay server for sending [`TunnelMessage`]s
    /// through the server to reach a specific peer
    io: Framed<Upgraded, TunnelCodec>,
    /// Receiver for receiving messages from [`Socket`]s within the [`Tunnel::pool`]
    /// that need to be sent through [`Tunnel::io`]
    rx: mpsc::UnboundedReceiver<TunnelMessage>,
    /// Pool of [`Socket`]s that this tunnel can use for sending out messages;
    /// a message's `index` field selects the handle within this pool
    pool: [SocketHandle; SOCKET_POOL_SIZE],
    /// Current state of writing [`TunnelMessage`]s to the [`Tunnel::io`]
    write_state: TunnelWriteState,
}

/// Holds the state for the current writing progress for a [`Tunnel`]
///
/// Write half state machine: Recv -> Write -> Flush -> Recv ...
/// until any step fails, which moves it to Stop
#[derive(Default)]
enum TunnelWriteState {
    /// Waiting for a message to come through the [`Tunnel::rx`]
    #[default]
    Recv,
    /// Waiting for the [`Tunnel::io`] to be writable, then writing the
    /// contained [`TunnelMessage`] (taken out of the `Option` once written)
    Write(Option<TunnelMessage>),
    /// Poll flushing the bytes written to [`Tunnel::io`]
    Flush,
    /// The tunnel has stopped and should not continue
    Stop,
}

/// Holds the state for the current reading progress for a [`Tunnel`]
enum TunnelReadState {
    /// Continue reading
    Continue,
    /// The tunnel has stopped and should not continue
    Stop,
}
163
impl Tunnel {
    /// Polls accepting messages from [`Tunnel::rx`] then writing them to [`Tunnel::io`] and
    /// flushing the underlying stream. Provides the next [`TunnelWriteState`]
    /// when [`Poll::Ready`] is returned
    ///
    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
    fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelWriteState> {
        Poll::Ready(match &mut self.write_state {
            TunnelWriteState::Recv => {
                // Try receive a packet from the write channel
                let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));

                if let Some(message) = result {
                    TunnelWriteState::Write(Some(message))
                } else {
                    // All writers have closed, tunnel must be closed (Future end)
                    TunnelWriteState::Stop
                }
            }
            TunnelWriteState::Write(message) => {
                // Wait until the `io` is ready
                if ready!(Pin::new(&mut self.io).poll_ready(cx)).is_ok() {
                    // The Write state is only entered with a Some message;
                    // take() leaves None behind but the state is replaced below
                    let message = message
                        .take()
                        .expect("Unexpected write state without message");

                    // Write the packet to the buffer
                    Pin::new(&mut self.io)
                        .start_send(message)
                        // Packet encoder impl shouldn't produce errors
                        .expect("Message encoder errored");

                    TunnelWriteState::Flush
                } else {
                    // Failed to ready, tunnel must be closed
                    TunnelWriteState::Stop
                }
            }
            TunnelWriteState::Flush => {
                // Poll flushing `io`
                if ready!(Pin::new(&mut self.io).poll_flush(cx)).is_ok() {
                    TunnelWriteState::Recv
                } else {
                    // Failed to flush, tunnel must be closed
                    TunnelWriteState::Stop
                }
            }

            // Tunnel should *NOT* be polled if its already stopped
            TunnelWriteState::Stop => panic!("Tunnel polled after already stopped"),
        })
    }

    /// Polls reading messages from [`Tunnel::io`] and sending them to the correct
    /// handle within the [`Tunnel::pool`]. Provides the next [`TunnelReadState`]
    /// when [`Poll::Ready`] is returned
    ///
    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
    fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelReadState> {
        // Try receive a message from the `io` (stop on stream end or decode error)
        let Some(Ok(message)) = ready!(Pin::new(&mut self.io).poll_next(cx)) else {
            // Cannot read next message stop the tunnel
            return Poll::Ready(TunnelReadState::Stop);
        };

        // Index 255 is the server keep-alive message (see the `codec` module docs)
        if message.index == 255 {
            // Write a ping response if we aren't already writing another message;
            // if a write is in progress this reply is skipped and the next
            // keep-alive message will trigger another attempt
            if let TunnelWriteState::Recv = self.write_state {
                // Move to a writing state (keep-alive reply: index 255, empty payload)
                self.write_state = TunnelWriteState::Write(Some(TunnelMessage {
                    index: 255,
                    message: Bytes::new(),
                }));

                // Poll the write state
                if let Poll::Ready(next_state) = self.poll_write_state(cx) {
                    self.write_state = next_state;

                    // Tunnel has stopped
                    if let TunnelWriteState::Stop = self.write_state {
                        return Poll::Ready(TunnelReadState::Stop);
                    }
                }
            }

            return Poll::Ready(TunnelReadState::Continue);
        }

        // Get the handle to use within the connection pool
        let handle = self.pool.get(message.index as usize);

        // Send the message to the handle if its valid
        // (out-of-range indices are silently dropped)
        if let Some(handle) = handle {
            _ = handle.0.send(message);
        }

        Poll::Ready(TunnelReadState::Continue)
    }
}
263
264impl Future for Tunnel {
265 type Output = ();
266
267 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
268 let this = self.get_mut();
269
270 // Poll the write half
271 while let Poll::Ready(next_state) = this.poll_write_state(cx) {
272 this.write_state = next_state;
273
274 // Tunnel has stopped
275 if let TunnelWriteState::Stop = this.write_state {
276 return Poll::Ready(());
277 }
278 }
279
280 // Poll the read half
281 while let Poll::Ready(next_state) = this.poll_read_state(cx) {
282 // Tunnel has stopped
283 if let TunnelReadState::Stop = next_state {
284 return Poll::Ready(());
285 }
286 }
287
288 Poll::Pending
289 }
290}
291
/// Handle to a [`Socket`] for sending [`TunnelMessage`]s that the
/// socket should send to the [`LOCAL_SEND_TARGET`]
#[derive(Clone)]
struct SocketHandle(mpsc::UnboundedSender<TunnelMessage>);

/// Size of the socket read buffer 2^16 bytes
///
/// Can likely be reduced to 2^15 bytes or 2^13 bytes (or lower) since
/// highest observed message length was 1254 bytes but testing is required
/// before that can take place
const READ_BUFFER_LENGTH: usize = 2usize.pow(16);
303
/// Socket used by a [`Tunnel`] for sending and receiving messages in
/// order to simulate another player on the local network
struct Socket {
    /// Index of the socket within the tunnel pool; stamped onto
    /// outgoing [`TunnelMessage`]s so the server can route them
    index: u8,
    /// The underlying UDP socket for sending and receiving
    socket: UdpSocket,
    /// Receiver for messages coming from the [`Tunnel`] that need to be
    /// sent through the socket
    rx: mpsc::UnboundedReceiver<TunnelMessage>,
    /// Sender for sending [`TunnelMessage`]s through the associated [`Tunnel`]
    /// in order for them to be sent to the correct peer on the other side
    tun_tx: mpsc::UnboundedSender<TunnelMessage>,
    /// Buffer for reading bytes from the `socket`
    read_buffer: [u8; READ_BUFFER_LENGTH],
    /// Current state of writing [`TunnelMessage`]s to the `socket`
    write_state: SocketWriteState,
}

/// Holds the state for the current writing progress for a [`Socket`]
///
/// Write half state machine: Recv -> Write -> Recv ... until a
/// send fails or the channel closes, which moves it to Stop
#[derive(Default)]
enum SocketWriteState {
    /// Waiting for a message to come through the [`Socket::rx`]
    #[default]
    Recv,
    /// Waiting for the [`Socket::socket`] to write the bytes
    Write(Bytes),
    /// The tunnel has stopped and should not continue
    Stop,
}

/// Holds the state for the current reading progress for a [`Socket`]
enum SocketReadState {
    /// Continue reading
    Continue,
    /// The tunnel has stopped and should not continue
    Stop,
}
342
impl Socket {
    /// Allocates a pool of [`Socket`]s for a [`Tunnel`] to use
    ///
    /// Socket indices 0..=3 match [`SOCKET_POOL_SIZE`]; index 0 is the
    /// host socket and binds a fixed port
    ///
    /// ## Arguments
    /// * `tun_tx` - The tunnel sender for sending [`TunnelMessage`]s through the tunnel
    async fn allocate_pool(
        tun_tx: mpsc::UnboundedSender<TunnelMessage>,
    ) -> std::io::Result<[SocketHandle; SOCKET_POOL_SIZE]> {
        let sockets = try_join!(
            // Host socket index *must* use a fixed port since its used on the server side
            Socket::start(0, TUNNEL_HOST_PORT, tun_tx.clone()),
            // Other sockets can use an OS auto-assigned port
            Socket::start(1, RANDOM_PORT, tun_tx.clone()),
            Socket::start(2, RANDOM_PORT, tun_tx.clone()),
            Socket::start(3, RANDOM_PORT, tun_tx),
        )?;
        Ok(sockets.into())
    }

    /// Starts a new tunnel socket returning a [`SocketHandle`] that can be used
    /// to send [`TunnelMessage`]s to the socket
    ///
    /// ## Arguments
    /// * `index` - The index of the socket
    /// * `port` - The port to bind the socket on
    /// * `tun_tx` - The tunnel sender for sending [`TunnelMessage`]s through the tunnel
    async fn start(
        index: u8,
        port: u16,
        tun_tx: mpsc::UnboundedSender<TunnelMessage>,
    ) -> std::io::Result<SocketHandle> {
        // Bind the socket on localhost only
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, port)).await?;
        // Set the socket send target so plain send/poll_send go to the local game
        socket.connect(LOCAL_SEND_TARGET).await?;

        // Create the message channel
        let (tx, rx) = mpsc::unbounded_channel();

        // Spawn the socket task; it runs independently and ends itself
        // when either direction stops
        spawn_server_task(Socket {
            index,
            socket,
            rx,
            tun_tx,
            read_buffer: [0; READ_BUFFER_LENGTH],
            write_state: Default::default(),
        });

        Ok(SocketHandle(tx))
    }

    /// Polls accepting messages from [`Socket::rx`] then writing them to the [`Socket::socket`].
    /// Provides the next [`SocketWriteState`] when [`Poll::Ready`] is returned
    ///
    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
    fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketWriteState> {
        Poll::Ready(match &mut self.write_state {
            SocketWriteState::Recv => {
                // Try receive a packet from the write channel
                let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));

                if let Some(message) = result {
                    SocketWriteState::Write(message.message)
                } else {
                    // All writers have closed, tunnel must be closed (Future end)
                    SocketWriteState::Stop
                }
            }
            SocketWriteState::Write(message) => {
                // Try send the message to the local target
                let Ok(count) = ready!(self.socket.poll_send(cx, message)) else {
                    return Poll::Ready(SocketWriteState::Stop);
                };

                // Didn't write the entire message
                // NOTE(review): UDP datagram sends are normally all-or-nothing,
                // so this tail-retry path looks defensive — confirm intent
                if count != message.len() {
                    // Continue with a writing state at the remaining message
                    let message = message.slice(count..);
                    SocketWriteState::Write(message)
                } else {
                    SocketWriteState::Recv
                }
            }

            // Tunnel socket should *NOT* be polled if its already stopped
            SocketWriteState::Stop => panic!("Tunnel socket polled after already stopped"),
        })
    }

    /// Polls reading messages from `socket` and sending them to the [`Tunnel`]
    /// in order for them to be sent out to the peer. Provides the next
    /// [`SocketReadState`] when [`Poll::Ready`] is returned
    ///
    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
    fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketReadState> {
        // Fresh ReadBuf each poll; only the filled portion is used below
        let mut read_buf = ReadBuf::new(&mut self.read_buffer);

        // Try receive a message from the socket
        if ready!(self.socket.poll_recv(cx, &mut read_buf)).is_err() {
            return Poll::Ready(SocketReadState::Stop);
        }

        // Copy the received datagram out of the reusable buffer
        let bytes = read_buf.filled();
        let message = Bytes::copy_from_slice(bytes);
        // Tag the payload with this socket's index so the server can route it
        let message = TunnelMessage {
            index: self.index,
            message,
        };

        // Send the message through the tunnel; a closed tunnel stops this socket
        Poll::Ready(if self.tun_tx.send(message).is_ok() {
            SocketReadState::Continue
        } else {
            SocketReadState::Stop
        })
    }
}
462
463impl Future for Socket {
464 type Output = ();
465
466 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
467 let this = self.get_mut();
468
469 // Poll the write half
470 while let Poll::Ready(next_state) = this.poll_write_state(cx) {
471 this.write_state = next_state;
472
473 // Tunnel has stopped
474 if let SocketWriteState::Stop = this.write_state {
475 return Poll::Ready(());
476 }
477 }
478
479 // Poll the read half
480 while let Poll::Ready(next_state) = this.poll_read_state(cx) {
481 // Tunnel has stopped
482 if let SocketReadState::Stop = next_state {
483 return Poll::Ready(());
484 }
485 }
486
487 Poll::Pending
488 }
489}
490
491mod codec {
492 //! This modules contains the codec and message structures for [TunnelMessage]s
493 //!
494 //! # Tunnel Messages
495 //!
496 //! Tunnel message frames are as follows:
497 //!
498 //! ```text
499 //! 0 1 2
500 //! 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
501 //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
502 //! | Index | Length |
503 //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
504 //! | :
505 //! : Payload :
506 //! : |
507 //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
508 //! ```
509 //!
510 //! Tunnel message frames contain the following fields:
511 //!
512 //! Index: 8-bits. Determines the destination of the message within the current pool.
513 //!
514 //! Length: 16-bits. Determines the size in bytes of the payload that follows
515 //!
516 //! Payload: Variable length. The message bytes payload of `Length`
517 //!
518 //!
519 //! ## Keep alive
520 //!
521 //! The server will send keep-alive messages, these are in the same
522 //! format as the packet above. However, the index will always be 255
523 //! and the payload will be empty.
524
525 use bytes::{Buf, BufMut, Bytes};
526 use tokio_util::codec::{Decoder, Encoder};
527
528 /// Header portion of a [TunnelMessage] that contains the
529 /// index of the message and the length of the expected payload
530 struct TunnelMessageHeader {
531 /// Socket index to use
532 index: u8,
533 /// The length of the tunnel message bytes
534 length: u16,
535 }
536
537 /// Message sent through the tunnel
538 pub struct TunnelMessage {
539 /// Socket index to use
540 pub index: u8,
541 /// The message contents
542 pub message: Bytes,
543 }
544
545 /// Codec for encoding and decoding tunnel messages
546 #[derive(Default)]
547 pub struct TunnelCodec {
548 /// Stores the current message header while its waiting
549 /// for the full payload to become available
550 partial: Option<TunnelMessageHeader>,
551 }
552
553 impl Decoder for TunnelCodec {
554 type Item = TunnelMessage;
555 type Error = std::io::Error;
556
557 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
558 let partial = match self.partial.as_mut() {
559 Some(value) => value,
560 None => {
561 // Not enough room for a partial frame
562 if src.len() < 5 {
563 return Ok(None);
564 }
565 let index = src.get_u8();
566 let length = src.get_u16();
567
568 self.partial.insert(TunnelMessageHeader { index, length })
569 }
570 };
571 // Not enough data for the partial frame
572 if src.len() < partial.length as usize {
573 return Ok(None);
574 }
575
576 let partial = self.partial.take().expect("Partial frame missing");
577 let bytes = src.split_to(partial.length as usize);
578
579 Ok(Some(TunnelMessage {
580 index: partial.index,
581 message: bytes.freeze(),
582 }))
583 }
584 }
585
586 impl Encoder<TunnelMessage> for TunnelCodec {
587 type Error = std::io::Error;
588
589 fn encode(
590 &mut self,
591 item: TunnelMessage,
592 dst: &mut bytes::BytesMut,
593 ) -> Result<(), Self::Error> {
594 dst.put_u8(item.index);
595 dst.put_u16(item.message.len() as u16);
596 dst.extend_from_slice(&item.message);
597 Ok(())
598 }
599 }
600}