harlequinn 0.1.1

A real-time networking library primarily aimed at games.
Documentation
use std::net::SocketAddr;

use {
    bytes::Bytes,
    crc32fast::Hasher,
    crossbeam_channel::{bounded, Receiver},
    quinn::{
        Certificate, CertificateChain, Connection, Endpoint, EndpointBuilder, PrivateKey,
        ServerConfigBuilder,
    },
    slotmap::{SecondaryMap, SlotMap, SparseSecondaryMap},
    tokio::{
        runtime::Runtime,
        sync::{
            mpsc::{channel, Sender as TokioSender},
            oneshot::Sender as OneshotSender,
        },
    },
};

use crate::{
    worker::{driver_worker, network_worker, ConnectionCommand, WorkerCommand, WorkerEvent},
    MessageOrder, PeerId,
};

/// A harlequinn client or server endpoint.
///
/// Owns the tokio runtime on which the network worker tasks are spawned, plus the channels used
/// to communicate with them. Incoming events must be drained regularly with `poll_events`.
pub struct HqEndpoint {
    // Runtime the driver and network worker tasks were spawned on (see `new_priv`).
    // NOTE(review): struct fields are dropped in declaration order after `Drop::drop` runs, so
    // this runtime is torn down before the connection handles in `peers` — confirm intended.
    runtime: Runtime,

    // Command channel into the network worker (`Connect` / `Stop` commands).
    sender: TokioSender<WorkerCommand>,
    // Event channel out of the network worker; drained by `poll_events` and during drop.
    receiver: Receiver<WorkerEvent>,

    // Allocator for peer ids. Ids are inserted by `connect` or when the worker reports a new
    // connection without a pre-assigned id, and removed on connection failure or disconnect.
    peer_ids: SlotMap<PeerId, ()>,
    // Per-peer state for connections the worker has reported as started.
    peers: SecondaryMap<PeerId, ActivePeer>,
    // Responders for connection requests still waiting on `accept` / `reject`.
    pending_confirms: SparseSecondaryMap<PeerId, OneshotSender<bool>>,
}

/// Bookkeeping for a peer whose connection the worker has reported as started.
struct ActivePeer {
    // Handle to the QUIC connection; `disconnect` closes it through this handle.
    connection: Connection,
    // Per-connection command channel used to queue datagrams and messages for this peer.
    sender: TokioSender<ConnectionCommand>,
    // True when the worker reported the connection without a pre-assigned peer id, i.e. the
    // connection was not started by this endpoint's `connect`, so the remote side is a client.
    is_client: bool,
}

impl HqEndpoint {
    /// Creates a new client endpoint.
    ///
    /// Doesn't listen for incoming connections, use `connect` to start a connection to a server.
    pub fn new_client(protocol: &str) -> Self {
        let socket_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();

        // No default client config is provided, certificates are given on a per-connection basis

        // Create the client endpoint
        let endpoint_builder = Endpoint::builder();

        Self::new_priv(endpoint_builder, socket_addr, protocol, false)
    }

    /// Creates a new server endpoint.
    ///
    /// Listens for incoming connections.
    pub fn new_server(
        protocol: &str,
        socket_addr: SocketAddr,
        certificate: Certificate,
        private_key: PrivateKey,
    ) -> Self {
        // Create the server config
        let mut config_builder = ServerConfigBuilder::default();
        config_builder
            .certificate(CertificateChain::from_certs(vec![certificate]), private_key)
            .unwrap();

        let server_config = config_builder.build();

        // Create the server endpoint
        let mut endpoint_builder = Endpoint::builder();
        endpoint_builder.listen(server_config);

        Self::new_priv(endpoint_builder, socket_addr, protocol, true)
    }

