1#![warn(clippy::all, clippy::pedantic)]
9#![allow(clippy::too_many_lines)]
10#![allow(clippy::duration_suboptimal_units)]
13#![cfg(not(target_arch = "wasm32"))]
14
15use std::net::SocketAddr;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19use async_trait::async_trait;
20use renet::{ChannelConfig, ConnectionConfig, RenetServer, SendType, ServerEvent};
21use renet_netcode::{NetcodeServerTransport, ServerConfig};
22use socket2::{Domain, Socket, Type};
23
24use aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE;
25use aetheris_protocol::error::TransportError;
26use aetheris_protocol::events::NetworkEvent;
27use aetheris_protocol::traits::PlatformTransport;
28use aetheris_protocol::types::ClientId;
29
/// Server-side [`PlatformTransport`] implementation backed by `renet` on top of the
/// `renet_netcode` UDP transport.
///
/// All mutable state sits behind [`Mutex`]es so the `&self` send methods can be called
/// through a shared reference. Outgoing messages are only written to the socket when
/// `poll_events` runs, so it must be called regularly.
pub struct RenetTransport {
    /// Renet connection/channel state for all clients.
    server: Mutex<RenetServer>,
    /// Netcode transport that moves packets between the UDP socket and `server`.
    transport: Mutex<NetcodeServerTransport>,
    /// Time of the previous `poll_events` call; used to compute the tick delta.
    last_update: Mutex<Instant>,
    /// Address the UDP socket is actually bound to (reflects the OS-chosen port when
    /// binding to port `0`).
    local_addr: SocketAddr,
    /// Per-IP token bucket limiting how fast new connections are accepted.
    rate_limiter: Mutex<IpRateLimiter>,
    /// Inbound messages longer than this many bytes are discarded by `poll_events`.
    max_payload_size: usize,
    /// Time of the last rate-limiter pruning pass.
    last_prune: Mutex<Instant>,
    /// Ids of clients disconnected by the rate limiter. Their `ClientDisconnected`
    /// event is swallowed because their `ClientConnected` event was never forwarded.
    suppressed_disconnects: Mutex<std::collections::HashSet<u64>>,
}
41
42pub struct RenetServerConfig {
44 pub protocol_id: u64,
46 pub max_clients: usize,
48 pub authentication: renet_netcode::ServerAuthentication,
50 pub max_new_connections_per_second: u32,
52 pub max_payload_size: usize,
54 pub max_unreliable_channel_memory_bytes: usize,
56}
57
58impl Default for RenetServerConfig {
59 fn default() -> Self {
60 Self {
61 protocol_id: 0,
62 max_clients: 1000,
63 authentication: renet_netcode::ServerAuthentication::Unsecure,
64 max_new_connections_per_second: 5,
65 max_payload_size: MAX_SAFE_PAYLOAD_SIZE,
66 max_unreliable_channel_memory_bytes: 1024 * 1024,
67 }
68 }
69}
70
/// Per-IP token-bucket limiter for newly accepted connections.
///
/// Every IP gets a bucket holding at most `max_rate` tokens, refilled at `max_rate`
/// tokens per second; accepting a connection costs one token.
struct IpRateLimiter {
    /// One bucket per source IP seen recently.
    limits: std::collections::HashMap<std::net::IpAddr, TokenBucket>,
    /// Refill rate in tokens per second, also used as the bucket capacity.
    max_rate: f64,
}

/// Token state for a single IP.
struct TokenBucket {
    /// Tokens available as of `last_refill`.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Token count this bucket would hold at `now` if refilled at `rate` tokens per
    /// second, capped at `rate`. Does not modify the bucket.
    fn tokens_at(&self, now: Instant, rate: f64) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        (self.tokens + elapsed * rate).min(rate)
    }
}

impl IpRateLimiter {
    /// How long a bucket must go unused before `prune` may forget it.
    const IDLE_EVICTION: Duration = Duration::from_secs(10 * 60);

    /// Creates a limiter allowing `max_rate` new connections per second per IP, with
    /// bursts of up to `max_rate`.
    fn new(max_rate: f64) -> Self {
        Self {
            limits: std::collections::HashMap::new(),
            max_rate,
        }
    }

