//! raknet_rust/transport/runtime.rs — sharded transport runtime.
1use std::io;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use bytes::Bytes;
6use tokio::sync::mpsc::error::TrySendError;
7use tokio::sync::{broadcast, mpsc};
8use tokio::task::JoinHandle;
9use tokio::time::{self, MissedTickBehavior};
10
11use crate::error::ConfigValidationError;
12use crate::protocol::reliability::Reliability;
13use crate::session::RakPriority;
14
15use super::config::TransportConfig;
16use super::server::{TransportEvent, TransportMetricsSnapshot, TransportServer};
17
/// Strategy applied when the shared runtime event queue is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventOverflowPolicy {
    /// Apply backpressure: the producing worker awaits until the consumer
    /// drains the queue. No events are ever lost, but a slow consumer can
    /// stall worker loops.
    BlockProducer,
    /// Drop non-critical events (transport events and metrics snapshots)
    /// when the queue is full, counting each drop. Critical events
    /// (worker errors / stop notices) still block.
    ShedNonCritical,
}
23
/// Tuning knobs for the sharded transport runtime.
///
/// Validated by [`ShardedRuntimeConfig::validate`] before workers are spawned.
#[derive(Debug, Clone)]
pub struct ShardedRuntimeConfig {
    /// Number of worker shards to spawn (must be >= 1).
    pub shard_count: usize,
    /// Interval at which each shard flushes queued outbound data.
    pub outbound_tick_interval: Duration,
    /// Interval at which each shard emits a `Metrics` event.
    pub metrics_emit_interval: Duration,
    /// Capacity of the shared worker -> consumer event queue.
    pub event_queue_capacity: usize,
    /// Capacity of each per-shard handle -> worker command queue.
    pub command_queue_capacity: usize,
    /// Behavior when the event queue is full (block vs. shed non-critical).
    pub event_overflow_policy: EventOverflowPolicy,
    /// Per outbound tick: maximum newly-sent datagrams per session.
    pub max_new_datagrams_per_session: usize,
    /// Per outbound tick: maximum newly-sent bytes per session
    /// (must be >= the protocol's minimum MTU).
    pub max_new_bytes_per_session: usize,
    /// Per outbound tick: maximum resent datagrams per session.
    pub max_resend_datagrams_per_session: usize,
    /// Per outbound tick: maximum resent bytes per session
    /// (must be >= the protocol's minimum MTU).
    pub max_resend_bytes_per_session: usize,
}
37
38impl Default for ShardedRuntimeConfig {
39    fn default() -> Self {
40        Self {
41            shard_count: std::thread::available_parallelism()
42                .map(|value| value.get())
43                .unwrap_or(1)
44                .max(1),
45            outbound_tick_interval: Duration::from_millis(10),
46            metrics_emit_interval: Duration::from_millis(1000),
47            event_queue_capacity: 4096,
48            command_queue_capacity: 4096,
49            event_overflow_policy: EventOverflowPolicy::ShedNonCritical,
50            max_new_datagrams_per_session: 8,
51            max_new_bytes_per_session: 64 * 1024,
52            max_resend_datagrams_per_session: 8,
53            max_resend_bytes_per_session: 64 * 1024,
54        }
55    }
56}
57
58impl ShardedRuntimeConfig {
59    pub fn validate(&self) -> Result<(), ConfigValidationError> {
60        if self.shard_count == 0 {
61            return Err(ConfigValidationError::new(
62                "ShardedRuntimeConfig",
63                "shard_count",
64                "must be >= 1",
65            ));
66        }
67        if self.outbound_tick_interval.is_zero() {
68            return Err(ConfigValidationError::new(
69                "ShardedRuntimeConfig",
70                "outbound_tick_interval",
71                "must be > 0",
72            ));
73        }
74        if self.metrics_emit_interval.is_zero() {
75            return Err(ConfigValidationError::new(
76                "ShardedRuntimeConfig",
77                "metrics_emit_interval",
78                "must be > 0",
79            ));
80        }
81        if self.event_queue_capacity == 0 {
82            return Err(ConfigValidationError::new(
83                "ShardedRuntimeConfig",
84                "event_queue_capacity",
85                "must be >= 1",
86            ));
87        }
88        if self.command_queue_capacity == 0 {
89            return Err(ConfigValidationError::new(
90                "ShardedRuntimeConfig",
91                "command_queue_capacity",
92                "must be >= 1",
93            ));
94        }
95        if self.max_new_datagrams_per_session == 0 {
96            return Err(ConfigValidationError::new(
97                "ShardedRuntimeConfig",
98                "max_new_datagrams_per_session",
99                "must be >= 1",
100            ));
101        }
102        if self.max_new_bytes_per_session < crate::protocol::constants::MINIMUM_MTU_SIZE as usize {
103            return Err(ConfigValidationError::new(
104                "ShardedRuntimeConfig",
105                "max_new_bytes_per_session",
106                format!(
107                    "must be >= {}, got {}",
108                    crate::protocol::constants::MINIMUM_MTU_SIZE,
109                    self.max_new_bytes_per_session
110                ),
111            ));
112        }
113        if self.max_resend_datagrams_per_session == 0 {
114            return Err(ConfigValidationError::new(
115                "ShardedRuntimeConfig",
116                "max_resend_datagrams_per_session",
117                "must be >= 1",
118            ));
119        }
120        if self.max_resend_bytes_per_session < crate::protocol::constants::MINIMUM_MTU_SIZE as usize
121        {
122            return Err(ConfigValidationError::new(
123                "ShardedRuntimeConfig",
124                "max_resend_bytes_per_session",
125                format!(
126                    "must be >= {}, got {}",
127                    crate::protocol::constants::MINIMUM_MTU_SIZE,
128                    self.max_resend_bytes_per_session
129                ),
130            ));
131        }
132
133        Ok(())
134    }
135}
136
/// Events emitted by shard workers onto the shared runtime event queue.
#[derive(Debug)]
pub enum ShardedRuntimeEvent {
    /// A transport-level event surfaced by one shard's `TransportServer`.
    /// Non-critical: may be shed under `EventOverflowPolicy::ShedNonCritical`.
    Transport {
        shard_id: usize,
        event: TransportEvent,
    },
    /// Periodic metrics snapshot from one shard. Boxed to keep this enum
    /// small; the snapshot is the largest payload. Non-critical (may be shed).
    Metrics {
        shard_id: usize,
        snapshot: Box<TransportMetricsSnapshot>,
        /// Total non-critical events this shard has shed, as of just before
        /// this snapshot was sent.
        dropped_non_critical_events: u64,
    },
    /// A fatal worker error; the worker exits after emitting this.
    /// Critical: never shed.
    WorkerError {
        shard_id: usize,
        message: String,
    },
    /// Emitted once when a worker exits cleanly after a shutdown signal.
    /// Critical: never shed.
    WorkerStopped {
        shard_id: usize,
    },
}
156
/// Commands sent from a `ShardedRuntimeHandle` to a shard worker.
///
/// `Clone` so the same command can be broadcast to every shard.
#[derive(Debug, Clone)]
pub enum ShardedRuntimeCommand {
    /// Queue an outbound payload for the session at `addr`.
    SendPayload {
        addr: SocketAddr,
        payload: Bytes,
        reliability: Reliability,
        channel: u8,
        priority: RakPriority,
        /// When `Some`, the worker tracks delivery under this receipt id.
        receipt_id: Option<u64>,
    },
    /// Disconnect the peer session at `addr`, if this shard owns one.
    DisconnectPeer {
        addr: SocketAddr,
    },
}
171
/// Caller-facing bundle of the fields needed to queue an outbound payload,
/// without the receipt id (supplied separately by the send helpers).
#[derive(Debug, Clone)]
pub struct ShardedSendPayload {
    /// Destination peer address.
    pub addr: SocketAddr,
    /// Payload bytes (cheaply cloneable for multi-shard broadcast).
    pub payload: Bytes,
    /// Reliability mode for this payload.
    pub reliability: Reliability,
    /// Ordering channel index.
    pub channel: u8,
    /// Outbound queue priority.
    pub priority: RakPriority,
}
180
/// Owner-side handle for a running sharded runtime: receives worker events,
/// dispatches commands to shards, and coordinates shutdown.
pub struct ShardedRuntimeHandle {
    /// Consumer end of the shared event queue fed by every shard worker.
    pub event_rx: mpsc::Receiver<ShardedRuntimeEvent>,
    // Broadcasts the shutdown signal to all workers.
    shutdown_tx: broadcast::Sender<()>,
    // One command sender per shard, indexed by shard id.
    command_txs: Vec<mpsc::Sender<ShardedRuntimeCommand>>,
    // Join handles for the spawned worker tasks, awaited on shutdown.
    handles: Vec<JoinHandle<io::Result<()>>>,
}
187
188impl ShardedRuntimeHandle {
189    pub fn shard_count(&self) -> usize {
190        self.command_txs.len()
191    }
192
193    pub async fn send_payload_to_shard(
194        &self,
195        shard_id: usize,
196        payload: ShardedSendPayload,
197    ) -> io::Result<()> {
198        self.send_command_to_shard(
199            shard_id,
200            ShardedRuntimeCommand::SendPayload {
201                addr: payload.addr,
202                payload: payload.payload,
203                reliability: payload.reliability,
204                channel: payload.channel,
205                priority: payload.priority,
206                receipt_id: None,
207            },
208        )
209        .await
210    }
211
212    pub async fn send_payload_to_shard_with_receipt(
213        &self,
214        shard_id: usize,
215        payload: ShardedSendPayload,
216        receipt_id: u64,
217    ) -> io::Result<()> {
218        self.send_command_to_shard(
219            shard_id,
220            ShardedRuntimeCommand::SendPayload {
221                addr: payload.addr,
222                payload: payload.payload,
223                reliability: payload.reliability,
224                channel: payload.channel,
225                priority: payload.priority,
226                receipt_id: Some(receipt_id),
227            },
228        )
229        .await
230    }
231
232    pub async fn send_payload_any_shard(
233        &self,
234        addr: SocketAddr,
235        payload: Bytes,
236        reliability: Reliability,
237        channel: u8,
238        priority: RakPriority,
239    ) -> io::Result<()> {
240        for tx in &self.command_txs {
241            tx.send(ShardedRuntimeCommand::SendPayload {
242                addr,
243                payload: payload.clone(),
244                reliability,
245                channel,
246                priority,
247                receipt_id: None,
248            })
249            .await
250            .map_err(|_| {
251                io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
252            })?;
253        }
254        Ok(())
255    }
256
257    pub async fn disconnect_peer_from_shard(
258        &self,
259        shard_id: usize,
260        addr: SocketAddr,
261    ) -> io::Result<()> {
262        self.send_command_to_shard(shard_id, ShardedRuntimeCommand::DisconnectPeer { addr })
263            .await
264    }
265
266    pub async fn disconnect_peer_any_shard(&self, addr: SocketAddr) -> io::Result<()> {
267        for tx in &self.command_txs {
268            tx.send(ShardedRuntimeCommand::DisconnectPeer { addr })
269                .await
270                .map_err(|_| {
271                    io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
272                })?;
273        }
274        Ok(())
275    }
276
277    async fn send_command_to_shard(
278        &self,
279        shard_id: usize,
280        command: ShardedRuntimeCommand,
281    ) -> io::Result<()> {
282        let tx = self.command_txs.get(shard_id).ok_or_else(|| {
283            io::Error::new(
284                io::ErrorKind::InvalidInput,
285                format!("invalid shard_id {shard_id}"),
286            )
287        })?;
288
289        tx.send(command).await.map_err(|_| {
290            io::Error::new(io::ErrorKind::BrokenPipe, "runtime command channel closed")
291        })
292    }
293
294    pub fn request_shutdown(&self) {
295        let _ = self.shutdown_tx.send(());
296    }
297
298    pub async fn shutdown(mut self) -> io::Result<()> {
299        self.request_shutdown();
300
301        while let Some(handle) = self.handles.pop() {
302            match handle.await {
303                Ok(Ok(())) => {}
304                Ok(Err(e)) => return Err(e),
305                Err(join_err) => {
306                    return Err(io::Error::other(format!("worker join error: {join_err}")));
307                }
308            }
309        }
310
311        Ok(())
312    }
313}
314
315pub async fn spawn_sharded_runtime(
316    transport_config: TransportConfig,
317    runtime_config: ShardedRuntimeConfig,
318) -> io::Result<ShardedRuntimeHandle> {
319    transport_config
320        .validate()
321        .map_err(invalid_config_io_error)?;
322    runtime_config.validate().map_err(invalid_config_io_error)?;
323
324    let shard_count = runtime_config.shard_count.max(1);
325    let workers = TransportServer::bind_shards(transport_config, shard_count).await?;
326
327    let (event_tx, event_rx) = mpsc::channel(runtime_config.event_queue_capacity.max(1));
328    let (shutdown_tx, _) = broadcast::channel(1);
329
330    let mut handles = Vec::with_capacity(workers.len());
331    let mut command_txs = Vec::with_capacity(workers.len());
332    for (shard_id, server) in workers.into_iter().enumerate() {
333        let tx = event_tx.clone();
334        let cfg = runtime_config.clone();
335        let shutdown_rx = shutdown_tx.subscribe();
336        let (command_tx, command_rx) = mpsc::channel(runtime_config.command_queue_capacity.max(1));
337        command_txs.push(command_tx);
338
339        handles.push(tokio::spawn(async move {
340            run_worker_loop(shard_id, server, cfg, tx, command_rx, shutdown_rx).await
341        }));
342    }
343
344    Ok(ShardedRuntimeHandle {
345        event_rx,
346        shutdown_tx,
347        command_txs,
348        handles,
349    })
350}
351
/// Per-shard worker: owns one `TransportServer` and multiplexes shutdown,
/// incoming commands, outbound pacing, metrics emission, and inbound
/// processing in a single `select!` loop.
///
/// Returns `Ok(())` after a clean shutdown, or the I/O error that stopped the
/// worker (also surfaced to the consumer as a `WorkerError` event).
async fn run_worker_loop(
    shard_id: usize,
    mut server: TransportServer,
    cfg: ShardedRuntimeConfig,
    event_tx: mpsc::Sender<ShardedRuntimeEvent>,
    mut command_rx: mpsc::Receiver<ShardedRuntimeCommand>,
    mut shutdown_rx: broadcast::Receiver<()>,
) -> io::Result<()> {
    // Skip (rather than burst) ticks that were missed while the loop was busy.
    let mut outbound_tick = time::interval(cfg.outbound_tick_interval);
    outbound_tick.set_missed_tick_behavior(MissedTickBehavior::Skip);

    let mut metrics_tick = time::interval(cfg.metrics_emit_interval);
    metrics_tick.set_missed_tick_behavior(MissedTickBehavior::Skip);

    // Running count of events shed under ShedNonCritical; reported alongside
    // each metrics snapshot so consumers can observe queue pressure.
    let mut dropped_non_critical_events = 0u64;
    // Emit an initial snapshot so the consumer sees this shard immediately.
    // The reported drop count is captured *before* the send, so this metrics
    // event can never count itself.
    let initial_dropped_snapshot = dropped_non_critical_events;
    send_non_critical_event(
        &event_tx,
        cfg.event_overflow_policy,
        &mut dropped_non_critical_events,
        ShardedRuntimeEvent::Metrics {
            shard_id,
            snapshot: Box::new(server.metrics_snapshot()),
            dropped_non_critical_events: initial_dropped_snapshot,
        },
    )
    .await?;

    loop {
        tokio::select! {
            // `biased` polls branches in written order, so shutdown always
            // wins over other ready branches.
            biased;

            // Fires on a shutdown signal — and, because the pattern is `_`,
            // also when the shutdown sender is dropped (Err case).
            _ = shutdown_rx.recv() => {
                send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerStopped { shard_id }).await?;
                return Ok(());
            }

            command = command_rx.recv() => {
                // `None` means the command channel closed; keep looping so
                // ticks and inbound traffic continue to be serviced until a
                // shutdown arrives.
                if let Some(command) = command {
                    apply_command(&mut server, command);
                }
            }

            _ = outbound_tick.tick() => {
                // Flush queued outbound data under the configured per-session
                // budgets; a transport error here is fatal for this worker.
                if let Err(e) = server.tick_outbound(
                    cfg.max_new_datagrams_per_session,
                    cfg.max_new_bytes_per_session,
                    cfg.max_resend_datagrams_per_session,
                    cfg.max_resend_bytes_per_session,
                ).await {
                    // Best-effort error report (ignore send failure), then die.
                    let _ = send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerError {
                        shard_id,
                        message: format!("outbound tick failed: {e}"),
                    }).await;
                    return Err(e);
                }
            }

            _ = metrics_tick.tick() => {
                // Capture the drop count before sending so this metrics event,
                // if shed, is not reflected in its own snapshot.
                let dropped_snapshot = dropped_non_critical_events;
                send_non_critical_event(
                    &event_tx,
                    cfg.event_overflow_policy,
                    &mut dropped_non_critical_events,
                    ShardedRuntimeEvent::Metrics {
                        shard_id,
                        snapshot: Box::new(server.metrics_snapshot()),
                        dropped_non_critical_events: dropped_snapshot,
                    },
                ).await?;
            }

            recv_result = server.recv_and_process() => {
                match recv_result {
                    // Forward transport events as non-critical (sheddable).
                    Ok(event) => {
                        send_non_critical_event(
                            &event_tx,
                            cfg.event_overflow_policy,
                            &mut dropped_non_critical_events,
                            ShardedRuntimeEvent::Transport {
                                shard_id,
                                event,
                            },
                        ).await?;
                    }
                    // Receive failure is fatal: report best-effort, then exit.
                    Err(e) => {
                        let _ = send_critical_event(&event_tx, ShardedRuntimeEvent::WorkerError {
                            shard_id,
                            message: format!("recv loop failed: {e}"),
                        }).await;
                        return Err(e);
                    }
                }
            }
        }
    }
}
449
450fn apply_command(server: &mut TransportServer, command: ShardedRuntimeCommand) {
451    match command {
452        ShardedRuntimeCommand::SendPayload {
453            addr,
454            payload,
455            reliability,
456            channel,
457            priority,
458            receipt_id,
459        } => {
460            let _ = if let Some(receipt_id) = receipt_id {
461                server.queue_payload_with_receipt(
462                    addr,
463                    payload,
464                    reliability,
465                    channel,
466                    priority,
467                    receipt_id,
468                )
469            } else {
470                server.queue_payload(addr, payload, reliability, channel, priority)
471            };
472        }
473        ShardedRuntimeCommand::DisconnectPeer { addr } => {
474            let _ = server.disconnect_peer(addr);
475        }
476    }
477}
478
479async fn send_non_critical_event(
480    event_tx: &mpsc::Sender<ShardedRuntimeEvent>,
481    overflow_policy: EventOverflowPolicy,
482    dropped_non_critical_events: &mut u64,
483    event: ShardedRuntimeEvent,
484) -> io::Result<()> {
485    match overflow_policy {
486        EventOverflowPolicy::BlockProducer => event_tx
487            .send(event)
488            .await
489            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "runtime event channel closed")),
490        EventOverflowPolicy::ShedNonCritical => match event_tx.try_send(event) {
491            Ok(()) => Ok(()),
492            Err(TrySendError::Full(_)) => {
493                *dropped_non_critical_events = dropped_non_critical_events.saturating_add(1);
494                Ok(())
495            }
496            Err(TrySendError::Closed(_)) => Err(io::Error::new(
497                io::ErrorKind::BrokenPipe,
498                "runtime event channel closed",
499            )),
500        },
501    }
502}
503
504async fn send_critical_event(
505    event_tx: &mpsc::Sender<ShardedRuntimeEvent>,
506    event: ShardedRuntimeEvent,
507) -> io::Result<()> {
508    event_tx
509        .send(event)
510        .await
511        .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "runtime event channel closed"))
512}
513
/// Maps a configuration validation failure onto an `InvalidInput` I/O error,
/// so config problems surface through the same `io::Result` as bind errors.
fn invalid_config_io_error(error: ConfigValidationError) -> io::Error {
    io::Error::new(io::ErrorKind::InvalidInput, error.to_string())
}
517
#[cfg(test)]
mod tests {
    use super::{
        EventOverflowPolicy, ShardedRuntimeConfig, ShardedRuntimeEvent, send_non_critical_event,
    };
    use crate::transport::server::TransportMetricsSnapshot;
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Builds an empty metrics event tagged with `shard_id`.
    fn metrics_event(shard_id: usize) -> ShardedRuntimeEvent {
        ShardedRuntimeEvent::Metrics {
            shard_id,
            snapshot: Box::new(TransportMetricsSnapshot::default()),
            dropped_non_critical_events: 0,
        }
    }

