1use std::collections::HashMap;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use tokio::sync::mpsc;
7use tokio::sync::mpsc::error::TrySendError;
8use tokio::sync::oneshot;
9use tokio::task::JoinHandle;
10
11use crate::connection::{
12 Connection, ConnectionCloseReason, ConnectionCommand, ConnectionInbound, ConnectionSharedState,
13 RemoteDisconnectReason,
14};
15use crate::error::server::ServerError;
16use crate::server::{PeerId, RaknetServer, RaknetServerEvent};
17use crate::transport::{ShardedRuntimeConfig, TransportConfig};
18
/// Default capacity of the channel delivering accepted connections to callers.
const DEFAULT_ACCEPT_QUEUE_CAPACITY: usize = 512;
/// Default capacity of each per-peer inbound event channel.
const DEFAULT_INBOUND_QUEUE_CAPACITY: usize = 256;
/// Default capacity of the listener worker's command channel.
const DEFAULT_COMMAND_QUEUE_CAPACITY: usize = 2048;
22
/// Handles owned by a started [`Listener`]: the channels into the background
/// worker task plus the worker's join handle.
struct ListenerRuntime {
    // Sends commands (send/disconnect/shutdown) to the worker task.
    command_tx: mpsc::Sender<ConnectionCommand>,
    // Receives newly accepted connections from the worker task.
    accept_rx: mpsc::Receiver<Connection>,
    // Join handle of the spawned worker; aborted when the `Listener` is dropped.
    worker: JoinHandle<()>,
}
28
/// Per-peer state tracked by the listener worker.
struct PeerRuntime {
    // Remote address; key into the addr -> peer-id index.
    addr: SocketAddr,
    // Delivers packets, decode errors, and the final `Closed` event to the
    // peer's `Connection` handle.
    inbound_tx: mpsc::Sender<ConnectionInbound>,
    // State shared with the `Connection`; used to record the close reason.
    shared: Arc<ConnectionSharedState>,
}
34
/// Borrowing stream of accepted connections, created by `Listener::incoming`.
pub struct Incoming<'a> {
    // Exclusive borrow of the listener's accept queue for the stream's lifetime.
    accept_rx: &'a mut mpsc::Receiver<Connection>,
}
38
/// A listener holding configuration before `start` and, once started, the
/// runtime handles of its background worker task.
pub struct Listener {
    bind_addr: SocketAddr,
    transport_config: TransportConfig,
    runtime_config: ShardedRuntimeConfig,
    // Channel capacities; clamped to at least 1 wherever they are used.
    accept_queue_capacity: usize,
    inbound_queue_capacity: usize,
    command_queue_capacity: usize,
    // `Some` while started; `None` before `start` and after `stop`.
    runtime: Option<ListenerRuntime>,
}
48
/// Read-only snapshot of a listener's configuration and running state,
/// produced by `Listener::metadata`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListenerMetadata {
    bind_addr: SocketAddr,
    started: bool,
    // Clamped to at least 1 when the snapshot is taken.
    shard_count: usize,
    advertisement: String,
}
56
57impl ListenerMetadata {
58 pub const fn bind_addr(&self) -> SocketAddr {
59 self.bind_addr
60 }
61
62 pub const fn started(&self) -> bool {
63 self.started
64 }
65
66 pub const fn shard_count(&self) -> usize {
67 self.shard_count
68 }
69
70 pub fn advertisement(&self) -> &str {
71 &self.advertisement
72 }
73}
74
75impl Listener {
76 pub async fn bind(bind_addr: SocketAddr) -> Result<Self, ServerError> {
77 let transport_config = TransportConfig {
78 bind_addr,
79 ..TransportConfig::default()
80 };
81
82 Ok(Self {
83 bind_addr,
84 transport_config,
85 runtime_config: ShardedRuntimeConfig::default(),
86 accept_queue_capacity: DEFAULT_ACCEPT_QUEUE_CAPACITY,
87 inbound_queue_capacity: DEFAULT_INBOUND_QUEUE_CAPACITY,
88 command_queue_capacity: DEFAULT_COMMAND_QUEUE_CAPACITY,
89 runtime: None,
90 })
91 }
92
93 pub fn set_pong_data(&mut self, data: impl Into<String>) {
94 self.transport_config.advertisement = data.into();
95 }
96
97 pub fn pong_data(&self) -> &str {
98 &self.transport_config.advertisement
99 }
100
101 pub fn set_accept_queue_capacity(&mut self, capacity: usize) {
102 self.accept_queue_capacity = capacity.max(1);
103 }
104
105 pub fn set_inbound_queue_capacity(&mut self, capacity: usize) {
106 self.inbound_queue_capacity = capacity.max(1);
107 }
108
109 pub fn set_command_queue_capacity(&mut self, capacity: usize) {
110 self.command_queue_capacity = capacity.max(1);
111 }
112
113 pub fn set_shard_count(&mut self, shard_count: usize) {
114 self.runtime_config.shard_count = shard_count.max(1);
115 }
116
117 pub fn bind_addr(&self) -> SocketAddr {
118 self.bind_addr
119 }
120
121 pub fn metadata(&self) -> ListenerMetadata {
122 ListenerMetadata {
123 bind_addr: self.bind_addr,
124 started: self.runtime.is_some(),
125 shard_count: self.runtime_config.shard_count.max(1),
126 advertisement: self.transport_config.advertisement.clone(),
127 }
128 }
129
130 pub fn is_started(&self) -> bool {
131 self.runtime.is_some()
132 }
133
134 pub async fn start(&mut self) -> Result<(), ServerError> {
135 if self.runtime.is_some() {
136 return Err(ServerError::AlreadyStarted);
137 }
138
139 let mut transport_config = self.transport_config.clone();
140 transport_config.bind_addr = self.bind_addr;
141
142 transport_config.validate()?;
143 self.runtime_config.validate()?;
144
145 let server =
146 RaknetServer::start_with_configs(transport_config, self.runtime_config.clone())
147 .await
148 .map_err(ServerError::from)?;
149
150 let (accept_tx, accept_rx) = mpsc::channel(self.accept_queue_capacity.max(1));
151 let (command_tx, command_rx) = mpsc::channel(self.command_queue_capacity.max(1));
152 let worker_command_tx = command_tx.clone();
153 let inbound_queue_capacity = self.inbound_queue_capacity.max(1);
154
155 let worker = tokio::spawn(async move {
156 run_listener_worker(
157 server,
158 command_rx,
159 worker_command_tx,
160 accept_tx,
161 inbound_queue_capacity,
162 )
163 .await;
164 });
165
166 self.runtime = Some(ListenerRuntime {
167 command_tx,
168 accept_rx,
169 worker,
170 });
171
172 Ok(())
173 }
174
175 pub async fn stop(&mut self) -> Result<(), ServerError> {
176 let Some(runtime) = self.runtime.take() else {
177 return Ok(());
178 };
179
180 let (response_tx, response_rx) = oneshot::channel();
181 if runtime
182 .command_tx
183 .send(ConnectionCommand::Shutdown {
184 response: response_tx,
185 })
186 .await
187 .is_err()
188 {
189 let _ = runtime.worker.await;
190 return Err(ServerError::CommandChannelClosed);
191 }
192
193 let response = response_rx.await.map_err(|_| ServerError::WorkerStopped)?;
194 let _ = runtime.worker.await;
195 response.map_err(ServerError::from)
196 }
197
198 pub async fn accept(&mut self) -> Result<Connection, ServerError> {
199 self.accept_receiver()?
200 .recv()
201 .await
202 .ok_or(ServerError::AcceptChannelClosed)
203 }
204
205 pub fn incoming(&mut self) -> Result<Incoming<'_>, ServerError> {
206 let accept_rx = self.accept_receiver()?;
207 Ok(Incoming { accept_rx })
208 }
209
210 fn accept_receiver(&mut self) -> Result<&mut mpsc::Receiver<Connection>, ServerError> {
211 let runtime = self.runtime.as_mut().ok_or(ServerError::NotStarted)?;
212 Ok(&mut runtime.accept_rx)
213 }
214}
215
impl Incoming<'_> {
    /// Waits for the next accepted connection; returns `None` once the
    /// accept channel is closed (the listener worker has exited).
    pub async fn next(&mut self) -> Option<Connection> {
        self.accept_rx.recv().await
    }
}
221
impl Drop for Listener {
    /// Aborts the background worker if the listener is still running.
    ///
    /// NOTE(review): this is a hard abort — unlike `stop`, no `Shutdown`
    /// command is sent, so the worker's `server.shutdown()` path is skipped.
    /// Callers should prefer `stop().await` for a graceful shutdown.
    fn drop(&mut self) {
        if let Some(runtime) = self.runtime.take() {
            runtime.worker.abort();
        }
    }
}
229
/// Background task that owns the `RaknetServer` and bridges its events to the
/// per-connection channel API.
///
/// Runs until a `Shutdown` command arrives, all command senders drop, the
/// server's event stream ends, or the accept channel is closed by the caller.
/// `command_tx` is a loopback clone handed to each new `Connection` so its
/// commands feed back into `command_rx`.
async fn run_listener_worker(
    mut server: RaknetServer,
    mut command_rx: mpsc::Receiver<ConnectionCommand>,
    command_tx: mpsc::Sender<ConnectionCommand>,
    accept_tx: mpsc::Sender<Connection>,
    inbound_queue_capacity: usize,
) {
    // Peer state indexed by id (primary) and by address (for address-keyed
    // events such as `DecodeError`). The two maps are kept in lock-step.
    let mut peers: HashMap<PeerId, PeerRuntime> = HashMap::new();
    let mut peer_ids_by_addr: HashMap<SocketAddr, PeerId> = HashMap::new();

    loop {
        tokio::select! {
            // --- Commands from the Listener / Connection handles ---
            command = command_rx.recv() => {
                match command {
                    Some(ConnectionCommand::Send { peer_id, payload, options, response }) => {
                        // Only forward to the server for peers we still track.
                        let result = if peers.contains_key(&peer_id) {
                            server.send_with_options(peer_id, payload, options).await
                        } else {
                            Err(io::Error::new(io::ErrorKind::NotFound, "peer not found"))
                        };
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::Disconnect { peer_id, response }) => {
                        let result = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                        let _ = response.send(result);
                    }
                    Some(ConnectionCommand::DisconnectNoWait { peer_id }) => {
                        // Fire-and-forget variant: result intentionally ignored.
                        let _ = disconnect_peer(
                            &mut server,
                            &mut peers,
                            &mut peer_ids_by_addr,
                            peer_id,
                            ConnectionCloseReason::RequestedByLocal,
                        )
                        .await;
                    }
                    Some(ConnectionCommand::Shutdown { response }) => {
                        // Politely disconnect every peer on the server side
                        // before tearing down local state.
                        for peer_id in peers.keys().copied().collect::<Vec<_>>() {
                            let _ = server.disconnect(peer_id).await;
                        }

                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let result = server.shutdown().await;
                        let _ = response.send(result);
                        break;
                    }
                    None => {
                        // All command senders dropped: shut down silently.
                        close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                        let _ = server.shutdown().await;
                        break;
                    }
                }
            }
            // --- Events from the underlying server ---
            server_event = server.next_event() => {
                let Some(server_event) = server_event else {
                    // Event stream ended: the server side is gone.
                    close_all_peers(&mut peers, &mut peer_ids_by_addr, ConnectionCloseReason::ListenerStopped);
                    break;
                };

                match server_event {
                    RaknetServerEvent::PeerConnected { peer_id, addr, .. } => {
                        // A stale entry under the same id is replaced; close
                        // its old Connection first.
                        if let Some(existing) = peers.remove(&peer_id) {
                            peer_ids_by_addr.remove(&existing.addr);
                            close_peer_entry(existing, ConnectionCloseReason::RequestedByLocal);
                        }

                        let shared = Arc::new(ConnectionSharedState::new());
                        let (inbound_tx, inbound_rx) = mpsc::channel(inbound_queue_capacity.max(1));
                        let connection = Connection::new(
                            peer_id,
                            addr,
                            command_tx.clone(),
                            inbound_rx,
                            Arc::clone(&shared),
                        );

                        peers.insert(
                            peer_id,
                            PeerRuntime {
                                addr,
                                inbound_tx,
                                shared,
                            },
                        );
                        peer_ids_by_addr.insert(addr, peer_id);

                        if let Err(err) = accept_tx.try_send(connection) {
                            match err {
                                TrySendError::Full(conn) => {
                                    // Accept queue full: reject the new peer
                                    // rather than block the worker.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                TrySendError::Closed(conn) => {
                                    // Nobody is accepting anymore: tear
                                    // everything down and exit.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        conn.peer_id(),
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                    close_all_peers(
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        ConnectionCloseReason::ListenerStopped,
                                    );
                                    let _ = server.shutdown().await;
                                    break;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::PeerDisconnected { peer_id, reason, .. } => {
                        if let Some(entry) = remove_peer(&mut peers, &mut peer_ids_by_addr, peer_id) {
                            close_peer_entry(
                                entry,
                                ConnectionCloseReason::PeerDisconnected(
                                    RemoteDisconnectReason::from(reason),
                                ),
                            );
                        }
                    }
                    RaknetServerEvent::Packet { peer_id, payload, .. } => {
                        if let Some(entry) = peers.get(&peer_id) {
                            match entry.inbound_tx.try_send(ConnectionInbound::Packet(payload)) {
                                Ok(()) => {}
                                Err(TrySendError::Full(_)) => {
                                    // Consumer can't keep up: drop the peer
                                    // instead of buffering unboundedly.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::InboundBackpressure,
                                    )
                                    .await;
                                }
                                Err(TrySendError::Closed(_)) => {
                                    // Connection handle dropped its receiver.
                                    let _ = disconnect_peer(
                                        &mut server,
                                        &mut peers,
                                        &mut peer_ids_by_addr,
                                        peer_id,
                                        ConnectionCloseReason::ListenerStopped,
                                    )
                                    .await;
                                }
                            }
                        }
                    }
                    RaknetServerEvent::DecodeError { addr, error } => {
                        // Best-effort forward to the peer at that address,
                        // if it is one we track.
                        if let Some(peer_id) = peer_ids_by_addr.get(&addr).copied()
                            && let Some(entry) = peers.get(&peer_id)
                        {
                            let _ = entry
                                .inbound_tx
                                .try_send(ConnectionInbound::DecodeError(error));
                        }
                    }
                    // Events not surfaced through the connection API.
                    RaknetServerEvent::PeerRateLimited { .. }
                    | RaknetServerEvent::SessionLimitReached { .. }
                    | RaknetServerEvent::ProxyDropped { .. }
                    | RaknetServerEvent::OfflinePacket { .. }
                    | RaknetServerEvent::ReceiptAcked { .. }
                    | RaknetServerEvent::WorkerError { .. }
                    | RaknetServerEvent::WorkerStopped { .. }
                    | RaknetServerEvent::Metrics { .. } => {}
                }
            }
        }
    }

    // Closing the accept channel signals `Incoming::next` / `accept` that no
    // more connections will arrive.
    drop(accept_tx);
}
416
417fn remove_peer(
418 peers: &mut HashMap<PeerId, PeerRuntime>,
419 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
420 peer_id: PeerId,
421) -> Option<PeerRuntime> {
422 let entry = peers.remove(&peer_id)?;
423 peer_ids_by_addr.remove(&entry.addr);
424 Some(entry)
425}
426
427async fn disconnect_peer(
428 server: &mut RaknetServer,
429 peers: &mut HashMap<PeerId, PeerRuntime>,
430 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
431 peer_id: PeerId,
432 reason: ConnectionCloseReason,
433) -> io::Result<()> {
434 let result = server.disconnect(peer_id).await;
435 if let Some(entry) = remove_peer(peers, peer_ids_by_addr, peer_id) {
436 close_peer_entry(entry, reason);
437 }
438 result
439}
440
441fn close_all_peers(
442 peers: &mut HashMap<PeerId, PeerRuntime>,
443 peer_ids_by_addr: &mut HashMap<SocketAddr, PeerId>,
444 reason: ConnectionCloseReason,
445) {
446 peer_ids_by_addr.clear();
447 for (_, entry) in peers.drain() {
448 close_peer_entry(entry, reason.clone());
449 }
450}
451
452fn close_peer_entry(entry: PeerRuntime, reason: ConnectionCloseReason) {
453 entry.shared.mark_closed(reason.clone());
454 let _ = entry.inbound_tx.try_send(ConnectionInbound::Closed(reason));
455}