Skip to main content

raknet_rust/
connection.rs

1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::{
6    Arc,
7    atomic::{AtomicBool, Ordering},
8};
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio::sync::{mpsc, oneshot};
15
16use crate::concurrency::FastMutex;
17use crate::server::{PeerDisconnectReason, PeerId, SendOptions};
18
/// Opaque identifier of a single connection.
///
/// Wraps a raw `u64`; [`ConnectionId::from_u64`] and [`ConnectionId::as_u64`]
/// convert losslessly in both directions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// Wraps the raw identifier `value`.
    pub const fn from_u64(value: u64) -> Self {
        ConnectionId(value)
    }

    /// Returns the raw `u64` carried by this identifier.
    pub const fn as_u64(self) -> u64 {
        let ConnectionId(raw) = self;
        raw
    }
}
31
32impl From<PeerId> for ConnectionId {
33    fn from(value: PeerId) -> Self {
34        Self::from_u64(value.as_u64())
35    }
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39pub struct ConnectionMetadata {
40    id: ConnectionId,
41    remote_addr: SocketAddr,
42}
43
44impl ConnectionMetadata {
45    pub const fn id(self) -> ConnectionId {
46        self.id
47    }
48
49    pub const fn remote_addr(self) -> SocketAddr {
50        self.remote_addr
51    }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub enum RemoteDisconnectReason {
56    Requested,
57    RemoteDisconnectionNotification { reason_code: Option<u8> },
58    RemoteDetectLostConnection,
59    WorkerStopped { shard_id: usize },
60}
61
62impl From<PeerDisconnectReason> for RemoteDisconnectReason {
63    fn from(value: PeerDisconnectReason) -> Self {
64        match value {
65            PeerDisconnectReason::Requested => Self::Requested,
66            PeerDisconnectReason::RemoteDisconnectionNotification { reason_code } => {
67                Self::RemoteDisconnectionNotification { reason_code }
68            }
69            PeerDisconnectReason::RemoteDetectLostConnection => Self::RemoteDetectLostConnection,
70            PeerDisconnectReason::WorkerStopped { shard_id } => Self::WorkerStopped { shard_id },
71        }
72    }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub enum ConnectionCloseReason {
77    RequestedByLocal,
78    PeerDisconnected(RemoteDisconnectReason),
79    ListenerStopped,
80    InboundBackpressure,
81    TransportError(String),
82}
83
84#[derive(Debug, Error, Clone, PartialEq, Eq)]
85pub enum RecvError {
86    #[error("connection closed: {reason:?}")]
87    ConnectionClosed { reason: ConnectionCloseReason },
88    #[error("decode error: {message}")]
89    DecodeError { message: String },
90    #[error("connection receive channel closed")]
91    ChannelClosed,
92}
93
94pub mod queue {
95    use thiserror::Error;
96
97    #[derive(Debug, Error, Clone, PartialEq, Eq)]
98    pub enum SendQueueError {
99        #[error("connection command channel closed")]
100        CommandChannelClosed,
101        #[error("connection command response dropped")]
102        ResponseDropped,
103        #[error("transport send failed: {message}")]
104        Transport { message: String },
105    }
106}
107
108#[derive(Debug)]
109pub(crate) enum ConnectionInbound {
110    Packet(Bytes),
111    DecodeError(String),
112    Closed(ConnectionCloseReason),
113}
114
115#[derive(Debug)]
116pub(crate) enum ConnectionCommand {
117    Send {
118        peer_id: PeerId,
119        payload: Bytes,
120        options: SendOptions,
121        response: oneshot::Sender<io::Result<()>>,
122    },
123    Disconnect {
124        peer_id: PeerId,
125        response: oneshot::Sender<io::Result<()>>,
126    },
127    DisconnectNoWait {
128        peer_id: PeerId,
129    },
130    Shutdown {
131        response: oneshot::Sender<io::Result<()>>,
132    },
133}
134
135#[derive(Debug)]
136pub(crate) struct ConnectionSharedState {
137    closed: AtomicBool,
138    close_reason: FastMutex<Option<ConnectionCloseReason>>,
139}
140
141impl ConnectionSharedState {
142    pub(crate) fn new() -> Self {
143        Self {
144            closed: AtomicBool::new(false),
145            close_reason: FastMutex::new(None),
146        }
147    }
148
149    pub(crate) fn mark_closed(&self, reason: ConnectionCloseReason) {
150        self.closed.store(true, Ordering::Release);
151        *self.close_reason.lock() = Some(reason);
152    }
153
154    pub(crate) fn is_closed(&self) -> bool {
155        self.closed.load(Ordering::Acquire)
156    }
157
158    pub(crate) fn close_reason(&self) -> Option<ConnectionCloseReason> {
159        self.close_reason.lock().clone()
160    }
161}
162
163type BoxSendFuture = Pin<Box<dyn Future<Output = Result<(), queue::SendQueueError>> + Send>>;
164type BoxIoFuture = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;
165
166struct PendingWrite {
167    len: usize,
168    fut: BoxSendFuture,
169}
170
171fn is_eof_close_reason(reason: &ConnectionCloseReason) -> bool {
172    matches!(
173        reason,
174        ConnectionCloseReason::RequestedByLocal
175            | ConnectionCloseReason::PeerDisconnected(_)
176            | ConnectionCloseReason::ListenerStopped
177    )
178}
179
180fn close_reason_to_io_error(reason: ConnectionCloseReason) -> io::Error {
181    if is_eof_close_reason(&reason) {
182        io::Error::new(
183            io::ErrorKind::UnexpectedEof,
184            format!("connection closed: {reason:?}"),
185        )
186    } else {
187        io::Error::new(
188            io::ErrorKind::BrokenPipe,
189            format!("connection closed: {reason:?}"),
190        )
191    }
192}
193
194fn send_queue_error_to_io_error(error: queue::SendQueueError) -> io::Error {
195    match error {
196        queue::SendQueueError::CommandChannelClosed => io::Error::new(
197            io::ErrorKind::BrokenPipe,
198            "connection command channel closed",
199        ),
200        queue::SendQueueError::ResponseDropped => io::Error::new(
201            io::ErrorKind::BrokenPipe,
202            "connection command response dropped",
203        ),
204        queue::SendQueueError::Transport { message } => {
205            io::Error::new(io::ErrorKind::BrokenPipe, message)
206        }
207    }
208}
209
210fn send_command_future(
211    shared: Arc<ConnectionSharedState>,
212    command_tx: mpsc::Sender<ConnectionCommand>,
213    peer_id: PeerId,
214    payload: Bytes,
215    options: SendOptions,
216) -> BoxSendFuture {
217    Box::pin(async move {
218        if shared.is_closed() {
219            return Err(queue::SendQueueError::Transport {
220                message: "connection already closed".to_string(),
221            });
222        }
223
224        let (response_tx, response_rx) = oneshot::channel();
225        command_tx
226            .send(ConnectionCommand::Send {
227                peer_id,
228                payload,
229                options,
230                response: response_tx,
231            })
232            .await
233            .map_err(|_| queue::SendQueueError::CommandChannelClosed)?;
234
235        match response_rx.await {
236            Ok(Ok(())) => Ok(()),
237            Ok(Err(err)) => Err(queue::SendQueueError::Transport {
238                message: err.to_string(),
239            }),
240            Err(_) => Err(queue::SendQueueError::ResponseDropped),
241        }
242    })
243}
244
245fn disconnect_command_future(
246    shared: Arc<ConnectionSharedState>,
247    command_tx: mpsc::Sender<ConnectionCommand>,
248    peer_id: PeerId,
249) -> BoxIoFuture {
250    Box::pin(async move {
251        if shared.is_closed() {
252            return Ok(());
253        }
254
255        let (response_tx, response_rx) = oneshot::channel();
256        command_tx
257            .send(ConnectionCommand::Disconnect {
258                peer_id,
259                response: response_tx,
260            })
261            .await
262            .map_err(|_| {
263                io::Error::new(
264                    io::ErrorKind::BrokenPipe,
265                    "connection command channel closed",
266                )
267            })?;
268
269        match response_rx.await {
270            Ok(result) => result,
271            Err(_) => Err(io::Error::new(
272                io::ErrorKind::BrokenPipe,
273                "connection command response dropped",
274            )),
275        }
276    })
277}
278
279fn fill_read_buf_from_payload(read_buf: &mut ReadBuf<'_>, payload: &mut Bytes) {
280    let copy_len = payload.len().min(read_buf.remaining());
281    if copy_len == 0 {
282        return;
283    }
284
285    let copied = payload.split_to(copy_len);
286    read_buf.put_slice(&copied);
287}
288
289pub struct Connection {
290    remote_addr: SocketAddr,
291    id: ConnectionId,
292    peer_id: PeerId,
293    command_tx: mpsc::Sender<ConnectionCommand>,
294    inbound_rx: mpsc::Receiver<ConnectionInbound>,
295    shared: Arc<ConnectionSharedState>,
296}
297
298impl Connection {
299    pub(crate) fn new(
300        peer_id: PeerId,
301        address: SocketAddr,
302        command_tx: mpsc::Sender<ConnectionCommand>,
303        inbound_rx: mpsc::Receiver<ConnectionInbound>,
304        shared: Arc<ConnectionSharedState>,
305    ) -> Self {
306        Self {
307            remote_addr: address,
308            id: ConnectionId::from(peer_id),
309            peer_id,
310            command_tx,
311            inbound_rx,
312            shared,
313        }
314    }
315
316    pub fn id(&self) -> ConnectionId {
317        self.id
318    }
319
320    pub fn remote_addr(&self) -> SocketAddr {
321        self.remote_addr
322    }
323
324    pub fn metadata(&self) -> ConnectionMetadata {
325        ConnectionMetadata {
326            id: self.id,
327            remote_addr: self.remote_addr,
328        }
329    }
330
331    pub(crate) fn peer_id(&self) -> PeerId {
332        self.peer_id
333    }
334
335    pub fn close_reason(&self) -> Option<ConnectionCloseReason> {
336        self.shared.close_reason()
337    }
338
339    pub(crate) async fn send_with_options(
340        &self,
341        payload: impl Into<Bytes>,
342        options: SendOptions,
343    ) -> Result<(), queue::SendQueueError> {
344        send_command_future(
345            self.shared.clone(),
346            self.command_tx.clone(),
347            self.peer_id,
348            payload.into(),
349            options,
350        )
351        .await
352    }
353
354    pub async fn send_bytes(&self, payload: impl Into<Bytes>) -> Result<(), queue::SendQueueError> {
355        self.send_with_options(payload, SendOptions::default())
356            .await
357    }
358
359    pub async fn send(&self, payload: impl AsRef<[u8]>) -> Result<(), queue::SendQueueError> {
360        self.send_bytes(Bytes::copy_from_slice(payload.as_ref()))
361            .await
362    }
363
364    pub async fn send_compat(
365        &self,
366        stream: &[u8],
367        _immediate: bool,
368    ) -> Result<(), queue::SendQueueError> {
369        self.send(stream).await
370    }
371
372    pub async fn recv_bytes(&mut self) -> Result<Bytes, RecvError> {
373        match self.inbound_rx.recv().await {
374            Some(ConnectionInbound::Packet(payload)) => Ok(payload),
375            Some(ConnectionInbound::DecodeError(message)) => {
376                Err(RecvError::DecodeError { message })
377            }
378            Some(ConnectionInbound::Closed(reason)) => {
379                self.shared.mark_closed(reason.clone());
380                Err(RecvError::ConnectionClosed { reason })
381            }
382            None => {
383                if let Some(reason) = self.shared.close_reason() {
384                    Err(RecvError::ConnectionClosed { reason })
385                } else {
386                    self.shared
387                        .mark_closed(ConnectionCloseReason::ListenerStopped);
388                    Err(RecvError::ChannelClosed)
389                }
390            }
391        }
392    }
393
394    pub async fn recv(&mut self) -> Result<Vec<u8>, RecvError> {
395        self.recv_bytes().await.map(|payload| payload.to_vec())
396    }
397
398    pub async fn close(&self) {
399        if self.shared.is_closed() {
400            return;
401        }
402
403        let (response_tx, response_rx) = oneshot::channel();
404        if self
405            .command_tx
406            .send(ConnectionCommand::Disconnect {
407                peer_id: self.peer_id,
408                response: response_tx,
409            })
410            .await
411            .is_err()
412        {
413            self.shared
414                .mark_closed(ConnectionCloseReason::ListenerStopped);
415            return;
416        }
417
418        if response_rx.await.is_ok() {
419            self.shared
420                .mark_closed(ConnectionCloseReason::RequestedByLocal);
421        }
422    }
423
424    pub async fn is_closed(&self) -> bool {
425        self.shared.is_closed()
426    }
427
428    pub fn into_io(self) -> ConnectionIo {
429        ConnectionIo::new(self)
430    }
431}
432
433impl Drop for Connection {
434    fn drop(&mut self) {
435        if self.shared.is_closed() {
436            return;
437        }
438
439        let _ = self
440            .command_tx
441            .try_send(ConnectionCommand::DisconnectNoWait {
442                peer_id: self.peer_id,
443            });
444    }
445}
446
447pub struct ConnectionIo {
448    connection: Connection,
449    read_remainder: Option<Bytes>,
450    write_in_flight: Option<PendingWrite>,
451    shutdown_in_flight: Option<BoxIoFuture>,
452}
453
454impl ConnectionIo {
455    fn new(connection: Connection) -> Self {
456        Self {
457            connection,
458            read_remainder: None,
459            write_in_flight: None,
460            shutdown_in_flight: None,
461        }
462    }
463
464    pub fn connection(&self) -> &Connection {
465        &self.connection
466    }
467
468    pub fn connection_mut(&mut self) -> &mut Connection {
469        &mut self.connection
470    }
471
472    pub fn into_inner(self) -> Connection {
473        self.connection
474    }
475
476    fn poll_pending_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<usize>>> {
477        let Some(mut state) = self.write_in_flight.take() else {
478            return Poll::Ready(Ok(None));
479        };
480
481        match state.fut.as_mut().poll(cx) {
482            Poll::Ready(Ok(())) => Poll::Ready(Ok(Some(state.len))),
483            Poll::Ready(Err(error)) => Poll::Ready(Err(send_queue_error_to_io_error(error))),
484            Poll::Pending => {
485                self.write_in_flight = Some(state);
486                Poll::Pending
487            }
488        }
489    }
490}
491
492impl AsyncRead for ConnectionIo {
493    fn poll_read(
494        mut self: Pin<&mut Self>,
495        cx: &mut Context<'_>,
496        read_buf: &mut ReadBuf<'_>,
497    ) -> Poll<io::Result<()>> {
498        if read_buf.remaining() == 0 {
499            return Poll::Ready(Ok(()));
500        }
501
502        if let Some(mut remainder) = self.read_remainder.take() {
503            fill_read_buf_from_payload(read_buf, &mut remainder);
504            if !remainder.is_empty() {
505                self.read_remainder = Some(remainder);
506            }
507            return Poll::Ready(Ok(()));
508        }
509
510        match Pin::new(&mut self.connection.inbound_rx).poll_recv(cx) {
511            Poll::Ready(Some(ConnectionInbound::Packet(mut payload))) => {
512                fill_read_buf_from_payload(read_buf, &mut payload);
513                if !payload.is_empty() {
514                    self.read_remainder = Some(payload);
515                }
516                Poll::Ready(Ok(()))
517            }
518            Poll::Ready(Some(ConnectionInbound::DecodeError(message))) => {
519                Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, message)))
520            }
521            Poll::Ready(Some(ConnectionInbound::Closed(reason))) => {
522                self.connection.shared.mark_closed(reason.clone());
523                if is_eof_close_reason(&reason) {
524                    Poll::Ready(Ok(()))
525                } else {
526                    Poll::Ready(Err(close_reason_to_io_error(reason)))
527                }
528            }
529            Poll::Ready(None) => {
530                if let Some(reason) = self.connection.shared.close_reason() {
531                    if is_eof_close_reason(&reason) {
532                        Poll::Ready(Ok(()))
533                    } else {
534                        Poll::Ready(Err(close_reason_to_io_error(reason)))
535                    }
536                } else {
537                    self.connection
538                        .shared
539                        .mark_closed(ConnectionCloseReason::ListenerStopped);
540                    Poll::Ready(Ok(()))
541                }
542            }
543            Poll::Pending => Poll::Pending,
544        }
545    }
546}
547
548impl AsyncWrite for ConnectionIo {
549    fn poll_write(
550        mut self: Pin<&mut Self>,
551        cx: &mut Context<'_>,
552        buf: &[u8],
553    ) -> Poll<io::Result<usize>> {
554        if self.shutdown_in_flight.is_some() {
555            return Poll::Ready(Err(io::Error::new(
556                io::ErrorKind::BrokenPipe,
557                "connection shutdown already in progress",
558            )));
559        }
560
561        match self.as_mut().get_mut().poll_pending_write(cx) {
562            Poll::Ready(Ok(Some(written))) => return Poll::Ready(Ok(written)),
563            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
564            Poll::Ready(Ok(None)) => {}
565            Poll::Pending => return Poll::Pending,
566        }
567
568        if buf.is_empty() {
569            return Poll::Ready(Ok(0));
570        }
571
572        if self.connection.shared.is_closed() {
573            return Poll::Ready(Err(io::Error::new(
574                io::ErrorKind::BrokenPipe,
575                "connection already closed",
576            )));
577        }
578
579        let payload = Bytes::copy_from_slice(buf);
580        self.write_in_flight = Some(PendingWrite {
581            len: buf.len(),
582            fut: send_command_future(
583                self.connection.shared.clone(),
584                self.connection.command_tx.clone(),
585                self.connection.peer_id,
586                payload,
587                SendOptions::default(),
588            ),
589        });
590
591        match self.as_mut().get_mut().poll_pending_write(cx) {
592            Poll::Ready(Ok(Some(written))) => Poll::Ready(Ok(written)),
593            Poll::Ready(Ok(None)) => Poll::Ready(Ok(0)),
594            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
595            Poll::Pending => Poll::Pending,
596        }
597    }
598
599    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
600        match self.as_mut().get_mut().poll_pending_write(cx) {
601            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
602            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
603            Poll::Pending => Poll::Pending,
604        }
605    }
606
607    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
608        match self.as_mut().poll_flush(cx) {
609            Poll::Ready(Ok(())) => {}
610            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
611            Poll::Pending => return Poll::Pending,
612        }
613
614        if self.connection.shared.is_closed() {
615            return Poll::Ready(Ok(()));
616        }
617
618        if self.shutdown_in_flight.is_none() {
619            self.shutdown_in_flight = Some(disconnect_command_future(
620                self.connection.shared.clone(),
621                self.connection.command_tx.clone(),
622                self.connection.peer_id,
623            ));
624        }
625
626        let Some(mut shutdown_future) = self.shutdown_in_flight.take() else {
627            return Poll::Ready(Ok(()));
628        };
629
630        match shutdown_future.as_mut().poll(cx) {
631            Poll::Ready(Ok(())) => {
632                self.connection
633                    .shared
634                    .mark_closed(ConnectionCloseReason::RequestedByLocal);
635                Poll::Ready(Ok(()))
636            }
637            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
638            Poll::Pending => {
639                self.shutdown_in_flight = Some(shutdown_future);
640                Poll::Pending
641            }
642        }
643    }
644}