xuko-net 0.6.1

xuko's networking utilities
Documentation
//! Generic packet format

use serde::{Deserialize, Serialize};
use std::{
    io::{Read, Write},
    marker::PhantomData,
};
use xuko_core::array::{Array, ArrayCreationError};

#[cfg(feature = "async")]
use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// zstd compression level applied to every serialized payload before it is sent
const COMPRESSION_LEVEL: i32 = 3;

// TODO: add UDP support?
// `UdpSocket` does not implement the I/O `Read`/`Write` traits this module is built on.

/// Integer type of the length prefix written in front of every packet
pub type PacketSize = u64;

/// Width in bytes of a [`PacketSize`], i.e. the length of the prefix on the wire
pub const PACKET_SIZE: usize = std::mem::size_of::<PacketSize>();

/// Max packet size (`2^31 - 1` by default); it may be configured using [`set_max_packet_size`]
static mut MAX_PACKET_SIZE: PacketSize = (1 << 31) - 1;

/// Becomes `true` as soon as any [`Packet`] is created; from then on
/// [`set_max_packet_size`] no longer has any effect.
static mut PACKET_HAS_BEEN_CREATED: bool = false;

/// Set the max packet size, measured in bytes of compressed payload (the length prefix is
/// not counted).
///
/// This only takes effect if it is called before any [`Packet`] has been created. Once a
/// packet exists the call is silently ignored; use [`get_max_packet_size`] to see which
/// limit is actually in force.
pub const fn set_max_packet_size(size: PacketSize) {
    // NOTE(review): unsynchronized `static mut` access. Calling this while other threads are
    // creating packets is a data race.
    let locked = unsafe { PACKET_HAS_BEEN_CREATED };
    if locked {
        return;
    }

    unsafe {
        MAX_PACKET_SIZE = size;
    }
}

/// Return the max packet size currently in force.
///
/// It starts at `2^31 - 1` bytes and can be changed with [`set_max_packet_size`] until the
/// first [`Packet`] is created.
pub const fn get_max_packet_size() -> PacketSize {
    // SAFETY: by-value read of a `Copy` static. This is only sound while no other thread is
    // concurrently writing it through `set_max_packet_size`.
    unsafe { MAX_PACKET_SIZE }
}

