Skip to main content

raknet_rust/
connection.rs

//! Per-peer connection API returned by [`crate::listener::Listener`].
//!
//! [`Connection`] offers message-oriented send/recv methods.
//! [`ConnectionIo`] adapts it to Tokio's [`AsyncRead`] / [`AsyncWrite`] byte-stream traits.
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::sync::{
11    Arc,
12    atomic::{AtomicBool, Ordering},
13};
14use std::task::{Context, Poll};
15
16use bytes::Bytes;
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::sync::{mpsc, oneshot};
20
21use crate::concurrency::FastMutex;
22use crate::server::{PeerDisconnectReason, PeerId, SendOptions};
23
/// Stable, opaque identifier for a [`Connection`].
///
/// A thin wrapper around a raw `u64`; compare, hash and copy it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// Wraps a raw `u64` value as a connection id.
    pub const fn from_u64(raw: u64) -> Self {
        ConnectionId(raw)
    }

    /// Unwraps the id back into its raw `u64` value.
    pub const fn as_u64(self) -> u64 {
        let raw = self.0;
        raw
    }
}
39
40impl From<PeerId> for ConnectionId {
41    fn from(value: PeerId) -> Self {
42        Self::from_u64(value.as_u64())
43    }
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq)]
47/// Immutable connection identity snapshot.
48pub struct ConnectionMetadata {
49    id: ConnectionId,
50    remote_addr: SocketAddr,
51}
52
53impl ConnectionMetadata {
54    /// Returns connection id.
55    pub const fn id(self) -> ConnectionId {
56        self.id
57    }
58
59    /// Returns remote peer socket address.
60    pub const fn remote_addr(self) -> SocketAddr {
61        self.remote_addr
62    }
63}
64
/// Remote-side reason a session ended, mapped from RakNet disconnect semantics.
///
/// Each variant corresponds one-to-one to a `PeerDisconnectReason` variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteDisconnectReason {
    /// Mapped from `PeerDisconnectReason::Requested`.
    Requested,
    /// Mapped from `PeerDisconnectReason::RemoteDisconnectionNotification`.
    RemoteDisconnectionNotification {
        /// Optional code forwarded unchanged from the transport.
        reason_code: Option<u8>,
    },
    /// Mapped from `PeerDisconnectReason::RemoteDetectLostConnection`.
    RemoteDetectLostConnection,
    /// Mapped from `PeerDisconnectReason::WorkerStopped`.
    WorkerStopped {
        /// Shard index forwarded unchanged from the transport.
        shard_id: usize,
    },
}
73
74impl From<PeerDisconnectReason> for RemoteDisconnectReason {
75    fn from(value: PeerDisconnectReason) -> Self {
76        match value {
77            PeerDisconnectReason::Requested => Self::Requested,
78            PeerDisconnectReason::RemoteDisconnectionNotification { reason_code } => {
79                Self::RemoteDisconnectionNotification { reason_code }
80            }
81            PeerDisconnectReason::RemoteDetectLostConnection => Self::RemoteDetectLostConnection,
82            PeerDisconnectReason::WorkerStopped { shard_id } => Self::WorkerStopped { shard_id },
83        }
84    }
85}
86
87#[derive(Debug, Clone, PartialEq, Eq)]
88/// Local close reason recorded by [`Connection`].
89pub enum ConnectionCloseReason {
90    RequestedByLocal,
91    PeerDisconnected(RemoteDisconnectReason),
92    ListenerStopped,
93    InboundBackpressure,
94    TransportError(String),
95}
96
97#[derive(Debug, Error, Clone, PartialEq, Eq)]
98/// Receive-side errors from [`Connection::recv`] / [`Connection::recv_bytes`].
99pub enum RecvError {
100    #[error("connection closed: {reason:?}")]
101    ConnectionClosed { reason: ConnectionCloseReason },
102    #[error("decode error: {message}")]
103    DecodeError { message: String },
104    #[error("connection receive channel closed")]
105    ChannelClosed,
106}
107
108pub mod queue {
109    use thiserror::Error;
110
111    #[derive(Debug, Error, Clone, PartialEq, Eq)]
112    /// Errors produced by connection send queue operations.
113    pub enum SendQueueError {
114        #[error("connection command channel closed")]
115        CommandChannelClosed,
116        #[error("connection command response dropped")]
117        ResponseDropped,
118        #[error("transport send failed: {message}")]
119        Transport { message: String },
120    }
121}
122
123#[derive(Debug)]
124pub(crate) enum ConnectionInbound {
125    Packet(Bytes),
126    DecodeError(String),
127    Closed(ConnectionCloseReason),
128}
129
130#[derive(Debug)]
131pub(crate) enum ConnectionCommand {
132    Send {
133        peer_id: PeerId,
134        payload: Bytes,
135        options: SendOptions,
136        response: oneshot::Sender<io::Result<()>>,
137    },
138    Disconnect {
139        peer_id: PeerId,
140        response: oneshot::Sender<io::Result<()>>,
141    },
142    DisconnectNoWait {
143        peer_id: PeerId,
144    },
145    Shutdown {
146        response: oneshot::Sender<io::Result<()>>,
147    },
148}
149
150#[derive(Debug)]
151pub(crate) struct ConnectionSharedState {
152    closed: AtomicBool,
153    close_reason: FastMutex<Option<ConnectionCloseReason>>,
154}
155
156impl ConnectionSharedState {
157    pub(crate) fn new() -> Self {
158        Self {
159            closed: AtomicBool::new(false),
160            close_reason: FastMutex::new(None),
161        }
162    }
163
164    pub(crate) fn mark_closed(&self, reason: ConnectionCloseReason) {
165        self.closed.store(true, Ordering::Release);
166        *self.close_reason.lock() = Some(reason);
167    }
168
169    pub(crate) fn is_closed(&self) -> bool {
170        self.closed.load(Ordering::Acquire)
171    }
172
173    pub(crate) fn close_reason(&self) -> Option<ConnectionCloseReason> {
174        self.close_reason.lock().clone()
175    }
176}
177
178type BoxSendFuture = Pin<Box<dyn Future<Output = Result<(), queue::SendQueueError>> + Send>>;
179type BoxIoFuture = Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;
180
181struct PendingWrite {
182    len: usize,
183    fut: BoxSendFuture,
184}
185
186fn is_eof_close_reason(reason: &ConnectionCloseReason) -> bool {
187    matches!(
188        reason,
189        ConnectionCloseReason::RequestedByLocal
190            | ConnectionCloseReason::PeerDisconnected(_)
191            | ConnectionCloseReason::ListenerStopped
192    )
193}
194
195fn close_reason_to_io_error(reason: ConnectionCloseReason) -> io::Error {
196    if is_eof_close_reason(&reason) {
197        io::Error::new(
198            io::ErrorKind::UnexpectedEof,
199            format!("connection closed: {reason:?}"),
200        )
201    } else {
202        io::Error::new(
203            io::ErrorKind::BrokenPipe,
204            format!("connection closed: {reason:?}"),
205        )
206    }
207}
208
209fn send_queue_error_to_io_error(error: queue::SendQueueError) -> io::Error {
210    match error {
211        queue::SendQueueError::CommandChannelClosed => io::Error::new(
212            io::ErrorKind::BrokenPipe,
213            "connection command channel closed",
214        ),
215        queue::SendQueueError::ResponseDropped => io::Error::new(
216            io::ErrorKind::BrokenPipe,
217            "connection command response dropped",
218        ),
219        queue::SendQueueError::Transport { message } => {
220            io::Error::new(io::ErrorKind::BrokenPipe, message)
221        }
222    }
223}
224
225fn send_command_future(
226    shared: Arc<ConnectionSharedState>,
227    command_tx: mpsc::Sender<ConnectionCommand>,
228    peer_id: PeerId,
229    payload: Bytes,
230    options: SendOptions,
231) -> BoxSendFuture {
232    Box::pin(async move {
233        if shared.is_closed() {
234            return Err(queue::SendQueueError::Transport {
235                message: "connection already closed".to_string(),
236            });
237        }
238
239        let (response_tx, response_rx) = oneshot::channel();
240        command_tx
241            .send(ConnectionCommand::Send {
242                peer_id,
243                payload,
244                options,
245                response: response_tx,
246            })
247            .await
248            .map_err(|_| queue::SendQueueError::CommandChannelClosed)?;
249
250        match response_rx.await {
251            Ok(Ok(())) => Ok(()),
252            Ok(Err(err)) => Err(queue::SendQueueError::Transport {
253                message: err.to_string(),
254            }),
255            Err(_) => Err(queue::SendQueueError::ResponseDropped),
256        }
257    })
258}
259
260fn disconnect_command_future(
261    shared: Arc<ConnectionSharedState>,
262    command_tx: mpsc::Sender<ConnectionCommand>,
263    peer_id: PeerId,
264) -> BoxIoFuture {
265    Box::pin(async move {
266        if shared.is_closed() {
267            return Ok(());
268        }
269
270        let (response_tx, response_rx) = oneshot::channel();
271        command_tx
272            .send(ConnectionCommand::Disconnect {
273                peer_id,
274                response: response_tx,
275            })
276            .await
277            .map_err(|_| {
278                io::Error::new(
279                    io::ErrorKind::BrokenPipe,
280                    "connection command channel closed",
281                )
282            })?;
283
284        match response_rx.await {
285            Ok(result) => result,
286            Err(_) => Err(io::Error::new(
287                io::ErrorKind::BrokenPipe,
288                "connection command response dropped",
289            )),
290        }
291    })
292}
293
294fn fill_read_buf_from_payload(read_buf: &mut ReadBuf<'_>, payload: &mut Bytes) {
295    let copy_len = payload.len().min(read_buf.remaining());
296    if copy_len == 0 {
297        return;
298    }
299
300    let copied = payload.split_to(copy_len);
301    read_buf.put_slice(&copied);
302}
303
304pub struct Connection {
305    remote_addr: SocketAddr,
306    id: ConnectionId,
307    peer_id: PeerId,
308    command_tx: mpsc::Sender<ConnectionCommand>,
309    inbound_rx: mpsc::Receiver<ConnectionInbound>,
310    shared: Arc<ConnectionSharedState>,
311}
312
313impl Connection {
314    pub(crate) fn new(
315        peer_id: PeerId,
316        address: SocketAddr,
317        command_tx: mpsc::Sender<ConnectionCommand>,
318        inbound_rx: mpsc::Receiver<ConnectionInbound>,
319        shared: Arc<ConnectionSharedState>,
320    ) -> Self {
321        Self {
322            remote_addr: address,
323            id: ConnectionId::from(peer_id),
324            peer_id,
325            command_tx,
326            inbound_rx,
327            shared,
328        }
329    }
330
331    /// Returns connection id.
332    pub fn id(&self) -> ConnectionId {
333        self.id
334    }
335
336    /// Returns remote peer address.
337    pub fn remote_addr(&self) -> SocketAddr {
338        self.remote_addr
339    }
340
341    /// Returns immutable metadata snapshot.
342    pub fn metadata(&self) -> ConnectionMetadata {
343        ConnectionMetadata {
344            id: self.id,
345            remote_addr: self.remote_addr,
346        }
347    }
348
349    pub(crate) fn peer_id(&self) -> PeerId {
350        self.peer_id
351    }
352
353    /// Returns close reason if connection is closed.
354    pub fn close_reason(&self) -> Option<ConnectionCloseReason> {
355        self.shared.close_reason()
356    }
357
358    pub(crate) async fn send_with_options(
359        &self,
360        payload: impl Into<Bytes>,
361        options: SendOptions,
362    ) -> Result<(), queue::SendQueueError> {
363        send_command_future(
364            self.shared.clone(),
365            self.command_tx.clone(),
366            self.peer_id,
367            payload.into(),
368            options,
369        )
370        .await
371    }
372
373    /// Sends bytes using default send options.
374    pub async fn send_bytes(&self, payload: impl Into<Bytes>) -> Result<(), queue::SendQueueError> {
375        self.send_with_options(payload, SendOptions::default())
376            .await
377    }
378
379    /// Sends borrowed bytes, copying into an owned payload buffer.
380    pub async fn send(&self, payload: impl AsRef<[u8]>) -> Result<(), queue::SendQueueError> {
381        self.send_bytes(Bytes::copy_from_slice(payload.as_ref()))
382            .await
383    }
384
385    /// Compatibility helper matching stream-like send signatures.
386    pub async fn send_compat(
387        &self,
388        stream: &[u8],
389        _immediate: bool,
390    ) -> Result<(), queue::SendQueueError> {
391        self.send(stream).await
392    }
393
394    /// Receives next payload as zero-copy [`Bytes`].
395    pub async fn recv_bytes(&mut self) -> Result<Bytes, RecvError> {
396        match self.inbound_rx.recv().await {
397            Some(ConnectionInbound::Packet(payload)) => Ok(payload),
398            Some(ConnectionInbound::DecodeError(message)) => {
399                Err(RecvError::DecodeError { message })
400            }
401            Some(ConnectionInbound::Closed(reason)) => {
402                self.shared.mark_closed(reason.clone());
403                Err(RecvError::ConnectionClosed { reason })
404            }
405            None => {
406                if let Some(reason) = self.shared.close_reason() {
407                    Err(RecvError::ConnectionClosed { reason })
408                } else {
409                    self.shared
410                        .mark_closed(ConnectionCloseReason::ListenerStopped);
411                    Err(RecvError::ChannelClosed)
412                }
413            }
414        }
415    }
416
417    /// Receives next payload as owned `Vec<u8>`.
418    pub async fn recv(&mut self) -> Result<Vec<u8>, RecvError> {
419        self.recv_bytes().await.map(|payload| payload.to_vec())
420    }
421
422    /// Gracefully closes this connection.
423    pub async fn close(&self) {
424        if self.shared.is_closed() {
425            return;
426        }
427
428        let (response_tx, response_rx) = oneshot::channel();
429        if self
430            .command_tx
431            .send(ConnectionCommand::Disconnect {
432                peer_id: self.peer_id,
433                response: response_tx,
434            })
435            .await
436            .is_err()
437        {
438            self.shared
439                .mark_closed(ConnectionCloseReason::ListenerStopped);
440            return;
441        }
442
443        if response_rx.await.is_ok() {
444            self.shared
445                .mark_closed(ConnectionCloseReason::RequestedByLocal);
446        }
447    }
448
449    /// Returns whether connection is currently closed.
450    pub async fn is_closed(&self) -> bool {
451        self.shared.is_closed()
452    }
453
454    /// Converts into Tokio AsyncRead/AsyncWrite adapter.
455    pub fn into_io(self) -> ConnectionIo {
456        ConnectionIo::new(self)
457    }
458}
459
460impl Drop for Connection {
461    fn drop(&mut self) {
462        if self.shared.is_closed() {
463            return;
464        }
465
466        let _ = self
467            .command_tx
468            .try_send(ConnectionCommand::DisconnectNoWait {
469                peer_id: self.peer_id,
470            });
471    }
472}
473
474/// Tokio AsyncRead/AsyncWrite adapter over [`Connection`].
475pub struct ConnectionIo {
476    connection: Connection,
477    read_remainder: Option<Bytes>,
478    write_in_flight: Option<PendingWrite>,
479    shutdown_in_flight: Option<BoxIoFuture>,
480}
481
482impl ConnectionIo {
483    fn new(connection: Connection) -> Self {
484        Self {
485            connection,
486            read_remainder: None,
487            write_in_flight: None,
488            shutdown_in_flight: None,
489        }
490    }
491
492    /// Returns immutable underlying connection reference.
493    pub fn connection(&self) -> &Connection {
494        &self.connection
495    }
496
497    /// Returns mutable underlying connection reference.
498    pub fn connection_mut(&mut self) -> &mut Connection {
499        &mut self.connection
500    }
501
502    /// Returns underlying connection and consumes adapter.
503    pub fn into_inner(self) -> Connection {
504        self.connection
505    }
506
507    fn poll_pending_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<usize>>> {
508        let Some(mut state) = self.write_in_flight.take() else {
509            return Poll::Ready(Ok(None));
510        };
511
512        match state.fut.as_mut().poll(cx) {
513            Poll::Ready(Ok(())) => Poll::Ready(Ok(Some(state.len))),
514            Poll::Ready(Err(error)) => Poll::Ready(Err(send_queue_error_to_io_error(error))),
515            Poll::Pending => {
516                self.write_in_flight = Some(state);
517                Poll::Pending
518            }
519        }
520    }
521}
522
523impl AsyncRead for ConnectionIo {
524    fn poll_read(
525        mut self: Pin<&mut Self>,
526        cx: &mut Context<'_>,
527        read_buf: &mut ReadBuf<'_>,
528    ) -> Poll<io::Result<()>> {
529        if read_buf.remaining() == 0 {
530            return Poll::Ready(Ok(()));
531        }
532
533        if let Some(mut remainder) = self.read_remainder.take() {
534            fill_read_buf_from_payload(read_buf, &mut remainder);
535            if !remainder.is_empty() {
536                self.read_remainder = Some(remainder);
537            }
538            return Poll::Ready(Ok(()));
539        }
540
541        match Pin::new(&mut self.connection.inbound_rx).poll_recv(cx) {
542            Poll::Ready(Some(ConnectionInbound::Packet(mut payload))) => {
543                fill_read_buf_from_payload(read_buf, &mut payload);
544                if !payload.is_empty() {
545                    self.read_remainder = Some(payload);
546                }
547                Poll::Ready(Ok(()))
548            }
549            Poll::Ready(Some(ConnectionInbound::DecodeError(message))) => {
550                Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, message)))
551            }
552            Poll::Ready(Some(ConnectionInbound::Closed(reason))) => {
553                self.connection.shared.mark_closed(reason.clone());
554                if is_eof_close_reason(&reason) {
555                    Poll::Ready(Ok(()))
556                } else {
557                    Poll::Ready(Err(close_reason_to_io_error(reason)))
558                }
559            }
560            Poll::Ready(None) => {
561                if let Some(reason) = self.connection.shared.close_reason() {
562                    if is_eof_close_reason(&reason) {
563                        Poll::Ready(Ok(()))
564                    } else {
565                        Poll::Ready(Err(close_reason_to_io_error(reason)))
566                    }
567                } else {
568                    self.connection
569                        .shared
570                        .mark_closed(ConnectionCloseReason::ListenerStopped);
571                    Poll::Ready(Ok(()))
572                }
573            }
574            Poll::Pending => Poll::Pending,
575        }
576    }
577}
578
579impl AsyncWrite for ConnectionIo {
580    fn poll_write(
581        mut self: Pin<&mut Self>,
582        cx: &mut Context<'_>,
583        buf: &[u8],
584    ) -> Poll<io::Result<usize>> {
585        if self.shutdown_in_flight.is_some() {
586            return Poll::Ready(Err(io::Error::new(
587                io::ErrorKind::BrokenPipe,
588                "connection shutdown already in progress",
589            )));
590        }
591
592        match self.as_mut().get_mut().poll_pending_write(cx) {
593            Poll::Ready(Ok(Some(written))) => return Poll::Ready(Ok(written)),
594            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
595            Poll::Ready(Ok(None)) => {}
596            Poll::Pending => return Poll::Pending,
597        }
598
599        if buf.is_empty() {
600            return Poll::Ready(Ok(0));
601        }
602
603        if self.connection.shared.is_closed() {
604            return Poll::Ready(Err(io::Error::new(
605                io::ErrorKind::BrokenPipe,
606                "connection already closed",
607            )));
608        }
609
610        let payload = Bytes::copy_from_slice(buf);
611        self.write_in_flight = Some(PendingWrite {
612            len: buf.len(),
613            fut: send_command_future(
614                self.connection.shared.clone(),
615                self.connection.command_tx.clone(),
616                self.connection.peer_id,
617                payload,
618                SendOptions::default(),
619            ),
620        });
621
622        match self.as_mut().get_mut().poll_pending_write(cx) {
623            Poll::Ready(Ok(Some(written))) => Poll::Ready(Ok(written)),
624            Poll::Ready(Ok(None)) => Poll::Ready(Ok(0)),
625            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
626            Poll::Pending => Poll::Pending,
627        }
628    }
629
630    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
631        match self.as_mut().get_mut().poll_pending_write(cx) {
632            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
633            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
634            Poll::Pending => Poll::Pending,
635        }
636    }
637
638    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
639        match self.as_mut().poll_flush(cx) {
640            Poll::Ready(Ok(())) => {}
641            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
642            Poll::Pending => return Poll::Pending,
643        }
644
645        if self.connection.shared.is_closed() {
646            return Poll::Ready(Ok(()));
647        }
648
649        if self.shutdown_in_flight.is_none() {
650            self.shutdown_in_flight = Some(disconnect_command_future(
651                self.connection.shared.clone(),
652                self.connection.command_tx.clone(),
653                self.connection.peer_id,
654            ));
655        }
656
657        let Some(mut shutdown_future) = self.shutdown_in_flight.take() else {
658            return Poll::Ready(Ok(()));
659        };
660
661        match shutdown_future.as_mut().poll(cx) {
662            Poll::Ready(Ok(())) => {
663                self.connection
664                    .shared
665                    .mark_closed(ConnectionCloseReason::RequestedByLocal);
666                Poll::Ready(Ok(()))
667            }
668            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
669            Poll::Pending => {
670                self.shutdown_in_flight = Some(shutdown_future);
671                Poll::Pending
672            }
673        }
674    }
675}