Skip to main content

raknet_rust/
listener.rs

1use std::collections::HashMap;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use tokio::sync::mpsc;
7use tokio::sync::mpsc::error::TrySendError;
8use tokio::sync::oneshot;
9use tokio::task::JoinHandle;
10
11use crate::connection::{
12    Connection, ConnectionCloseReason, ConnectionCommand, ConnectionInbound, ConnectionSharedState,
13    RemoteDisconnectReason,
14};
15use crate::error::server::ServerError;
16use crate::server::{PeerId, RaknetServer, RaknetServerEvent};
17use crate::transport::{ShardedRuntimeConfig, TransportConfig};
18
/// Default capacity of the channel that carries accepted connections from the
/// worker task to the listener's accept side.
const DEFAULT_ACCEPT_QUEUE_CAPACITY: usize = 512;
/// Default capacity of the per-connection inbound channel created for every
/// connected peer.
const DEFAULT_INBOUND_QUEUE_CAPACITY: usize = 256;
/// Default capacity of the command channel consumed by the worker task.
const DEFAULT_COMMAND_QUEUE_CAPACITY: usize = 2048;
22
/// Handles that exist only while the listener is started.
struct ListenerRuntime {
    /// Sender side of the worker's command channel; `Listener::stop` uses it
    /// to deliver the shutdown command.
    command_tx: mpsc::Sender<ConnectionCommand>,
    /// Receiver of connections queued by the worker for `accept`/`incoming`.
    accept_rx: mpsc::Receiver<Connection>,
    /// Join handle of the spawned worker task.
    worker: JoinHandle<()>,
}
28
/// Worker-side bookkeeping for one connected peer.
struct PeerRuntime {
    /// Remote address reported when the peer connected; also the key of the
    /// worker's address-to-peer index.
    addr: SocketAddr,
    /// Sender used by the worker to push packets, decode errors and the final
    /// close notification to the peer's `Connection`.
    inbound_tx: mpsc::Sender<ConnectionInbound>,
    /// State shared with the peer's `Connection`; marked closed on teardown.
    shared: Arc<ConnectionSharedState>,
}
34
/// Borrowed view over a started listener's accept queue, obtained from
/// `Listener::incoming`.
pub struct Incoming<'a> {
    /// Accept receiver borrowed from the listener runtime for the lifetime `'a`.
    accept_rx: &'a mut mpsc::Receiver<Connection>,
}
38
/// Connection-oriented front end over a `RaknetServer`.
///
/// Holds the configuration gathered before `start` and, once started, the
/// runtime handles of the background worker task.
pub struct Listener {
    /// Address this listener was created for; copied into the transport
    /// configuration again when the listener starts.
    bind_addr: SocketAddr,
    /// Transport settings, including the advertisement ("pong data") string.
    transport_config: TransportConfig,
    /// Sharded runtime settings, including the shard count.
    runtime_config: ShardedRuntimeConfig,
    /// Capacity used for the accept channel on the next `start`.
    accept_queue_capacity: usize,
    /// Capacity used for each connection's inbound channel on the next `start`.
    inbound_queue_capacity: usize,
    /// Capacity used for the worker command channel on the next `start`.
    command_queue_capacity: usize,
    /// `Some` while the listener is started, `None` otherwise.
    runtime: Option<ListenerRuntime>,
}
48
/// Point-in-time snapshot of a listener's configuration and state, as
/// produced by `Listener::metadata`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListenerMetadata {
    /// Address the listener was configured with.
    bind_addr: SocketAddr,
    /// Whether the listener had a running worker when the snapshot was taken.
    started: bool,
    /// Number of shards configured for the runtime.
    shard_count: usize,
    /// Advertisement ("pong data") string configured at snapshot time.
    advertisement: String,
}

impl ListenerMetadata {
    /// Returns the address the listener was configured with.
    pub const fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// Returns `true` when the listener was started at snapshot time.
    pub const fn started(&self) -> bool {
        self.started
    }

