// naia-server 0.25.0
//
// A server that uses either UDP or WebRTC to send messages to and receive messages from connected clients, and that syncs registered Entities/Components to the clients for which they are in scope.
use std::{net::SocketAddr, panic, time::Duration};

use naia_shared::{CompressionConfig, Decoder, Encoder, OutgoingPacket, OwnedBitReader};

use super::bandwidth_monitor::BandwidthMonitor;
use crate::{
    error::NaiaServerError,
    transport::{PacketReceiver, PacketSender},
};

/// Server-side packet I/O.
///
/// Wraps the transport's sender/receiver pair and layers optional payload
/// compression and optional bandwidth measurement on top of it. It also keeps
/// an always-on counter of outgoing bytes for the current tick.
#[derive(Clone)]
pub struct Io {
    /// Outgoing transport half; `None` until `load` is called.
    packet_sender: Option<Box<dyn PacketSender>>,
    /// Incoming transport half; `None` until `load` is called.
    packet_receiver: Option<Box<dyn PacketReceiver>>,
    /// Records the size of each outgoing payload (after compression).
    /// Present only when `Io::new` received a bandwidth measurement duration.
    outgoing_bandwidth_monitor: Option<BandwidthMonitor>,
    /// Records the size of each incoming payload as received (before
    /// decompression). Present only when `Io::new` received a bandwidth
    /// measurement duration.
    incoming_bandwidth_monitor: Option<BandwidthMonitor>,
    /// Compresses server-to-client payloads; built from the compression
    /// config's `server_to_client` mode, if one is set.
    outgoing_encoder: Option<Encoder>,
    /// Decompresses client-to-server payloads; built from the compression
    /// config's `client_to_server` mode, if one is set.
    incoming_decoder: Option<Decoder>,
    /// Bytes sent during the most recent `send_all_packets` tick.
    /// Reset at the start of each `send_all_packets` via
    /// `reset_outgoing_bytes_this_tick`, incremented in `send_packet`.
    /// Unconditionally tracked — no bandwidth monitor required.
    outgoing_bytes_this_tick: u64,
}

impl Io {
    pub fn new(
        bandwidth_measure_duration: &Option<Duration>,
        compression_config: &Option<CompressionConfig>,
    ) -> Self {
        let outgoing_bandwidth_monitor = bandwidth_measure_duration.map(BandwidthMonitor::new);
        let incoming_bandwidth_monitor = bandwidth_measure_duration.map(BandwidthMonitor::new);

        let outgoing_encoder = compression_config.as_ref().and_then(|config| {
            config
                .server_to_client
                .as_ref()
                .map(|mode| Encoder::new(mode.clone()))
        });
        let incoming_decoder = compression_config.as_ref().and_then(|config| {
            config
                .client_to_server
                .as_ref()
                .map(|mode| Decoder::new(mode.clone()))
        });

        Self {
            packet_sender: None,
            packet_receiver: None,
            outgoing_bandwidth_monitor,
            incoming_bandwidth_monitor,
            outgoing_encoder,
            incoming_decoder,
            outgoing_bytes_this_tick: 0,
        }
    }

    pub fn load(
        &mut self,
        packet_sender: Box<dyn PacketSender>,
        packet_receiver: Box<dyn PacketReceiver>,
    ) {
        if self.packet_sender.is_some() {
            panic!("Packet sender/receiver already loaded! Cannot do this twice!");
        }

        self.packet_sender = Some(packet_sender);
        self.packet_receiver = Some(packet_receiver);
    }

    pub fn is_loaded(&self) -> bool {
        self.packet_sender.is_some()
    }

    pub fn sender_cloned(&self) -> Box<dyn PacketSender> {
        if self.packet_sender.is_none() {
            panic!("Cannot call Server.sender_cloned() until you call Server.listen()!");
        }

        self.packet_sender.as_ref().unwrap().clone()
    }

