harlequinn 0.1.1

A real-time networking library primarily aimed at games.
Documentation
use std::convert::TryFrom;
use std::net::SocketAddr;

use {
    bytes::{Bytes, BytesMut},
    crossbeam_channel::Sender,
    futures::StreamExt,
    quinn::{
        Connecting, Connection, ConnectionDriver, Datagrams, IncomingUniStreams, NewConnection,
        RecvStream, SendStream,
    },
    tokio::sync::{
        mpsc::{channel, Receiver as TokioReceiver, Sender as TokioSender},
        oneshot::channel as oneshot_channel,
    },
};

use crate::{
    worker::{
        handshake::{verify_handshake_client, verify_handshake_server},
        WorkerEvent,
    },
    MessageOrder, PeerId,
};

/// A request from the synchronous side asking the connection worker to send data to the peer.
#[derive(Debug)]
pub enum ConnectionCommand {
    /// Send the payload as a single QUIC datagram.
    SendDatagram {
        /// Datagram payload, sent as-is.
        bytes: Bytes,
    },
    /// Send the payload as a length-prefixed message on a stream.
    SendMessage {
        /// Message payload.
        bytes: Bytes,
        /// Selects the shared ordered stream or a fresh one-shot stream.
        order: MessageOrder,
    },
}

/// Runs a single QUIC connection from setup until its send worker stops.
///
/// `server_peer_id` is `Some` when we are connecting to a server as a client; in that case the
/// peer id was already assigned. It is `None` when we are accepting an incoming connection.
///
/// Steps, in order:
/// 1. Await the QUIC connection. If it fails and we initiated it, report `ConnectionFailed`.
/// 2. Register the connection with the synchronous side, which replies with the peer id.
/// 3. Spawn the connection driver so we can actually send and receive.
/// 4. Run the protocol handshake and ask the application to confirm the peer. The connection is
///    closed if either step rejects it.
/// 5. Spawn the receive workers and run the send worker on this task. If the send worker fails,
///    the connection is closed.
///
/// # Panics
///
/// Panics if the synchronous side has stopped receiving `WorkerEvent`s (event sends `unwrap`).
pub async fn handle_connecting(
    connecting: Connecting,
    sender: Sender<WorkerEvent>,
    server_peer_id: Option<PeerId>,
    protocol_checksum: u32,
) {
    let new_connection = match connecting.await {
        Ok(value) => value,
        Err(_error) => {
            // We only really care if a connection we initiated failed
            if let Some(peer_id) = server_peer_id {
                sender
                    .send(WorkerEvent::ConnectionFailed { peer_id })
                    .unwrap();
            }
            return;
        }
    };

    let NewConnection {
        driver,
        connection,
        datagrams,
        uni_streams,
        mut bi_streams,
        ..
    } = new_connection;

    // The capacity of this channel decides how many messages can back up before the peer gets
    // dropped
    let (connection_sender, connection_receiver) = channel(256);

    // Notify the synchronous thread of the new connection
    let peer_id = initialize_connection_data(
        &sender,
        connection.clone(),
        connection_sender,
        server_peer_id,
    )
    .await;

    // Start the driver so we can actually send and receive. This must happen before the
    // handshake, which exchanges data with the peer.
    tokio::spawn(drive_connection(driver, sender.clone(), peer_id));

    // Complete a handshake to make sure we're connecting to the right QUIC endpoint
    let result = if server_peer_id.is_none() {
        verify_handshake_server(&connection, protocol_checksum).await
    } else {
        verify_handshake_client(&mut bi_streams, protocol_checksum).await
    };

    let (send, recv) = match result {
        Some((send, recv)) => (send, recv),
        None => {
            connection.close(0u32.into(), b"Handshake Failed");
            return;
        }
    };

    // Verify with the caller that the connection is valid
    let confirmed = verify_connection(&sender, connection.remote_address(), peer_id).await;

    if !confirmed {
        // The connection has been denied
        connection.close(0u32.into(), b"Denied by Application");
        return;
    };

    // Spawn listening workers
    tokio::spawn(listen_datagrams(datagrams, sender.clone(), peer_id));
    tokio::spawn(listen_ordered_stream(recv, sender.clone(), peer_id));
    tokio::spawn(listen_oneshot_streams(uni_streams, sender, peer_id));

    // The send worker only returns an error once sending has become impossible (a failed write,
    // a rejected datagram or message, ...). Close the connection instead of leaving a peer that
    // silently loses everything sent to it. Closing lets the driver task finish and report the
    // disconnect.
    if connection_worker(connection.clone(), send, connection_receiver)
        .await
        .is_err()
    {
        connection.close(0u32.into(), b"Connection Worker Failed");
    }
}

/// Announces a freshly established connection to the synchronous side and waits for the peer id
/// it assigns.
///
/// Panics if the event cannot be delivered, or if the peer id sender is dropped without replying.
async fn initialize_connection_data(
    sender: &Sender<WorkerEvent>,
    connection: Connection,
    connection_sender: TokioSender<ConnectionCommand>,
    server_peer_id: Option<PeerId>,
) -> PeerId {
    let (peer_id_sender, peer_id_receiver) = oneshot_channel();

    let event = WorkerEvent::ConnectionStarted {
        connection,
        connection_sender,
        server_peer_id,
        peer_id_sender,
    };
    sender.send(event).unwrap();

    peer_id_receiver.await.unwrap()
}

