1use std::io;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use bytes::Bytes;
6use tokio::sync::mpsc::error::TrySendError;
7use tokio::sync::{broadcast, mpsc};
8use tokio::task::JoinHandle;
9use tokio::time::{self, MissedTickBehavior};
10
11use crate::error::ConfigValidationError;
12use crate::protocol::reliability::Reliability;
13use crate::session::RakPriority;
14
15use super::config::TransportConfig;
16use super::server::{TransportEvent, TransportMetricsSnapshot, TransportServer};
17
/// Policy applied when the shared runtime event queue is full.
///
/// Only non-critical events (transport events and metrics snapshots) are
/// subject to this policy; worker errors and worker-stopped notifications
/// are always delivered via a blocking send.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventOverflowPolicy {
    /// Await queue capacity, applying backpressure to the shard worker.
    BlockProducer,
    /// Drop the event and increment the shard's dropped-event counter,
    /// which is reported in subsequent `Metrics` events.
    ShedNonCritical,
}
23
/// Tuning parameters for a sharded transport runtime.
///
/// All limits are enforced by [`ShardedRuntimeConfig::validate`].
#[derive(Debug, Clone)]
pub struct ShardedRuntimeConfig {
    /// Number of shard workers to spawn (must be >= 1).
    pub shard_count: usize,
    /// Interval between outbound flush ticks on each shard (must be > 0).
    pub outbound_tick_interval: Duration,
    /// Interval between metrics snapshot emissions per shard (must be > 0).
    pub metrics_emit_interval: Duration,
    /// Capacity of the shared runtime event channel (must be >= 1).
    pub event_queue_capacity: usize,
    /// Capacity of each shard's command channel (must be >= 1).
    pub command_queue_capacity: usize,
    /// What to do with non-critical events when the event channel is full.
    pub event_overflow_policy: EventOverflowPolicy,
    /// Per-session cap on new datagrams per outbound tick (must be >= 1).
    pub max_new_datagrams_per_session: usize,
    /// Per-session cap on new bytes per outbound tick (must be >= minimum MTU).
    pub max_new_bytes_per_session: usize,
    /// Per-session cap on resent datagrams per outbound tick (must be >= 1).
    pub max_resend_datagrams_per_session: usize,
    /// Per-session cap on resent bytes per outbound tick (must be >= minimum MTU).
    pub max_resend_bytes_per_session: usize,
}
37
38impl Default for ShardedRuntimeConfig {
39 fn default() -> Self {
40 Self {
41 shard_count: std::thread::available_parallelism()
42 .map(|value| value.get())
43 .unwrap_or(1)
44 .max(1),
45 outbound_tick_interval: Duration::from_millis(10),
46 metrics_emit_interval: Duration::from_millis(1000),
47 event_queue_capacity: 4096,
48 command_queue_capacity: 4096,
49 event_overflow_policy: EventOverflowPolicy::ShedNonCritical,
50 max_new_datagrams_per_session: 8,
51 max_new_bytes_per_session: 64 * 1024,
52 max_resend_datagrams_per_session: 8,
53 max_resend_bytes_per_session: 64 * 1024,
54 }
55 }
56}
57
58impl ShardedRuntimeConfig {
59 pub fn validate(&self) -> Result<(), ConfigValidationError> {
60 if self.shard_count == 0 {
61 return Err(ConfigValidationError::new(
62 "ShardedRuntimeConfig",
63 "shard_count",
64 "must be >= 1",
65 ));
66 }
67 if self.outbound_tick_interval.is_zero() {
68 return Err(ConfigValidationError::new(
69 "ShardedRuntimeConfig",
70 "outbound_tick_interval",
71 "must be > 0",
72 ));
73 }
74 if self.metrics_emit_interval.is_zero() {
75 return Err(ConfigValidationError::new(
76 "ShardedRuntimeConfig",
77 "metrics_emit_interval",
78 "must be > 0",
79 ));
80 }
81 if self.event_queue_capacity == 0 {
82 return Err(ConfigValidationError::new(
83 "ShardedRuntimeConfig",
84 "event_queue_capacity",
85 "must be >= 1",
86 ));
87 }
88 if self.command_queue_capacity == 0 {
89 return Err(ConfigValidationError::new(
90 "ShardedRuntimeConfig",
91 "command_queue_capacity",
92 "must be >= 1",
93 ));
94 }
95 if self.max_new_datagrams_per_session == 0 {
96 return Err(ConfigValidationError::new(
97 "ShardedRuntimeConfig",
98 "max_new_datagrams_per_session",
99 "must be >= 1",
100 ));
101 }
102 if self.max_new_bytes_per_session < crate::protocol::constants::MINIMUM_MTU_SIZE as usize {
103 return Err(ConfigValidationError::new(
104 "ShardedRuntimeConfig",
105 "max_new_bytes_per_session",
106 format!(
107 "must be >= {}, got {}",
108 crate::protocol::constants::MINIMUM_MTU_SIZE,
109 self.max_new_bytes_per_session
110 ),
111 ));
112 }
113 if self.max_resend_datagrams_per_session == 0 {
114 return Err(ConfigValidationError::new(
115 "ShardedRuntimeConfig",
116 "max_resend_datagrams_per_session",
117 "must be >= 1",
118 ));
119 }
120 if self.max_resend_bytes_per_session < crate::protocol::constants::MINIMUM_MTU_SIZE as usize
121 {
122 return Err(ConfigValidationError::new(
123 "ShardedRuntimeConfig",
124 "max_resend_bytes_per_session",
125 format!(
126 "must be >= {}, got {}",
127 crate::protocol::constants::MINIMUM_MTU_SIZE,
128 self.max_resend_bytes_per_session
129 ),
130 ));
131 }
132
133 Ok(())
134 }
135}
136
/// Events emitted by shard workers onto the single shared event channel.
#[derive(Debug)]
pub enum ShardedRuntimeEvent {
    /// A transport-level event produced by one shard's server.
    Transport {
        shard_id: usize,
        event: TransportEvent,
    },
    /// Periodic metrics snapshot from one shard (boxed to keep the enum small).
    Metrics {
        shard_id: usize,
        snapshot: Box<TransportMetricsSnapshot>,
        /// Cumulative count of non-critical events this shard has dropped
        /// under [`EventOverflowPolicy::ShedNonCritical`] (never reset).
        dropped_non_critical_events: u64,
    },
    /// A shard worker hit a fatal error and is terminating with `Err`.
    WorkerError {
        shard_id: usize,
        message: String,
    },
    /// A shard worker exited cleanly after observing a shutdown request.
    WorkerStopped {
        shard_id: usize,
    },
}
156
/// Commands accepted by a shard worker over its per-shard command channel.
#[derive(Debug, Clone)]
pub enum ShardedRuntimeCommand {
    /// Queue an outbound payload for `addr` on the receiving shard.
    SendPayload {
        addr: SocketAddr,
        payload: Bytes,
        reliability: Reliability,
        channel: u8,
        priority: RakPriority,
        /// When set, the payload is queued via `queue_payload_with_receipt`
        /// with this id; otherwise `queue_payload` is used.
        receipt_id: Option<u64>,
    },
    /// Disconnect the peer at `addr` on the receiving shard.
    DisconnectPeer {
        addr: SocketAddr,
    },
}
171
/// Caller-facing bundle of send parameters, converted by
/// [`ShardedRuntimeHandle`] into a [`ShardedRuntimeCommand::SendPayload`].
#[derive(Debug, Clone)]
pub struct ShardedSendPayload {
    pub addr: SocketAddr,
    pub payload: Bytes,
    pub reliability: Reliability,
    pub channel: u8,
    pub priority: RakPriority,
}
180
/// Handle to a running sharded transport runtime.
///
/// Owns the receive side of the shared event channel, one command sender per
/// shard, the shutdown broadcast, and the worker task join handles.
pub struct ShardedRuntimeHandle {
    /// Consolidated event stream from all shard workers.
    pub event_rx: mpsc::Receiver<ShardedRuntimeEvent>,
    // Broadcasting () asks every worker to stop.
    shutdown_tx: broadcast::Sender<()>,
    // Indexed by shard_id; one bounded command queue per worker.
    command_txs: Vec<mpsc::Sender<ShardedRuntimeCommand>>,
    // Worker tasks; joined by `shutdown()`.
    handles: Vec<JoinHandle<io::Result<()>>>,
}
187
188impl ShardedRuntimeHandle {
189 pub fn shard_count(&self) -> usize {
190 self.command_txs.len()
191 }
192
193 pub async fn send_payload_to_shard(
194 &self,
195 shard_id: usize,
196 payload: ShardedSendPayload,
197 ) -> io::Result<()> {
198 self.send_command_to_shard(
199 shard_id,
200 ShardedRuntimeCommand::SendPayload {
201 addr: payload.addr,
202 payload: payload.payload,
203 reliability: payload.reliability,
204 channel: payload.channel,
205 priority: payload.priority,
206 receipt_id: None,
207 },
208 )
209 .await
210 }
211
212 pub async fn send_payload_to_shard_with_receipt(
213 &self,
214 shard_id: usize,
215 payload: ShardedSendPayload,
216 receipt_id: u64,
217 ) -> io::Result<()> {
218 self.send_command_to_shard(
219 shard_id,
220 ShardedRuntimeCommand::SendPayload {
221 addr: payload.addr,
222 payload: payload.payload,
223 reliability: payload.reliability,
224 channel: payload.channel,
225 priority: payload.priority,
226 receipt_id: Some(receipt_id),
227 },
228 )
229 .await
230 }
231
232 pub async fn send_payload_any_shard(
233 &self,
234 addr: SocketAddr,
235 payload: Bytes,
236 reliability: Reliability,
237 channel: u8,
238 priority: RakPriority,
239 ) -> io::Result<()> {
240 for tx in &self.command_txs {
241 tx.send(ShardedRuntimeCommand::SendPayload {
242 addr,
243 payload: payload.clone(),
244 reliability,
245 channel,
246 priority,
247 receipt_id: None,
248 })
249 .await
250 .map_err(|_| {
251 io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
252 })?;
253 }
254 Ok(())
255 }
256
257 pub async fn disconnect_peer_from_shard(
258 &self,
259 shard_id: usize,
260 addr: SocketAddr,
261 ) -> io::Result<()> {
262 self.send_command_to_shard(shard_id, ShardedRuntimeCommand::DisconnectPeer { addr })
263 .await
264 }
265
266 pub async fn disconnect_peer_any_shard(&self, addr: SocketAddr) -> io::Result<()> {
267 for tx in &self.command_txs {
268 tx.send(ShardedRuntimeCommand::DisconnectPeer { addr })
269 .await
270 .map_err(|_| {
271 io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
272 })?;
273 }
274 Ok(())
275 }
276
277 async fn send_command_to_shard(
278 &self,
279 shard_id: usize,
280 command: ShardedRuntimeCommand,
281 ) -> io::Result<()> {
282 let tx = self.command_txs.get(shard_id).ok_or_else(|| {
283 io::Error::new(
284 io::ErrorKind::InvalidInput,
285 format!("invalid shard_id {shard_id}"),
286 )
287 })?;
288
289 tx.send(command).await.map_err(|_| {
290 io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
291 })
292 }
293
294 pub fn request_shutdown(&self) {
295 let _ = self.shutdown_tx.send(());
296 }
297
298 pub async fn shutdown(mut self) -> io::Result<()> {
299 self.request_shutdown();
300
301 while let Some(handle) = self.handles.pop() {
302 match handle.await {
303 Ok(Ok(())) => {}
304 Ok(Err(e)) => return Err(e),
305 Err(join_err) => {
306 return Err(io::Error::other(format!("worker join error: {join_err}")));
307 }
308 }
309 }
310
311 Ok(())
312 }
313}
314
315pub async fn spawn_sharded_runtime(
316 transport_config: TransportConfig,
317 runtime_config: ShardedRuntimeConfig,
318) -> io::Result<ShardedRuntimeHandle> {
319 transport_config
320 .validate()
321 .map_err(invalid_config_io_error)?;
322 runtime_config.validate().map_err(invalid_config_io_error)?;
323
324 let shard_count = runtime_config.shard_count.max(1);
325 let workers = TransportServer::bind_shards(transport_config, shard_count).await?;
326
327 let (event_tx, event_rx) = mpsc::channel(runtime_config.event_queue_capacity.max(1));
328 let (shutdown_tx, _) = broadcast::channel(1);
329
330 let mut handles = Vec::with_capacity(workers.len());
331 let mut command_txs = Vec::with_capacity(workers.len());
332 for (shard_id, server) in workers.into_iter().enumerate() {
333 let tx = event_tx.clone();
334 let cfg = runtime_config.clone();
335 let shutdown_rx = shutdown_tx.subscribe();
336 let (command_tx, command_rx) = mpsc::channel(runtime_config.command_queue_capacity.max(1));
337 command_txs.push(command_tx);
338
339 handles.push(tokio::spawn(async move {
340 run_worker_loop(shard_id, server, cfg, tx, command_rx, shutdown_rx).await
341 }));
342 }
343
344 Ok(ShardedRuntimeHandle {
345 event_rx,
346 shutdown_tx,
347 command_txs,
348 handles,
349 })
350}
351
352async fn run_worker_loop(
353 shard_id: usize,
354 mut server: TransportServer,
355 cfg: ShardedRuntimeConfig,
356 event_tx: mpsc::Sender<ShardedRuntimeEvent>,
357 mut command_rx: mpsc::Receiver<ShardedRuntimeCommand>,
358 mut shutdown_rx: broadcast::Receiver<()>,
359) -> io::Result<()> {
360 let mut outbound_tick = time::interval(cfg.outbound_tick_interval);
361 outbound_tick.set_missed_tick_behavior(MissedTickBehavior::Skip);
362
363 let mut metrics_tick = time::interval(cfg.metrics_emit_interval);
364 metrics_tick.set_missed_tick_behavior(MissedTickBehavior::Skip);
365
366 let mut dropped_non_critical_events = 0u64;
367 let initial_dropped_snapshot = dropped_non_critical_events;
368 send_non_critical_event(
369 &event_tx,
370 cfg.event_overflow_policy,
371 &mut dropped_non_critical_events,
372 ShardedRuntimeEvent::Metrics {
373 shard_id,
374 snapshot: Box::new(server.metrics_snapshot()),
375 dropped_non_critical_events: initial_dropped_snapshot,
376 },
377 )
378 .await?;
379
380 loop {
381 tokio::select! {
382 biased;
383
384 _ = shutdown_rx.recv() => {
385 send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerStopped { shard_id }).await?;
386 return Ok(());
387 }
388
389 command = command_rx.recv() => {
390 if let Some(command) = command {
391 apply_command(&mut server, command);
392 }
393 }
394
395 _ = outbound_tick.tick() => {
396 if let Err(e) = server.tick_outbound(
397 cfg.max_new_datagrams_per_session,
398 cfg.max_new_bytes_per_session,
399 cfg.max_resend_datagrams_per_session,
400 cfg.max_resend_bytes_per_session,
401 ).await {
402 let _ = send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerError {
403 shard_id,
404 message: format!("outbound tick failed: {e}"),
405 }).await;
406 return Err(e);
407 }
408 }
409
410 _ = metrics_tick.tick() => {
411 let dropped_snapshot = dropped_non_critical_events;
412 send_non_critical_event(
413 &event_tx,
414 cfg.event_overflow_policy,
415 &mut dropped_non_critical_events,
416 ShardedRuntimeEvent::Metrics {
417 shard_id,
418 snapshot: Box::new(server.metrics_snapshot()),
419 dropped_non_critical_events: dropped_snapshot,
420 },
421 ).await?;
422 }
423
424 recv_result = server.recv_and_process() => {
425 match recv_result {
426 Ok(event) => {
427 send_non_critical_event(
428 &event_tx,
429 cfg.event_overflow_policy,
430 &mut dropped_non_critical_events,
431 ShardedRuntimeEvent::Transport {
432 shard_id,
433 event,
434 },
435 ).await?;
436 }
437 Err(e) => {
438 let _ = send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerError {
439 shard_id,
440 message: format!("recv loop failed: {e}"),
441 }).await;
442 return Err(e);
443 }
444 }
445 }
446 }
447 }
448}
449
450fn apply_command(server: &mut TransportServer, command: ShardedRuntimeCommand) {
451 match command {
452 ShardedRuntimeCommand::SendPayload {
453 addr,
454 payload,
455 reliability,
456 channel,
457 priority,
458 receipt_id,
459 } => {
460 let _ = if let Some(receipt_id) = receipt_id {
461 server.queue_payload_with_receipt(
462 addr,
463 payload,
464 reliability,
465 channel,
466 priority,
467 receipt_id,
468 )
469 } else {
470 server.queue_payload(addr, payload, reliability, channel, priority)
471 };
472 }
473 ShardedRuntimeCommand::DisconnectPeer { addr } => {
474 let _ = server.disconnect_peer(addr);
475 }
476 }
477}
478
479async fn send_non_critical_event(
480 event_tx: &mpsc::Sender<ShardedRuntimeEvent>,
481 overflow_policy: EventOverflowPolicy,
482 dropped_non_critical_events: &mut u64,
483 event: ShardedRuntimeEvent,
484) -> io::Result<()> {
485 match overflow_policy {
486 EventOverflowPolicy::BlockProducer => event_tx
487 .send(event)
488 .await
489 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "runtime event channel closed")),
490 EventOverflowPolicy::ShedNonCritical => match event_tx.try_send(event) {
491 Ok(()) => Ok(()),
492 Err(TrySendError::Full(_)) => {
493 *dropped_non_critical_events = dropped_non_critical_events.saturating_add(1);
494 Ok(())
495 }
496 Err(TrySendError::Closed(_)) => Err(io::Error::new(
497 io::ErrorKind::BrokenPipe,
498 "runtime event channel closed",
499 )),
500 },
501 }
502}
503
504async fn send_critical_event(
505 event_tx: &mpsc::Sender<ShardedRuntimeEvent>,
506 event: ShardedRuntimeEvent,
507) -> io::Result<()> {
508 event_tx
509 .send(event)
510 .await
511 .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "runtime event channel closed"))
512}
513
514fn invalid_config_io_error(error: ConfigValidationError) -> io::Error {
515 io::Error::new(io::ErrorKind::InvalidInput, error.to_string())
516}
517
#[cfg(test)]
mod tests {
    use super::{
        EventOverflowPolicy, ShardedRuntimeConfig, ShardedRuntimeEvent, send_non_critical_event,
    };
    use crate::transport::server::TransportMetricsSnapshot;
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Builds a metrics event with a default snapshot for `shard_id`.
    fn metrics_event(shard_id: usize) -> ShardedRuntimeEvent {
        ShardedRuntimeEvent::Metrics {
            shard_id,
            snapshot: Box::new(TransportMetricsSnapshot::default()),
            dropped_non_critical_events: 0,
        }
    }