    fn new_priv(
        endpoint_builder: EndpointBuilder,
        socket_addr: SocketAddr,
        protocol: &str,
        is_listening: bool,
    ) -> Self {
        // Calculate the checksum for the protocol
        let mut hasher = Hasher::new();
        hasher.update(protocol.as_bytes());
        let protocol_checksum = hasher.finalize();

        // This is only used for initiating connections, doesn't need many slots
        let (sender, worker_receiver) = channel(64);

        // This handles incoming messages and events, needs a lot of slots
        let (worker_sender, receiver) = bounded(512);

        // Spawn the network worker on a tokio runtime
        let mut runtime = Runtime::new().unwrap();
        runtime.block_on(async move {
            let (driver, endpoint, incoming) = endpoint_builder.bind(&socket_addr).unwrap();

            tokio::spawn(driver_worker(driver, worker_sender.clone()));
            tokio::spawn(network_worker(
                endpoint,
                incoming,
                worker_sender,
                worker_receiver,
                protocol_checksum,
                is_listening,
            ));
        });

        Self {
            runtime,

            sender,
            receiver,

            peer_ids: SlotMap::with_key(),
            peers: SecondaryMap::new(),
            pending_confirms: SparseSecondaryMap::new(),
        }
    }

    /// Initiate a connection to a server endpoint.
    ///
    /// Raises `ConnectionFailed` on failure, and `ConnectionRequested` on success.
    pub fn connect<S: ToString>(
        &mut self,
        server_addr: SocketAddr,
        server_name: S,
        certificate: Certificate,
    ) -> PeerId {
        // When the caller initiates a connection, they want to keep track of it, that's why we
        // create a PeerId in advance.
        let peer_id = self.peer_ids.insert(());

        let future = self.sender.send(WorkerCommand::Connect {
            server_addr,
            server_name: server_name.to_string(),
            certificate,
            assigned_peer_id: peer_id,
        });
        self.runtime.block_on(future).unwrap();

        peer_id
    }

    /// Disconnects a peer, ending the connection.
    ///
    /// `reason` is not preserved intact, and will be mangled on the other side, but can be used to
    /// provide human-readable reasons.
    pub fn disconnect(&mut self, peer_id: PeerId, reason: String) {
        if let Some(peer) = self.peers.get_mut(peer_id) {
            // We only need to close at the handle here, the worker later notifies us when the
            // connection is disconnected.
            peer.connection.close(0u32.into(), reason.as_bytes());
        }
    }

    /// Accepts a pending connection.
    pub fn accept(&mut self, peer_id: PeerId) {
        if let Some(sender) = self.pending_confirms.remove(peer_id) {
            // Send an acception back
            sender.send(true).unwrap();
        }
    }

    /// Rejects a pending connection.
    pub fn reject(&mut self, peer_id: PeerId) {
        if let Some(sender) = self.pending_confirms.remove(peer_id) {
            // Send a rejection back
            sender.send(false).unwrap();
        }
    }

    /// Send an unreliable datagram to a peer.
    ///
    /// If the peer is taking too long to respond to messages, this may disconnect the peer.
    pub fn send_datagram(&mut self, peer_id: PeerId, bytes: Bytes) {
        self.attempt_connection_command(peer_id, ConnectionCommand::SendDatagram { bytes });
    }

    /// Send a reliable message over the main stream to a peer.
    ///
    /// Messages with `ordered` set to true will arrive in order at the peer.
    ///
    /// If the peer is taking too long to respond to messages, this may disconnect the peer.
    pub fn send_message(&mut self, peer_id: PeerId, bytes: Bytes, order: MessageOrder) {
        self.attempt_connection_command(peer_id, ConnectionCommand::SendMessage { bytes, order });
    }

    fn attempt_connection_command(&mut self, peer_id: PeerId, command: ConnectionCommand) {
        let mut should_disconnect = false;

        if let Some(peer) = self.peers.get_mut(peer_id) {
            let result = peer.sender.try_send(command);

            // If the channel is full, the peer is taking too long to respond and should be dropped
            should_disconnect = result.is_err();
        }

        if should_disconnect {
            self.disconnect(
                peer_id,
                "Peer is too slow, command buffer exceeded".to_string(),
            );
        }
    }

