Skip to main content

raknet_rust/
listener.rs

1//! Listener-oriented API built on top of [`crate::server::RaknetServer`].
2//!
3//! [`Listener`] accepts inbound peers and exposes them as [`Connection`] objects.
4
5use std::collections::HashMap;
6use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use tokio::sync::mpsc;
11use tokio::sync::mpsc::error::TrySendError;
12use tokio::sync::oneshot;
13use tokio::task::JoinHandle;
14
15use crate::connection::{
16    Connection, ConnectionCloseReason, ConnectionCommand, ConnectionInbound, ConnectionSharedState,
17    RemoteDisconnectReason,
18};
19use crate::error::server::ServerError;
20use crate::server::{PeerId, RaknetServer, RaknetServerEvent};
21use crate::transport::{ShardedRuntimeConfig, TransportConfig};
22
/// Default number of accepted connections that may wait in the accept queue.
const DEFAULT_ACCEPT_QUEUE_CAPACITY: usize = 512;
/// Default per-connection inbound queue capacity (packets, decode errors, close notice).
const DEFAULT_INBOUND_QUEUE_CAPACITY: usize = 256;
/// Default capacity of the command channel shared by the listener and its connections.
const DEFAULT_COMMAND_QUEUE_CAPACITY: usize = 2048;
26
27struct ListenerRuntime {
28    command_tx: mpsc::Sender<ConnectionCommand>,
29    accept_rx: mpsc::Receiver<Connection>,
30    worker: JoinHandle<()>,
31}
32
33struct PeerRuntime {
34    addr: SocketAddr,
35    inbound_tx: mpsc::Sender<ConnectionInbound>,
36    shared: Arc<ConnectionSharedState>,
37}
38
39/// Stream-like helper for sequentially accepting [`Connection`] values.
40pub struct Incoming<'a> {
41    accept_rx: &'a mut mpsc::Receiver<Connection>,
42}
43
44/// High-level listener that accepts inbound RakNet peers as [`Connection`] objects.
45pub struct Listener {
46    bind_addr: SocketAddr,
47    transport_config: TransportConfig,
48    runtime_config: ShardedRuntimeConfig,
49    accept_queue_capacity: usize,
50    inbound_queue_capacity: usize,
51    command_queue_capacity: usize,
52    runtime: Option<ListenerRuntime>,
53}
54
/// Point-in-time snapshot of a [`Listener`]'s configuration and state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListenerMetadata {
    /// Configured bind address.
    bind_addr: SocketAddr,
    /// Whether the runtime was running when the snapshot was taken.
    started: bool,
    /// Configured shard count.
    shard_count: usize,
    /// Pong/advertisement payload at snapshot time.
    advertisement: String,
}

impl ListenerMetadata {
    /// Address the listener is configured to bind to.
    pub const fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// `true` if the listener runtime was running when this snapshot was taken.
    pub const fn started(&self) -> bool {
        self.started
    }