    /// Spends one token for `ip` and returns `true` if the connection is allowed.
    ///
    /// IPs without a bucket start with a full one.
    fn check(&mut self, ip: std::net::IpAddr) -> bool {
        let now = Instant::now();
        let rate = self.max_rate;
        let bucket = self.limits.entry(ip).or_insert_with(|| TokenBucket {
            tokens: rate,
            last_refill: now,
        });

        bucket.tokens = bucket.tokens_at(now, rate);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Forgets buckets that have been idle longer than [`Self::IDLE_EVICTION`] and
    /// would be full by now, so the map does not grow without bound.
    ///
    /// Buckets are only refilled lazily inside `check`, so the stored `tokens` value is
    /// stale. Any bucket that allowed a connection was left at most `max_rate - 1`. The
    /// refill is therefore projected to `now` before testing fullness; otherwise no
    /// bucket would ever qualify. Dropping a full bucket is equivalent to `check`
    /// recreating it later.
    fn prune(&mut self) {
        let now = Instant::now();
        let rate = self.max_rate;
        self.limits.retain(|_ip, bucket| {
            let full = bucket.tokens_at(now, rate) >= rate - 0.1;
            let idle = now.saturating_duration_since(bucket.last_refill) > Self::IDLE_EVICTION;
            !(full && idle)
        });
    }
}
120
/// Renet channel id for unreliable traffic (configured server-side as
/// `SendType::Unreliable`).
pub const CHANNEL_UNRELIABLE: u8 = 0x00;
/// Renet channel id for reliable traffic (configured server-side as
/// `SendType::ReliableOrdered`).
pub const CHANNEL_RELIABLE: u8 = 0x01;
125
126impl RenetTransport {
127 pub fn new_server(
134 addr: SocketAddr,
135 config: Option<RenetServerConfig>,
136 ) -> Result<Self, TransportError> {
137 let config = config.unwrap_or_default();
138 let connection_config = ConnectionConfig {
139 server_channels_config: vec![
140 ChannelConfig {
141 channel_id: CHANNEL_UNRELIABLE,
142 max_memory_usage_bytes: config.max_unreliable_channel_memory_bytes,
143 send_type: SendType::Unreliable,
144 },
145 ChannelConfig {
146 channel_id: CHANNEL_RELIABLE,
147 max_memory_usage_bytes: 1024 * 1024,
148 send_type: SendType::ReliableOrdered {
149 resend_time: Duration::from_millis(300),
150 },
151 },
152 ],
153 ..Default::default()
154 };
155
156 let server = RenetServer::new(connection_config);
157
158 let server_config = ServerConfig {
159 current_time: Duration::ZERO,
160 max_clients: config.max_clients,
161 protocol_id: config.protocol_id,
162 public_addresses: vec![addr],
163 authentication: config.authentication,
164 };
165
166 let raw = Socket::new(
172 if addr.is_ipv6() {
173 Domain::IPV6
174 } else {
175 Domain::IPV4
176 },
177 Type::DGRAM,
178 None,
179 )
180 .map_err(TransportError::Io)?;
181 raw.set_reuse_address(true).map_err(TransportError::Io)?;
182 raw.set_recv_buffer_size(8 * 1024 * 1024)
183 .map_err(TransportError::Io)?;
184 raw.set_send_buffer_size(8 * 1024 * 1024)
185 .map_err(TransportError::Io)?;
186 raw.set_nonblocking(true).map_err(TransportError::Io)?;
187 raw.bind(&addr.into()).map_err(TransportError::Io)?;
188 let socket: std::net::UdpSocket = raw.into();
189 let local_addr = socket.local_addr().map_err(TransportError::Io)?;
190
191 let transport = NetcodeServerTransport::new(server_config, socket)
192 .map_err(|e| TransportError::Io(std::io::Error::other(e)))?;
193
194 Ok(Self {
195 server: Mutex::new(server),
196 transport: Mutex::new(transport),
197 last_update: Mutex::new(Instant::now()),
198 local_addr,
199 rate_limiter: Mutex::new(IpRateLimiter::new(f64::from(
200 config.max_new_connections_per_second,
201 ))),
202 max_payload_size: config.max_payload_size,
203 last_prune: Mutex::new(Instant::now()),
204 suppressed_disconnects: Mutex::new(std::collections::HashSet::new()),
205 })
206 }
207
208 #[must_use]
210 pub fn addr(&self) -> SocketAddr {
211 self.local_addr
212 }
213}
214
#[async_trait]
impl PlatformTransport for RenetTransport {
    /// Queues `data` for one client on [`CHANNEL_UNRELIABLE`].
    ///
    /// The message is only handed to `RenetServer::send_message`; the socket write
    /// happens when `poll_events` calls `transport.send_packets`.
    ///
    /// # Errors
    ///
    /// - [`TransportError::PayloadTooLarge`] if `data` is longer than
    ///   `MAX_SAFE_PAYLOAD_SIZE`. The configurable `max_payload_size` only applies to
    ///   inbound messages.
    /// - [`TransportError::ClientNotConnected`] if renet does not report `client_id` as
    ///   connected.
    /// - [`TransportError::Io`] if the server mutex is poisoned.
    #[tracing::instrument(skip(self, data), fields(client_id = %client_id.0, size = data.len()))]
    async fn send_unreliable(
        &self,
        client_id: ClientId,
        data: &[u8],
    ) -> Result<(), TransportError> {
        if data.len() > MAX_SAFE_PAYLOAD_SIZE {
            metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "payload_too_large").increment(1);
            return Err(TransportError::PayloadTooLarge {
                size: data.len(),
                max: MAX_SAFE_PAYLOAD_SIZE,
            });
        }

