use serde::{Deserialize, Serialize};
use std::{
io::{Read, Write},
marker::PhantomData,
};
use xuko_core::array::{Array, ArrayCreationError};
#[cfg(feature = "async")]
use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Integer type of the big-endian length prefix written in front of every packet body.
pub type PacketSize = u64;
/// Width of the length prefix in bytes, i.e. the size of [`PacketSize`].
pub const PACKET_SIZE: usize = std::mem::size_of::<PacketSize>();
/// zstd compression level used when encoding outgoing packets.
const COMPRESSION_LEVEL: i32 = 3;
/// Upper bound, in bytes, on a compressed outgoing packet body (default: 2^31 - 1).
///
/// NOTE(review): this is an unsynchronized `static mut`; reading and writing it from
/// several threads at once is a data race.
static mut MAX_PACKET_SIZE: PacketSize = (1 << 31) - 1;
/// Latched to `true` when a `Packet` is constructed; after that the maximum packet size
/// can no longer be changed. Unsynchronized, like `MAX_PACKET_SIZE`.
static mut PACKET_HAS_BEEN_CREATED: bool = false;
/// Sets the largest compressed payload, in bytes, that `Packet::send` and
/// `Packet::send_async` will accept.
///
/// The limit can only be changed before the first `Packet` is constructed; once
/// `Packet::new` has run, further calls are silently ignored.
pub const fn set_max_packet_size(size: PacketSize) {
    // SAFETY: plain by-value read of a `Copy` static; no reference to it is created.
    // NOTE(review): unsynchronized — racing with `Packet::new` on another thread is a
    // data race.
    let locked = unsafe { PACKET_HAS_BEEN_CREATED };
    if locked {
        return;
    }
    // SAFETY: plain store to a `Copy` static; no reference to it is created.
    unsafe {
        MAX_PACKET_SIZE = size;
    }
}
/// Returns the current maximum compressed payload size, in bytes.
///
/// The value is fixed once the first `Packet` has been created (see
/// [`set_max_packet_size`]).
#[must_use]
pub const fn get_max_packet_size() -> PacketSize {
    // SAFETY: copies the `u64` out of the static by value; no reference escapes.
    // NOTE(review): unsynchronized with concurrent `set_max_packet_size` calls.
    unsafe { MAX_PACKET_SIZE }
}
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
#[error("packet is too large ({0} bytes)")]
TooLarge(usize),
#[error("{context}; expected {expected} got {real}")]
LengthFailure {
context: String,
expected: usize,
real: usize,
},
#[error("(de)serialization error: {0}")]
RonError(#[from] ron::Error),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("invalid utf-8: {0}")]
UTF8Error(#[from] std::string::FromUtf8Error),
#[error("array creation error: {0}")]
ArrayCreationError(#[from] ArrayCreationError),
}
/// A typed payload that travels over a stream as RON text, compressed with zstd and
/// preceded by a big-endian [`PacketSize`] length header.
///
/// `'de` is the lifetime of the input `T` may borrow from when it is produced by
/// `Packet::deserialize`.
pub struct Packet<'de, T: Serialize + Deserialize<'de>> {
    /// The value being carried.
    payload: T,
    /// Zero-sized marker that keeps the `'de` lifetime parameter in use.
    _marker: PhantomData<&'de T>,
}
impl<'de, T: Serialize + Deserialize<'de>> Packet<'de, T> {
    /// Wraps `payload` in a packet.
    ///
    /// Also latches the global "a packet has been created" flag, after which
    /// [`set_max_packet_size`] no longer has any effect.
    pub const fn new(payload: T) -> Self {
        // SAFETY: plain store of `true` to a `Copy` static; no reference is created.
        // NOTE(review): unsynchronized — creating packets on several threads at once races
        // on this flag. An atomic would fix that but would make `new` non-`const`.
        unsafe {
            PACKET_HAS_BEEN_CREATED = true;
        }
        Self {
            payload,
            _marker: PhantomData,
        }
    }

    /// Consumes the packet and returns the wrapped payload.
    pub fn unwrap(self) -> T {
        self.payload
    }

    /// Serializes the payload to a RON string.
    ///
    /// # Errors
    /// Returns the error reported by `ron::to_string`.
    pub fn serialize(&self) -> Result<String, ron::Error> {
        ron::to_string(&self.payload)
    }

    /// Parses a RON string into a packet. Like [`Packet::new`], this latches the global
    /// packet-created flag.
    ///
    /// # Errors
    /// Returns the error reported by `ron::from_str`.
    pub fn deserialize(serialized: &'de str) -> Result<Packet<'de, T>, ron::Error> {
        Ok(Self::new(ron::from_str(serialized)?))
    }

    /// Builds the complete wire frame: a big-endian [`PacketSize`] header holding the
    /// compressed length, followed by the zstd-compressed RON text.
    ///
    /// Shared by [`Packet::send`] and [`Packet::send_async`] so both enforce the same limit.
    fn encode_frame(&self) -> Result<Vec<u8>, PacketError> {
        let serialized = self.serialize()?;
        let compressed = zstd::encode_all(serialized.as_bytes(), COMPRESSION_LEVEL)?;
        // Compare in the `PacketSize` (u64) domain: casting the limit down to `usize`
        // would truncate it on 32-bit targets. `usize -> u64` is lossless.
        if compressed.len() as PacketSize > get_max_packet_size() {
            return Err(PacketError::TooLarge(compressed.len()));
        }
        let header = (compressed.len() as PacketSize).to_be_bytes();
        // One buffer means one write on unbuffered streams, instead of a tiny header write
        // followed by the body.
        let mut frame = Vec::with_capacity(header.len() + compressed.len());
        frame.extend_from_slice(&header);
        frame.extend_from_slice(&compressed);
        Ok(frame)
    }

    /// Writes this packet to `stream` as a single length-prefixed frame, then flushes.
    ///
    /// # Errors
    /// - [`PacketError::RonError`] if serialization fails.
    /// - [`PacketError::TooLarge`] if the compressed body exceeds [`get_max_packet_size`].
    /// - [`PacketError::IOError`] on compression, write or flush failures.
    pub fn send<W: Write>(&self, stream: &mut W) -> Result<(), PacketError> {
        let frame = self.encode_frame()?;
        stream.write_all(&frame)?;
        stream.flush()?;
        Ok(())
    }

    /// Async counterpart of [`Packet::send`]; produces exactly the same bytes.
    ///
    /// # Errors
    /// Same as [`Packet::send`].
    #[cfg(feature = "async")]
    pub async fn send_async<W: AsyncWrite + Unpin>(
        &self,
        stream: &mut W,
    ) -> Result<(), PacketError> {
        let frame = self.encode_frame()?;
        stream.write_all(&frame).await?;
        stream.flush().await?;
        Ok(())
    }
}
pub fn chomp_packet<R: Read>(stream: &mut R) -> Result<String, PacketError> {
let mut buf = [0u8; PACKET_SIZE];
let read = stream.read(&mut buf)?;
if read != PACKET_SIZE {
return Err(packet_size_error(PACKET_SIZE, read));
}
let len = PacketSize::from_be_bytes(buf) as usize;
let mut buf = Array::new(len)?;
let mut read = 0;
while let Ok(n) = stream.read(&mut buf) {
if n == 0 {
break;
}
read += n;
}
if read != len {
return Err(packet_bytes_read_error(len, read));
}
let b: &[u8] = &buf;
let serialized_packet = String::from_utf8(zstd::decode_all(b)?)?;
Ok(serialized_packet)
}
/// Async counterpart of `chomp_packet`: reads one length-prefixed, zstd-compressed
/// packet from `stream` and returns the decompressed text.
///
/// Exactly one frame is consumed: the [`PacketSize`] big-endian header, then exactly the
/// number of body bytes it announces. Any following packets are left in the stream.
///
/// # Errors
/// - [`PacketError::LengthFailure`] if the stream ends before the header or body is complete.
/// - [`PacketError::TooLarge`] if the header announces more than [`get_max_packet_size`]
///   bytes, or more than fits in `usize`.
/// - [`PacketError::ArrayCreationError`] if the receive buffer cannot be created.
/// - [`PacketError::IOError`] on read or decompression failures.
/// - [`PacketError::UTF8Error`] if the decompressed bytes are not UTF-8.
#[cfg(feature = "async")]
pub async fn chomp_packet_async<R: AsyncRead + Unpin>(
    stream: &mut R,
) -> Result<String, PacketError> {
    let mut header = [0u8; PACKET_SIZE];
    let read = fill_from_async(stream, &mut header).await?;
    if read != PACKET_SIZE {
        return Err(packet_size_error(PACKET_SIZE, read));
    }

    let declared = PacketSize::from_be_bytes(header);
    // The length comes from the peer: reject it before allocating if it exceeds the
    // configured maximum (the limit `Packet::send_async` enforces) or cannot fit in `usize`.
    let usize_limit = usize::MAX as PacketSize;
    if declared > get_max_packet_size().min(usize_limit) {
        return Err(PacketError::TooLarge(declared.min(usize_limit) as usize));
    }
    // Lossless: `declared <= usize::MAX` was checked above.
    let len = declared as usize;

    let mut body = Array::new(len)?;
    let read = fill_from_async(stream, &mut body).await?;
    if read != len {
        return Err(packet_bytes_read_error(len, read));
    }

    let compressed: &[u8] = &body;
    Ok(String::from_utf8(zstd::decode_all(compressed)?)?)
}

/// Reads from `stream` into `buf` until `buf` is full or the stream reports EOF, and
/// returns how many bytes were stored. Each read targets the unfilled tail of `buf`.
/// Interrupted reads are retried; any other I/O error is returned.
#[cfg(feature = "async")]
async fn fill_from_async<R: AsyncRead + Unpin>(
    stream: &mut R,
    buf: &mut [u8],
) -> std::io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        match stream.read(&mut buf[filled..]).await {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(filled)
}
/// Builds the [`PacketError::LengthFailure`] reported when the length header could not be
/// read in full: `expected` header bytes were required, `real` were obtained.
fn packet_size_error(expected: usize, real: usize) -> PacketError {
    let context = String::from("invalid packet length");
    PacketError::LengthFailure { context, expected, real }
}
/// Builds the [`PacketError::LengthFailure`] reported when the packet body came up short:
/// `expected` body bytes were announced by the header, `real` were obtained.
fn packet_bytes_read_error(expected: usize, real: usize) -> PacketError {
    let context = String::from("didnt read expected amount of bytes from packet");
    PacketError::LengthFailure { context, expected, real }
}