use base64::DecodeError;
use hex::FromHexError;
use std::{
convert::TryInto,
fmt::Debug,
io::{Cursor, Error as IOError, Read, Write},
};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SerError {
#[error("Attempted to deserialize non-minmal VarInt. Someone is doing something fishy.")]
NonMinimalVarInt,
#[error(transparent)]
IoError(#[from] IOError),
#[error(transparent)]
FromHexError(#[from] FromHexError),
#[error(transparent)]
DecodeError(#[from] DecodeError),
#[error("Error in component (de)serialization: {0}")]
ComponentError(String),
#[error("Expected a sequence of exaclty {expected} items. Got only {got} items")]
InsufficientSeqItems {
expected: usize,
got: usize,
},
}
/// Controls how many items `ByteFormat::read_seq_from` reads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadSeqMode {
    /// Read exactly this many items.
    Exactly(usize),
    /// Read no more than this many items. How a shortfall is handled depends
    /// on the `read_seq_from` implementation.
    AtMost(usize),
    /// Keep reading items until the input stops yielding them.
    UntilEnd,
}
/// Convenience alias for results whose error type is [`SerError`].
pub type SerResult<T> = std::result::Result<T, SerError>;
/// Returns the total number of bytes (marker plus body) needed to encode
/// `number` as a compact int.
///
/// Values up to `0xfc` fit in a single byte. Larger values take a one-byte
/// marker followed by a 2-, 4- or 8-byte body.
pub fn prefix_byte_len(number: u64) -> u8 {
    if number <= 0xfc {
        1
    } else if number <= 0xffff {
        3
    } else if number <= 0xffff_ffff {
        5
    } else {
        9
    }
}
/// Maps a compact-int encoded length to the marker byte that introduces it.
///
/// Returns `None` for lengths that use no marker (single-byte encodings) or
/// that are not valid compact-int lengths.
pub fn first_byte_from_len(number: u8) -> Option<u8> {
    // (encoded length, marker byte)
    const MARKERS: [(u8, u8); 3] = [(3, 0xfd), (5, 0xfe), (9, 0xff)];
    MARKERS
        .iter()
        .find(|&&(len, _)| len == number)
        .map(|&(_, marker)| marker)
}
/// Returns the total encoded length of a compact int, given its first byte.
///
/// Bytes below `0xfd` are the value itself (length 1). `0xfd`, `0xfe` and
/// `0xff` announce a 2-, 4- and 8-byte body respectively.
pub fn prefix_len_from_first_byte(number: u8) -> u8 {
    if number < 0xfd {
        return 1;
    }
    // Only the three marker bytes remain.
    match number {
        0xfd => 3,
        0xfe => 5,
        _ => 9,
    }
}
/// Writes `number` to `writer` as a minimal-length compact int.
///
/// Values up to `0xfc` are written as a single byte. Larger values are written
/// as a marker byte (`0xfd`, `0xfe` or `0xff`) followed by the low 2, 4 or 8
/// bytes of the value in little-endian order.
///
/// Returns the number of bytes written, which always equals
/// `prefix_byte_len(number)`.
///
/// # Errors
/// Returns `SerError::IoError` if the writer fails. A writer that cannot
/// accept all bytes reports `ErrorKind::WriteZero` through `write_all`, rather
/// than leaving a silently truncated encoding.
pub fn write_compact_int<W>(writer: &mut W, number: u64) -> SerResult<usize>
where
    W: Write,
{
    let prefix_len = prefix_byte_len(number);
    match first_byte_from_len(prefix_len) {
        // Small values are their own single-byte encoding.
        None => writer.write_all(&[number as u8])?,
        Some(marker) => {
            writer.write_all(&[marker])?;
            // Body: the low `prefix_len - 1` bytes of the little-endian value.
            writer.write_all(&number.to_le_bytes()[..prefix_len as usize - 1])?;
        }
    }
    Ok(prefix_len as usize)
}
/// Reads a compact int from `reader`.
///
/// The first byte is either the value itself (`0x00..=0xfc`) or a marker
/// (`0xfd`/`0xfe`/`0xff`) announcing a 2-/4-/8-byte little-endian body.
///
/// # Errors
/// * `SerError::IoError` if the marker or the full body cannot be read. This
///   includes truncated input (`ErrorKind::UnexpectedEof`).
/// * `SerError::NonMinimalVarInt` if the value was encoded with more bytes
///   than `prefix_byte_len` says it needs.
pub fn read_compact_int<R>(reader: &mut R) -> SerResult<u64>
where
    R: Read,
{
    let mut prefix = [0u8; 1];
    reader.read_exact(&mut prefix)?;
    let prefix_len = prefix_len_from_first_byte(prefix[0]);
    let number = if prefix_len > 1 {
        // Read the whole body. A single `read` may return fewer bytes, and the
        // zero-filled buffer would then decode to a wrong value, so use
        // `read_exact` to turn truncation into an error.
        let mut buf = [0u8; 8];
        reader.read_exact(&mut buf[..prefix_len as usize - 1])?;
        u64::from_le_bytes(buf)
    } else {
        u64::from(prefix[0])
    };
    if prefix_byte_len(number) < prefix_len {
        Err(SerError::NonMinimalVarInt)
    } else {
        Ok(number)
    }
}
pub fn read_u32_le<R>(reader: &mut R) -> SerResult<u32>
where
R: Read,
{
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(u32::from_le_bytes(buf))
}
/// Writes `number` to `writer` as four little-endian bytes.
///
/// Returns the number of bytes written (always 4).
///
/// # Errors
/// Returns `SerError::IoError` if the writer fails or cannot accept all four
/// bytes. `write_all` is used so a short write is never reported as success.
pub fn write_u32_le<W>(writer: &mut W, number: u32) -> SerResult<usize>
where
    W: Write,
{
    let bytes = number.to_le_bytes();
    writer.write_all(&bytes)?;
    Ok(bytes.len())
}
pub fn read_u64_le<R>(reader: &mut R) -> SerResult<u64>
where
R: Read,
{
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(u64::from_le_bytes(buf))
}
/// Writes `number` to `writer` as eight little-endian bytes.
///
/// Returns the number of bytes written (always 8).
///
/// # Errors
/// Returns `SerError::IoError` if the writer fails or cannot accept all eight
/// bytes. `write_all` is used so a short write is never reported as success.
pub fn write_u64_le<W>(writer: &mut W, number: u64) -> SerResult<usize>
where
    W: Write,
{
    let bytes = number.to_le_bytes();
    writer.write_all(&bytes)?;
    Ok(bytes.len())
}
pub fn read_prefix_vec<R, E, I>(reader: &mut R) -> Result<Vec<I>, E>
where
R: Read,
E: From<SerError> + From<IOError> + std::error::Error,
I: ByteFormat<Error = E>,
{
let items = read_compact_int(reader)?;
I::read_seq_from(reader, ReadSeqMode::Exactly(items.try_into().unwrap())).map_err(Into::into)
}
/// Writes `vector.len()` as a compact int, followed by every element of
/// `vector` in order.
///
/// Returns the total number of bytes written (prefix plus elements).
pub fn write_prefix_vec<W, E, I>(writer: &mut W, vector: &[I]) -> Result<usize, E>
where
    W: Write,
    E: From<SerError> + From<IOError> + std::error::Error,
    I: ByteFormat<Error = E>,
{
    let prefix_bytes = write_compact_int(writer, vector.len() as u64)?;
    let body_bytes = I::write_seq_to(writer, vector.iter())?;
    Ok(prefix_bytes + body_bytes)
}
/// A type that can be read from and written to a byte stream.
///
/// Implementors provide `serialized_length`, `read_from` and `write_to`. The
/// sequence and hex/base64 helpers have default implementations built on
/// those three.
pub trait ByteFormat {
    /// Error type for (de)serialization. It must be constructible from both
    /// [`SerError`] and I/O errors so the provided helpers can use `?`.
    type Error: From<SerError> + From<IOError> + std::error::Error;

    /// Length in bytes of this value's serialized form.
    ///
    /// NOTE(review): nothing here checks that this matches what `write_to`
    /// emits; implementors must keep the two consistent.
    fn serialized_length(&self) -> usize;

    /// Reads one value from `reader`.
    fn read_from<R>(reader: &mut R) -> Result<Self, Self::Error>
    where
        R: Read,
        Self: std::marker::Sized;

    /// Writes this value to `writer`, returning the number of bytes written.
    fn write_to<W>(&self, writer: &mut W) -> Result<usize, <Self as ByteFormat>::Error>
    where
        W: Write;

    /// Reads a sequence of values from `reader` according to `mode`.
    ///
    /// Default behaviour:
    /// * `Exactly(n)` and `AtMost(n)` both call `read_from` `n` times and
    ///   return the first error encountered. NOTE(review): for `AtMost` this
    ///   means fewer than `n` available items is an error, whereas the `u8`
    ///   override returns what it could read. Confirm which semantics callers
    ///   expect.
    /// * `UntilEnd` calls `read_from` until it fails. Any error, not only
    ///   end-of-input, ends the sequence, and that error is discarded.
    fn read_seq_from<R>(reader: &mut R, mode: ReadSeqMode) -> Result<Vec<Self>, Self::Error>
    where
        R: Read,
        Self: std::marker::Sized,
    {
        match mode {
            // Deliberately no `Vec::with_capacity(n)`: `n` may come from
            // untrusted input.
            ReadSeqMode::Exactly(n) | ReadSeqMode::AtMost(n) => {
                (0..n).map(|_| Self::read_from(reader)).collect()
            }
            ReadSeqMode::UntilEnd => {
                Ok(std::iter::from_fn(|| Self::read_from(reader).ok()).collect())
            }
        }
    }

    /// Writes every item yielded by `iter` in order, returning the total
    /// number of bytes written.
    ///
    /// Items need not be `Self`. Their error type is converted into
    /// `Self::Error`.
    fn write_seq_to<'a, W, E, Iter, Item>(
        writer: &mut W,
        iter: Iter,
    ) -> Result<usize, <Self as ByteFormat>::Error>
    where
        W: Write,
        E: Into<Self::Error> + From<SerError> + From<IOError> + std::error::Error,
        Item: 'a + ByteFormat<Error = E>,
        Iter: IntoIterator<Item = &'a Item>,
    {
        let mut written = 0;
        for item in iter {
            written += item.write_to(writer).map_err(Into::into)?;
        }
        Ok(written)
    }

    /// Decodes `s` as hex and reads one value from the resulting bytes.
    ///
    /// Any bytes left after that value are not checked and are ignored.
    fn deserialize_hex(s: &str) -> Result<Self, Self::Error>
    where
        Self: std::marker::Sized,
    {
        let v: Vec<u8> = hex::decode(s).map_err(SerError::from)?;
        let mut cursor = Cursor::new(v);
        Self::read_from(&mut cursor)
    }

    /// Decodes `s` as base64 and reads one value from the resulting bytes.
    ///
    /// Any bytes left after that value are not checked and are ignored.
    fn deserialize_base64(s: &str) -> Result<Self, Self::Error>
    where
        Self: std::marker::Sized,
    {
        let v: Vec<u8> = base64::decode(s).map_err(SerError::from)?;
        let mut cursor = Cursor::new(v);
        Self::read_from(&mut cursor)
    }

    /// Serializes this value into memory and returns it hex-encoded.
    ///
    /// # Panics
    /// Panics if `write_to` returns an error.
    fn serialize_hex(&self) -> String {
        let mut v: Vec<u8> = vec![];
        self.write_to(&mut v).expect("No error on heap write");
        hex::encode(v)
    }

    /// Serializes this value into memory and returns it base64-encoded.
    ///
    /// # Panics
    /// Panics if `write_to` returns an error.
    fn serialize_base64(&self) -> String {
        let mut v: Vec<u8> = vec![];
        self.write_to(&mut v).expect("No error on heap write");
        base64::encode(v)
    }
}
impl ByteFormat for u8 {
type Error = SerError;
fn serialized_length(&self) -> usize {
1
}
fn read_seq_from<R>(reader: &mut R, mode: ReadSeqMode) -> SerResult<Vec<u8>>
where
R: Read,
Self: std::marker::Sized,
{
match mode {
ReadSeqMode::Exactly(number) => {
let mut v = vec![0u8; number];
reader.read_exact(v.as_mut_slice())?;
Ok(v)
}
ReadSeqMode::AtMost(limit) => {
let mut v = vec![0u8; limit];
let n = reader.read(v.as_mut_slice())?;
v.truncate(n);
Ok(v)
}
ReadSeqMode::UntilEnd => Ok(reader.bytes().collect::<Result<Vec<u8>, _>>()?),
}
}
fn read_from<R>(reader: &mut R) -> SerResult<Self>
where
R: Read,
Self: std::marker::Sized,
{
let mut buf = [0u8; 1];
reader.read_exact(&mut buf)?;
Ok(u8::from_le_bytes(buf))
}
fn write_to<W>(&self, writer: &mut W) -> SerResult<usize>
where
W: Write,
{
Ok(writer.write(&self.to_le_bytes())?)
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn it_matches_byte_len_and_prefix() {
        // (value, encoded length, marker byte), covering every size boundary.
        let cases: &[(u64, u8, Option<u8>)] = &[
            (0, 1, None),
            (1, 1, None),
            (0xfc, 1, None),
            (0xfd, 3, Some(0xfd)),
            (0xff, 3, Some(0xfd)),
            (0xffff, 3, Some(0xfd)),
            (0x1_0000, 5, Some(0xfe)),
            (0xffff_ffff, 5, Some(0xfe)),
            (0x1_0000_0000, 9, Some(0xff)),
            (0xffff_ffff_ffff_ffff, 9, Some(0xff)),
        ];
        for &(number, len, prefix) in cases.iter() {
            assert_eq!(prefix_byte_len(number), len);
            assert_eq!(first_byte_from_len(len), prefix);
            // The first encoded byte is the marker, or the value itself when small.
            let first = prefix.unwrap_or(number as u8);
            assert_eq!(prefix_len_from_first_byte(first), len);
        }
    }

    #[test]
    fn it_round_trips_compact_ints() {
        let cases = [
            0u64,
            1,
            0xfc,
            0xfd,
            0xffff,
            0x1_0000,
            0xffff_ffff,
            0x1_0000_0000,
            u64::MAX,
        ];
        for &number in cases.iter() {
            let mut buf: Vec<u8> = Vec::new();
            let written = write_compact_int(&mut buf, number).unwrap();
            assert_eq!(written, buf.len());
            assert_eq!(written, prefix_byte_len(number) as usize);
            assert_eq!(read_compact_int(&mut buf.as_slice()).unwrap(), number);
        }
    }

    #[test]
    fn it_rejects_malformed_compact_ints() {
        // The value 1 encoded in the 3-byte form instead of a single byte.
        let mut non_minimal: &[u8] = &[0xfd, 0x01, 0x00];
        match read_compact_int(&mut non_minimal) {
            Err(SerError::NonMinimalVarInt) => {}
            other => panic!("expected NonMinimalVarInt, got {:?}", other),
        }
        let mut empty: &[u8] = &[];
        assert!(read_compact_int(&mut empty).is_err());
        let mut truncated: &[u8] = &[0xfd];
        assert!(read_compact_int(&mut truncated).is_err());
    }

    #[test]
    fn it_implements_byteformat_for_u8() {
        // Inclusive range so 255 is covered too.
        for i in 0..=u8::MAX {
            assert_eq!(i.serialized_length(), 1);
            let mut v: Vec<u8> = vec![];
            i.write_to(&mut v).unwrap();
            let mut slice = v.as_slice();
            let read_back = u8::read_from(&mut slice).unwrap();
            assert_eq!(i, read_back);
        }
    }

    #[test]
    fn it_implements_seq_ops_for_u8() {
        let input: Vec<u8> = vec![0, 1, 2, 3, 4];
        let mut buf: Vec<u8> = vec![];
        u8::write_seq_to(&mut buf, input.iter()).unwrap();
        assert_eq!(buf.len(), input.len());
        assert_eq!(buf, input);

        let exact_len =
            u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::Exactly(buf.len())).unwrap();
        assert_eq!(exact_len.len(), buf.len());
        assert_eq!(input, exact_len);

        let exact_too_long =
            u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::Exactly(buf.len() + 1));
        assert!(exact_too_long.is_err());

        let exact_first = u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::Exactly(1)).unwrap();
        assert_eq!(exact_first, vec![0]);

        let exact_none = u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::Exactly(0)).unwrap();
        assert_eq!(exact_none, Vec::<u8>::new());

        let at_most_all =
            u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::AtMost(buf.len())).unwrap();
        assert_eq!(at_most_all, buf);

        let at_most_more =
            u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::AtMost(buf.len() + 10)).unwrap();
        assert_eq!(at_most_more, buf);

        let at_most_less =
            u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::AtMost(buf.len() - 1)).unwrap();
        assert_eq!(at_most_less, &buf[..buf.len() - 1]);

        let until_end = u8::read_seq_from(&mut buf.as_slice(), ReadSeqMode::UntilEnd).unwrap();
        assert_eq!(until_end, buf);
    }
}