    #[tokio::test]
    async fn shed_policy_drops_non_critical_when_channel_is_full() {
        // Fill the single-slot channel so the next send must overflow.
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(metrics_event(1))
            .await
            .expect("initial send should succeed");

        let mut dropped = 0u64;
        let outcome = send_non_critical_event(
            &tx,
            EventOverflowPolicy::ShedNonCritical,
            &mut dropped,
            metrics_event(2),
        )
        .await;
        outcome.expect("shed policy should not fail on full queue");

        assert_eq!(dropped, 1);

        let first = rx.recv().await.expect("first event should be present");
        let kept_original = matches!(first, ShardedRuntimeEvent::Metrics { shard_id: 1, .. });
        assert!(kept_original);
        assert!(rx.try_recv().is_err(), "second event should be shed");
    }

    #[tokio::test]
    async fn block_policy_enqueues_event_normally() {
        let (tx, mut rx) = mpsc::channel(1);

        let mut dropped = 0u64;
        let outcome = send_non_critical_event(
            &tx,
            EventOverflowPolicy::BlockProducer,
            &mut dropped,
            metrics_event(5),
        )
        .await;
        outcome.expect("block policy send should succeed");

        assert_eq!(dropped, 0);

        let event = rx.recv().await.expect("event should be available");
        let delivered = matches!(event, ShardedRuntimeEvent::Metrics { shard_id: 5, .. });
        assert!(delivered);
    }

    #[test]
    fn runtime_config_validate_rejects_zero_shards() {
        let mut cfg = ShardedRuntimeConfig::default();
        cfg.shard_count = 0;

        let err = cfg.validate().expect_err("shard_count=0 must be rejected");
        assert_eq!(err.config, "ShardedRuntimeConfig");
        assert_eq!(err.field, "shard_count");
    }

    #[test]
    fn runtime_config_validate_rejects_zero_tick_interval() {
        let mut cfg = ShardedRuntimeConfig::default();
        cfg.outbound_tick_interval = Duration::ZERO;

        let err = cfg
            .validate()
            .expect_err("outbound_tick_interval=0 must be rejected");
        assert_eq!(err.config, "ShardedRuntimeConfig");
        assert_eq!(err.field, "outbound_tick_interval");
    }
}