    /// Number of shards the runtime is configured with.
    pub const fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Pong/advertisement string captured in this snapshot.
    pub fn advertisement(&self) -> &str {
        self.advertisement.as_str()
    }
}
85
86impl Listener {
87    /// Creates a listener bound to `bind_addr` with default configs.
88    pub async fn bind(bind_addr: SocketAddr) -> Result<Self, ServerError> {
89        let transport_config = TransportConfig {
90            bind_addr,
91            ..TransportConfig::default()
92        };
93
94        Ok(Self {
95            bind_addr,
96            transport_config,
97            runtime_config: ShardedRuntimeConfig::default(),
98            accept_queue_capacity: DEFAULT_ACCEPT_QUEUE_CAPACITY,
99            inbound_queue_capacity: DEFAULT_INBOUND_QUEUE_CAPACITY,
100            command_queue_capacity: DEFAULT_COMMAND_QUEUE_CAPACITY,
101            runtime: None,
102        })
103    }
104
105    /// Sets pong/advertisement payload returned during offline ping.
106    pub fn set_pong_data(&mut self, data: impl Into<String>) {
107        self.transport_config.advertisement = data.into();
108    }
109
110    /// Returns pong/advertisement payload.
111    pub fn pong_data(&self) -> &str {
112        &self.transport_config.advertisement
113    }
114
115    /// Sets incoming connection queue capacity.
116    pub fn set_accept_queue_capacity(&mut self, capacity: usize) {
117        self.accept_queue_capacity = capacity.max(1);
118    }
119
120    /// Sets per-connection inbound packet queue capacity.
121    pub fn set_inbound_queue_capacity(&mut self, capacity: usize) {
122        self.inbound_queue_capacity = capacity.max(1);
123    }
124
125    /// Sets command channel capacity used by accepted [`Connection`]s.
126    pub fn set_command_queue_capacity(&mut self, capacity: usize) {
127        self.command_queue_capacity = capacity.max(1);
128    }
129
130    /// Sets shard count (minimum `1`).
131    pub fn set_shard_count(&mut self, shard_count: usize) {
132        self.runtime_config.shard_count = shard_count.max(1);
133    }
134
135    /// Returns configured bind address.
136    pub fn bind_addr(&self) -> SocketAddr {
137        self.bind_addr
138    }
139
140    /// Returns listener metadata snapshot.
141    pub fn metadata(&self) -> ListenerMetadata {
142        ListenerMetadata {
143            bind_addr: self.bind_addr,
144            started: self.runtime.is_some(),
145            shard_count: self.runtime_config.shard_count.max(1),
146            advertisement: self.transport_config.advertisement.clone(),
147        }
148    }
149
150    /// Returns `true` if runtime is started.
151    pub fn is_started(&self) -> bool {
152        self.runtime.is_some()
153    }
154
155    /// Starts listener runtime.
156    pub async fn start(&mut self) -> Result<(), ServerError> {
157        if self.runtime.is_some() {
158            return Err(ServerError::AlreadyStarted);
159        }
160
161        let mut transport_config = self.transport_config.clone();
162        transport_config.bind_addr = self.bind_addr;
163
164        transport_config.validate()?;
165        self.runtime_config.validate()?;
166
167        let server =
168            RaknetServer::start_with_configs(transport_config, self.runtime_config.clone())
169                .await
170                .map_err(ServerError::from)?;
171
172        let (accept_tx, accept_rx) = mpsc::channel(self.accept_queue_capacity.max(1));
173        let (command_tx, command_rx) = mpsc::channel(self.command_queue_capacity.max(1));
174        let worker_command_tx = command_tx.clone();
175        let inbound_queue_capacity = self.inbound_queue_capacity.max(1);
176
177        let worker = tokio::spawn(async move {
178            run_listener_worker(
179                server,
180                command_rx,
181                worker_command_tx,
182                accept_tx,
183                inbound_queue_capacity,
184            )
185            .await;
186        });
187
188        self.runtime = Some(ListenerRuntime {
189            command_tx,
190            accept_rx,
191            worker,
192        });
193
194        Ok(())
195    }
196
197    /// Stops listener runtime and disconnects active peers.
198    pub async fn stop(&mut self) -> Result<(), ServerError> {
199        let Some(runtime) = self.runtime.take() else {
200            return Ok(());
201        };
202
203        let (response_tx, response_rx) = oneshot::channel();
204        if runtime
205            .command_tx
206            .send(ConnectionCommand::Shutdown {
207                response: response_tx,
208            })
209            .await
210            .is_err()
211        {
212            let _ = runtime.worker.await;
213            return Err(ServerError::CommandChannelClosed);
214        }
215
216        let response = response_rx.await.map_err(|_| ServerError::WorkerStopped)?;
217        let _ = runtime.worker.await;
218        response.map_err(ServerError::from)
219    }
220
221    /// Accepts next inbound connection.
222    pub async fn accept(&mut self) -> Result<Connection, ServerError> {
223        self.accept_receiver()?
224            .recv()
225            .await
226            .ok_or(ServerError::AcceptChannelClosed)
227    }
228
229    /// Returns `Incoming` helper for stream-style accept loop.
230    pub fn incoming(&mut self) -> Result<Incoming<'_>, ServerError> {
231        let accept_rx = self.accept_receiver()?;
232        Ok(Incoming { accept_rx })
233    }
234
235    fn accept_receiver(&mut self) -> Result<&mut mpsc::Receiver<Connection>, ServerError> {
236        let runtime = self.runtime.as_mut().ok_or(ServerError::NotStarted)?;
237        Ok(&mut runtime.accept_rx)
238    }
239}
240
241impl Incoming<'_> {
242    /// Waits for the next accepted connection.
243    pub async fn next(&mut self) -> Option<Connection> {
244        self.accept_rx.recv().await
245    }
246}
247
248impl Drop for Listener {
249    fn drop(&mut self) {
250        if let Some(runtime) = self.runtime.take() {
251            runtime.worker.abort();
252        }
253    }
254}
255
256async fn run_listener_worker(
257    mut server: RaknetServer,
258    mut command_rx: mpsc::Receiver<ConnectionCommand>,
259    command_tx: mpsc::Sender<ConnectionCommand>,
260    accept_tx: mpsc::Sender<Connection>,
261    inbound_queue_capacity: usize,
262) {
263    let mut peers: HashMap<PeerId, PeerRuntime> = HashMap::new();
264    let mut peer_ids_by_addr: HashMap<SocketAddr, PeerId> = HashMap::new();
265
266    loop {
267        tokio::select! {
268            command = command_rx.recv() => {
269                match command {
270                    Some(ConnectionCommand::Send { peer_id, payload, options, response }) => {
271                        let result = if peers.contains_key(&peer_id) {
272                            server.send_with_options(peer_id, payload, options).await
273                        } else {
274                            Err(io::Error::new(io::ErrorKind::NotFound, "peer not found"))
275                        };
276                        let _ = response.send(result);
277                    }
278                    Some(ConnectionCommand::Disconnect { peer_id, response }) => {
279                        let result = disconnect_peer(
280                            &mut server,
281                            &mut peers,
282                            &mut peer_ids_by_addr,
283                            peer_id,
284                            ConnectionCloseReason::RequestedByLocal,
285                        )
286                        .await;
287                        let _ = response.send(result);
288                    }
289                    Some(ConnectionCommand::DisconnectNoWait { peer_id }) => {
290                        let _ = disconnect_peer(
291                            &mut server,
292                            &mut peers,
293                            &mut peer_ids_by_addr,
294                            peer_id,
295                            ConnectionCloseReason::RequestedByLocal,
296                        )
297                        .await;
298                    }
299                    Some(ConnectionCommand::Shutdown { response }) => {
300                        for peer_id in peers.keys().copied().collect::<Vec<_>>() {
301                            let _ = server.disconnect(peer_id).await;
302                        }
303
304                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
305                        let result = server.shutdown().await;
306                        let _ = response.send(result);
307                        break;
308                    }
309                    None => {
310                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
311                        let _ = server.shutdown().await;
312                        break;
313                    }
314                }
315            }
316            server_event = server.next_event() => {
317                let Some(server_event) = server_event else {
318                    close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
319                    break;
320                };
321
322                match server_event {
323                    RaknetServerEvent::PeerConnected { peer_id, addr, .. } => {
324                        if let Some(existing) = peers.remove(&peer_id) {
325                            peer_ids_by_addr.remove(&existing.addr);
326                            close_peer_entry(existing, ConnectionCloseReason::RequestedByLocal);
327                        }
328
329                        let shared = Arc::new(ConnectionSharedState::new());
330                        let (inbound_tx, inbound_rx) = mpsc::channel(inbound_queue_capacity.max(1));
331                        let connection = Connection::new(
332                            peer_id,
333                            addr,
334                            command_tx.clone(),
335                            inbound_rx,
336                            Arc::clone(&shared),
337                        );
338
339                        peers.insert(
340                            peer_id,
341                            PeerRuntime {
342                                addr,
343                                inbound_tx,
344                                shared,
345                            },
346                        );
347                        peer_ids_by_addr.insert(addr, peer_id);
348
349                        if let Err(err) = accept_tx.try_send(connection) {
350                            match err {
351                                TrySendError::Full(conn) => {
352                                    let _ = disconnect_peer(
353                                        &mut server,
354                                        &mut peers,
355                                        &mut peer_ids_by_addr,
356                                        conn.peer_id(),
357                                        ConnectionCloseReason::InboundBackpressure,
358                                    )
359                                    .await;
360                                }
361                                TrySendError::Closed(conn) => {
362                                    let _ = disconnect_peer(
363                                        &mut server,
364                                        &mut peers,
365                                        &mut peer_ids_by_addr,
366                                        conn.peer_id(),
367                                        ConnectionCloseReason::ListenerStopped,
368                                    )
369                                    .await;
370                                    close_all_peers(
371                                        &mut peers,
372                                        &mut peer_ids_by_addr,
373                                        ConnectionCloseReason::ListenerStopped,
374                                    );
375                                    let _ = server.shutdown().await;
376                                    break;
377                                }
378                            }
379                        }
380                    }
381                    RaknetServerEvent::PeerDisconnected { peer_id, reason, .. } => {
382                        if let Some(entry) = remove_peer(&mut peers, &mut peer_ids_by_addr, peer_id) {
383                            close_peer_entry(
384                                entry,
385                                ConnectionCloseReason::PeerDisconnected(
386                                    RemoteDisconnectReason::from(reason),
387                                ),
388                            );
389                        }
390                    }
391                    RaknetServerEvent::Packet { peer_id, payload, .. } => {
392                        if let Some(entry) = peers.get(&peer_id) {
393                            match entry.inbound_tx.try_send(ConnectionInbound::Packet(payload)) {
394                                Ok(()) => {}
395                                Err(TrySendError::Full(_)) => {
396                                    let _ = disconnect_peer(
397                                        &mut server,
398                                        &mut peers,
399                                        &mut peer_ids_by_addr,
400                                        peer_id,
401                                        ConnectionCloseReason::InboundBackpressure,
402                                    )
403                                    .await;
404                                }
405                                Err(TrySendError::Closed(_)) => {
406                                    let _ = disconnect_peer(
407                                        &mut server,
408                                        &mut peers,
409                                        &mut peer_ids_by_addr,
410                                        peer_id,
411                                        ConnectionCloseReason::ListenerStopped,
412                                    )
413                                    .await;
414                                }
415                            }
416                        }
417                    }
418                    RaknetServerEvent::DecodeError { addr, error } => {
419                        if let Some(peer_id) = peer_ids_by_addr.get(&addr).copied()
420                            && let Some(entry) = peers.get(&peer_id)
421                        {
422                            let _ = entry
423                                .inbound_tx
424                                .try_send(ConnectionInbound::DecodeError(error));
425                        }
426                    }
427                    RaknetServerEvent::PeerRateLimited { .. }
428                    | RaknetServerEvent::SessionLimitReached { .. }
429                    | RaknetServerEvent::ProxyDropped { .. }
430                    | RaknetServerEvent::OfflinePacket { .. }
431                    | RaknetServerEvent::ReceiptAcked { .. }
432                    | RaknetServerEvent::WorkerError { .. }
433                    | RaknetServerEvent::WorkerStopped { .. }
434                    | RaknetServerEvent::Metrics { .. } => {}
435                }
436            }
437        }
438    }
439
440    drop(accept_tx);
441}
442
443fn remove_peer(
444    peers: &mut HashMap<PeerId, PeerRuntime>,
445    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
446    peer_id: PeerId,
447) -> Option<PeerRuntime> {
448    let entry = peers.remove(&peer_id)?;
449    peer_ids_by_addr.remove(&entry.addr);
450    Some(entry)
451}
452
453async fn disconnect_peer(
454    server: &mut RaknetServer,
455    peers: &mut HashMap<PeerId, PeerRuntime>,
456    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
457    peer_id: PeerId,
458    reason: ConnectionCloseReason,
459) -> io::Result<()> {
460    let result = server.disconnect(peer_id).await;
461    if let Some(entry) = remove_peer(peers, peer_ids_by_addr, peer_id) {
462        close_peer_entry(entry, reason);
463    }
464    result
465}
466
467fn close_all_peers(
468    peers: &mut HashMap<PeerId, PeerRuntime>,
469    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
470    reason: ConnectionCloseReason,
471) {
472    peer_ids_by_addr.clear();
473    for (_, entry) in peers.drain() {
474        close_peer_entry(entry, reason.clone());
475    }
476}
477
478fn close_peer_entry(entry: PeerRuntime, reason: ConnectionCloseReason) {
479    entry.shared.mark_closed(reason.clone());
480    let _ = entry.inbound_tx.try_send(ConnectionInbound::Closed(reason));
481}