    /// Returns the configured shard count.
    pub const fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Returns the advertisement string captured in this snapshot.
    pub fn advertisement(&self) -> &str {
        self.advertisement.as_str()
    }
}
74
75impl Listener {
76    pub async fn bind(bind_addr: SocketAddr) -> Result<Self, ServerError> {
77        let transport_config = TransportConfig {
78            bind_addr,
79            ..TransportConfig::default()
80        };
81
82        Ok(Self {
83            bind_addr,
84            transport_config,
85            runtime_config: ShardedRuntimeConfig::default(),
86            accept_queue_capacity: DEFAULT_ACCEPT_QUEUE_CAPACITY,
87            inbound_queue_capacity: DEFAULT_INBOUND_QUEUE_CAPACITY,
88            command_queue_capacity: DEFAULT_COMMAND_QUEUE_CAPACITY,
89            runtime: None,
90        })
91    }
92
93    pub fn set_pong_data(&mut self, data: impl Into<String>) {
94        self.transport_config.advertisement = data.into();
95    }
96
97    pub fn pong_data(&self) -> &str {
98        &self.transport_config.advertisement
99    }
100
101    pub fn set_accept_queue_capacity(&mut self, capacity: usize) {
102        self.accept_queue_capacity = capacity.max(1);
103    }
104
105    pub fn set_inbound_queue_capacity(&mut self, capacity: usize) {
106        self.inbound_queue_capacity = capacity.max(1);
107    }
108
109    pub fn set_command_queue_capacity(&mut self, capacity: usize) {
110        self.command_queue_capacity = capacity.max(1);
111    }
112
113    pub fn set_shard_count(&mut self, shard_count: usize) {
114        self.runtime_config.shard_count = shard_count.max(1);
115    }
116
117    pub fn bind_addr(&self) -> SocketAddr {
118        self.bind_addr
119    }
120
121    pub fn metadata(&self) -> ListenerMetadata {
122        ListenerMetadata {
123            bind_addr: self.bind_addr,
124            started: self.runtime.is_some(),
125            shard_count: self.runtime_config.shard_count.max(1),
126            advertisement: self.transport_config.advertisement.clone(),
127        }
128    }
129
130    pub fn is_started(&self) -> bool {
131        self.runtime.is_some()
132    }
133
134    pub async fn start(&mut self) -> Result<(), ServerError> {
135        if self.runtime.is_some() {
136            return Err(ServerError::AlreadyStarted);
137        }
138
139        let mut transport_config = self.transport_config.clone();
140        transport_config.bind_addr = self.bind_addr;
141
142        transport_config.validate()?;
143        self.runtime_config.validate()?;
144
145        let server =
146            RaknetServer::start_with_configs(transport_config, self.runtime_config.clone())
147                .await
148                .map_err(ServerError::from)?;
149
150        let (accept_tx, accept_rx) = mpsc::channel(self.accept_queue_capacity.max(1));
151        let (command_tx, command_rx) = mpsc::channel(self.command_queue_capacity.max(1));
152        let worker_command_tx = command_tx.clone();
153        let inbound_queue_capacity = self.inbound_queue_capacity.max(1);
154
155        let worker = tokio::spawn(async move {
156            run_listener_worker(
157                server,
158                command_rx,
159                worker_command_tx,
160                accept_tx,
161                inbound_queue_capacity,
162            )
163            .await;
164        });
165
166        self.runtime = Some(ListenerRuntime {
167            command_tx,
168            accept_rx,
169            worker,
170        });
171
172        Ok(())
173    }
174
175    pub async fn stop(&mut self) -> Result<(), ServerError> {
176        let Some(runtime) = self.runtime.take() else {
177            return Ok(());
178        };
179
180        let (response_tx, response_rx) = oneshot::channel();
181        if runtime
182            .command_tx
183            .send(ConnectionCommand::Shutdown {
184                response: response_tx,
185            })
186            .await
187            .is_err()
188        {
189            let _ = runtime.worker.await;
190            return Err(ServerError::CommandChannelClosed);
191        }
192
193        let response = response_rx.await.map_err(|_| ServerError::WorkerStopped)?;
194        let _ = runtime.worker.await;
195        response.map_err(ServerError::from)
196    }
197
198    pub async fn accept(&mut self) -> Result<Connection, ServerError> {
199        self.accept_receiver()?
200            .recv()
201            .await
202            .ok_or(ServerError::AcceptChannelClosed)
203    }
204
205    pub fn incoming(&mut self) -> Result<Incoming<'_>, ServerError> {
206        let accept_rx = self.accept_receiver()?;
207        Ok(Incoming { accept_rx })
208    }
209
210    fn accept_receiver(&mut self) -> Result<&mut mpsc::Receiver<Connection>, ServerError> {
211        let runtime = self.runtime.as_mut().ok_or(ServerError::NotStarted)?;
212        Ok(&mut runtime.accept_rx)
213    }
214}
215
216impl Incoming<'_> {
217    pub async fn next(&mut self) -> Option<Connection> {
218        self.accept_rx.recv().await
219    }
220}
221
222impl Drop for Listener {
223    fn drop(&mut self) {
224        if let Some(runtime) = self.runtime.take() {
225            runtime.worker.abort();
226        }
227    }
228}
229
/// Body of the listener's background task.
///
/// Owns the `RaknetServer` and loops over two sources: commands arriving on
/// `command_rx` and events produced by `server.next_event()`. It keeps one
/// `PeerRuntime` per connected peer plus an address-to-peer index, creates a
/// `Connection` for every `PeerConnected` event and offers it on `accept_tx`,
/// and forwards packets and decode errors to the matching connection's inbound
/// channel.
///
/// * `server` - the started server driven by this task.
/// * `command_rx` - receiver of send/disconnect/shutdown commands.
/// * `command_tx` - sender cloned into every `Connection` that is created.
/// * `accept_tx` - sender used to hand new connections to the listener.
/// * `inbound_queue_capacity` - capacity of each connection's inbound channel
///   (clamped to at least 1).
///
/// The loop ends on a `Shutdown` command, when the command channel reports no
/// more commands, when the server's event stream ends, or when the accept
/// channel is found closed.
async fn run_listener_worker(
    mut server: RaknetServer,
    mut command_rx: mpsc::Receiver<ConnectionCommand>,
    command_tx: mpsc::Sender<ConnectionCommand>,
    accept_tx: mpsc::Sender<Connection>,
    inbound_queue_capacity: usize,
) {
    // `peers` is the source of truth; `peer_ids_by_addr` is a secondary index
    // used to route address-keyed events (decode errors) to a peer.
    let mut peers: HashMap<PeerId, PeerRuntime> = HashMap::new();
    let mut peer_ids_by_addr: HashMap<SocketAddr, PeerId> = HashMap::new();

    loop {
        tokio::select! {
            command = command_rx.recv() => {
                match command {
                    Some(ConnectionCommand::Send { peer_id, payload, options, response }) => {
                        // Only peers tracked by this worker may be sent to;
                        // anything else is answered with `NotFound`.
                        let result = if peers.contains_key(&peer_id) {
                            server.send_with_options(peer_id, payload, options).await
                        } else {
                            Err(io::Error::new(io::ErrorKind::NotFound, "peer not found"))
                        };
                        // The requester may have stopped waiting; ignore that.
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::Disconnect { peer_id, response }) => {
                        let result = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::DisconnectNoWait { peer_id }) => {
                        // Same as `Disconnect`, but nobody is waiting for the result.
                        let _ = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                    }
                    Some(ConnectionCommand::Shutdown { response }) => {
                        // Snapshot the ids first: the map cannot stay borrowed
                        // across the awaited disconnect calls on `server`.
                        for peer_id in peers.keys().copied().collect::<Vec<_>>() {
                            let _ = server.disconnect(peer_id).await;
                        }

                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let result = server.shutdown().await;
                        let _ = response.send(result);
                        break;
                    }
                    None => {
                        // The command channel reports that no further commands
                        // will arrive.
                        // NOTE(review): `command_tx` held by this function looks
                        // like a sender of this same channel, which would keep
                        // this arm from firing while the loop runs - confirm.
                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let _ = server.shutdown().await;
                        break;
                    }
                }
            }
            server_event = server.next_event() => {
                // End of the server's event stream: close every connection and
                // stop. Unlike the other exits, `server.shutdown()` is not
                // called on this path.
                let Some(server_event) = server_event else {
                    close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                    break;
                };

                match server_event {
                    RaknetServerEvent::PeerConnected { peer_id, addr, .. } => {
                        // A peer id that is already tracked is replaced: the
                        // old entry is dropped from both maps and its
                        // connection is told it was closed.
                        if let Some(existing) = peers.remove(&peer_id) {
                            peer_ids_by_addr.remove(&existing.addr);
                            close_peer_entry(existing, ConnectionCloseReason::RequestedByLocal);
                        }

                        let shared = Arc::new(ConnectionSharedState::new());
                        let (inbound_tx, inbound_rx) = mpsc::channel(inbound_queue_capacity.max(1));
                        let connection = Connection::new(
                            peer_id,
                            addr,
                            command_tx.clone(),
                            inbound_rx,
                            Arc::clone(&shared),
                        );

                        // Register the peer before offering the connection so
                        // the failure paths below can clean up via
                        // `disconnect_peer`.
                        peers.insert(
                            peer_id,
                            PeerRuntime {
                                addr,
                                inbound_tx,
                                shared,
                            },
                        );
                        peer_ids_by_addr.insert(addr, peer_id);

                        // Non-blocking hand-off: the worker never waits for the
                        // accept side.
                        if let Err(err) = accept_tx.try_send(connection) {
                            match err {
                                TrySendError::Full(conn) => {
                                    // Accept queue is full: drop this peer only.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                TrySendError::Closed(conn) => {
                                    // Accept receiver is gone: nobody can take
                                    // connections any more, so tear everything
                                    // down and stop the worker.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                    close_all_peers(
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        ConnectionCloseReason::ListenerStopped,
                                    );
                                    let _ = server.shutdown().await;
                                    break;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::PeerDisconnected { peer_id, reason, .. } => {
                        // The server already dropped the peer; only local state
                        // needs cleaning up, and only if the peer was tracked.
                        if let Some(entry) = remove_peer(&mut peers, &mut peer_ids_by_addr, peer_id) {
                            close_peer_entry(
                                entry,
                                ConnectionCloseReason::PeerDisconnected(
                                    RemoteDisconnectReason::from(reason),
                                ),
                            );
                        }
                    }
                    RaknetServerEvent::Packet { peer_id, payload, .. } => {
                        // Packets for untracked peers are dropped silently.
                        if let Some(entry) = peers.get(&peer_id) {
                            match entry.inbound_tx.try_send(ConnectionInbound::Packet(payload)) {
                                Ok(()) => {}
                                Err(TrySendError::Full(_)) => {
                                    // The connection's inbound queue is full:
                                    // disconnect the peer rather than block.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                Err(TrySendError::Closed(_)) => {
                                    // The inbound receiver no longer exists, so
                                    // there is nobody to deliver to.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::DecodeError { addr, error } => {
                        // Decode errors are keyed by address; forward them on a
                        // best-effort basis when the address maps to a tracked
                        // peer.
                        if let Some(peer_id) = peer_ids_by_addr.get(&addr).copied()
                            && let Some(entry) = peers.get(&peer_id)
                        {
                            let _ = entry
                                .inbound_tx
                                .try_send(ConnectionInbound::DecodeError(error));
                        }
                    }
                    // Events that need no per-connection handling here.
                    RaknetServerEvent::PeerRateLimited { .. }
                    | RaknetServerEvent::SessionLimitReached { .. }
                    | RaknetServerEvent::ProxyDropped { .. }
                    | RaknetServerEvent::OfflinePacket { .. }
                    | RaknetServerEvent::ReceiptAcked { .. }
                    | RaknetServerEvent::WorkerError { .. }
                    | RaknetServerEvent::WorkerStopped { .. }
                    | RaknetServerEvent::Metrics { .. } => {}
                }
            }
        }
    }

    // Explicitly release the accept sender once the loop has ended.
    drop(accept_tx);
}
416
417fn remove_peer(
418    peers: &mut HashMap<PeerId, PeerRuntime>,
419    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
420    peer_id: PeerId,
421) -> Option<PeerRuntime> {
422    let entry = peers.remove(&peer_id)?;
423    peer_ids_by_addr.remove(&entry.addr);
424    Some(entry)
425}
426
427async fn disconnect_peer(
428    server: &mut RaknetServer,
429    peers: &mut HashMap<PeerId, PeerRuntime>,
430    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
431    peer_id: PeerId,
432    reason: ConnectionCloseReason,
433) -> io::Result<()> {
434    let result = server.disconnect(peer_id).await;
435    if let Some(entry) = remove_peer(peers, peer_ids_by_addr, peer_id) {
436        close_peer_entry(entry, reason);
437    }
438    result
439}
440
441fn close_all_peers(
442    peers: &mut HashMap<PeerId, PeerRuntime>,
443    peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
444    reason: ConnectionCloseReason,
445) {
446    peer_ids_by_addr.clear();
447    for (_, entry) in peers.drain() {
448        close_peer_entry(entry, reason.clone());
449    }
450}
451
452fn close_peer_entry(entry: PeerRuntime, reason: ConnectionCloseReason) {
453    entry.shared.mark_closed(reason.clone());
454    let _ = entry.inbound_tx.try_send(ConnectionInbound::Closed(reason));
455}