    /// Polls for events and stores them in the given events buffer.
    ///
    /// This *must* be called frequently, or the endpoint will run out of space in the events
    /// channel and stall.
    pub fn poll_events(&mut self, events: &mut Vec<EndpointEvent>) {
        while let Ok(event) = self.receiver.try_recv() {
            match event {
                WorkerEvent::ConnectionFailed { peer_id } => {
                    self.peer_ids.remove(peer_id);

                    events.push(EndpointEvent::ConnectionFailed { peer_id });
                }
                WorkerEvent::ConnectionStarted {
                    connection,
                    connection_sender,
                    server_peer_id,
                    peer_id_sender,
                } => {
                    // If we already got assigned an ID previously, use that
                    let peer_id = if let Some(peer_id) = server_peer_id {
                        peer_id
                    } else {
                        self.peer_ids.insert(())
                    };

                    // Respond to the worker with the peer ID decided on
                    peer_id_sender.send(peer_id).unwrap();

                    self.peers.insert(
                        peer_id,
                        ActivePeer {
                            connection,
                            sender: connection_sender,
                            is_client: server_peer_id.is_none(),
                        },
                    );
                }
                WorkerEvent::ConnectionRequested {
                    peer_id,
                    socket_addr,
                    confirm_sender,
                } => {
                    // Don't do anything it the peer already got removed
                    if let Some(peer) = self.peers.get(peer_id) {
                        // Store this as pending confirmation
                        self.pending_confirms.insert(peer_id, confirm_sender);

                        events.push(EndpointEvent::ConnectionRequested {
                            peer_id,
                            socket_addr,
                            is_client: peer.is_client,
                        });
                    }
                }
                WorkerEvent::Disconnected { peer_id, reason } => {
                    self.peer_ids.remove(peer_id);
                    self.peers.remove(peer_id);

                    // Just in case the game hasn't gotten to the request yet
                    self.pending_confirms.remove(peer_id);

                    events.push(EndpointEvent::Disconnected { peer_id, reason });
                }
                WorkerEvent::ReceivedDatagram { peer_id, bytes } => {
                    events.push(EndpointEvent::ReceivedDatagram { peer_id, bytes });
                }
                WorkerEvent::ReceivedMessage { peer_id, bytes } => {
                    events.push(EndpointEvent::ReceivedMessage { peer_id, bytes });
                }
                WorkerEvent::Stopped => {
                    // TODO: Handle gracefully
                    panic!("Unexpected stopped event");
                }
            }
        }
    }
}

impl Drop for HqEndpoint {
    /// Asks the network worker to stop and blocks until it acknowledges with `Stopped`.
    fn drop(&mut self) {
        // Send a shutdown message. If the command channel is already closed the worker is gone
        // and there is nothing to wait for. Panicking here instead would abort the process if
        // this `drop` runs while already unwinding from another panic.
        let future = self.sender.send(WorkerCommand::Stop);
        if self.runtime.block_on(future).is_err() {
            return;
        }

        // Wait until we receive the stopped event, discarding any other events still queued.
        // `recv` returns `Err` once the channel is empty and every worker-side sender has been
        // dropped, so the loop also ends if the worker exits without sending `Stopped`.
        while let Ok(event) = self.receiver.recv() {
            if let WorkerEvent::Stopped = event {
                return;
            }
        }
    }
}

/// An event raised on the endpoint, collected by `HqEndpoint::poll_events`.
pub enum EndpointEvent {
    /// A connection initiated by this endpoint has failed.
    ConnectionFailed {
        /// Associated peer id.
        peer_id: PeerId,
    },
    /// A new connection has been requested by a peer.
    ///
    /// You can use this to filter connections, limit player counts, etc.
    /// Either accept or reject a connection using [`HqEndpoint::accept`] and
    /// [`HqEndpoint::reject`].
    ConnectionRequested {
        /// Associated peer id.
        peer_id: PeerId,
        /// The socket address of the connecting peer.
        socket_addr: SocketAddr,
        /// If true, the remote peer is a client (the connection was not initiated by this
        /// endpoint's `connect`); if false, it is a server this endpoint connected to.
        is_client: bool,
    },
    /// A connection got disconnected.
    ///
    /// This is raised whether or not the peer's connection request was accepted, because the
    /// disconnect can happen before the game loop gets to processing the request.
    Disconnected {
        /// Associated peer id.
        peer_id: PeerId,
        /// An optional reason for the disconnection, None if the connection was closed by the local
        /// endpoint.
        reason: Option<String>,
    },
    /// A datagram was received from a peer.
    ReceivedDatagram {
        /// Associated peer id.
        peer_id: PeerId,
        /// Datagram data as bytes.
        bytes: Bytes,
    },
    /// A reliable message was received from a peer.
    ReceivedMessage {
        /// Associated peer id.
        peer_id: PeerId,
        /// Message data as bytes.
        bytes: Bytes,
    },
}