collab-server 0.0.7

Nomad's collab server
Documentation
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::Arc;

use async_compat::Compat;
use futures_rustls::TlsAcceptor;
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tokio::sync::mpsc;

use crate::*;

/// Nomad's collaboration server.
///
/// Construct one with [`Server::new`] or [`Server::builder`], then drive it
/// with [`Server::run`].
pub struct Server {
    /// Socket address the TCP listener binds to.
    listening_socket: SocketAddrV4,

    /// Optional rustls configuration. When present, accepted connections go
    /// through a TLS handshake before being served.
    tls_config: Option<ServerConfig>,

    /// Sessions currently tracked by the server (joined, started and left
    /// through the session messages it receives).
    active_sessions: Sessions,

    /// Sending half of the session-message channel. In test builds it is
    /// created eagerly so handles can be given out before the server runs.
    #[cfg(any(test, feature = "__tests"))]
    sender: Option<mpsc::Sender<SessionMessage>>,

    /// Receiving half of the session-message channel, created eagerly in
    /// test builds and taken when the server starts running.
    #[cfg(any(test, feature = "__tests"))]
    receiver: Option<mpsc::Receiver<SessionMessage>>,
}

impl Default for Server {
    #[inline]
    fn default() -> Self {
        Self::builder().build()
    }
}

impl Server {
    /// Returns a [`ServerBuilder`] initialized with the default settings,
    /// which can be customized before calling [`ServerBuilder::build`].
    #[inline]
    pub fn builder() -> ServerBuilder {
        ServerBuilder::default()
    }

    /// Handles a message received from a connection.
    ///
    /// Requests that expect an answer carry their own reply sender. The
    /// outcome of sending the reply is ignored (e.g. when the requester has
    /// already gone away).
    #[inline]
    fn handle_message(&mut self, msg: SessionMessage) {
        use SessionMessage::*;

        match msg {
            JoinSession(peer_id, session_id, sender) => {
                let res = self.active_sessions.join(session_id, peer_id);
                let _ = sender.send(res);
            },

            PeerLeft(peer_id, session_id) => {
                self.active_sessions.remove_peer(session_id, peer_id);
            },

            StartSession(peer_id, sender) => {
                let infos = self.active_sessions.start(peer_id);
                let _ = sender.send(infos);
            },

            #[cfg(any(test, feature = "__tests"))]
            Ping(sender) => {
                let _ = sender.send(());
            },

            #[cfg(any(test, feature = "__tests"))]
            Sessions(sender) => {
                let _ = sender.send(self.active_sessions.clone());
            },
        }
    }

    /// Creates a server with the default configuration.
    ///
    /// Equivalent to `Server::builder().build()`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds the listening socket and serves incoming connections until an
    /// error occurs.
    ///
    /// Session messages from every connection are funneled through a single
    /// channel into one task that owns the server state. As a result,
    /// [`handle_message`](Self::handle_message) never runs concurrently with
    /// itself.
    ///
    /// # Errors
    ///
    /// Returns an error if the listening socket can't be bound or if
    /// accepting a TCP connection fails. A failed TLS handshake only drops
    /// that client's connection; it does not stop the server.
    #[inline]
    pub async fn run(mut self) -> Result<(), anyhow::Error> {
        let tcp_listener = TcpListener::bind(self.listening_socket).await?;

        let tls_acceptor =
            self.tls_config.take().map(Arc::new).map(TlsAcceptor::from);

        #[cfg(not(any(test, feature = "__tests")))]
        let (sender, mut receiver) = mpsc::channel(1024);

        // In test builds the channel is created by the builder, so that
        // `test_sender` can hand out handles before the server is running.
        #[cfg(any(test, feature = "__tests"))]
        let (sender, mut receiver) =
            (self.sender.take().unwrap(), self.receiver.take().unwrap());

        // The message loop takes ownership of `self`. Every state mutation
        // therefore happens sequentially inside this single task.
        tokio::spawn(async move {
            while let Some(msg) = receiver.recv().await {
                self.handle_message(msg);
            }
        });

        loop {
            let (tcp_stream, client_addr) = tcp_listener.accept().await?;

            let tcp_stream = Compat::new(tcp_stream);

            let sender = sender.clone();

            match &tls_acceptor {
                Some(tls_acceptor) => {
                    // The handshake is done inside the connection's own task.
                    // Awaiting it here would let one slow client stall the
                    // accept loop. Propagating its error with `?` would shut
                    // down the whole server over a single misbehaving client.
                    let handshake = tls_acceptor.accept(tcp_stream);

                    tokio::spawn(async move {
                        // NOTE(review): the handshake error is dropped because
                        // no logger is wired up here; consider logging it.
                        // The `Option` wrapper lets this task forward the
                        // connection's output unchanged, whatever its type.
                        match handshake.await {
                            Ok(tls_stream) => Some(
                                Connection::run(tls_stream, client_addr, sender)
                                    .await,
                            ),
                            Err(_) => None,
                        }
                    });
                },

                None => {
                    tokio::spawn(Connection::run(
                        tcp_stream,
                        client_addr,
                        sender,
                    ));
                },
            }
        }
    }

