use std::{
io::{Error, ErrorKind, Result},
marker::{Send, Unpin}
};
use async_trait::async_trait;
pub use byteorder::{LittleEndian as LE, ReadBytesExt, WriteBytesExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Builds a packet by letting `f` fill a freshly allocated byte buffer.
///
/// `n` is used only as the initial capacity of the buffer; `f` may write
/// fewer or more bytes than that. The closure is invoked exactly once, so
/// any `FnOnce` is accepted (including closures that move owned data into
/// the packet).
///
/// # Errors
///
/// Returns the error produced by `f`, if any; the partially filled buffer is
/// dropped in that case.
#[inline]
pub fn make_packet<F>(n: usize, f: F) -> Result<Vec<u8>>
where
    F: FnOnce(&mut Vec<u8>) -> Result<()>,
{
    let mut packet = Vec::with_capacity(n);
    f(&mut packet)?;
    Ok(packet)
}
/// Blocking helpers for writing length-prefixed byte strings.
///
/// Each `write_bN` method emits the payload length as a prefix followed by
/// the payload itself; the suffix names the width of the prefix in bits.
pub trait WriteExtraExt: WriteBytesExt {
    /// Writes the UTF-8 bytes of `s` using the one-byte length prefix of
    /// [`write_b8`](Self::write_b8).
    fn write_str(&mut self, s: &str) -> Result<()> {
        let payload = s.as_bytes();
        self.write_b8(payload)
    }

    /// Writes `bytes` preceded by their length as a single byte.
    fn write_b8(&mut self, bytes: &[u8]) -> Result<()>;

    /// Writes `bytes` preceded by their length as a little-endian `u16`.
    fn write_b16(&mut self, bytes: &[u8]) -> Result<()>;

    /// Writes `bytes` preceded by their length as a little-endian `u32`;
    /// payloads longer than `max` bytes are rejected.
    fn write_b32(&mut self, bytes: &[u8], max: usize) -> Result<()>;
}
impl<W: WriteBytesExt> WriteExtraExt for W {
    /// Writes a one-byte length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `u8::MAX`, otherwise any error from the underlying writer.
    fn write_b8(&mut self, bytes: &[u8]) -> Result<()> {
        if bytes.len() > usize::from(u8::MAX) {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u8::MAX`.
        self.write_u8(bytes.len() as u8)?;
        self.write_all(bytes)
    }

    /// Writes a little-endian `u16` length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `u16::MAX`, otherwise any error from the underlying writer.
    fn write_b16(&mut self, bytes: &[u8]) -> Result<()> {
        if bytes.len() > usize::from(u16::MAX) {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u16::MAX`.
        self.write_u16::<LE>(bytes.len() as u16)?;
        self.write_all(bytes)
    }

    /// Writes a little-endian `u32` length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `max` or than `u32::MAX`, otherwise any error from the
    /// underlying writer.
    fn write_b32(&mut self, bytes: &[u8], max: usize) -> Result<()> {
        // `max` is caller-supplied and may exceed what the 32-bit prefix can
        // encode; without the second check the cast below would silently
        // truncate the length and corrupt the stream.
        if bytes.len() > max || bytes.len() > u32::MAX as usize {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u32::MAX`.
        self.write_u32::<LE>(bytes.len() as u32)?;
        self.write_all(bytes)
    }
}
/// Asynchronous helpers for writing length-prefixed byte strings.
///
/// Each `write_bN` method emits the payload length as a prefix followed by
/// the payload itself; the suffix names the width of the prefix in bits.
#[async_trait]
pub trait AsyncWriteExtraExt: AsyncWriteExt {
    /// Writes the UTF-8 bytes of `s` using the one-byte length prefix of
    /// [`write_b8`](Self::write_b8).
    async fn write_str(&mut self, s: &str) -> Result<()> {
        let payload = s.as_bytes();
        self.write_b8(payload).await
    }

    /// Writes `bytes` preceded by their length as a single byte.
    async fn write_b8(&mut self, bytes: &[u8]) -> Result<()>;

    /// Writes `bytes` preceded by their length as a little-endian `u16`.
    async fn write_b16(&mut self, bytes: &[u8]) -> Result<()>;

    /// Writes `bytes` preceded by their length as a little-endian `u32`;
    /// payloads longer than `max` bytes are rejected.
    async fn write_b32(&mut self, bytes: &[u8], max: usize) -> Result<()>;
}
#[async_trait]
impl<W: AsyncWriteExt + Send + Unpin> AsyncWriteExtraExt for W {
    /// Writes a one-byte length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `u8::MAX`, otherwise any error from the underlying writer.
    async fn write_b8(&mut self, bytes: &[u8]) -> Result<()> {
        if bytes.len() > usize::from(u8::MAX) {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u8::MAX`.
        self.write_u8(bytes.len() as u8).await?;
        self.write_all(bytes).await
    }

    /// Writes a little-endian `u16` length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `u16::MAX`, otherwise any error from the underlying writer.
    async fn write_b16(&mut self, bytes: &[u8]) -> Result<()> {
        if bytes.len() > usize::from(u16::MAX) {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u16::MAX`.
        self.write_u16_le(bytes.len() as u16).await?;
        self.write_all(bytes).await
    }

    /// Writes a little-endian `u32` length prefix followed by `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if `bytes` is longer
    /// than `max` or than `u32::MAX`, otherwise any error from the
    /// underlying writer.
    async fn write_b32(&mut self, bytes: &[u8], max: usize) -> Result<()> {
        // `max` is caller-supplied and may exceed what the 32-bit prefix can
        // encode; without the second check the cast below would silently
        // truncate the length and corrupt the stream.
        if bytes.len() > max || bytes.len() > u32::MAX as usize {
            return Err(Error::new(ErrorKind::InvalidInput, "Data too long to read back"));
        }
        // Lossless: the guard above bounds the length by `u32::MAX`.
        self.write_u32_le(bytes.len() as u32).await?;
        self.write_all(bytes).await
    }
}
/// Blocking helpers for reading length-prefixed byte strings.
///
/// Each `read_bN` method reads a length prefix and then that many payload
/// bytes; the suffix names the width of the prefix in bits.
pub trait ReadExtraExt: ReadBytesExt {
    /// Reads a payload via [`read_b8`](Self::read_b8) and decodes it as
    /// UTF-8.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` when the payload is not valid UTF-8, or with
    /// whatever error `read_b8` reports.
    fn read_str(&mut self) -> Result<String> {
        match String::from_utf8(self.read_b8()?) {
            Ok(text) => Ok(text),
            Err(_) => Err(Error::new(
                ErrorKind::InvalidData,
                "String contains non-utf8 bytes",
            )),
        }
    }

    /// Reads a payload whose length is given by a single leading byte.
    fn read_b8(&mut self) -> Result<Vec<u8>>;

    /// Reads a payload whose length is given by a leading little-endian
    /// `u16`.
    fn read_b16(&mut self) -> Result<Vec<u8>>;

    /// Reads a payload whose length is given by a leading little-endian
    /// `u32`; announced lengths above `max` are rejected.
    fn read_b32(&mut self, max: usize) -> Result<Vec<u8>>;
}
impl<R: ReadBytesExt> ReadExtraExt for R {
    /// Reads a one-byte length prefix, then exactly that many bytes.
    fn read_b8(&mut self) -> Result<Vec<u8>> {
        let size = usize::from(self.read_u8()?);
        let mut payload = vec![0u8; size];
        self.read_exact(&mut payload).map(|()| payload)
    }

    /// Reads a little-endian `u16` length prefix, then exactly that many
    /// bytes.
    fn read_b16(&mut self) -> Result<Vec<u8>> {
        let size = usize::from(self.read_u16::<LE>()?);
        let mut payload = vec![0u8; size];
        self.read_exact(&mut payload).map(|()| payload)
    }

    /// Reads a little-endian `u32` length prefix, then exactly that many
    /// bytes.
    ///
    /// The announced length is compared against `max` before the buffer is
    /// allocated; a larger value yields an `InvalidData` error.
    fn read_b32(&mut self, max: usize) -> Result<Vec<u8>> {
        let size = self.read_u32::<LE>()? as usize;
        if size <= max {
            let mut payload = vec![0u8; size];
            self.read_exact(&mut payload)?;
            Ok(payload)
        } else {
            Err(Error::new(ErrorKind::InvalidData, "Data too long"))
        }
    }
}
/// Asynchronous helpers for reading length-prefixed byte strings.
///
/// Each `read_bN` method reads a length prefix and then that many payload
/// bytes; the suffix names the width of the prefix in bits.
#[async_trait]
pub trait AsyncReadExtraExt: AsyncReadExt {
    /// Reads a payload via [`read_b8`](Self::read_b8) and decodes it as
    /// UTF-8.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` when the payload is not valid UTF-8, or with
    /// whatever error `read_b8` reports.
    async fn read_str(&mut self) -> Result<String> {
        let raw = self.read_b8().await?;
        match String::from_utf8(raw) {
            Ok(text) => Ok(text),
            Err(_) => Err(Error::new(
                ErrorKind::InvalidData,
                "String contains non-utf8 bytes",
            )),
        }
    }

    /// Reads a payload whose length is given by a single leading byte.
    async fn read_b8(&mut self) -> Result<Vec<u8>>;

    /// Reads a payload whose length is given by a leading little-endian
    /// `u16`.
    async fn read_b16(&mut self) -> Result<Vec<u8>>;

    /// Reads a payload whose length is given by a leading little-endian
    /// `u32`; announced lengths above `max` are rejected.
    async fn read_b32(&mut self, max: usize) -> Result<Vec<u8>>;
}
#[async_trait]
impl<R: AsyncReadExt + Send + Unpin> AsyncReadExtraExt for R {
    /// Reads a one-byte length prefix, then exactly that many bytes.
    async fn read_b8(&mut self) -> Result<Vec<u8>> {
        let size = usize::from(self.read_u8().await?);
        let mut payload = vec![0u8; size];
        self.read_exact(&mut payload).await?;
        Ok(payload)
    }

    /// Reads a little-endian `u16` length prefix, then exactly that many
    /// bytes.
    async fn read_b16(&mut self) -> Result<Vec<u8>> {
        let size = usize::from(self.read_u16_le().await?);
        let mut payload = vec![0u8; size];
        self.read_exact(&mut payload).await?;
        Ok(payload)
    }

    /// Reads a little-endian `u32` length prefix, then exactly that many
    /// bytes.
    ///
    /// The announced length is compared against `max` before the buffer is
    /// allocated; a larger value yields an `InvalidData` error.
    async fn read_b32(&mut self, max: usize) -> Result<Vec<u8>> {
        let size = self.read_u32_le().await? as usize;
        if size <= max {
            let mut payload = vec![0u8; size];
            self.read_exact(&mut payload).await?;
            Ok(payload)
        } else {
            Err(Error::new(ErrorKind::InvalidData, "Data too long"))
        }
    }
}