zero-trust-rps 0.0.5

Online Multiplayer Rock Paper Scissors
Documentation
use std::io::Read as _;

use bytes::Buf as _;
use quinn::RecvStream;
use serde::de::DeserializeOwned;

use crate::common::{
    connection::{GetNextMessageError, Reader},
    constants::MAX_CHUNK_SIZE,
};

/// Initial capacity reserved for a freshly created [`QuicReader`] receive buffer.
const BUFFER_SIZE: usize = 1 << 10;
/// Largest the receive buffer may grow while still waiting for the rest of a
/// message: ten chunks' worth of bytes.
const MAX_MESSAGE_LEN: usize = MAX_CHUNK_SIZE * 10;

/// Message reader on top of a QUIC receive stream.
///
/// Chunks read from the stream need not line up with message boundaries, so
/// any bytes that have not yet been decoded into a complete message are kept
/// in `buffer` between calls.
pub struct QuicReader {
    /// Stream that chunks are read from.
    stream: RecvStream,
    /// Received bytes not yet consumed as a complete message.
    buffer: Vec<u8>,
}

impl From<RecvStream> for QuicReader {
    /// Wraps `value` together with an empty buffer pre-sized to `BUFFER_SIZE` bytes.
    fn from(value: RecvStream) -> Self {
        let buffer = Vec::with_capacity(BUFFER_SIZE);
        Self {
            buffer,
            stream: value,
        }
    }
}

impl Reader for QuicReader {
    /// Decodes the next message, reading further chunks from the stream until
    /// one is complete.
    ///
    /// Returns `Ok(None)` once the stream has finished with nothing left over.
    ///
    /// # Errors
    /// Propagates stream/decoding errors from the helper, and returns
    /// [`GetNextMessageError::MessageTooLarge`] (after discarding the buffered
    /// bytes) when the buffer grows past `MAX_MESSAGE_LEN` without yielding a
    /// message.
    async fn get_next_message<T: DeserializeOwned>(
        &mut self,
    ) -> Result<Option<T>, GetNextMessageError> {
        loop {
            match get_next_message(self).await {
                Some(outcome) => return outcome,
                None if self.buffer.len() > MAX_MESSAGE_LEN => {
                    log::debug!("Clearing buffer, was too large");
                    self.buffer.clear();
                    return Err(GetNextMessageError::MessageTooLarge);
                }
                None => log::trace!("call next_message again to read more data"),
            }
        }
    }
}

async fn get_next_message<T: DeserializeOwned>(
    quic_reader: &mut QuicReader,
) -> Option<Result<Option<T>, GetNextMessageError>> {
    if !quic_reader.buffer.is_empty() {
        log::trace!("buffer is not empty");
        match postcard::take_from_bytes(&quic_reader.buffer) {
            Ok((value, rest)) => {
                if rest.is_empty() {
                    quic_reader.buffer.clear();
                } else {
                    let mut new_buffer = Vec::with_capacity(BUFFER_SIZE.max(rest.len()));
                    new_buffer.extend_from_slice(rest);
                    quic_reader.buffer = new_buffer;
                }
                return Some(Ok(Some(value)));
            }
            Err(postcard::Error::DeserializeUnexpectedEnd) => {
                log::trace!("rest didn't contain full message")
            }
            Err(err) => {
                log::debug!("postcard error: {err:?}");
                return Some(Err(err.into()));
            }
        }
    }
    let chunk = match quic_reader.stream.read_chunk(MAX_CHUNK_SIZE, true).await {
        Ok(chunk) => chunk,
        Err(err) => return Some(Err(err.into())),
    };
    Some(if let Some(chunk) = chunk {
        log::trace!("got chunk of length {}", chunk.bytes.len());
        match postcard::from_io((
            chunk.bytes.clone().reader(),
            if quic_reader.buffer.is_empty() {
                &mut []
            } else {
                &mut quic_reader.buffer
            },
        )) {
            Ok((value, (mut reader, rest))) => {
                log::trace!("Read chunk, got rest of len {}", rest.len());
                let mut new_buffer = Vec::with_capacity(BUFFER_SIZE.max(rest.len()));
                reader
                    .read_to_end(&mut new_buffer)
                    .expect("reading from Bytes to Vec should not fail");
                new_buffer.extend_from_slice(rest);
                quic_reader.buffer = new_buffer;

                Ok(Some(value))
            }
            Err(postcard::Error::DeserializeUnexpectedEnd) => {
                log::trace!("chunk didn't contain full message");
                let mut new_buffer = Vec::with_capacity(
                    BUFFER_SIZE.max(chunk.bytes.len() + quic_reader.buffer.len()),
                );
                new_buffer.extend_from_slice(&quic_reader.buffer);
                new_buffer.extend_from_slice(chunk.bytes.as_ref());
                quic_reader.buffer = new_buffer;
                return None; // call this again (Recursion without recursion)
            }
            Err(err) => {
                log::debug!("postcard error: {err:?}");
                Err(err.into())
            }
        }
    } else if quic_reader.buffer.is_empty() {
        Ok(None) // stream finished and no partial data at the end
    } else {
        // stream finished with partial data at the end
        Err(postcard::Error::DeserializeUnexpectedEnd.into())
    })
}