zero-trust-rps 0.0.4

Online Multiplayer Rock Paper Scissors
Documentation
use std::{num::NonZeroU64, ops::Div, time::Duration};

use futures::future::join3;
use tokio::{
    sync::mpsc::{channel, unbounded_channel},
    time::{sleep, sleep_until, Instant},
};

use crate::{
    common::{
        client::{
            channel::{AsyncChannelReceiver, AsyncChannelSender, SendError},
            do_moves::{do_move, MoveError},
            simple_move::SimpleUserMove,
            state::{ClientState, ClientStateView},
            update::handle_new_room_state,
        },
        connection::{Reader, WriteMessageError, Writer},
        message::ClientMessage,
    },
    log_result,
};

use super::{
    read::{read_server_messages, ReadMessagesError, SimplifiedServerMessage},
    write::send_messages,
};

/// Capacity of the bounded channel that carries [`SimplifiedServerMessage`]s from the
/// reader task to the state-management task (see `run_client`).
///
/// Kept at a single slot; presumably so that the reader cannot run far ahead of the
/// state manager (assumes the reader awaits each send — TODO confirm in `read`).
const INTERNAL_CHANNEL_BUF_SIZE: usize = 1;

/// Every way [`run_client`] can fail.
///
/// Each variant displays exactly as the error it wraps, and the `#[from]` conversions
/// let the `?` operator lift the wrapped error directly into this type.
#[allow(clippy::enum_variant_names)]
#[derive(thiserror::Error, Debug)]
pub enum RunClientError {
    /// Wraps a [`WriteMessageError`] (e.g. from the task writing client messages).
    #[error("{}", .0)]
    WriteError(#[from] WriteMessageError),
    /// Wraps a [`ReadMessagesError`] (e.g. from the task reading server messages).
    #[error("{}", .0)]
    ReadError(#[from] ReadMessagesError),
    /// Wraps a [`ManageStateError`] raised by the state-management task.
    #[error("{}", .0)]
    ManageStateError(#[from] ManageStateError),
    /// A spawned task could not be joined (tokio reports this when a task panics or is
    /// cancelled).
    #[error("{}", .0)]
    TokioError(#[from] tokio::task::JoinError),
}

/// Runs one client session until all of its background tasks have finished.
///
/// Three tasks are spawned, in this order:
/// 1. the state manager (`manage_state`). It consumes server messages and user moves
///    (`umoves`), publishes state snapshots on `states` and relays server messages on
///    `repeat`.
/// 2. the writer (`send_messages`). It drains client messages produced by the state
///    manager into `writer`.
/// 3. the reader (`read_server_messages`). It reads from `reader` and forwards into a
///    bounded channel of [`INTERNAL_CHANNEL_BUF_SIZE`] slots.
///
/// `timout` is the keep-alive timeout in seconds, forwarded unchanged to the state
/// manager. Each task's future is wrapped in the crate's `log_result!` macro.
///
/// # Errors
/// All three tasks are awaited first. Their outcomes are then checked in spawn order
/// (state manager, writer, reader), and the first failure found is returned. A task
/// that could not be joined yields [`RunClientError::TokioError`].
#[inline]
pub async fn run_client(
    timout: Option<NonZeroU64>,
    writer: impl Writer + 'static,
    reader: impl Reader + 'static,
    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
    states: impl AsyncChannelSender<ClientStateView> + 'static,
    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
) -> Result<(), RunClientError> {
    // reader task -> state manager
    let (server_tx, server_rx) = channel::<SimplifiedServerMessage>(INTERNAL_CHANNEL_BUF_SIZE);
    // state manager -> writer task
    let (client_tx, client_rx) = unbounded_channel::<ClientMessage>();

    let state_task = tokio::spawn(log_result!(manage_state(
        timout, server_rx, states, repeat, umoves, client_tx
    )));
    let write_task = tokio::spawn(log_result!(send_messages(writer, client_rx)));
    let read_task = tokio::spawn(log_result!(read_server_messages(reader, server_tx)));

    let (state_outcome, write_outcome, read_outcome) =
        join3(state_task, write_task, read_task).await;

    // Outer `?`: the task could not be joined; inner `?`: the task itself failed.
    let () = state_outcome??;
    let () = write_outcome??;
    let () = read_outcome??;

    Ok(())
}

/// Errors that stop the client's state-management loop.
#[allow(clippy::enum_variant_names)]
#[derive(thiserror::Error, Debug)]
pub enum ManageStateError {
    /// Sending on one of the loop's channels failed.
    #[error("Failed to send to channel: {}", .0)]
    SendError(#[from] SendError),
    /// Performing a user move failed.
    #[error("{}", .0)]
    MoveError(#[from] MoveError),
    /// Free-form failure while updating the local state; the string is shown as-is.
    #[error("{}", .0)]
    UpdateError(String),
}

impl From<String> for ManageStateError {
    fn from(value: String) -> Self {
        ManageStateError::UpdateError(value)
    }
}

/// Owns the client's [`ClientState`] and drives it from server messages, user moves
/// and keep-alive timers.
///
/// When the state is published:
/// - on `output` once at startup;
/// - after every server message;
/// - after a timeout.
///
/// Every server message is also relayed on `repeat`. User moves are passed to
/// `do_move` together with clones of `sender` and `repeat`.
///
/// When `timout` is `Some(secs)`, two timers run relative to the last server message:
/// - if no pong is outstanding, a `ClientMessage::Ping` carrying a random byte is sent
///   `secs / 2` after the last server message;
/// - if no server message arrives within `secs`, `state.timed_out` is set and the state
///   is published. Both timers then stay disarmed until the next server message arrives
///   (`update_state` clears `timed_out`).
///
/// Returns `Ok(())` once both `smesgs` and `umoves` are closed and no timer is armed.
///
/// # Errors
/// The loop stops with an error when:
/// - a channel send fails;
/// - `do_move` fails;
/// - `update_state` rejects a message (e.g. an unexpected or mismatched pong);
/// - the timeout is too large to add to the last-message instant.
async fn manage_state(
    timout: Option<NonZeroU64>,
    smesgs: impl AsyncChannelReceiver<SimplifiedServerMessage> + 'static,
    output: impl AsyncChannelSender<ClientStateView> + 'static,
    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
    sender: impl AsyncChannelSender<ClientMessage> + 'static + Clone,
) -> Result<(), ManageStateError> {
    let mut smesgs = smesgs;
    let mut umoves = umoves;

    let timeout = timout.map(NonZeroU64::get).map(Duration::from_secs);

    let mut expected_pong: Option<u8> = None;
    let mut state: ClientState = Default::default();

    output.send(Box::new(state.clone().into())).await?; // send state asap!

    loop {
        // Deadlines for this iteration; `None` means that timer is disarmed.
        let (ping_at, timing_out_at) = match timeout {
            Some(timeout) if !state.timed_out => {
                let timing_out_at = state
                    .last_server_message
                    .checked_add(timeout)
                    .ok_or_else(|| format!("{timeout:?} is too long"))?;
                // Only one ping is kept in flight at a time.
                let ping_at = if expected_pong.is_some() {
                    None
                } else {
                    let send_ping_at = state
                        .last_server_message
                        .checked_add(timeout.div(2))
                        .ok_or_else(|| format!("{timeout:?} is too long"))?;
                    Some(send_ping_at)
                };
                (ping_at, Some(timing_out_at))
            }
            // No keep-alive configured, or already timed out: only wait for input.
            _ => (None, None),
        };

        // `select!` still evaluates the future of a disabled branch, so a disarmed timer
        // needs some `Sleep` value; it is never polled. The branches must be *disabled*
        // via the `if` preconditions below, not just "never complete". Otherwise the
        // `else` branch is unreachable and the loop hangs forever once both input
        // channels are closed.
        let ping_in_fut = match ping_at {
            Some(at) => sleep_until(at),
            None => sleep(Duration::MAX),
        };
        let timeout_fut = match timing_out_at {
            Some(at) => sleep_until(at),
            None => sleep(Duration::MAX),
        };

        let _: () = tokio::select! {
            _ = ping_in_fut, if ping_at.is_some() => {
                let c: u8 = rand::random();
                sender.send(ClientMessage::Ping { c }).await?;
                expected_pong = Some(c);
            },
            _ = timeout_fut, if timing_out_at.is_some() => {
                log::warn!("timed out");
                state.timed_out = true;
                output.send(Box::new(state.clone().into())).await?;
            },
            Some(mesg) = smesgs.receive() => {
                log::trace!("manage_state got {mesg:?}");
                update_state(&mut state, &mut expected_pong, &mesg)?;
                log::trace!("updated state");

                output.send(Box::new(state.clone().into())).await?;
                log::trace!("send state");
                repeat.send(mesg).await?;
                log::trace!("relayed server msg")
            },
            Some(umove) = umoves.receive() => {
                log::trace!("User wants to move: {umove:?}");

                do_move(&mut state, umove, sender.clone(), repeat.clone()).await?;
            },
            // Reached only when both channels are closed and no timer is armed.
            else => return Ok(()),
        };
    }
}

/// Applies a single server message to `state`.
///
/// Every message, whatever its kind, refreshes `state.last_server_message` and clears
/// `state.timed_out`. Then:
/// - errors reported by the server are only logged;
/// - a pong consumes `expected_pong` and must match it;
/// - room-state updates are delegated to `handle_new_room_state`.
///
/// # Errors
/// Returns a description when:
/// - a pong arrives while none was expected;
/// - a pong carries the wrong value;
/// - `handle_new_room_state` fails.
fn update_state(
    state: &mut ClientState,
    expected_pong: &mut Option<u8>,
    server_msg: &SimplifiedServerMessage,
) -> Result<(), String> {
    // Any message at all shows the server is still talking to us.
    state.timed_out = false;
    state.last_server_message = Instant::now();

    match server_msg {
        SimplifiedServerMessage::Error(error) => {
            log::error!("got error: {error:?}");
            Ok(())
        }
        SimplifiedServerMessage::Pong(pong) => match expected_pong.take() {
            None => Err(String::from("Didn't expect pong")),
            Some(expected) if expected != *pong => {
                Err(format!("Got invalid pong {pong}, but expected {expected}"))
            }
            Some(_) => Ok(()),
        },
        SimplifiedServerMessage::NewRoomState(room_state) => {
            handle_new_room_state(state, room_state)
        }
    }
}