    #[tokio::test]
    async fn shed_policy_drops_non_critical_when_channel_is_full() {
        // Fill the single-slot channel so the next non-critical send overflows.
        let (sender, mut receiver) = mpsc::channel(1);
        sender
            .send(metrics_event(1))
            .await
            .expect("initial send should succeed");

        let mut drop_count = 0u64;
        send_non_critical_event(
            &sender,
            EventOverflowPolicy::ShedNonCritical,
            &mut drop_count,
            metrics_event(2),
        )
        .await
        .expect("shed policy should not fail on full queue");

        // The overflowing event was counted and discarded; only the first
        // event remains in the channel.
        assert_eq!(drop_count, 1);
        let queued = receiver.recv().await.expect("first event should be present");
        assert!(matches!(
            queued,
            ShardedRuntimeEvent::Metrics { shard_id: 1, .. }
        ));
        assert!(receiver.try_recv().is_err(), "second event should be shed");
    }

    #[tokio::test]
    async fn block_policy_enqueues_event_normally() {
        let (sender, mut receiver) = mpsc::channel(1);
        let mut drop_count = 0u64;
        send_non_critical_event(
            &sender,
            EventOverflowPolicy::BlockProducer,
            &mut drop_count,
            metrics_event(5),
        )
        .await
        .expect("block policy send should succeed");

        // Blocking mode never sheds; the event must arrive intact.
        assert_eq!(drop_count, 0);
        let delivered = receiver.recv().await.expect("event should be available");
        assert!(matches!(
            delivered,
            ShardedRuntimeEvent::Metrics { shard_id: 5, .. }
        ));
    }

    #[test]
    fn runtime_config_validate_rejects_zero_shards() {
        let config = ShardedRuntimeConfig {
            shard_count: 0,
            ..ShardedRuntimeConfig::default()
        };
        let failure = config
            .validate()
            .expect_err("shard_count=0 must be rejected");
        assert_eq!(failure.config, "ShardedRuntimeConfig");
        assert_eq!(failure.field, "shard_count");
    }

    #[test]
    fn runtime_config_validate_rejects_zero_tick_interval() {
        let config = ShardedRuntimeConfig {
            outbound_tick_interval: Duration::ZERO,
            ..ShardedRuntimeConfig::default()
        };
        let failure = config
            .validate()
            .expect_err("outbound_tick_interval=0 must be rejected");
        assert_eq!(failure.config, "ShardedRuntimeConfig");
        assert_eq!(failure.field, "outbound_tick_interval");
    }
}
605}