1use std::collections::HashMap;
6use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use tokio::sync::mpsc;
11use tokio::sync::mpsc::error::TrySendError;
12use tokio::sync::oneshot;
13use tokio::task::JoinHandle;
14
15use crate::connection::{
16 Connection, ConnectionCloseReason, ConnectionCommand, ConnectionInbound, ConnectionSharedState,
17 RemoteDisconnectReason,
18};
19use crate::error::server::ServerError;
20use crate::server::{PeerId, RaknetServer, RaknetServerEvent};
21use crate::transport::{ShardedRuntimeConfig, TransportConfig};
22
/// Default capacity of the accept queue. When it is full, newly connected peers are
/// disconnected with `InboundBackpressure` instead of being queued.
const DEFAULT_ACCEPT_QUEUE_CAPACITY: usize = 512;
/// Default capacity of each connection's inbound queue. A peer whose queue overflows is
/// disconnected with `InboundBackpressure`.
const DEFAULT_INBOUND_QUEUE_CAPACITY: usize = 256;
/// Default capacity of the command queue shared by the listener and every `Connection`.
const DEFAULT_COMMAND_QUEUE_CAPACITY: usize = 2_048;
26
/// Handles owned by a started [`Listener`] for talking to its background worker task.
struct ListenerRuntime {
    /// Sender half of the worker's command queue; `stop` uses it to request shutdown.
    command_tx: mpsc::Sender<ConnectionCommand>,
    /// Receiver of connections the worker has accepted; drained by `accept` / `incoming`.
    accept_rx: mpsc::Receiver<Connection>,
    /// Task running `run_listener_worker`; awaited by `stop`, aborted when the listener drops.
    worker: JoinHandle<()>,
}
32
/// Worker-side bookkeeping for one connected peer.
struct PeerRuntime {
    /// Remote address; mirrored in the worker's `addr -> PeerId` index.
    addr: SocketAddr,
    /// Feeds the peer's [`Connection`] with packets, decode errors and the final close notice.
    inbound_tx: mpsc::Sender<ConnectionInbound>,
    /// State shared with the peer's [`Connection`]; marked closed when the peer is torn down.
    shared: Arc<ConnectionSharedState>,
}
38
/// Borrowing view over a started listener's accept queue, obtained from
/// [`Listener::incoming`]. Each call to `next` yields the next accepted [`Connection`].
pub struct Incoming<'a> {
    accept_rx: &'a mut mpsc::Receiver<Connection>,
}
43
/// A RakNet listener that hands out accepted peers as [`Connection`]s.
///
/// Create it with [`Listener::bind`], adjust it with the setters, then call
/// [`Listener::start`]. The configuration fields are read when `start` runs.
pub struct Listener {
    /// Address the transport is bound to when the listener starts.
    bind_addr: SocketAddr,
    /// Transport settings, cloned into the server on start.
    transport_config: TransportConfig,
    /// Sharded runtime settings (e.g. shard count), cloned into the server on start.
    runtime_config: ShardedRuntimeConfig,
    /// Capacity of the accept queue between the worker and `accept`.
    accept_queue_capacity: usize,
    /// Capacity of each connection's inbound queue.
    inbound_queue_capacity: usize,
    /// Capacity of the worker's command queue.
    command_queue_capacity: usize,
    /// `Some` while the listener is started.
    runtime: Option<ListenerRuntime>,
}
54
/// Point-in-time snapshot of a [`Listener`]'s configuration, as returned by
/// `Listener::metadata`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListenerMetadata {
    bind_addr: SocketAddr,
    started: bool,
    shard_count: usize,
    advertisement: String,
}

impl ListenerMetadata {
    /// Whether the listener had a running worker when the snapshot was taken.
    pub const fn started(&self) -> bool {
        self.started
    }