    /// Returns a [`TestSender`] connected to this server's message channel.
    ///
    /// It can be obtained before [`run`](Self::run) is called. Messages sent
    /// through it are buffered in the channel until the server's message
    /// loop starts processing them.
    #[cfg(any(test, feature = "__tests"))]
    pub fn test_sender(&self) -> TestSender {
        TestSender::new(self.sender.as_ref().unwrap().clone())
    }
}

/// Builder used to configure a [`Server`] before constructing it.
///
/// Obtain one from [`Server::builder`] (or `ServerBuilder::default()`),
/// chain the `with_*` setters, then call [`ServerBuilder::build`].
#[derive(Clone)]
pub struct ServerBuilder {
    /// Optional TLS configuration handed over to the built [`Server`].
    tls_config: Option<ServerConfig>,

    /// IPv4 address the server will bind to.
    ip: Ipv4Addr,

    /// TCP port the server will listen on.
    port: u16,
}

impl Default for ServerBuilder {
    /// Listens on every IPv4 interface (`0.0.0.0`) at
    /// `common::SERVER_LISTENING_PORT`, with TLS disabled.
    #[inline]
    fn default() -> Self {
        let ip = Ipv4Addr::UNSPECIFIED;
        let port = common::SERVER_LISTENING_PORT;

        Self { ip, port, tls_config: None }
    }
}

impl ServerBuilder {
    /// Consumes the builder and creates a [`Server`] with its settings.
    ///
    /// In test builds (`cfg(test)` or the `__tests` feature) this also
    /// creates the session-message channel up front. That lets the server
    /// hand out test senders before it is run.
    #[inline]
    #[must_use]
    pub fn build(self) -> Server {
        #[cfg(any(test, feature = "__tests"))]
        let (sender, receiver) = mpsc::channel(1024);

        Server {
            active_sessions: Sessions::new(),
            listening_socket: SocketAddrV4::new(self.ip, self.port),
            tls_config: self.tls_config,
            #[cfg(any(test, feature = "__tests"))]
            sender: Some(sender),
            #[cfg(any(test, feature = "__tests"))]
            receiver: Some(receiver),
        }
    }

    /// Sets the IPv4 address the server will bind to.
    ///
    /// Defaults to `0.0.0.0` (all interfaces). The returned builder must be
    /// used; dropping it discards the setting.
    #[inline]
    #[must_use]
    pub fn with_ip(mut self, ip: Ipv4Addr) -> Self {
        self.ip = ip;
        self
    }

    /// Sets the TCP port the server will listen on.
    ///
    /// Defaults to `common::SERVER_LISTENING_PORT`. The returned builder must
    /// be used; dropping it discards the setting.
    #[inline]
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Enables TLS. The built server performs a TLS handshake with `config`
    /// on every accepted connection.
    ///
    /// TLS is disabled by default. The returned builder must be used;
    /// dropping it discards the setting.
    #[inline]
    #[must_use]
    pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
        self.tls_config = Some(config);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A second server bound to an already-occupied socket must fail to run.
    #[tokio::test]
    async fn server_start_twice() -> Result<(), anyhow::Error> {
        let first = Server::new();
        let first_handle = first.test_sender();

        tokio::spawn(async move { first.run().await });

        // `run` only spawns its message loop after binding succeeds. Getting a
        // reply here therefore proves the first server owns the socket.
        first_handle.start().await?;

        let outcome = Server::new().run().await;

        assert!(
            outcome.is_err(),
            "running the server should fail if the socket is being used"
        );

        Ok(())
    }
}