/// Error produced by [`Packet`]'s methods and by the `chomp_packet` readers
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
    /// The compressed packet exceeds the max packet size; carries its length in bytes
    #[error("packet is too large ({0} bytes)")]
    TooLarge(usize),

    /// Fewer (or more) bytes were read than the reader expected
    #[error("{context}; expected {expected} got {real}")]
    LengthFailure {
        /// Short description of where the mismatch happened
        context: String,
        /// Number of bytes the reader wanted
        expected: usize,
        /// Number of bytes that were actually read
        real: usize,
    },

    /// Serialization or deserialization failure reported by [`ron`]
    #[error("(de)serialization error: {0}")]
    RonError(#[from] ron::Error),

    /// Underlying I/O failure
    #[error("IO error: {0}")]
    IOError(#[from] std::io::Error),

    /// The decoded payload is not valid UTF-8
    #[error("invalid utf-8: {0}")]
    UTF8Error(#[from] std::string::FromUtf8Error),

    /// The receive buffer could not be created
    #[error("array creation error: {0}")]
    ArrayCreationError(#[from] ArrayCreationError),
}

/// Generic packet format.
///
/// Useful for game servers and the like.
///
/// A [`Packet`] can be written to anything that implements [`Write`] and read back from anything
/// that implements [`Read`], including [`std::net::TcpStream`].
///
/// It currently does not work with [`std::net::UdpSocket`], since that type implements neither
/// [`Write`] nor [`Read`].
///
/// The payload can be any value that implements [`serde::Serialize`] and [`serde::Deserialize`].
///
/// On the wire a packet is an 8-byte big-endian length prefix followed by the zstd-compressed
/// [`ron`] serialization of the payload.
///
/// The [`chomp_packet`] function reads one packet from a [`Read`]er and returns its serialized
/// form. Use [`Packet::deserialize`] to turn that back into a typed [`Packet`], and
/// [`Packet::unwrap`] to get the original payload out of it.
///
/// # Examples
///
/// ```
/// use serde::{Serialize, Deserialize};
/// use std::io::Cursor;
/// use xuko_net::packet::{Packet, chomp_packet};
///
/// #[derive(Serialize, Deserialize)]
/// struct Payload {
///     x: i32,
///     y: f32
/// }
///
/// let packet = Packet::new(Payload {
///     x: 6,
///     y: 1.5
/// });
///
/// // Write the packet to an in-memory stream...
/// let mut stream = Cursor::new(Vec::new());
/// packet.send(&mut stream).unwrap();
///
/// // ...then read it back and recover the original value.
/// stream.set_position(0);
/// let serialized = chomp_packet(&mut stream).unwrap();
/// let received: Packet<Payload> = Packet::deserialize(&serialized).unwrap();
/// assert_eq!(received.unwrap().x, 6);
/// ```
// Trait bounds live on the `impl` block rather than the struct definition, so merely naming
// `Packet<'de, T>` does not require `T: Serialize + Deserialize<'de>`.
pub struct Packet<'de, T> {
    payload: T,
    _marker: PhantomData<&'de T>,
}

impl<'de, T: Serialize + Deserialize<'de>> Packet<'de, T> {
    /// Create a new [Packet]
    pub const fn new(payload: T) -> Self {
        unsafe {
            PACKET_HAS_BEEN_CREATED = true;
        }

        Self {
            payload,
            _marker: PhantomData,
        }
    }

    /// Unwrap the [`Packet`] to its underlying type
    pub fn unwrap(self) -> T {
        self.payload
    }

    /// Serialize [`Packet`] into a [`String`]
    pub fn serialize(&self) -> Result<String, ron::Error> {
        ron::to_string(&self.payload)
    }

    /// Deserialize a string into a [`Packet`]
    pub fn deserialize(serialized: &'de str) -> Result<Packet<'de, T>, ron::Error> {
        Ok(Self::new(ron::from_str(serialized)?))
    }

    /// Write the [`Packet`] in its serialized form to a writer
    pub fn send<W: Write>(&self, stream: &mut W) -> Result<(), PacketError> {
        let packet = self.serialize()?;
        let bytes = &zstd::encode_all(packet.as_bytes(), COMPRESSION_LEVEL)?;

        if bytes.len() > get_max_packet_size() as usize {
            return Err(PacketError::TooLarge(bytes.len()));
        }

        let packet_len = (bytes.len() as PacketSize).to_be_bytes();

        stream.write_all(&packet_len)?;
        stream.write_all(bytes)?;
        stream.flush()?;

        Ok(())
    }

    /// Same as [`Self::send`] but supports [`smol`]'s async I/O
    #[cfg(feature = "async")]
    pub async fn send_async<W: AsyncWrite + Unpin>(
        &self,
        stream: &mut W,
    ) -> Result<(), PacketError> {
        let packet = self.serialize()?;
        let bytes = &zstd::encode_all(packet.as_bytes(), COMPRESSION_LEVEL)?;

        if bytes.len() > unsafe { MAX_PACKET_SIZE } as usize {
            return Err(PacketError::TooLarge(bytes.len()));
        }

        let packet_len = (bytes.len() as PacketSize).to_be_bytes();

        stream.write_all(&packet_len).await?;
        stream.write_all(bytes).await?;
        stream.flush().await?;

        Ok(())
    }
}

/// Read a serialized [`Packet`] from a reader
///
/// # Note
///
/// This function does not care what the packet type is as it doesn't deserialize it.
/// It should then be deserialized to a specific type using [`Packet::deserialize`].
///
/// The format of a packet looks like the following: (BBBBBBBB) *data*
///
/// The first 8 bytes tell us what the size of the is.
/// The rest of the data is a generic packet serialized as a [`ron`] value.
pub fn chomp_packet<R: Read>(stream: &mut R) -> Result<String, PacketError> {
    let mut buf = [0u8; PACKET_SIZE];
    let read = stream.read(&mut buf)?;

    if read != PACKET_SIZE {
        return Err(packet_size_error(PACKET_SIZE, read));
    }

    let len = PacketSize::from_be_bytes(buf) as usize;

    let mut buf = Array::new(len)?;
    let mut read = 0;

    while let Ok(n) = stream.read(&mut buf) {
        if n == 0 {
            break;
        }
        read += n;
    }

    if read != len {
        return Err(packet_bytes_read_error(len, read));
    }

    let b: &[u8] = &buf;
    let serialized_packet = String::from_utf8(zstd::decode_all(b)?)?;

    Ok(serialized_packet)
}

/// Same as [`chomp_packet`] but supports [`smol`]'s async I/O
///
/// Reads exactly one packet (8-byte big-endian length prefix + zstd-compressed [`ron`]
/// payload) and never reads past its end. It fails with [`PacketError::TooLarge`] when the
/// prefix exceeds [`get_max_packet_size`], with [`PacketError::LengthFailure`] when the
/// stream ends early, and with [`PacketError::IOError`] on read or decompression failure.
#[cfg(feature = "async")]
pub async fn chomp_packet_async<R: AsyncRead + Unpin>(
    stream: &mut R,
) -> Result<String, PacketError> {
    // Read into `buf` until it is full or the stream hits EOF, and return how many bytes
    // were filled. A single `read` may return fewer bytes than requested, and
    // `Interrupted` reads are retried.
    async fn fill<S: AsyncRead + Unpin>(
        stream: &mut S,
        buf: &mut [u8],
    ) -> std::io::Result<usize> {
        let mut filled = 0;
        while filled < buf.len() {
            match stream.read(&mut buf[filled..]).await {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(filled)
    }

    let mut header = [0u8; PACKET_SIZE];
    let read = fill(stream, &mut header).await?;

    if read != PACKET_SIZE {
        return Err(packet_size_error(PACKET_SIZE, read));
    }

    let declared = PacketSize::from_be_bytes(header);

    // The length comes from the peer: reject it before allocating.
    if declared > get_max_packet_size() {
        return Err(PacketError::TooLarge(usize::try_from(declared).unwrap_or(usize::MAX)));
    }
    let len = usize::try_from(declared).map_err(|_| PacketError::TooLarge(usize::MAX))?;

    let mut storage = Array::new(len)?;
    let body: &mut [u8] = &mut storage;
    let read = fill(stream, body).await?;

    if read != len {
        return Err(packet_bytes_read_error(len, read));
    }

    let serialized_packet = String::from_utf8(zstd::decode_all(&*body)?)?;

    Ok(serialized_packet)
}

/// Build the [`PacketError::LengthFailure`] reported when the length prefix is short.
fn packet_size_error(expected: usize, real: usize) -> PacketError {
    let context = String::from("invalid packet length");
    PacketError::LengthFailure {
        context,
        expected,
        real,
    }
}

/// Build the [`PacketError::LengthFailure`] reported when the payload is shorter than the
/// length prefix announced.
fn packet_bytes_read_error(expected: usize, real: usize) -> PacketError {
    let context = String::from("didnt read expected amount of bytes from packet");
    PacketError::LengthFailure {
        context,
        expected,
        real,
    }
}