use crate::error::{DecodeError, EncodeError};
use std::convert::TryFrom;
use std::io::Write;
/// A decoded value header.
///
/// On the wire a header starts with a lead byte: the top three bits hold a
/// type code and the low five bits hold either a small value directly or, for
/// 24..=31, the number of big-endian payload bytes that follow (1..=8).
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` because every
/// payload is a plain integer, so headers can be used as set/map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum Header {
    /// Payload-free marker (type code 0, low bits 0).
    Null,
    /// Payload-free marker (type code 0, low bits 1).
    True,
    /// Payload-free marker (type code 0, low bits 2).
    False,
    /// Payload-free marker (type code 0, low bits 3).
    F32,
    /// Payload-free marker (type code 0, low bits 4).
    F64,
    /// Shares type code 0 with the markers above; carries a `usize`
    /// (presumably the byte length of the binary data that follows).
    Bin(usize),
    /// Unsigned magnitude under type code 1.
    Pos(u64),
    /// Magnitude under type code 2. It is stored on the wire as
    /// `magnitude - 1`; `Neg(0)` is written as `Pos(0)`.
    Neg(u64),
    /// Type code 3; carries a `usize` (presumably an element count).
    Bag(usize),
    /// Type code 4; carries a `usize` (presumably a length).
    Str(usize),
    /// Type code 5; carries a `usize` (presumably a length).
    Sym(usize),
    /// Type code 6; carries a `usize` (meaning defined by the wider format).
    Key(usize),
    /// Type code 7; carries a `usize` (meaning defined by the wider format).
    Ref(usize),
}
impl Header {
/// Returns the variant's name as a static string, ignoring any payload.
pub fn name(&self) -> &'static str {
    // Arms follow the declaration order of the enum.
    match self {
        Header::Null => "Null",
        Header::True => "True",
        Header::False => "False",
        Header::F32 => "F32",
        Header::F64 => "F64",
        Header::Bin(_) => "Bin",
        Header::Pos(_) => "Pos",
        Header::Neg(_) => "Neg",
        Header::Bag(_) => "Bag",
        Header::Str(_) => "Str",
        Header::Sym(_) => "Sym",
        Header::Key(_) => "Key",
        Header::Ref(_) => "Ref",
    }
}
/// Writes this header to `w` and returns the number of bytes written.
///
/// Payload-free markers and `Neg(0)` always occupy a single byte; every
/// other variant is handed to `encode_long_header`, which picks between the
/// one-byte inline form and the multi-byte form.
///
/// # Errors
///
/// Returns an `EncodeError` if the writer fails or if a `usize` payload
/// cannot be converted to `u64`.
pub fn encode<W: Write>(&self, w: &mut W) -> Result<usize, EncodeError> {
    let lead = match *self {
        Header::Null => self.code() << 5,
        Header::True => self.code() << 5 | 1,
        Header::False => self.code() << 5 | 2,
        Header::F32 => self.code() << 5 | 3,
        Header::F64 => self.code() << 5 | 4,
        // Negative zero has no representation of its own: it is written
        // as the single-byte positive zero.
        Header::Neg(0) => Header::Pos(0).code() << 5,
        Header::Pos(magnitude) => return self.encode_long_header(magnitude, w),
        // Non-zero negatives are stored off by one.
        Header::Neg(magnitude) => return self.encode_long_header(magnitude - 1, w),
        Header::Bin(len)
        | Header::Bag(len)
        | Header::Str(len)
        | Header::Sym(len)
        | Header::Key(len)
        | Header::Ref(len) => return self.encode_long_header(Self::to_u64(len)?, w),
    };
    w.write_all(&[lead])?;
    Ok(1)
}
pub fn decode<B: ?Sized + AsRef<[u8]>>(buf: &B) -> Result<(Self, usize), DecodeError> {
let buf = buf.as_ref();
if buf.len() < 1 {
return Err(DecodeError::Eof);
}
let code = buf[0] >> 5;
let sz = buf[0] & 0x1f;
match code {
0 => {
match sz {
0 => Ok((Header::Null, 1)),
1 => Ok((Header::True, 1)),
2 => Ok((Header::False, 1)),
3 => Ok((Header::F32, 1)),
4 => Ok((Header::F64, 1)),
x if x < 24 => Ok((Header::Bin(x as usize - 5), 1)),
x => Self::decode_u64(&buf[1..], x).and_then(|(i, c)| Ok((Header::Bin(Self::to_usize(i)?), c + 1))),
}
},
1 => Self::decode_u64(&buf[1..], sz).map(|(i, c)| (Header::Pos(i), c + 1)),
2 => Self::decode_u64(&buf[1..], sz).map(|(i, c)| (Header::Neg(i.saturating_add(1)), c + 1)),
3 => Self::decode_u64(&buf[1..], sz).and_then(|(i, c)| Ok((Header::Bag(Self::to_usize(i)?), c + 1))),
4 => Self::decode_u64(&buf[1..], sz).and_then(|(i, c)| Ok((Header::Str(Self::to_usize(i)?), c + 1))),
5 => Self::decode_u64(&buf[1..], sz).and_then(|(i, c)| Ok((Header::Sym(Self::to_usize(i)?), c + 1))),
6 => Self::decode_u64(&buf[1..], sz).and_then(|(i, c)| Ok((Header::Key(Self::to_usize(i)?), c + 1))),
7 => Self::decode_u64(&buf[1..], sz).and_then(|(i, c)| Ok((Header::Ref(Self::to_usize(i)?), c + 1))),
_ => unreachable!(),
}
}
/// Writes the lead byte for `self` carrying the integer `i`, followed by a
/// big-endian payload when `i` is too large to sit in the lead byte.
///
/// Returns the total number of bytes written.
#[inline]
fn encode_long_header<W: Write>(&self, i: u64, w: &mut W) -> Result<usize, EncodeError> {
    let tag = self.code() << 5;
    let limit = self.sz_limit();

    if i >= u64::from(limit) {
        // Multi-byte form: low bits 24..=31 announce 1..=8 payload bytes,
        // then only the significant tail of the big-endian value is sent.
        let width = Self::size(i);
        let be = i.to_be_bytes();
        let tail = &be[be.len() - usize::from(width)..];
        w.write_all(&[tag | (width + 23)])?;
        w.write_all(tail)?;
        return Ok(1 + usize::from(width));
    }

    // Inline form: `i` is below `limit` (at most 24), so the cast is
    // lossless. The offset skips low-bit values reserved for other headers.
    let inline = i as u8 + (24 - limit);
    w.write_all(&[tag | inline])?;
    Ok(1)
}
/// Reads the integer announced by the low bits `sz` of a lead byte.
///
/// `buf` is the data that follows the lead byte. Returns the value and the
/// number of payload bytes consumed from `buf` (zero for the inline form).
///
/// # Errors
///
/// Returns `DecodeError::Eof` if `buf` is shorter than the announced payload.
#[inline]
fn decode_u64(buf: &[u8], sz: u8) -> Result<(u64, usize), DecodeError> {
    // Values below 24 are the integer itself.
    if sz < 24 {
        return Ok((u64::from(sz), 0));
    }

    // 24..=31 announce 1..=8 big-endian payload bytes.
    let width = usize::from(sz) - 23;
    let payload = buf.get(..width).ok_or(DecodeError::Eof)?;

    // Right-align the payload in an 8-byte buffer so the missing high-order
    // bytes read as zero.
    let mut be = [0u8; 8];
    be[8 - width..].copy_from_slice(payload);
    Ok((u64::from_be_bytes(be), width))
}
/// Returns the type code that is stored in the top three bits of the lead
/// byte.
#[inline]
fn code(&self) -> u8 {
    match self {
        // The payload-free markers share code 0 with `Bin`; they are told
        // apart by the low bits of the lead byte.
        Header::Null
        | Header::True
        | Header::False
        | Header::F32
        | Header::F64
        | Header::Bin(_) => 0,
        Header::Pos(_) => 1,
        Header::Neg(_) => 2,
        Header::Bag(_) => 3,
        Header::Str(_) => 4,
        Header::Sym(_) => 5,
        Header::Key(_) => 6,
        Header::Ref(_) => 7,
    }
}
/// Returns how many distinct small values this header kind can carry
/// directly in the lead byte.
///
/// `Bin` only has 19 slots because it shares its type code with the five
/// payload-free markers; every other kind has the full 24.
#[inline]
fn sz_limit(&self) -> u8 {
    if let Header::Bin(_) = self {
        19
    } else {
        24
    }
}
/// Returns the minimum number of bytes (1..=8) needed to hold `value`.
///
/// Zero still needs one byte.
#[inline]
fn size(value: u64) -> u8 {
    // Number of significant bits, 0 for zero and up to 64.
    let bits = 64 - value.leading_zeros();
    // Round up to whole bytes, never reporting fewer than one.
    let bytes = (bits + 7) / 8;
    bytes.max(1) as u8
}
/// Converts a decoded `u64` to `usize`.
///
/// # Errors
///
/// Returns `DecodeError::Length` carrying the original value when it does
/// not fit in `usize`.
#[inline]
fn to_usize(value: u64) -> Result<usize, DecodeError> {
    match usize::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(DecodeError::Length(value)),
    }
}
/// Converts a `usize` payload to the `u64` used on the wire.
///
/// # Errors
///
/// Returns `EncodeError::Length` carrying the original value when it does
/// not fit in `u64`.
#[inline]
fn to_u64(value: usize) -> Result<u64, EncodeError> {
    match u64::try_from(value) {
        Ok(widened) => Ok(widened),
        Err(_) => Err(EncodeError::Length(value)),
    }
}
}
#[cfg(test)]
mod tests {
use super::Header;
/// Every possible lead byte, followed by eight zero bytes, must decode, and
/// the decoded header must survive a re-encode/re-decode cycle.
#[test]
fn lead_bytes() {
    let mut src = [0u8; 9];
    let mut dst = Vec::with_capacity(9);
    // Inclusive range: an exclusive `0..u8::MAX` would never exercise 0xff.
    for l in 0..=u8::MAX {
        dst.clear();
        src[0] = l;
        let decoded = Header::decode(&src).unwrap().0;
        let written = decoded.encode(&mut dst).unwrap();
        assert_eq!(written, dst.len());
        // The re-encoding may be shorter than `src` (zero-padded payloads
        // collapse to the inline form) but must describe the same header.
        assert_eq!(decoded, Header::decode(&dst).unwrap().0);
    }
}
/// `Neg(0)` has no encoding of its own and must come back as `Pos(0)`.
#[test]
fn negative_zero() {
    let mut buf = Vec::new();
    // Unwrap instead of discarding the result so an encode failure is
    // reported as such rather than as a confusing decode error.
    let written = Header::Neg(0).encode(&mut buf).unwrap();
    assert_eq!(1, written);
    assert_eq!(Header::Pos(0), Header::decode(&buf).unwrap().0);
}
/// The two largest stored magnitudes both decode to `Neg(u64::MAX)`: the
/// stored value is off by one and the decoder saturates instead of
/// overflowing.
#[test]
fn negative_max() {
    // Lead byte 0x5f (type code 2, eight payload bytes) followed by 0xff..
    let mut buf = [0xffu8; 9];
    buf[0] = 0x5f;
    for &last in &[0xffu8, 0xfe] {
        buf[8] = last;
        let (header, _) = Header::decode(&buf).unwrap();
        assert_eq!(Header::Neg(u64::MAX), header);
    }
}
/// Round-trips the payload-free markers and values in `0..24` for every
/// payload-carrying kind (`Bin` only below 19).
#[test]
fn roundtrip_compact() {
    let mut buf = Vec::new();

    let markers = [
        Header::Null,
        Header::True,
        Header::False,
        Header::F32,
        Header::F64,
    ];
    for &marker in &markers {
        assert_roundtrip(marker, &mut buf);
    }

    for i in 0..24usize {
        if i < 19 {
            assert_roundtrip(Header::Bin(i), &mut buf);
        }
        let n = i as u64;
        assert_roundtrip(Header::Pos(n), &mut buf);
        // `Neg(0)` does not round-trip (it decodes as `Pos(0)`), so use 1.
        assert_roundtrip(Header::Neg(n.max(1)), &mut buf);
        assert_roundtrip(Header::Bag(i), &mut buf);
        assert_roundtrip(Header::Str(i), &mut buf);
        assert_roundtrip(Header::Sym(i), &mut buf);
        assert_roundtrip(Header::Key(i), &mut buf);
        assert_roundtrip(Header::Ref(i), &mut buf);
    }
}
/// Round-trips a coarse sweep of the whole `u64` range through every
/// payload-carrying kind.
#[test]
fn roundtrip_long() {
    let mut buf = Vec::new();
    for i in (0..u64::MAX).step_by(3_203_431_780_337) {
        let len = i as usize;
        // `Neg(0)` does not round-trip (it decodes as `Pos(0)`), so use 1.
        let magnitude = if i == 0 { 1 } else { i };

        assert_roundtrip(Header::Bin(len), &mut buf);
        assert_roundtrip(Header::Pos(i), &mut buf);
        assert_roundtrip(Header::Neg(magnitude), &mut buf);
        assert_roundtrip(Header::Bag(len), &mut buf);
        assert_roundtrip(Header::Str(len), &mut buf);
        assert_roundtrip(Header::Sym(len), &mut buf);
        assert_roundtrip(Header::Key(len), &mut buf);
        assert_roundtrip(Header::Ref(len), &mut buf);
    }
}
/// A small value padded out to the widest (eight-byte) payload must still
/// decode to the same header as its compact form would.
#[test]
fn inefficient_encoding() {
    // Lead byte 0x7f: type code 3 with eight payload bytes.
    let mut buf = [0u8; 9];
    buf[0] = 0x7f;
    buf[8] = 0x02;
    let (header, _) = Header::decode(&buf).unwrap();
    assert_eq!(Header::Bag(2), header);
}
/// Encodes `value` into `buf`, decodes it back and checks that the header
/// and all byte counts agree. `buf` is used as scratch space and is left
/// empty on return.
fn assert_roundtrip(value: Header, buf: &mut Vec<u8>) {
    // Start from a clean buffer so the length checks below are meaningful.
    buf.clear();
    // Unwrap instead of discarding: an encode failure must fail the test
    // here rather than surface as a misleading decode error.
    let written = value.encode(buf).unwrap();
    assert_eq!(written, buf.len());
    let (decoded, consumed) = Header::decode(buf).unwrap();
    assert_eq!(value, decoded);
    assert_eq!(written, consumed);
    buf.clear();
}
}