    pub fn send_packet(
        &mut self,
        address: &SocketAddr,
        packet: OutgoingPacket,
    ) -> Result<(), NaiaServerError> {
        // get payload
        let mut payload = packet.slice();

        // Compression
        if let Some(encoder) = &mut self.outgoing_encoder {
            payload = encoder.encode(payload);
        }

        // Bandwidth monitoring
        if let Some(monitor) = &mut self.outgoing_bandwidth_monitor {
            monitor.record_packet(address, payload.len());
        }

        // Per-tick byte counter (always tracked; cheap)
        self.outgoing_bytes_this_tick =
            self.outgoing_bytes_this_tick.saturating_add(payload.len() as u64);

        self.packet_sender
            .as_ref()
            .expect("Cannot call Server.send_packet() until you call Server.listen()!")
            .send(address, payload)
            .map_err(|_| NaiaServerError::SendError(*address))
    }

    pub fn recv_reader(&mut self) -> Result<Option<(SocketAddr, OwnedBitReader)>, NaiaServerError> {
        let receive_result = self
            .packet_receiver
            .as_mut()
            .expect("Cannot call Server.receive_packet() until you call Server.listen()!")
            .receive();

        match receive_result {
            Ok(Some((address, mut payload))) => {
                // Bandwidth monitoring
                if let Some(monitor) = &mut self.incoming_bandwidth_monitor {
                    monitor.record_packet(&address, payload.len());
                }

                // Decompression
                if let Some(decoder) = &mut self.incoming_decoder {
                    payload = decoder.decode(payload);
                }

                Ok(Some((address, OwnedBitReader::new(payload))))
            }
            Ok(None) => Ok(None),
            Err(_) => Err(NaiaServerError::RecvError),
        }
    }

    pub fn bandwidth_monitor_enabled(&self) -> bool {
        self.outgoing_bandwidth_monitor.is_some() && self.incoming_bandwidth_monitor.is_some()
    }

    pub fn register_client(&mut self, address: &SocketAddr) {
        self.outgoing_bandwidth_monitor
            .as_mut()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .create_client(address);
        self.incoming_bandwidth_monitor
            .as_mut()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .create_client(address);
    }

    pub fn deregister_client(&mut self, address: &SocketAddr) {
        self.outgoing_bandwidth_monitor
            .as_mut()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .delete_client(address);
        self.incoming_bandwidth_monitor
            .as_mut()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .delete_client(address);
    }

    /// Tick bandwidth monitors to clear expired packets.
    /// Call this during the update phase of the tick cycle.
    pub fn tick_bandwidth_monitors(&mut self) {
        if let Some(monitor) = &mut self.outgoing_bandwidth_monitor {
            monitor.tick();
        }
        if let Some(monitor) = &mut self.incoming_bandwidth_monitor {
            monitor.tick();
        }
    }

    /// Zero out the per-tick byte counter. Called by `send_all_packets` at
    /// the start of each server tick so that `outgoing_bytes_last_tick`
    /// reflects only that tick's work.
    pub fn reset_outgoing_bytes_this_tick(&mut self) {
        self.outgoing_bytes_this_tick = 0;
    }

    /// Total bytes sent (post-compression) during the most recent
    /// `send_all_packets` call. Call AFTER the tick completes.
    pub fn outgoing_bytes_last_tick(&self) -> u64 {
        self.outgoing_bytes_this_tick
    }

    pub fn outgoing_bandwidth_total(&self) -> f32 {
        self.outgoing_bandwidth_monitor
            .as_ref()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .total_bandwidth()
    }

    pub fn incoming_bandwidth_total(&self) -> f32 {
        self.incoming_bandwidth_monitor
            .as_ref()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .total_bandwidth()
    }

    pub fn outgoing_bandwidth_to_client(&self, address: &SocketAddr) -> f32 {
        self.outgoing_bandwidth_monitor
            .as_ref()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .client_bandwidth(address)
    }

    pub fn incoming_bandwidth_from_client(&self, address: &SocketAddr) -> f32 {
        self.incoming_bandwidth_monitor
            .as_ref()
            .expect("Need to call `enable_bandwidth_monitor()` on Io before calling this")
            .client_bandwidth(address)
    }
}