    /// Number of server shards configured for the listener.
    pub const fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Address the listener binds (or is bound) to.
    pub const fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// Advertisement (pong) payload configured for the listener.
    pub fn advertisement(&self) -> &str {
        self.advertisement.as_str()
    }
}
85
86impl Listener {
87 pub async fn bind(bind_addr: SocketAddr) -> Result<Self, ServerError> {
89 let transport_config = TransportConfig {
90 bind_addr,
91 ..TransportConfig::default()
92 };
93
94 Ok(Self {
95 bind_addr,
96 transport_config,
97 runtime_config: ShardedRuntimeConfig::default(),
98 accept_queue_capacity: DEFAULT_ACCEPT_QUEUE_CAPACITY,
99 inbound_queue_capacity: DEFAULT_INBOUND_QUEUE_CAPACITY,
100 command_queue_capacity: DEFAULT_COMMAND_QUEUE_CAPACITY,
101 runtime: None,
102 })
103 }
104
105 pub fn set_pong_data(&mut self, data: impl Into<String>) {
107 self.transport_config.advertisement = data.into();
108 }
109
110 pub fn pong_data(&self) -> &str {
112 &self.transport_config.advertisement
113 }
114
115 pub fn set_accept_queue_capacity(&mut self, capacity: usize) {
117 self.accept_queue_capacity = capacity.max(1);
118 }
119
120 pub fn set_inbound_queue_capacity(&mut self, capacity: usize) {
122 self.inbound_queue_capacity = capacity.max(1);
123 }
124
125 pub fn set_command_queue_capacity(&mut self, capacity: usize) {
127 self.command_queue_capacity = capacity.max(1);
128 }
129
130 pub fn set_shard_count(&mut self, shard_count: usize) {
132 self.runtime_config.shard_count = shard_count.max(1);
133 }
134
135 pub fn bind_addr(&self) -> SocketAddr {
137 self.bind_addr
138 }
139
140 pub fn metadata(&self) -> ListenerMetadata {
142 ListenerMetadata {
143 bind_addr: self.bind_addr,
144 started: self.runtime.is_some(),
145 shard_count: self.runtime_config.shard_count.max(1),
146 advertisement: self.transport_config.advertisement.clone(),
147 }
148 }
149
150 pub fn is_started(&self) -> bool {
152 self.runtime.is_some()
153 }
154
155 pub async fn start(&mut self) -> Result<(), ServerError> {
157 if self.runtime.is_some() {
158 return Err(ServerError::AlreadyStarted);
159 }
160
161 let mut transport_config = self.transport_config.clone();
162 transport_config.bind_addr = self.bind_addr;
163
164 transport_config.validate()?;
165 self.runtime_config.validate()?;
166
167 let server =
168 RaknetServer::start_with_configs(transport_config, self.runtime_config.clone())
169 .await
170 .map_err(ServerError::from)?;
171
172 let (accept_tx, accept_rx) = mpsc::channel(self.accept_queue_capacity.max(1));
173 let (command_tx, command_rx) = mpsc::channel(self.command_queue_capacity.max(1));
174 let worker_command_tx = command_tx.clone();
175 let inbound_queue_capacity = self.inbound_queue_capacity.max(1);
176
177 let worker = tokio::spawn(async move {
178 run_listener_worker(
179 server,
180 command_rx,
181 worker_command_tx,
182 accept_tx,
183 inbound_queue_capacity,
184 )
185 .await;
186 });
187
188 self.runtime = Some(ListenerRuntime {
189 command_tx,
190 accept_rx,
191 worker,
192 });
193
194 Ok(())
195 }
196
197 pub async fn stop(&mut self) -> Result<(), ServerError> {
199 let Some(runtime) = self.runtime.take() else {
200 return Ok(());
201 };
202
203 let (response_tx, response_rx) = oneshot::channel();
204 if runtime
205 .command_tx
206 .send(ConnectionCommand::Shutdown {
207 response: response_tx,
208 })
209 .await
210 .is_err()
211 {
212 let _ = runtime.worker.await;
213 return Err(ServerError::CommandChannelClosed);
214 }
215
216 let response = response_rx.await.map_err(|_| ServerError::WorkerStopped)?;
217 let _ = runtime.worker.await;
218 response.map_err(ServerError::from)
219 }
220
221 pub async fn accept(&mut self) -> Result<Connection, ServerError> {
223 self.accept_receiver()?
224 .recv()
225 .await
226 .ok_or(ServerError::AcceptChannelClosed)
227 }
228
229 pub fn incoming(&mut self) -> Result<Incoming<'_>, ServerError> {
231 let accept_rx = self.accept_receiver()?;
232 Ok(Incoming { accept_rx })
233 }
234
235 fn accept_receiver(&mut self) -> Result<&mut mpsc::Receiver<Connection>, ServerError> {
236 let runtime = self.runtime.as_mut().ok_or(ServerError::NotStarted)?;
237 Ok(&mut runtime.accept_rx)
238 }
239}
240
241impl Incoming<'_> {
242 pub async fn next(&mut self) -> Option<Connection> {
244 self.accept_rx.recv().await
245 }
246}
247
248impl Drop for Listener {
249 fn drop(&mut self) {
250 if let Some(runtime) = self.runtime.take() {
251 runtime.worker.abort();
252 }
253 }
254}
255
/// Background task that owns the [`RaknetServer`] and bridges it to per-peer [`Connection`]s.
///
/// The loop multiplexes two sources with `tokio::select!`:
/// - `command_rx`: requests from the `Listener` and from `Connection` handles (send,
///   disconnect, shutdown);
/// - `server.next_event()`: transport events (peer connect/disconnect, packets, decode errors).
///
/// Each `PeerConnected` event creates a `Connection` wired to a fresh inbound queue of
/// `inbound_queue_capacity` and offers it on `accept_tx`. Queues are fed with `try_send`, so a
/// full accept or inbound queue disconnects the peer rather than blocking the loop.
///
/// The loop exits in four cases:
/// - on `Shutdown`;
/// - when the command channel ends;
/// - when the server event stream ends;
/// - when the accept queue's receiver has been dropped.
///
/// In every case, still-tracked peers are closed with `ListenerStopped` before returning.
async fn run_listener_worker(
    mut server: RaknetServer,
    mut command_rx: mpsc::Receiver<ConnectionCommand>,
    command_tx: mpsc::Sender<ConnectionCommand>,
    accept_tx: mpsc::Sender<Connection>,
    inbound_queue_capacity: usize,
) {
    // Peer table plus a reverse index used to route address-keyed events (decode errors).
    let mut peers: HashMap<PeerId, PeerRuntime> = HashMap::new();
    let mut peer_ids_by_addr: HashMap<SocketAddr, PeerId> = HashMap::new();

    loop {
        tokio::select! {
            command = command_rx.recv() => {
                match command {
                    Some(ConnectionCommand::Send { peer_id, payload, options, response }) => {
                        // Only forward to peers this worker still tracks.
                        let result = if peers.contains_key(&peer_id) {
                            server.send_with_options(peer_id, payload, options).await
                        } else {
                            Err(io::Error::new(io::ErrorKind::NotFound, "peer not found"))
                        };
                        // The requester may have stopped waiting; a dropped receiver is ignored.
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::Disconnect { peer_id, response }) => {
                        let result = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::DisconnectNoWait { peer_id }) => {
                        // Fire-and-forget variant: the disconnect result is discarded.
                        let _ = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                    }
                    Some(ConnectionCommand::Shutdown { response }) => {
                        // Ask the transport to disconnect every peer (errors ignored), then close
                        // the local Connections and shut the server down.
                        for peer_id in peers.keys().copied().collect::<Vec<_>>() {
                            let _ = server.disconnect(peer_id).await;
                        }

                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let result = server.shutdown().await;
                        let _ = response.send(result);
                        break;
                    }
                    None => {
                        // NOTE(review): this task holds its own `command_tx` (cloned into each
                        // Connection), so `recv()` cannot yield `None` while the loop runs; this
                        // branch appears unreachable and is kept as a defensive fallback.
                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let _ = server.shutdown().await;
                        break;
                    }
                }
            }
            server_event = server.next_event() => {
                // Event stream ended: close local peers without calling `server.shutdown()`.
                let Some(server_event) = server_event else {
                    close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                    break;
                };

                match server_event {
                    RaknetServerEvent::PeerConnected { peer_id, addr, .. } => {
                        // A connect event for an id that is already tracked replaces the old
                        // entry; the stale Connection is closed locally first.
                        if let Some(existing) = peers.remove(&peer_id) {
                            peer_ids_by_addr.remove(&existing.addr);
                            close_peer_entry(existing, ConnectionCloseReason::RequestedByLocal);
                        }

                        let shared = Arc::new(ConnectionSharedState::new());
                        let (inbound_tx, inbound_rx) = mpsc::channel(inbound_queue_capacity.max(1));
                        let connection = Connection::new(
                            peer_id,
                            addr,
                            command_tx.clone(),
                            inbound_rx,
                            Arc::clone(&shared),
                        );

                        peers.insert(
                            peer_id,
                            PeerRuntime {
                                addr,
                                inbound_tx,
                                shared,
                            },
                        );
                        peer_ids_by_addr.insert(addr, peer_id);

                        // Offer the connection without blocking the event loop.
                        if let Err(err) = accept_tx.try_send(connection) {
                            match err {
                                TrySendError::Full(conn) => {
                                    // Nobody is accepting fast enough: refuse this peer.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                TrySendError::Closed(conn) => {
                                    // The Listener dropped its accept receiver: disconnect this
                                    // peer, close the rest and stop the worker.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                    close_all_peers(
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        ConnectionCloseReason::ListenerStopped,
                                    );
                                    let _ = server.shutdown().await;
                                    break;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::PeerDisconnected { peer_id, reason, .. } => {
                        // Remote side went away: close the Connection with the mapped reason.
                        if let Some(entry) = remove_peer(&mut peers, &mut peer_ids_by_addr, peer_id) {
                            close_peer_entry(
                                entry,
                                ConnectionCloseReason::PeerDisconnected(
                                    RemoteDisconnectReason::from(reason),
                                ),
                            );
                        }
                    }
                    RaknetServerEvent::Packet { peer_id, payload, .. } => {
                        if let Some(entry) = peers.get(&peer_id) {
                            match entry.inbound_tx.try_send(ConnectionInbound::Packet(payload)) {
                                Ok(()) => {}
                                Err(TrySendError::Full(_)) => {
                                    // The Connection is not draining its queue: drop the peer.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                Err(TrySendError::Closed(_)) => {
                                    // The Connection's receiver is gone; disconnect the peer.
                                    // NOTE(review): `ListenerStopped` may be a misleading reason
                                    // here, though no receiver remains to observe it.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::DecodeError { addr, error } => {
                        // Decode errors are keyed by address; deliver best-effort (dropped if
                        // the peer is unknown or its inbound queue is full).
                        if let Some(peer_id) = peer_ids_by_addr.get(&addr).copied()
                            && let Some(entry) = peers.get(&peer_id)
                        {
                            let _ = entry
                                .inbound_tx
                                .try_send(ConnectionInbound::DecodeError(error));
                        }
                    }
                    // Events with no per-connection effect are ignored.
                    RaknetServerEvent::PeerRateLimited { .. }
                    | RaknetServerEvent::SessionLimitReached { .. }
                    | RaknetServerEvent::ProxyDropped { .. }
                    | RaknetServerEvent::OfflinePacket { .. }
                    | RaknetServerEvent::ReceiptAcked { .. }
                    | RaknetServerEvent::WorkerError { .. }
                    | RaknetServerEvent::WorkerStopped { .. }
                    | RaknetServerEvent::Metrics { .. } => {}
                }
            }
        }
    }

    // Explicitly close the accept queue; the Listener's `accept` then observes `None`.
    drop(accept_tx);
}
442
443fn remove_peer(
444 peers: &mut HashMap<PeerId, PeerRuntime>,
445 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
446 peer_id: PeerId,
447) -> Option<PeerRuntime> {
448 let entry = peers.remove(&peer_id)?;
449 peer_ids_by_addr.remove(&entry.addr);
450 Some(entry)
451}
452
453async fn disconnect_peer(
454 server: &mut RaknetServer,
455 peers: &mut HashMap<PeerId, PeerRuntime>,
456 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
457 peer_id: PeerId,
458 reason: ConnectionCloseReason,
459) -> io::Result<()> {
460 let result = server.disconnect(peer_id).await;
461 if let Some(entry) = remove_peer(peers, peer_ids_by_addr, peer_id) {
462 close_peer_entry(entry, reason);
463 }
464 result
465}
466
467fn close_all_peers(
468 peers: &mut HashMap<PeerId, PeerRuntime>,
469 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
470 reason: ConnectionCloseReason,
471) {
472 peer_ids_by_addr.clear();
473 for (_, entry) in peers.drain() {
474 close_peer_entry(entry, reason.clone());
475 }
476}
477
478fn close_peer_entry(entry: PeerRuntime, reason: ConnectionCloseReason) {
479 entry.shared.mark_closed(reason.clone());
480 let _ = entry.inbound_tx.try_send(ConnectionInbound::Closed(reason));
481}