use std::{
io::{Read, Write},
net::TcpStream,
};
use serde::{de::DeserializeOwned, Serialize};
/// Errors produced while framing, sending or receiving packages.
#[derive(Debug)]
pub enum Error {
    /// A socket operation failed.
    Io(std::io::Error),
    /// bincode failed to encode or decode a payload.
    Parser(bincode::Error),
    /// The bytes received did not form a usable package header.
    InvalidPkg,
    /// A non-blocking read found no data waiting (`WouldBlock`).
    NoPackage,
    /// The peer closed the connection (a read returned 0 bytes) before a header began.
    ConnClosed,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The inner error's message is included here, so `source()` is left at its
        // default (`None`) to avoid printing the same message twice in error chains.
        match self {
            Error::Io(e) => write!(f, "I/O error: {}", e),
            Error::Parser(e) => write!(f, "serialization error: {}", e),
            Error::InvalidPkg => f.write_str("invalid package header"),
            Error::NoPackage => f.write_str("no package available"),
            Error::ConnClosed => f.write_str("connection closed by peer"),
        }
    }
}

impl std::error::Error for Error {}

/// Result type used by every function in this module.
pub type Result<T> = std::result::Result<T, Error>;
pub fn get_pkg_head(s: &mut TcpStream) -> Result<u64> {
let mut buf = [0u8; 8];
match s.read(&mut buf) {
Ok(s) => match s {
0 => {
Err(Error::ConnClosed)
}
1..=7 => Err(Error::InvalidPkg),
8 => Ok(u64::from_le_bytes(buf)),
_ => unreachable!(),
},
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(Error::NoPackage),
Err(e) => Err(Error::Io(e)),
}
}
pub fn read<T>(s: &mut TcpStream) -> Result<T>
where
T: DeserializeOwned,
{
let size = get_pkg_head(s)? as usize;
let mut buf = vec![0u8; size];
match s.read_exact(&mut buf) {
Err(e) => Err(Error::Io(e)),
Ok(_) => bincode::deserialize(&buf).map_err(Error::Parser),
}
}
pub fn send<T>(data: &T, s: &mut TcpStream) -> Result<()>
where
T: Serialize + for<'a> DeserializeOwned,
{
let ser = bincode::serialize(data).map_err(Error::Parser)?;
let len = ser.len() as u64;
let buf: Vec<u8> = len.to_le_bytes().into_iter().chain(ser).collect();
s.write_all(&buf).map_err(Error::Io)?;
Ok(())
}