zero-trust-rps 0.1.1

Online Multiplayer Rock Paper Scissors
Documentation
use futures::future::{join, join_all};
use quinn::{Endpoint, SendStream, VarInt};
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::{self, JoinHandle};

use crate::cli_utils::configure_logging;
use crate::common::connection::quic::QuicReader;
use crate::common::constants::{LOCALHOST4, LOCALHOST6};
use crate::common::ip_display::IpDisplay;
use crate::common::message::UserState;
use crate::common::result::DynResult;

use super::bot::server_run_bots;
use super::options::ServerOptions;
use super::process::process;
use super::signals::{run_and_configure_signals, SignalHandler};
use super::sockets::create_server_config;
use super::types::{Connections, Rooms};
use super::utils::send_room_update;

pub async fn main(module: &'static str, options: ServerOptions) -> DynResult<()> {
    configure_logging(module)?;

    let streams = Arc::new(Mutex::new(HashMap::new()));
    let rooms = Arc::new(Mutex::new(HashMap::new()));

    let mut local_addrs: Vec<SocketAddr> = vec![];
    let mut endpoints: Vec<Arc<Endpoint>> = vec![];

    let futures = super::sockets::listen(&options)
        .await?
        .into_iter()
        .map(|endpoint| {
            log::info!(
                "Accepting connections on {} ({})",
                endpoint
                    .local_addr()
                    .map(|addr| format!("{addr}"))
                    .unwrap_or_else(|err| format!("{err:?}")),
                options.domain
            );
            if let Ok(addr) = endpoint.local_addr() {
                local_addrs.push(addr);
            }
            let endpoint = Arc::new(endpoint);
            endpoints.push(endpoint.clone());
            handle_endpoint(endpoint, streams.clone(), rooms.clone())
        })
        .collect::<Vec<_>>();

    let shutdown = run_and_configure_signals(SignalHandlerImpl {
        endpoints,
        options: options.clone(),
    })?;

    let results = if let Some(addr) = local_addrs.into_iter().next() {
        let ip = match addr.ip() {
            IpAddr::V4(_) => LOCALHOST4,
            IpAddr::V6(_) => LOCALHOST6,
        };
        join(server_run_bots(ip, &options), join_all(futures))
            .await
            .1
    } else {
        join_all(futures).await
    };

    for result in results {
        match result {
            Ok(()) => (),
            Err(err) => log::error!("Got error: {err:?}"),
        }
    }

    shutdown.await?;

    Ok(())
}

struct SignalHandlerImpl {
    endpoints: Vec<Arc<Endpoint>>,
    options: ServerOptions,
}

impl SignalHandler for SignalHandlerImpl {
    // async fn shutdown(&self) {
    //     log::info!("Shutting down!");
    //     self.shutdown_event.store(true, atomic::Ordering::Relaxed);
    // }

    async fn reload_config(&self) {
        match create_server_config(&self.options) {
            Ok(server_config) => {
                for endpoint in &self.endpoints {
                    endpoint.set_server_config(Some(server_config.clone()));
                    log::debug!(
                        "Updated server config for endpoint: {:?}",
                        endpoint.local_addr()
                    );
                }
            }
            Err(err) => {
                log::error!("Error while trying to create server config: {err:?}")
            }
        }
    }
}

async fn handle_endpoint(
    endpoint: Arc<Endpoint>,
    streams: Connections<IpDisplay, SendStream>,
    rooms: Rooms<IpDisplay>,
) -> DynResult<()> {
    while let Some(stream) = endpoint.accept().await {
        let jh: JoinHandle<()> = match stream.accept() {
            Ok(stream) => {
                let peer: IpDisplay = stream.remote_address().into();
                match stream.await {
                    Ok(connection) => match connection.open_bi().await {
                        Ok((writer, reader)) => {
                            let streams = streams.clone();
                            let rooms = rooms.clone();
                            task::spawn(async move {
                                {
                                    let mut guard = streams.lock().await;
                                    (*guard).insert(peer, writer);
                                };
                                match process(
                                    &peer,
                                    QuicReader::from(reader),
                                    streams.clone(),
                                    rooms.clone(),
                                )
                                .await
                                {
                                    Ok(()) => (),
                                    Err(err) => {
                                        connection
                                            .close(VarInt::from_u32(0), format!("{err}").as_ref());
                                        log::warn!("Disconnected `{peer}`: {err}")
                                    }
                                };
                                // CLEAN UP
                                let mut updated_rooms = vec![];
                                {
                                    let mut guard = rooms.lock().await;
                                    let mut empty_keys = Vec::new();
                                    for (key, value) in guard.iter_mut() {
                                        let mut room = value.lock().await;
                                        let removed = room.users.remove(&peer);
                                        if room.users.is_empty() {
                                            if Arc::strong_count(value) > 1 {
                                                log::error!("({key}: {value:?}) value has to high strong count");
                                            }
                                            empty_keys.push(*key);
                                        } else if let Some(own) = removed.as_ref() {
                                            if let Some(round) = room.round.as_ref() {
                                                if round.users.contains(&own.id) {
                                                    // TODO: improve! Current behaviour is weird
                                                    // user of round left -> round broken -> set state of all to InRoom
                                                    room.round = None;
                                                    for (_, user) in room.users.iter_mut() {
                                                        user.state = UserState::InRoom;
                                                    }
                                                }
                                            }
                                            updated_rooms.push(*key);
                                        }
                                    }
                                    for key in empty_keys {
                                        guard.remove(&key);
                                    }
                                    log::debug!("ROOMS: {:?}", *guard);
                                };
                                {
                                    let mut guard = streams.lock().await;
                                    (*guard).remove(&peer);
                                    log::debug!("STREAMS: {:?}", guard.keys());
                                };
                                for room_id in updated_rooms {
                                    if let Some(room) = rooms.lock().await.get(&room_id) {
                                        if let Err(err) =
                                            send_room_update(room_id, room, &streams).await
                                        {
                                            log::warn!("Failed to send room update: {err:?}");
                                        }
                                    }
                                }
                            })
                        }
                        Err(err) => {
                            log::trace!("Failed to accept bi: {err:?}");
                            continue;
                        }
                    },
                    Err(err) => {
                        log::trace!("Failed to await connecting: {err:?}");
                        continue;
                    }
                }
            }
            Err(err) => {
                log::trace!("Failed to get stream: {err}");
                continue;
            }
        };
        drop(jh);
    }
    Ok(())
}