pocket_relay_client_shared/servers/
tunnel.rs

//! Tunneling server
//!
//! Provides a local tunnel that connects clients by tunneling their traffic through the
//! Pocket Relay server. This allows clients behind stricter NATs to host games without
//! the connection issues they would otherwise commonly run into.
//!
//! Details can be found in the GitHub issue: <https://github.com/PocketRelay/Server/issues/64>
8
9use self::codec::{TunnelCodec, TunnelMessage};
10use crate::{
11    api::create_server_tunnel,
12    ctx::ClientContext,
13    servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT},
14};
15use bytes::Bytes;
16use futures::{Sink, Stream};
17use log::{debug, error};
18use reqwest::Upgraded;
19use std::{
20    future::Future,
21    io::ErrorKind,
22    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
23    pin::Pin,
24    sync::Arc,
25    task::{ready, Context, Poll},
26    time::Duration,
27};
28use tokio::{io::ReadBuf, net::UdpSocket, sync::mpsc, try_join};
29use tokio_util::codec::Framed;
30
/// Number of [`Socket`]s allocated in every tunnel's socket pool
const SOCKET_POOL_SIZE: usize = 4;

/// Number of consecutive failed tunnel creation attempts tolerated before giving up
const MAX_ERROR_ATTEMPTS: u8 = 5;
35
36// Local address the client uses to send packets
37static LOCAL_SEND_TARGET: SocketAddr =
38    SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, GAME_HOST_PORT));
39
40/// Starts the tunnel socket pool and creates the tunnel
41/// connection to the server
42///
43/// ## Arguments
44/// * `ctx` - The client context
45pub async fn start_tunnel_server(ctx: Arc<ClientContext>) -> std::io::Result<()> {
46    let association = match Option::as_ref(&ctx.association) {
47        Some(value) => value,
48        // Don't try and tunnel without a token
49        None => return Ok(()),
50    };
51
52    // Last encountered error
53    let mut last_error: Option<std::io::Error> = None;
54    // Number of attempts that errored
55    let mut attempt_errors: u8 = 0;
56
57    // Looping to attempt reconnecting if lost
58    while attempt_errors < MAX_ERROR_ATTEMPTS {
59        // Create the tunnel (Future will end if tunnel stopped)
60        let reconnect_time = if let Err(err) = create_tunnel(ctx.clone(), association).await {
61            error!("Failed to create tunnel: {}", err);
62
63            // Set last error
64            last_error = Some(err);
65
66            // Increase error attempts
67            attempt_errors += 1;
68
69            // Error should be delayed by the number of errors already hit
70            Duration::from_millis(1000 * attempt_errors as u64)
71        } else {
72            // Reset error attempts
73            attempt_errors = 0;
74
75            // Non errored reconnect can be quick
76            Duration::from_millis(1000)
77        };
78
79        debug!(
80            "Next tunnel create attempt in: {}s",
81            reconnect_time.as_secs()
82        );
83
84        // Wait before attempting to re-create the tunnel
85        tokio::time::sleep(reconnect_time).await;
86    }
87
88    Err(last_error.unwrap_or(std::io::Error::new(
89        ErrorKind::Other,
90        "Reached error connect limit",
91    )))
92}
93
94/// Creates a new tunnel
95///
96/// ## Arguments
97/// * `ctx`         - The client context
98/// * `association` - The client association token
99async fn create_tunnel(ctx: Arc<ClientContext>, association: &str) -> std::io::Result<()> {
100    // Create the tunnel with the server
101    let io = create_server_tunnel(&ctx.http_client, &ctx.base_url, association)
102        .await
103        // Wrap the tunnel with the [`TunnelCodec`] framing
104        .map(|io| Framed::new(io, TunnelCodec::default()))
105        // Wrap the error into an [`std::io::Error`]
106        .map_err(|err| std::io::Error::new(ErrorKind::Other, err))?;
107    debug!("Created server tunnel");
108
109    // Allocate the socket pool for the tunnel
110    let (tx, rx) = mpsc::unbounded_channel();
111    let pool = Socket::allocate_pool(tx).await?;
112    debug!("Allocated tunnel pool");
113
114    // Start the tunnel
115    Tunnel {
116        io,
117        rx,
118        pool,
119        write_state: Default::default(),
120    }
121    .await;
122
123    Ok(())
124}
125
126/// Represents a tunnel and its pool of connections that it can
127/// send data to and receive data from
128struct Tunnel {
129    /// Tunnel connection to the Pocket Relay server for sending [`TunnelMessage`]s
130    /// through the server to reach a specific peer
131    io: Framed<Upgraded, TunnelCodec>,
132    /// Receiver for receiving messages from [`Socket`]s within the [`Tunnel::pool`]
133    /// that need to be sent through [`Tunnel::io`]
134    rx: mpsc::UnboundedReceiver<TunnelMessage>,
135    /// Pool of [`Socket`]s that this tunnel can use for sending out messages
136    pool: [SocketHandle; SOCKET_POOL_SIZE],
137    /// Current state of writing [`TunnelMessage`]s to the [`Tunnel::io`]
138    write_state: TunnelWriteState,
139}
140
141/// Holds the state for the current writing progress for a [`Tunnel`]
142#[derive(Default)]
143enum TunnelWriteState {
144    /// Waiting for a message to come through the [`Tunnel::rx`]
145    #[default]
146    Recv,
147    /// Waiting for the [`Tunnel::io`] to be writable, then writing the
148    /// contained [`TunnelMessage`]
149    Write(Option<TunnelMessage>),
150    /// Poll flushing the bytes written to [`Tunnel::io`]
151    Flush,
152    /// The tunnel has stopped and should not continue
153    Stop,
154}
155
/// Outcome of a single poll of a [`Tunnel`]'s read half
enum TunnelReadState {
    /// A message was handled; reading may carry on
    Continue,
    /// The tunnel has stopped and must not be polled again
    Stop,
}
163
164impl Tunnel {
165    /// Polls accepting messages from [`Tunnel::rx`] then writing them to [`Tunnel::io`] and
166    /// flushing the underlying stream. Provides the next [`TunnelWriteState`]
167    /// when [`Poll::Ready`] is returned
168    ///
169    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
170    fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelWriteState> {
171        Poll::Ready(match &mut self.write_state {
172            TunnelWriteState::Recv => {
173                // Try receive a packet from the write channel
174                let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
175
176                if let Some(message) = result {
177                    TunnelWriteState::Write(Some(message))
178                } else {
179                    // All writers have closed, tunnel must be closed (Future end)
180                    TunnelWriteState::Stop
181                }
182            }
183            TunnelWriteState::Write(message) => {
184                // Wait until the `io` is ready
185                if ready!(Pin::new(&mut self.io).poll_ready(cx)).is_ok() {
186                    let message = message
187                        .take()
188                        .expect("Unexpected write state without message");
189
190                    // Write the packet to the buffer
191                    Pin::new(&mut self.io)
192                        .start_send(message)
193                        // Packet encoder impl shouldn't produce errors
194                        .expect("Message encoder errored");
195
196                    TunnelWriteState::Flush
197                } else {
198                    // Failed to ready, tunnel must be closed
199                    TunnelWriteState::Stop
200                }
201            }
202            TunnelWriteState::Flush => {
203                // Poll flushing `io`
204                if ready!(Pin::new(&mut self.io).poll_flush(cx)).is_ok() {
205                    TunnelWriteState::Recv
206                } else {
207                    // Failed to flush, tunnel must be closed
208                    TunnelWriteState::Stop
209                }
210            }
211
212            // Tunnel should *NOT* be polled if its already stopped
213            TunnelWriteState::Stop => panic!("Tunnel polled after already stopped"),
214        })
215    }
216
217    /// Polls reading messages from [`Tunnel::io`] and sending them to the correct
218    /// handle within the [`Tunnel::pool`]. Provides the next [`TunnelReadState`]
219    /// when [`Poll::Ready`] is returned
220    ///
221    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
222    fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<TunnelReadState> {
223        // Try receive a message from the `io`
224        let Some(Ok(message)) = ready!(Pin::new(&mut self.io).poll_next(cx)) else {
225            // Cannot read next message stop the tunnel
226            return Poll::Ready(TunnelReadState::Stop);
227        };
228
229        if message.index == 255 {
230            // Write a ping response if we aren't already writing another message
231            if let TunnelWriteState::Recv = self.write_state {
232                // Move to a writing state
233                self.write_state = TunnelWriteState::Write(Some(TunnelMessage {
234                    index: 255,
235                    message: Bytes::new(),
236                }));
237
238                // Poll the write state
239                if let Poll::Ready(next_state) = self.poll_write_state(cx) {
240                    self.write_state = next_state;
241
242                    // Tunnel has stopped
243                    if let TunnelWriteState::Stop = self.write_state {
244                        return Poll::Ready(TunnelReadState::Stop);
245                    }
246                }
247            }
248
249            return Poll::Ready(TunnelReadState::Continue);
250        }
251
252        // Get the handle to use within the connection pool
253        let handle = self.pool.get(message.index as usize);
254
255        // Send the message to the handle if its valid
256        if let Some(handle) = handle {
257            _ = handle.0.send(message);
258        }
259
260        Poll::Ready(TunnelReadState::Continue)
261    }
262}
263
264impl Future for Tunnel {
265    type Output = ();
266
267    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
268        let this = self.get_mut();
269
270        // Poll the write half
271        while let Poll::Ready(next_state) = this.poll_write_state(cx) {
272            this.write_state = next_state;
273
274            // Tunnel has stopped
275            if let TunnelWriteState::Stop = this.write_state {
276                return Poll::Ready(());
277            }
278        }
279
280        // Poll the read half
281        while let Poll::Ready(next_state) = this.poll_read_state(cx) {
282            // Tunnel has stopped
283            if let TunnelReadState::Stop = next_state {
284                return Poll::Ready(());
285            }
286        }
287
288        Poll::Pending
289    }
290}
291
292/// Handle to a [`Socket`] for sending [`TunnelMessage`]s that the
293/// socket should send to the [`LOCAL_SEND_TARGET`]
294#[derive(Clone)]
295struct SocketHandle(mpsc::UnboundedSender<TunnelMessage>);
296
/// Size in bytes of each socket's read buffer (2^16 = 65536 bytes)
///
/// This could likely be reduced to 2^15 or 2^13 bytes (or lower), since the
/// largest message observed so far was 1254 bytes, but that needs testing
/// before the change can be made
const READ_BUFFER_LENGTH: usize = 1 << 16;
303
304/// Socket used by a [`Tunnel`] for sending and receiving messages in
305/// order to simulate another player on the local network
306struct Socket {
307    // Index of the socket
308    index: u8,
309    // The underlying socket for sending and receiving
310    socket: UdpSocket,
311    /// Receiver for messages coming from the the [`Tunnel`] that need to be
312    /// send through the socket
313    rx: mpsc::UnboundedReceiver<TunnelMessage>,
314    /// Sender for sending [`TunnelMessage`]s through the associated [`Tunnel`]
315    /// in order for them to be sent to the correct peer on the other side
316    tun_tx: mpsc::UnboundedSender<TunnelMessage>,
317    /// Buffer for reading bytes from the `socket`
318    read_buffer: [u8; READ_BUFFER_LENGTH],
319    /// Current state of writing [`TunnelMessage`]s to the `socket`
320    write_state: SocketWriteState,
321}
322
323/// Holds the state for the current writing progress for a [`Socket`]
324#[derive(Default)]
325enum SocketWriteState {
326    /// Waiting for a message to come through the [`Socket::rx`]
327    #[default]
328    Recv,
329    /// Waiting for the [`Socket::socket`] to write the bytes
330    Write(Bytes),
331    /// The tunnel has stopped and should not continue
332    Stop,
333}
334
/// Outcome of a single poll of a [`Socket`]'s read half
enum SocketReadState {
    /// A datagram was handled; reading may carry on
    Continue,
    /// The socket has stopped and must not be polled again
    Stop,
}
342
343impl Socket {
344    /// Allocates a pool of [`Socket`]s for a [`Tunnel`] to use
345    ///
346    /// ## Arguments
347    /// * `tun_tx` - The tunnel sender for sending [`TunnelMessage`]s through the tunnel
348    async fn allocate_pool(
349        tun_tx: mpsc::UnboundedSender<TunnelMessage>,
350    ) -> std::io::Result<[SocketHandle; SOCKET_POOL_SIZE]> {
351        let sockets = try_join!(
352            // Host socket index *must* use a fixed port since its used on the server side
353            Socket::start(0, TUNNEL_HOST_PORT, tun_tx.clone()),
354            // Other sockets can used OS auto assigned port
355            Socket::start(1, RANDOM_PORT, tun_tx.clone()),
356            Socket::start(2, RANDOM_PORT, tun_tx.clone()),
357            Socket::start(3, RANDOM_PORT, tun_tx),
358        )?;
359        Ok(sockets.into())
360    }
361
362    /// Starts a new tunnel socket returning a [`SocketHandle`] that can be used
363    /// to send [`TunnelMessage`]s to the socket
364    ///
365    /// ## Arguments
366    /// * `index`  - The index of the socket
367    /// * `port`   - The port to bind the socket on
368    /// * `tun_tx` - The tunnel sender for sending [`TunnelMessage`]s through the tunnel
369    async fn start(
370        index: u8,
371        port: u16,
372        tun_tx: mpsc::UnboundedSender<TunnelMessage>,
373    ) -> std::io::Result<SocketHandle> {
374        // Bind the socket
375        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, port)).await?;
376        // Set the socket send target
377        socket.connect(LOCAL_SEND_TARGET).await?;
378
379        // Create the message channel
380        let (tx, rx) = mpsc::unbounded_channel();
381
382        // Spawn the socket task
383        spawn_server_task(Socket {
384            index,
385            socket,
386            rx,
387            tun_tx,
388            read_buffer: [0; READ_BUFFER_LENGTH],
389            write_state: Default::default(),
390        });
391
392        Ok(SocketHandle(tx))
393    }
394
395    /// Polls accepting messages from [`Socket::rx`] then writing them to the [`Socket::socket`].
396    /// Provides the next [`SocketWriteState`] when [`Poll::Ready`] is returned
397    ///
398    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
399    fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketWriteState> {
400        Poll::Ready(match &mut self.write_state {
401            SocketWriteState::Recv => {
402                // Try receive a packet from the write channel
403                let result = ready!(Pin::new(&mut self.rx).poll_recv(cx));
404
405                if let Some(message) = result {
406                    SocketWriteState::Write(message.message)
407                } else {
408                    // All writers have closed, tunnel must be closed (Future end)
409                    SocketWriteState::Stop
410                }
411            }
412            SocketWriteState::Write(message) => {
413                // Try send the message to the local target
414                let Ok(count) = ready!(self.socket.poll_send(cx, message)) else {
415                    return Poll::Ready(SocketWriteState::Stop);
416                };
417
418                // Didn't write the entire message
419                if count != message.len() {
420                    // Continue with a writing state at the remaining message
421                    let message = message.slice(count..);
422                    SocketWriteState::Write(message)
423                } else {
424                    SocketWriteState::Recv
425                }
426            }
427
428            // Tunnel socket should *NOT* be polled if its already stopped
429            SocketWriteState::Stop => panic!("Tunnel socket polled after already stopped"),
430        })
431    }
432
433    /// Polls reading messages from `socket` and sending them to the [`Tunnel`]
434    /// in order for them to be sent out to the peer. Provides the next
435    /// [`SocketReadState`] when [`Poll::Ready`] is returned
436    ///
437    /// Should be repeatedly called until it no-longer returns [`Poll::Ready`]
438    fn poll_read_state(&mut self, cx: &mut Context<'_>) -> Poll<SocketReadState> {
439        let mut read_buf = ReadBuf::new(&mut self.read_buffer);
440
441        // Try receive a message from the socket
442        if ready!(self.socket.poll_recv(cx, &mut read_buf)).is_err() {
443            return Poll::Ready(SocketReadState::Stop);
444        }
445
446        // Get the received message
447        let bytes = read_buf.filled();
448        let message = Bytes::copy_from_slice(bytes);
449        let message = TunnelMessage {
450            index: self.index,
451            message,
452        };
453
454        // Send the message through the tunnel
455        Poll::Ready(if self.tun_tx.send(message).is_ok() {
456            SocketReadState::Continue
457        } else {
458            SocketReadState::Stop
459        })
460    }
461}
462
463impl Future for Socket {
464    type Output = ();
465
466    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
467        let this = self.get_mut();
468
469        // Poll the write half
470        while let Poll::Ready(next_state) = this.poll_write_state(cx) {
471            this.write_state = next_state;
472
473            // Tunnel has stopped
474            if let SocketWriteState::Stop = this.write_state {
475                return Poll::Ready(());
476            }
477        }
478
479        // Poll the read half
480        while let Poll::Ready(next_state) = this.poll_read_state(cx) {
481            // Tunnel has stopped
482            if let SocketReadState::Stop = next_state {
483                return Poll::Ready(());
484            }
485        }
486
487        Poll::Pending
488    }
489}
490
491mod codec {
492    //! This modules contains the codec and message structures for [TunnelMessage]s
493    //!
494    //! # Tunnel Messages
495    //!
496    //! Tunnel message frames are as follows:
497    //!
498    //! ```text
499    //!  0                   1                   2                      
500    //!  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
501    //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
502    //! |     Index     |            Length             |
503    //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
504    //! |                                               :
505    //! :                    Payload                    :
506    //! :                                               |
507    //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
508    //! ```
509    //!
510    //! Tunnel message frames contain the following fields:
511    //!
512    //! Index: 8-bits. Determines the destination of the message within the current pool.
513    //!
514    //! Length: 16-bits. Determines the size in bytes of the payload that follows
515    //!
516    //! Payload: Variable length. The message bytes payload of `Length`
517    //!
518    //!
519    //! ## Keep alive
520    //!
521    //! The server will send keep-alive messages, these are in the same
522    //! format as the packet above. However, the index will always be 255
523    //! and the payload will be empty.
524
525    use bytes::{Buf, BufMut, Bytes};
526    use tokio_util::codec::{Decoder, Encoder};
527
528    /// Header portion of a [TunnelMessage] that contains the
529    /// index of the message and the length of the expected payload
530    struct TunnelMessageHeader {
531        /// Socket index to use
532        index: u8,
533        /// The length of the tunnel message bytes
534        length: u16,
535    }
536
537    /// Message sent through the tunnel
538    pub struct TunnelMessage {
539        /// Socket index to use
540        pub index: u8,
541        /// The message contents
542        pub message: Bytes,
543    }
544
545    /// Codec for encoding and decoding tunnel messages
546    #[derive(Default)]
547    pub struct TunnelCodec {
548        /// Stores the current message header while its waiting
549        /// for the full payload to become available
550        partial: Option<TunnelMessageHeader>,
551    }
552
553    impl Decoder for TunnelCodec {
554        type Item = TunnelMessage;
555        type Error = std::io::Error;
556
557        fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
558            let partial = match self.partial.as_mut() {
559                Some(value) => value,
560                None => {
561                    // Not enough room for a partial frame
562                    if src.len() < 5 {
563                        return Ok(None);
564                    }
565                    let index = src.get_u8();
566                    let length = src.get_u16();
567
568                    self.partial.insert(TunnelMessageHeader { index, length })
569                }
570            };
571            // Not enough data for the partial frame
572            if src.len() < partial.length as usize {
573                return Ok(None);
574            }
575
576            let partial = self.partial.take().expect("Partial frame missing");
577            let bytes = src.split_to(partial.length as usize);
578
579            Ok(Some(TunnelMessage {
580                index: partial.index,
581                message: bytes.freeze(),
582            }))
583        }
584    }
585
586    impl Encoder<TunnelMessage> for TunnelCodec {
587        type Error = std::io::Error;
588
589        fn encode(
590            &mut self,
591            item: TunnelMessage,
592            dst: &mut bytes::BytesMut,
593        ) -> Result<(), Self::Error> {
594            dst.put_u8(item.index);
595            dst.put_u16(item.message.len() as u16);
596            dst.extend_from_slice(&item.message);
597            Ok(())
598        }
599    }
600}