/// Asks the application whether the connection from `socket_addr` (peer `peer_id`) is accepted.
///
/// Returns the application's answer. If the answer channel is dropped without a reply, the
/// connection counts as denied. Panics if the request itself cannot be delivered.
async fn verify_connection(
    sender: &Sender<WorkerEvent>,
    socket_addr: SocketAddr,
    peer_id: PeerId,
) -> bool {
    let (confirm_sender, confirm_receiver) = oneshot_channel();

    let request = WorkerEvent::ConnectionRequested {
        peer_id,
        socket_addr,
        confirm_sender,
    };
    sender.send(request).unwrap();

    match confirm_receiver.await {
        Ok(confirmed) => confirmed,
        Err(_) => false,
    }
}

async fn connection_worker(
    connection: Connection,
    mut ordered_send: SendStream,
    mut receiver: TokioReceiver<ConnectionCommand>,
) -> Result<(), Box<dyn std::error::Error>> {
    while let Some(command) = receiver.recv().await {
        match command {
            ConnectionCommand::SendDatagram { bytes } => {
                connection.send_datagram(bytes).await?;
            }
            ConnectionCommand::SendMessage { bytes, order } => {
                let length = bytes.len() as u16;
                let length_bytes = length.to_be_bytes();

                match order {
                    MessageOrder::Unordered => {
                        // Send it through a disposable stream
                        let mut oneshot_send = connection.open_uni().await?;
                        oneshot_send.write_all(&length_bytes).await?;
                        oneshot_send.write_all(&bytes).await?;
                    }
                    MessageOrder::Ordered => {
                        // Write the next message length and then the message itself
                        ordered_send.write_all(&length_bytes).await?;
                        ordered_send.write_all(&bytes).await?;
                    }
                }
            }
        }
    }

    Ok(())
}

/// Polls the connection driver to completion, then reports the disconnect to the synchronous
/// side.
///
/// `reason` is `None` when the driver finished cleanly. Otherwise it holds the driver error's
/// display text. Panics if the event cannot be delivered.
async fn drive_connection(driver: ConnectionDriver, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    let reason = driver.await.err().map(|error| error.to_string());

    let event = WorkerEvent::Disconnected { peer_id, reason };
    sender.send(event).unwrap();
}

/// Forwards every incoming datagram to the synchronous side until the datagram stream ends or
/// yields an error.
async fn listen_datagrams(mut datagrams: Datagrams, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    // Stop at the first error: connection errors are handled on the driver await instead.
    while let Some(Ok(bytes)) = datagrams.next().await {
        let event = WorkerEvent::ReceivedDatagram { peer_id, bytes };
        sender.send(event).unwrap();
    }
}

/// Reads length-prefixed messages from the ordered stream and forwards each one, in order, to
/// the synchronous side. Stops as soon as a message cannot be read completely.
async fn listen_ordered_stream(mut recv: RecvStream, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    loop {
        let bytes = match read_stream_message(&mut recv).await {
            Some(bytes) => bytes,
            None => break,
        };

        // Hand the message over to the synchronous thread
        let event = WorkerEvent::ReceivedMessage { peer_id, bytes };
        sender.send(event).unwrap();
    }
}

/// Receives unordered messages. Each one arrives on its own unidirectional stream.
///
/// Every incoming stream carries one length-prefixed message, which is forwarded to the
/// synchronous side. A stream that ends early or fails to read only loses its own message.
/// Listening stops when the incoming-stream iterator ends or yields an error; connection errors
/// are reported by the driver task instead.
async fn listen_oneshot_streams(
    mut uni_streams: IncomingUniStreams,
    sender: Sender<WorkerEvent>,
    peer_id: PeerId,
) {
    while let Some(Ok(mut recv)) = uni_streams.next().await {
        // A single bad stream (reset, truncated data) must not stop delivery of every later
        // unordered message, so skip it rather than returning.
        if let Some(bytes) = read_stream_message(&mut recv).await {
            // Send the message to the synchronous thread
            sender
                .send(WorkerEvent::ReceivedMessage { peer_id, bytes })
                .unwrap();
        }
    }
}

/// Reads one message framed as a 2-byte big-endian length followed by that many bytes.
///
/// Returns `None` if the stream ends or errors before the whole message has been read.
async fn read_stream_message(recv: &mut RecvStream) -> Option<Bytes> {
    let mut header = [0u8; 2];
    if recv.read_exact(&mut header).await.is_err() {
        return None;
    }
    let message_len = usize::from(u16::from_be_bytes(header));

    // Zero-filled buffer of exactly the announced size, filled in place by read_exact
    let mut message = BytesMut::new();
    message.resize(message_len, 0);

    match recv.read_exact(&mut message).await {
        Ok(_) => Some(message.freeze()),
        Err(_) => None,
    }
}