        let mut server = self
            .server
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;

        if !server.is_connected(client_id.0) {
            metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "client_not_connected").increment(1);
            return Err(TransportError::ClientNotConnected(client_id));
        }

        server.send_message(client_id.0, CHANNEL_UNRELIABLE, data.to_vec());
        metrics::counter!("aetheris_transport_packets_total", "transport" => "renet", "direction" => "outbound", "channel" => "unreliable").increment(1);
        metrics::counter!("aetheris_transport_bytes_total", "transport" => "renet", "direction" => "outbound", "channel" => "unreliable").increment(data.len() as u64);
        Ok(())
    }

    /// Queues `data` for one client on [`CHANNEL_RELIABLE`].
    ///
    /// Same checks and error conditions as `send_unreliable`. The socket write happens
    /// in `poll_events`.
    ///
    /// # Errors
    ///
    /// [`TransportError::PayloadTooLarge`], [`TransportError::ClientNotConnected`], or
    /// [`TransportError::Io`] if the server mutex is poisoned.
    #[tracing::instrument(skip(self, data), fields(client_id = %client_id.0, size = data.len()))]
    async fn send_reliable(&self, client_id: ClientId, data: &[u8]) -> Result<(), TransportError> {
        if data.len() > MAX_SAFE_PAYLOAD_SIZE {
            metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "payload_too_large").increment(1);
            return Err(TransportError::PayloadTooLarge {
                size: data.len(),
                max: MAX_SAFE_PAYLOAD_SIZE,
            });
        }

        let mut server = self
            .server
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;

        if !server.is_connected(client_id.0) {
            metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "client_not_connected").increment(1);
            return Err(TransportError::ClientNotConnected(client_id));
        }

        server.send_message(client_id.0, CHANNEL_RELIABLE, data.to_vec());
        metrics::counter!("aetheris_transport_packets_total", "transport" => "renet", "direction" => "outbound", "channel" => "reliable").increment(1);
        metrics::counter!("aetheris_transport_bytes_total", "transport" => "renet", "direction" => "outbound", "channel" => "reliable").increment(data.len() as u64);
        Ok(())
    }

    /// Queues `data` for every client on [`CHANNEL_UNRELIABLE`] via
    /// `RenetServer::broadcast_message`.
    ///
    /// # Errors
    ///
    /// [`TransportError::PayloadTooLarge`] if `data` exceeds `MAX_SAFE_PAYLOAD_SIZE`, or
    /// [`TransportError::Io`] if the server mutex is poisoned.
    #[tracing::instrument(skip(self, data), fields(size = data.len()))]
    async fn broadcast_unreliable(&self, data: &[u8]) -> Result<(), TransportError> {
        if data.len() > MAX_SAFE_PAYLOAD_SIZE {
            metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "payload_too_large").increment(1);
            return Err(TransportError::PayloadTooLarge {
                size: data.len(),
                max: MAX_SAFE_PAYLOAD_SIZE,
            });
        }

        let mut server = self
            .server
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
        server.broadcast_message(CHANNEL_UNRELIABLE, data.to_vec());
        metrics::counter!("aetheris_transport_packets_total", "transport" => "renet", "direction" => "outbound", "channel" => "broadcast_unreliable").increment(1);
        metrics::counter!("aetheris_transport_bytes_total", "transport" => "renet", "direction" => "outbound", "channel" => "broadcast_unreliable").increment(data.len() as u64);
        Ok(())
    }

    /// Advances the transport by one tick and returns everything observed.
    ///
    /// In order, this:
    /// 1. computes the time elapsed since the previous call (or since construction);
    /// 2. prunes the connection rate limiter at most once every 60 seconds;
    /// 3. receives packets (`transport.update`), advances renet (`server.update`) and
    ///    flushes queued outgoing packets (`transport.send_packets`);
    /// 4. converts renet server events into [`NetworkEvent`]s. Clients exceeding the
    ///    per-IP connection rate are disconnected, and neither their connect nor their
    ///    disconnect event is reported;
    /// 5. drains both channels of every client, dropping messages longer than
    ///    `max_payload_size`;
    /// 6. sets the `aetheris_datagram_drop_rate` gauge to the mean `packet_loss` over
    ///    clients whose network info is available (0.0 when there are none).
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Io`] if an internal mutex is poisoned. A failing
    /// netcode `update` is only logged.
    #[tracing::instrument(skip(self))]
    async fn poll_events(&mut self) -> Result<Vec<NetworkEvent>, TransportError> {
        let now = Instant::now();

        // Tick delta fed to both netcode and renet.
        let duration = {
            let mut last_update = self
                .last_update
                .lock()
                .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
            let d = now.duration_since(*last_update);
            *last_update = now;
            d
        };

        // Periodic rate-limiter cleanup (lock order: last_prune -> rate_limiter).
        {
            let mut last_prune = self
                .last_prune
                .lock()
                .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
            if now.duration_since(*last_prune) > Duration::from_secs(60) {
                let mut rate_limiter = self
                    .rate_limiter
                    .lock()
                    .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
                rate_limiter.prune();
                *last_prune = now;
            }
        }

        // Lock order from here on: server -> transport -> rate_limiter / suppressed_disconnects.
        let mut events = Vec::new();
        let mut server = self
            .server
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
        let mut transport = self
            .transport
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;

        if let Err(e) = transport.update(duration, &mut server) {
            tracing::error!(error = ?e, "Netcode transport update failure");
        }
        server.update(duration);
        transport.send_packets(&mut server);

        // NOTE(review): an early `?` return inside this loop discards the events already
        // collected in `events` and already drained from `server.get_event()`.
        while let Some(event) = server.get_event() {
            match event {
                ServerEvent::ClientConnected { client_id } => {
                    let addr = transport.client_addr(client_id);
                    // Clients whose address is unknown bypass the rate limiter.
                    let allowed = if let Some(addr) = addr {
                        let mut rate_limiter = self.rate_limiter.lock().map_err(|e| {
                            TransportError::Io(std::io::Error::other(e.to_string()))
                        })?;
                        rate_limiter.check(addr.ip())
                    } else {
                        true
                    };

                    if allowed {
                        events.push(NetworkEvent::ClientConnected(ClientId(client_id)));
                    } else {
                        tracing::warn!(
                            client_id,
                            ?addr,
                            "Connection rate limit exceeded, disconnecting"
                        );
                        // Remember the id so the ClientDisconnected event expected for it
                        // is swallowed below; its connect event was never forwarded.
                        let mut suppressed = self.suppressed_disconnects.lock().map_err(|e| {
                            TransportError::Io(std::io::Error::other(e.to_string()))
                        })?;
                        suppressed.insert(client_id);
                        server.disconnect(client_id);
                    }
                }
                ServerEvent::ClientDisconnected { client_id, reason } => {
                    let mut suppressed = self
                        .suppressed_disconnects
                        .lock()
                        .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;

                    if suppressed.remove(&client_id) {
                        tracing::debug!(client_id, "Suppressed rate-limited disconnect event");
                    } else {
                        tracing::debug!(client_id, ?reason, "Client disconnected");
                        events.push(NetworkEvent::ClientDisconnected(ClientId(client_id)));
                    }
                }
            }
        }

        // The transport is not needed for draining messages; release its lock early.
        drop(transport);

        let max_payload = self.max_payload_size;
        let client_ids: Vec<u64> = server.clients_id();
        for client_id in &client_ids {
            while let Some(message) = server.receive_message(*client_id, CHANNEL_UNRELIABLE) {
                if message.len() > max_payload {
                    tracing::warn!(
                        client_id,
                        size = message.len(),
                        limit = max_payload,
                        "Discarding oversized unreliable message"
                    );
                    metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "oversized_unreliable_msg").increment(1);
                    continue;
                }
                events.push(NetworkEvent::UnreliableMessage {
                    client_id: ClientId(*client_id),
                    data: message.to_vec(),
                });
                metrics::counter!("aetheris_transport_packets_total", "transport" => "renet", "direction" => "inbound", "channel" => "unreliable").increment(1);
                metrics::counter!("aetheris_transport_bytes_total", "transport" => "renet", "direction" => "inbound", "channel" => "unreliable").increment(message.len() as u64);
            }
            while let Some(message) = server.receive_message(*client_id, CHANNEL_RELIABLE) {
                if message.len() > max_payload {
                    tracing::warn!(
                        client_id,
                        size = message.len(),
                        limit = max_payload,
                        "Discarding oversized reliable message"
                    );
                    metrics::counter!("aetheris_transport_errors_total", "transport" => "renet", "type" => "oversized_reliable_msg").increment(1);
                    continue;
                }
                events.push(NetworkEvent::ReliableMessage {
                    client_id: ClientId(*client_id),
                    data: message.to_vec(),
                });
                metrics::counter!("aetheris_transport_packets_total", "transport" => "renet", "direction" => "inbound", "channel" => "reliable").increment(1);
                metrics::counter!("aetheris_transport_bytes_total", "transport" => "renet", "direction" => "inbound", "channel" => "reliable").increment(message.len() as u64);
            }
        }

        // Mean packet loss over clients that report network info.
        let mut total_loss = 0.0;
        let mut connected_count = 0;
        for client_id in &client_ids {
            if let Ok(info) = server.network_info(*client_id) {
                total_loss += info.packet_loss;
                connected_count += 1;
            }
        }
        if connected_count > 0 {
            metrics::gauge!("aetheris_datagram_drop_rate")
                .set(total_loss / f64::from(connected_count));
        } else {
            metrics::gauge!("aetheris_datagram_drop_rate").set(0.0);
        }

        Ok(events)
    }

    /// Number of clients renet reports as connected, or `0` if the server mutex is
    /// poisoned.
    async fn connected_client_count(&self) -> usize {
        let Ok(server) = self.server.lock() else {
            return 0;
        };
        server.connected_clients()
    }

    /// Asks renet to disconnect `client_id`.
    ///
    /// Such ids are not added to the suppression set, so any resulting
    /// `ClientDisconnected` event is reported by a later `poll_events`.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Io`] if the server mutex is poisoned.
    async fn disconnect(&self, client_id: ClientId) -> Result<(), TransportError> {
        let mut server = self
            .server
            .lock()
            .map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
        server.disconnect(client_id.0);
        Ok(())
    }
}
468
469#[cfg(test)]
470mod tests {
471 use super::*;
472 use renet::RenetClient;
473 use renet_netcode::NetcodeClientTransport;
474
475 #[tokio::test]
476 #[allow(clippy::too_many_lines)]
477 async fn test_renet_loopback_connectivity() {
478 let addr = "127.0.0.1:0".parse().unwrap();
479 let mut server_transport = RenetTransport::new_server(addr, None).unwrap();
480 let server_addr = server_transport.addr();
481
482 let connection_config = ConnectionConfig::default();
484 let mut client = RenetClient::new(connection_config);
485
486 let client_id = 42;
487 let auth = renet_netcode::ClientAuthentication::Unsecure {
488 protocol_id: 0,
489 client_id,
490 server_addr,
491 user_data: None,
492 };
493
494 let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
495 let mut client_transport =
496 NetcodeClientTransport::new(Duration::ZERO, auth, socket).unwrap();
497
498 let mut connected = false;
500 let duration = Duration::from_millis(10);
501 for _ in 0..100 {
502 let _ = client_transport.update(duration, &mut client);
503 client.update(duration);
504 client_transport.send_packets(&mut client).unwrap();
505
506 let events = server_transport.poll_events().await.unwrap();
507 for event in events {
508 if let NetworkEvent::ClientConnected(id) = event
509 && id.0 == client_id
510 {
511 connected = true;
512 }
513 }
514
515 if connected {
516 break;
517 }
518 tokio::time::sleep(duration).await;
519 }
520
521 assert!(connected, "Client failed to connect to server");
522
523 let msg = b"hello aetheris";
525 client.send_message(CHANNEL_UNRELIABLE, msg.to_vec());
526
527 let mut received = false;
529 for _ in 0..100 {
530 let _ = client_transport.update(duration, &mut client);
531 client.update(duration);
532 client_transport.send_packets(&mut client).unwrap();
533
534 let events = server_transport.poll_events().await.unwrap();
535 for event in events {
536 if let NetworkEvent::UnreliableMessage {
537 client_id: id,
538 data,
539 } = event
540 && id.0 == client_id
541 && data == msg
542 {
543 received = true;
544 }
545 }
546 if received {
547 break;
548 }
549 tokio::time::sleep(duration).await;
550 }
551
552 assert!(received, "Server failed to receive message from client");
553
554 let server_msg = b"welcome to aetheris";
556 server_transport
557 .send_reliable(ClientId(client_id), server_msg)
558 .await
559 .unwrap();
560
561 let mut client_received = false;
563 for _ in 0..100 {
564 let _ = client_transport.update(duration, &mut client);
565 client.update(duration);
566 client_transport.send_packets(&mut client).unwrap();
567
568 while let Some(data) = client.receive_message(CHANNEL_RELIABLE) {
569 if &data[..] == server_msg {
570 client_received = true;
571 }
572 }
573
574 server_transport.poll_events().await.unwrap(); if client_received {
577 break;
578 }
579 tokio::time::sleep(duration).await;
580 }
581
582 assert!(
583 client_received,
584 "Client failed to receive message from server"
585 );
586
587 let broadcast_msg = b"broadcast message";
589 server_transport
590 .broadcast_unreliable(broadcast_msg)
591 .await
592 .unwrap();
593
594 let mut broadcast_received = false;
595 for _ in 0..100 {
596 let _ = client_transport.update(duration, &mut client);
597 client.update(duration);
598 client_transport.send_packets(&mut client).unwrap();
599
600 while let Some(data) = client.receive_message(CHANNEL_UNRELIABLE) {
601 if &data[..] == broadcast_msg {
602 broadcast_received = true;
603 }
604 }
605 server_transport.poll_events().await.unwrap();
606 if broadcast_received {
607 break;
608 }
609 tokio::time::sleep(duration).await;
610 }
611 assert!(
612 broadcast_received,
613 "Client failed to receive broadcast message from server"
614 );
615
616 assert_eq!(server_transport.connected_client_count().await, 1);
618
619 client_transport.disconnect();
621 for _ in 0..10 {
622 let _ = client_transport.update(duration, &mut client);
623 client.update(duration);
624 let _ = client_transport.send_packets(&mut client);
625 tokio::time::sleep(duration).await;
626 }
627
628 let mut disconnected = false;
630 for _ in 0..100 {
631 let events = server_transport.poll_events().await.unwrap();
632 for event in events {
633 if let NetworkEvent::ClientDisconnected(id) = event
634 && id.0 == client_id
635 {
636 disconnected = true;
637 }
638 }
639 if disconnected {
640 break;
641 }
642 tokio::time::sleep(duration).await;
643 }
644 assert!(
645 disconnected,
646 "Server failed to observe client disconnection"
647 );
648 assert_eq!(server_transport.connected_client_count().await, 0);
649 }
650
651 #[tokio::test]
652 async fn test_inbound_payload_size_limit() {
653 let addr = "127.0.0.1:0".parse().unwrap();
654 let mut server_transport = RenetTransport::new_server(
655 addr,
656 Some(RenetServerConfig {
657 max_payload_size: 10,
658 ..Default::default()
659 }),
660 )
661 .unwrap();
662 let server_addr = server_transport.addr();
663
664 let connection_config = ConnectionConfig::default();
666 let mut client = RenetClient::new(connection_config);
667
668 let client_id = 99;
669 let auth = renet_netcode::ClientAuthentication::Unsecure {
670 protocol_id: 0,
671 client_id,
672 server_addr,
673 user_data: None,
674 };
675
676 let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
677 let mut client_transport =
678 NetcodeClientTransport::new(Duration::ZERO, auth, socket).unwrap();
679
680 let duration = Duration::from_millis(10);
681 for _ in 0..50 {
683 let _ = client_transport.update(duration, &mut client);
684 client.update(duration);
685 let _ = client_transport.send_packets(&mut client);
686 server_transport.poll_events().await.unwrap();
687 tokio::time::sleep(duration).await;
688 }
689
690 let too_large_msg = vec![0u8; 11];
692 client.send_message(CHANNEL_UNRELIABLE, too_large_msg);
693
694 let mut received = false;
696 for _ in 0..50 {
697 let _ = client_transport.update(duration, &mut client);
698 client.update(duration);
699 let _ = client_transport.send_packets(&mut client);
700 let events = server_transport.poll_events().await.unwrap();
701 for event in events {
702 if let NetworkEvent::UnreliableMessage { .. } = event {
703 received = true;
704 }
705 }
706 if received {
707 break;
708 }
709 tokio::time::sleep(duration).await;
710 }
711
712 assert!(
713 !received,
714 "Server should have discarded the oversized message"
715 );
716 }
717
718 #[tokio::test]
719 async fn test_connection_rate_limit() {
720 let addr = "127.0.0.1:0".parse().unwrap();
721 let mut server_transport = RenetTransport::new_server(
722 addr,
723 Some(RenetServerConfig {
724 max_new_connections_per_second: 1,
725 ..Default::default()
726 }),
727 )
728 .unwrap();
729 let server_addr = server_transport.addr();
730
731 let duration = Duration::from_millis(10);
732
733 macro_rules! attempt_connect {
734 ($id:expr) => {{
735 let mut connected = false;
736 let config = ConnectionConfig::default();
737 let mut client = RenetClient::new(config);
738 let auth = renet_netcode::ClientAuthentication::Unsecure {
739 protocol_id: 0,
740 client_id: $id,
741 server_addr,
742 user_data: None,
743 };
744 let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
745 let mut transport =
746 NetcodeClientTransport::new(Duration::ZERO, auth, socket).unwrap();
747
748 for _ in 0..20 {
749 let _ = transport.update(duration, &mut client);
750 client.update(duration);
751 let _ = transport.send_packets(&mut client);
752 let events = server_transport.poll_events().await.unwrap();
753 for event in events {
754 if let NetworkEvent::ClientConnected(cid) = event
755 && cid.0 == $id
756 {
757 connected = true;
758 }
759 }
760 if connected {
761 break;
762 }
763 tokio::time::sleep(duration).await;
764 }
765 connected
766 }};
767 }
768
769 let connected1 = attempt_connect!(1);
771 assert!(connected1, "First connection should succeed");
772
773 let connected2 = attempt_connect!(2);
775 assert!(!connected2, "Second connection should be rate-limited");
776
777 tokio::time::sleep(Duration::from_millis(1100)).await;
779 let connected3 = attempt_connect!(3);
780 assert!(connected3, "Third connection should succeed after timeout");
781 }
782}