// zero-trust-rps 0.0.5
//
// Online Multiplayer Rock Paper Scissors
// Documentation
use futures::future::{join, join_all};
use notify::Watcher as _;
use quinn::{Endpoint, SendStream, VarInt};
use std::collections::HashMap;
use std::fs::canonicalize;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::{self, JoinHandle};

use crate::cli_utils::configure_logging;
use crate::common::connection::quic::QuicReader;
use crate::common::constants::{LOCALHOST4, LOCALHOST6};
use crate::common::ip_display::IpDisplay;
use crate::common::message::UserState;
use crate::common::result::DynResult;

use super::bot::server_run_bots;
use super::options::ServerOptions;
use super::process::process;
use super::sockets::create_server_config;
use super::types::{Connections, Rooms};
use super::utils::send_room_update;

pub async fn main(module: &'static str, options: &ServerOptions) -> DynResult<()> {
    configure_logging(module)?;

    let streams = Arc::new(Mutex::new(HashMap::new()));
    let rooms = Arc::new(Mutex::new(HashMap::new()));

    let mut local_addrs: Vec<SocketAddr> = vec![];
    let mut endpoints: Vec<Arc<Endpoint>> = vec![];

    let futures = super::sockets::listen(options)
        .await?
        .into_iter()
        .map(|endpoint| {
            log::info!(
                "Accepting connections on {} ({})",
                endpoint
                    .local_addr()
                    .map(|addr| format!("{addr}"))
                    .unwrap_or_else(|err| format!("{err:?}")),
                options.domain
            );
            if let Ok(addr) = endpoint.local_addr() {
                local_addrs.push(addr);
            }
            let endpoint = Arc::new(endpoint);
            endpoints.push(endpoint.clone());
            handle_endpoint(endpoint, streams.clone(), rooms.clone())
        })
        .collect::<Vec<_>>();

    let _watcher = if let Some(path) = options.private_pem.as_ref() {
        endpoints.shrink_to_fit();
        let path = canonicalize(path)?;
        let path_clone = path.clone();
        let options = options.clone();
        // TODO: PollWatcher is bad, but it works even when removing the file
        let mut watcher = notify::PollWatcher::new(
            move |event: Result<notify::Event, _>| match event {
                Ok(event) => {
                    if !event.paths.contains(&path_clone) {
                        return;
                    }
                    if !(event.kind.is_modify() || event.kind.is_create()) {
                        return;
                    }
                    log::debug!("Got modify notify events: {event:?}");
                    match create_server_config(&options) {
                        Ok(server_config) => {
                            for endpoint in &endpoints {
                                endpoint.set_server_config(Some(server_config.clone()));
                                log::debug!(
                                    "Updated server config for endpoint: {:?}",
                                    endpoint.local_addr()
                                );
                            }
                        }
                        Err(err) => {
                            log::error!("Error while trying to create server config: {err:?}")
                        }
                    }
                }
                Err(err) => log::error!("notify error: {err:?}"),
            },
            notify::Config::default()
                .with_compare_contents(false)
                .with_poll_interval(Duration::from_secs(100)), // certs do not need to change often
        )?;
        watcher.watch(&path, notify::RecursiveMode::NonRecursive)?;

        log::debug!("watching {path:?}");

        Some(watcher) // keep reference alive!
    } else {
        drop(endpoints);
        None
    };

    let results = if let Some(addr) = local_addrs.into_iter().next() {
        let ip = match addr.ip() {
            IpAddr::V4(_) => LOCALHOST4,
            IpAddr::V6(_) => LOCALHOST6,
        };
        join(server_run_bots(ip, options), join_all(futures))
            .await
            .1
    } else {
        join_all(futures).await
    };

    for result in results {
        match result {
            Ok(()) => (),
            Err(err) => log::error!("Got error: {err:?}"),
        }
    }

    Ok(())
}

async fn handle_endpoint(
    endpoint: Arc<Endpoint>,
    streams: Connections<IpDisplay, SendStream>,
    rooms: Rooms<IpDisplay>,
) -> DynResult<()> {
    while let Some(stream) = endpoint.accept().await {
        let jh: JoinHandle<()> = match stream.accept() {
            Ok(stream) => {
                let peer: IpDisplay = stream.remote_address().into();
                match stream.await {
                    Ok(connection) => match connection.open_bi().await {
                        Ok((writer, reader)) => {
                            let streams = streams.clone();
                            let rooms = rooms.clone();
                            task::spawn(async move {
                                {
                                    let mut guard = streams.lock().await;
                                    (*guard).insert(peer, writer);
                                };
                                match process(
                                    &peer,
                                    QuicReader::from(reader),
                                    streams.clone(),
                                    rooms.clone(),
                                )
                                .await
                                {
                                    Ok(()) => (),
                                    Err(err) => {
                                        connection
                                            .close(VarInt::from_u32(0), format!("{err}").as_ref());
                                        log::warn!("Disconnected `{peer}`: {err}")
                                    }
                                };
                                // CLEAN UP
                                let mut updated_rooms = vec![];
                                {
                                    let mut guard = rooms.lock().await;
                                    let mut empty_keys = Vec::new();
                                    for (key, value) in guard.iter_mut() {
                                        let mut room = value.lock().await;
                                        let removed = room.users.remove(&peer);
                                        if room.users.is_empty() {
                                            if Arc::strong_count(value) > 1 {
                                                log::error!("({key}: {value:?}) value has to high strong count");
                                            }
                                            empty_keys.push(*key);
                                        } else if let Some(own) = removed.as_ref() {
                                            if let Some(round) = room.round.as_ref() {
                                                if round.users.contains(&own.id) {
                                                    // TODO: improve! Current behaviour is weird
                                                    // user of round left -> round broken -> set state of all to InRoom
                                                    room.round = None;
                                                    for (_, user) in room.users.iter_mut() {
                                                        user.state = UserState::InRoom;
                                                    }
                                                }
                                            }
                                            updated_rooms.push(*key);
                                        }
                                    }
                                    for key in empty_keys {
                                        guard.remove(&key);
                                    }
                                    log::debug!("ROOMS: {:?}", *guard);
                                };
                                {
                                    let mut guard = streams.lock().await;
                                    (*guard).remove(&peer);
                                    log::debug!("STREAMS: {:?}", guard.keys());
                                };
                                for room_id in updated_rooms {
                                    if let Some(room) = rooms.lock().await.get(&room_id) {
                                        if let Err(err) =
                                            send_room_update(room_id, room, &streams).await
                                        {
                                            log::warn!("Failed to send room update: {err:?}");
                                        }
                                    }
                                }
                            })
                        }
                        Err(err) => {
                            log::trace!("Failed to accept bi: {err:?}");
                            continue;
                        }
                    },
                    Err(err) => {
                        log::trace!("Failed to await connecting: {err:?}");
                        continue;
                    }
                }
            }
            Err(err) => {
                log::trace!("Failed to get stream: {err}");
                continue;
            }
        };
        drop(jh);
    }
    Ok(())
}