seraphic 0.1.55

Lightweight JSON RPC 2.0
Documentation
use crate::MainResult;
use serde::{Deserialize, Serialize};
use std::{
    io::{BufRead, ErrorKind, Write},
    marker::PhantomData,
};

/// A length-prefixed JSON frame carrying a serialized `T`.
///
/// `buffer` holds the raw packet bytes; for packets built with `From<&T>` that is a
/// little-endian length header followed by the JSON payload. `T` is only a type-level
/// marker, so `Clone` and `Debug` are implemented by hand: a `#[derive]` would add
/// `T: Clone` / `T: Debug` bounds even though no `T` value is stored.
pub struct TcpPacket<T> {
    pub(crate) buffer: Vec<u8>,
    marker: PhantomData<T>,
}

impl<T> Clone for TcpPacket<T> {
    fn clone(&self) -> Self {
        Self {
            buffer: self.buffer.clone(),
            marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for TcpPacket<T> {
    /// Same output as the previously derived impl: `TcpPacket { buffer: [..], marker: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TcpPacket")
            .field("buffer", &self.buffer)
            .field("marker", &self.marker)
            .finish()
    }
}

impl<T> PartialEq for TcpPacket<T> {
    /// Packets compare equal exactly when their raw byte buffers match; `T` is never
    /// compared, so no `T: PartialEq` bound is required.
    fn eq(&self, other: &Self) -> bool {
        self.buffer == other.buffer
    }
}

/// Integer type of the little-endian length prefix written before every payload.
type HeaderSize = u32;

/// Width in bytes of the length prefix, i.e. the byte size of [`HeaderSize`].
pub(crate) const fn header_size() -> usize {
    std::mem::size_of::<HeaderSize>()
}

impl<T> TcpPacket<T> {
    /// Borrows the raw bytes of this packet (for packets built with `From<&T>`: the
    /// length header followed by the JSON payload).
    pub fn buffer(&self) -> &[u8] {
        self.buffer.as_slice()
    }
}

impl<T> TcpPacket<T>
where
    T: Serialize + std::fmt::Debug + for<'de> Deserialize<'de>,
{
    /// Decodes the JSON payload stored after the length header into a `T`.
    ///
    /// The header bytes themselves are skipped, not checked.
    ///
    /// # Errors
    /// Returns an error if the buffer is shorter than the header (possible for packets
    /// built via `Deserialize` from arbitrary bytes), or if the payload is not valid JSON
    /// for `T`.
    pub fn try_into_inner(self) -> MainResult<T> {
        // `get` instead of slicing: a too-short buffer is reported, not a panic.
        let Some(payload) = self.buffer.get(header_size()..) else {
            return Err(std::io::Error::other(format!(
                "tcp packet buffer is {} bytes, shorter than its {}-byte header",
                self.buffer.len(),
                header_size()
            ))
            .into());
        };
        serde_json::from_slice::<T>(payload).map_err(|err| {
            // Only render the payload as text on the error path.
            std::io::Error::other(format!(
                "error getting tcp packet inner from slice: {err:#?}\nbuffer: {}",
                String::from_utf8_lossy(payload)
            ))
            .into()
        })
    }
}

impl<T> From<&T> for TcpPacket<T>
where
    T: Serialize + std::fmt::Debug + for<'de> Deserialize<'de>,
{
    /// Serializes `r` to JSON and frames it as `[length: HeaderSize, little-endian][json]`.
    ///
    /// # Panics
    /// Panics if `r` cannot be serialized to JSON, or if the JSON is longer than
    /// `HeaderSize::MAX` bytes.
    fn from(r: &T) -> Self {
        // Reserve the header first and serialize straight after it, so the JSON bytes
        // are written once instead of being built in a temporary Vec and copied.
        let mut buffer = vec![0u8; header_size()];
        serde_json::to_writer(&mut buffer, r).expect("T will not work");

        // Use the HeaderSize alias rather than a hard-coded integer type, so the prefix
        // always has exactly header_size() bytes.
        let size = HeaderSize::try_from(buffer.len() - header_size())
            .expect("consider making the header size larger");
        buffer[..header_size()].copy_from_slice(&size.to_le_bytes());

        Self {
            marker: PhantomData,
            buffer,
        }
    }
}

impl<'de, T> serde::Deserialize<'de> for TcpPacket<T> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let buffer = <Vec<u8> as Deserialize>::deserialize(deserializer)?;
        Ok(Self {
            buffer,
            marker: PhantomData,
        })
    }
}

impl<T> Serialize for TcpPacket<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.buffer.serialize(serializer)
    }
}

/// Outcome of a single `TcpPacket::read` call.
#[derive(PartialEq, Debug)]
pub enum PacketRead<T> {
    /// No data was available yet (the reader reported `WouldBlock`).
    Empty,
    /// End of stream was reached while waiting for the next packet header.
    Disconnected,
    /// A complete packet was read and decoded into `T`.
    Message(T),
}

impl<T> TcpPacket<T>
where
    T: Serialize + std::fmt::Debug + for<'de> Deserialize<'de>,
{
    pub fn read(inp: &mut dyn BufRead) -> std::io::Result<PacketRead<T>> {
        let mut header = [0u8; header_size()];
        let mut buffer = [0u8; 1024].to_vec();
        let mut size = None;
        while size.is_none() {
            match inp.read_exact(&mut header) {
                Ok(_) => {
                    if header.is_empty() {
                        break;
                    }
                    let payload_size = u32::from_le_bytes(header) as usize;
                    size = Some(payload_size);
                }
                Err(err)
                    if err.kind() == ErrorKind::UnexpectedEof && header == [0u8; header_size()] =>
                {
                    return Ok(PacketRead::Disconnected);
                }
                Err(err) if err.kind() == ErrorKind::WouldBlock => {
                    return Ok(PacketRead::Empty);
                }
                Err(err) => {
                    return Err(std::io::Error::other(format!(
                        "unexepect error when reading header: {err:#?}\nbuffer: {}",
                        String::from_utf8_lossy(&buffer)
                    )));
                }
            }
        }
        let size: usize = size.ok_or(std::io::Error::other("no content length"))?;
        tracing::debug!("got payload size from header: {size}");
        buffer.resize(size, 0);
        match inp.read_exact(&mut buffer) {
            Ok(_) => {
                let typ = serde_json::from_slice::<T>(&buffer).map_err(|err| {
                    std::io::Error::other(format!(
                        "malformed payload: {}\nErr: {err:#?}",
                        String::from_utf8_lossy(&buffer),
                    ))
                })?;
                Ok(PacketRead::Message(typ))
            }
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                return Ok(PacketRead::Empty);
            }
            Err(err) => {
                return Err(std::io::Error::other(format!(
                    "unexepect error when reading payload: {err:#?}\nbuffer: {}",
                    String::from_utf8_lossy(&buffer)
                )));
            }
        }
    }

    pub fn write(out: &mut dyn Write, typ: &T) -> std::io::Result<()> {
        let packet = Self::from(typ);
        out.write_all(&packet.buffer)?;
        out.flush()?;
        Ok(())
    }
}