use hex::encode;
use num_enum::TryFromPrimitive;
use std::fmt;
use std::io::{Error, ErrorKind, Result};
/// Upper bound, in bytes, on the encoded size of a [`Header`] (1 KiB).
pub const MAX_HEADER_BYTES: usize = 1 << 10;
/// Field kind, stored in the high nibble of each field's tag byte.
#[derive(Debug, Clone, Copy, PartialEq, TryFromPrimitive)]
#[repr(u8)]
enum Type {
    /// 1-byte little-endian unsigned integer payload.
    U8 = 0x0,
    /// 2-byte little-endian unsigned integer payload.
    U16 = 0x1,
    /// 4-byte little-endian unsigned integer payload.
    U32 = 0x2,
    /// 8-byte little-endian unsigned integer payload.
    U64 = 0x3,
    /// Presence-only marker; carries no payload.
    Flag = 0x8,
    /// Text payload preceded by a one-byte length; read back as UTF-8.
    String = 0x9,
    /// Raw payload preceded by a one-byte length.
    Bytes = 0xA,
}
/// A compact binary header made of a sequence of tagged fields.
///
/// Each field starts with a tag byte (type in the high nibble, id in the low nibble),
/// followed by its payload.
#[derive(PartialEq, Eq, Clone)]
pub struct Header {
    /// Raw encoded bytes. [`Header::pack`] exposes only the first `len` of them.
    pub data: Vec<u8>,
    /// Number of bytes of `data` currently holding encoded fields.
    len: usize,
}
impl Header {
    /// Creates an empty header.
    pub fn new() -> Header {
        Header { data: Vec::new(), len: 0 }
    }

    /// Returns an iterator that decodes the fields stored in `data`.
    ///
    /// Decoding stops at the end of `data` or at the first malformed field.
    pub fn iter(&self) -> HeaderIterator<'_> {
        HeaderIterator { header: self, index: 0 }
    }

    /// Decodes every field into a `Vec`.
    pub fn fields(&self) -> Vec<Field<'_>> {
        self.iter().collect()
    }

    /// Returns the first field whose type and id both match, if any.
    fn find_field(&self, field_type: Type, id: u8) -> Option<Field<'_>> {
        self.iter().find(|f| f.field_type == field_type && f.id == id)
    }

    /// Returns `true` if a flag field with `id` is present.
    pub fn get_flag(&self, id: u8) -> bool {
        self.find_field(Type::Flag, id).is_some()
    }

    /// Returns the first integer field (of any width) with `id`, widened to `u64`.
    ///
    /// Non-integer fields that happen to share the id are skipped.
    pub fn get_int(&self, id: u8) -> Option<u64> {
        self.iter()
            .filter(|f| f.id == id)
            .find_map(|f| match f.field_type {
                Type::U8 => f.get_u8().map(u64::from),
                Type::U16 => f.get_u16().map(u64::from),
                Type::U32 => f.get_u32().map(u64::from),
                Type::U64 => f.get_u64(),
                Type::Flag | Type::String | Type::Bytes => None,
            })
    }

    /// Returns the first string field with `id`.
    ///
    /// Returns `None` if there is no such field, or if that field is not valid UTF-8.
    pub fn get_string(&self, id: u8) -> Option<&str> {
        self.find_field(Type::String, id).and_then(|f| f.get_string())
    }

    /// Returns the payload of the first bytes field with `id`, if any.
    pub fn get_bytes(&self, id: u8) -> Option<&[u8]> {
        self.find_field(Type::Bytes, id).and_then(|f| f.get_bytes())
    }

    /// Appends a tag byte and reserves `n` zeroed payload bytes after it.
    ///
    /// Fails without modifying the header if `id` does not fit in a nibble. It also fails
    /// if the tag plus `n` payload bytes would push the header past `MAX_HEADER_BYTES`.
    fn add_header(&mut self, t: Type, id: u8, n: usize) -> Result<()> {
        if id > 15 {
            return Err(Error::new(ErrorKind::InvalidInput, "Header ID must be between 0 and 15, inclusive"));
        }
        if self.len + 1 + n > MAX_HEADER_BYTES {
            return Err(Error::new(ErrorKind::OutOfMemory, "Header can only be 1KB"));
        }
        self.data.resize(self.len + 1 + n, 0);
        self.data[self.len] = ((t as u8) << 4) | id;
        self.len += 1;
        Ok(())
    }

    /// Copies `bytes` to the current write position (already reserved by `add_header`)
    /// and advances past them.
    fn write_bytes(&mut self, bytes: &[u8]) {
        let end = self.len + bytes.len();
        self.data[self.len..end].copy_from_slice(bytes);
        self.len = end;
    }

    /// Appends a field whose payload is exactly `payload` (fixed-width types, no prefix).
    fn add_field(&mut self, t: Type, id: u8, payload: &[u8]) -> Result<()> {
        self.add_header(t, id, payload.len())?;
        self.write_bytes(payload);
        Ok(())
    }

    /// Appends a field whose payload is preceded by a one-byte length.
    ///
    /// Any length that fits the prefix (0..=255) is accepted. The length is validated
    /// before the id and size checks, so an over-long payload reports the length error first.
    fn add_length_prefixed(&mut self, t: Type, id: u8, payload: &[u8]) -> Result<()> {
        let prefix = u8::try_from(payload.len()).map_err(|_| {
            Error::new(ErrorKind::InvalidInput, "Header field length must be less than 256")
        })?;
        self.add_header(t, id, 1 + payload.len())?;
        self.write_bytes(&[prefix]);
        self.write_bytes(payload);
        Ok(())
    }

    /// Appends a flag field (tag byte only).
    ///
    /// # Errors
    /// `InvalidInput` if `id > 15`; `OutOfMemory` if the header would exceed `MAX_HEADER_BYTES`.
    pub fn add_flag(&mut self, id: u8) -> Result<&mut Self> {
        self.add_field(Type::Flag, id, &[])?;
        Ok(self)
    }

    /// Appends `value` using the narrowest of U8/U16/U32/U64 that can hold it, little-endian.
    ///
    /// # Errors
    /// `InvalidInput` if `id > 15`; `OutOfMemory` if the header would exceed `MAX_HEADER_BYTES`.
    pub fn add_int(&mut self, id: u8, value: u64) -> Result<&mut Self> {
        if let Ok(v) = u8::try_from(value) {
            self.add_field(Type::U8, id, &v.to_le_bytes())?;
        } else if let Ok(v) = u16::try_from(value) {
            self.add_field(Type::U16, id, &v.to_le_bytes())?;
        } else if let Ok(v) = u32::try_from(value) {
            self.add_field(Type::U32, id, &v.to_le_bytes())?;
        } else {
            self.add_field(Type::U64, id, &value.to_le_bytes())?;
        }
        Ok(self)
    }

    /// Appends a string field holding the UTF-8 bytes of `value`.
    ///
    /// # Errors
    /// `InvalidInput` if `value` is longer than 255 bytes or `id > 15`;
    /// `OutOfMemory` if the header would exceed `MAX_HEADER_BYTES`.
    pub fn add_string(&mut self, id: u8, value: &str) -> Result<&mut Self> {
        self.add_length_prefixed(Type::String, id, value.as_bytes())?;
        Ok(self)
    }

    /// Appends a raw bytes field.
    ///
    /// # Errors
    /// `InvalidInput` if `value` is longer than 255 bytes or `id > 15`;
    /// `OutOfMemory` if the header would exceed `MAX_HEADER_BYTES`.
    pub fn add_bytes(&mut self, id: u8, value: &[u8]) -> Result<&mut Self> {
        self.add_length_prefixed(Type::Bytes, id, value)?;
        Ok(self)
    }

    /// Returns the encoded header: the first `len` bytes of `data`.
    pub fn pack(&self) -> &[u8] {
        &self.data[0..self.len]
    }

    /// Returns the `Debug` rendering of each decoded field.
    pub fn dump(&self) -> Vec<String> {
        self.iter().map(|f| format!("{:?}", f)).collect()
    }
}
impl Default for Header {
fn default() -> Self {
Header::new()
}
}
impl fmt::Debug for Header {
    /// Renders as `Header(<field>, <field>, ...)`, each field in its own `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Header(")?;
        let mut separator = "";
        for field in self.iter() {
            write!(f, "{}{:?}", separator, field)?;
            separator = ", ";
        }
        f.write_str(")")
    }
}
impl From<&[u8]> for Header {
    /// Wraps a copy of already-encoded bytes; nothing is validated at this point.
    fn from(bytes: &[u8]) -> Header {
        let data = bytes.to_vec();
        let len = data.len();
        Header { data, len }
    }
}
impl From<Vec<u8>> for Header {
    /// Takes ownership of already-encoded bytes; nothing is validated at this point.
    fn from(data: Vec<u8>) -> Header {
        Header { len: data.len(), data }
    }
}
pub struct Field<'a> {
field_type: Type,
id: u8,
data: &'a [u8],
}
impl<'a> Field<'a> {
pub fn get_u8(&self) -> Option<u8> {
if self.field_type != Type::U8 { return None }
self.data.try_into().ok().map(u8::from_le_bytes)
}
pub fn get_u16(&self) -> Option<u16> {
if self.field_type != Type::U16 { return None }
self.data.try_into().ok().map(u16::from_le_bytes)
}
pub fn get_u32(&self) -> Option<u32> {
if self.field_type != Type::U32 { return None }
self.data.try_into().ok().map(u32::from_le_bytes)
}
pub fn get_u64(&self) -> Option<u64> {
if self.field_type != Type::U64 { return None }
self.data.try_into().ok().map(u64::from_le_bytes)
}
pub fn get_string(&self) -> Option<&'a str> {
if self.field_type != Type::String { return None }
std::str::from_utf8(self.data).ok()
}
pub fn get_bytes(&self) -> Option<&'a [u8]> {
if self.field_type != Type::Bytes { return None }
Some(self.data)
}
}
impl<'a> fmt::Debug for Field<'a> {
    /// Renders compactly, e.g. `F(1)`, `U16(11)=1000`, `S(3)="iron"`, `B(2)=040506`.
    ///
    /// String payloads are decoded lossily. A `Header` can be built from arbitrary bytes,
    /// and formatting it must not panic on invalid UTF-8.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // In this file `Field`s are only produced by `HeaderIterator`. It slices numeric payloads
        // to exactly their type's width, so these getters cannot return `None`.
        const WIDTH_INVARIANT: &str = "numeric field payload has the width of its type";
        match self.field_type {
            Type::Flag => write!(f, "F({})", self.id),
            Type::U8 => write!(f, "U8({})={}", self.id, self.get_u8().expect(WIDTH_INVARIANT)),
            Type::U16 => write!(f, "U16({})={}", self.id, self.get_u16().expect(WIDTH_INVARIANT)),
            Type::U32 => write!(f, "U32({})={}", self.id, self.get_u32().expect(WIDTH_INVARIANT)),
            Type::U64 => write!(f, "U64({})={}", self.id, self.get_u64().expect(WIDTH_INVARIANT)),
            Type::String => write!(f, "S({})=\"{}\"", self.id, String::from_utf8_lossy(self.data)),
            Type::Bytes => write!(f, "B({})={}", self.id, encode(self.data)),
        }
    }
}
/// Cursor that decodes the fields of a [`Header`] one at a time.
///
/// The yielded [`Field`]s borrow from the header's bytes.
pub struct HeaderIterator<'a> {
    /// Byte offset in `header.data` of the next tag byte to decode.
    index: usize,
    /// The header being decoded.
    header: &'a Header,
}
impl<'a> Iterator for HeaderIterator<'a> {
type Item = Field<'a>;
fn next(&mut self) -> Option<Field<'a>> {
if self.index >= self.header.data.len() {
return None
}
let b = self.header.data[self.index];
self.index += 1;
let field_type: Type = match ((b >> 4) & 15).try_into() {
Ok(t) => t,
Err(_) => return None,
};
let id = b & 15;
let len = match field_type {
Type::Flag => 0,
Type::U8 => 1,
Type::U16 => 2,
Type::U32 => 4,
Type::U64 => 8,
Type::String | Type::Bytes => {
if self.index >= self.header.data.len() {
return None;
}
let len = self.header.data[self.index];
self.index += 1;
len
},
} as usize;
if self.index + len > self.header.data.len() {
return None
}
let data = &self.header.data[self.index .. self.index + len];
self.index += len;
Some(Field { field_type, id, data })
}
}
#[cfg(test)]
mod test {
    use hex::{decode, encode};
    use std::io::ErrorKind;
    use super::Header;

    #[test]
    fn builder() {
        let mut h = Header::new();
        h.add_flag(1).unwrap();
        assert_eq!(encode(h.pack()), "81");
        h.add_int(10, 10).unwrap();
        assert_eq!(encode(h.pack()), "810a0a");
        h.add_int(11, 1000).unwrap();
        assert_eq!(encode(h.pack()), "810a0a1be803");
        h.add_int(12, 65538).unwrap();
        assert_eq!(encode(h.pack()), "810a0a1be8032c02000100");
        h.add_int(13, 1189998819991197253).unwrap();
        assert_eq!(encode(h.pack()), "810a0a1be8032c020001003d4596b5349fb98310");
        h = Header::new();
        h.add_string(3, "iron").unwrap();
        assert_eq!(encode(h.pack()), "930469726f6e");
        h.add_bytes(2, &[ 4, 5, 6 ]).unwrap();
        assert_eq!(encode(h.pack()), "930469726f6ea203040506");
    }

    #[test]
    fn builder_chains() {
        let mut h = Header::new();
        h.add_flag(1).unwrap().add_int(2, 3).unwrap();
        assert_eq!(encode(h.pack()), "810203");
    }

    #[test]
    fn unpack() {
        let data1 = decode("81").unwrap();
        assert_eq!(format!("{:?}", Header::from(data1)), "Header(F(1))");
        let data2 = decode("810a0a").unwrap();
        assert_eq!(format!("{:?}", Header::from(data2)), "Header(F(1), U8(10)=10)");
        let data3 = decode("810a0a1be803").unwrap();
        assert_eq!(format!("{:?}", Header::from(data3)), "Header(F(1), U8(10)=10, U16(11)=1000)");
        let data4 = decode("810a0a1be8032c02000100").unwrap();
        assert_eq!(format!("{:?}", Header::from(data4)), "Header(F(1), U8(10)=10, U16(11)=1000, U32(12)=65538)");
        let data5 = decode("3d4596b5349fb98310").unwrap();
        assert_eq!(format!("{:?}", Header::from(data5)), "Header(U64(13)=1189998819991197253)");
        let data6 = decode("930469726f6e").unwrap();
        assert_eq!(format!("{:?}", Header::from(data6)), "Header(S(3)=\"iron\")");
        let data7 = decode("930469726f6ea203040506").unwrap();
        assert_eq!(format!("{:?}", Header::from(data7)), "Header(S(3)=\"iron\", B(2)=040506)");
    }

    #[test]
    fn accessors() {
        let data1 = decode("810a0a1be8032c02000100").unwrap();
        let h1 = Header::from(data1);
        assert_eq!(h1.get_int(10), Some(10));
        assert_eq!(h1.get_int(11), Some(1000));
        assert_eq!(h1.get_int(12), Some(65538));
        assert_eq!(h1.get_int(1), None);
        let data2 = decode("3d4596b5349fb98310").unwrap();
        let h2 = Header::from(data2);
        assert_eq!(h2.get_int(13), Some(1189998819991197253));
        let data3 = decode("930469726f6ea203040506").unwrap();
        let h3 = Header::from(data3);
        assert_eq!(h3.get_string(3), Some("iron"));
        assert_eq!(h3.get_bytes(2).map(|b| encode(b)), Some(String::from("040506")));
    }

    #[test]
    fn flags() {
        let h = Header::from(decode("8187").unwrap());
        assert!(h.get_flag(1));
        assert!(h.get_flag(7));
        assert!(!h.get_flag(2));
        // A U8 field with the same id is not a flag.
        assert!(!Header::from(decode("010a").unwrap()).get_flag(1));
    }

    #[test]
    fn get_int_skips_non_integer_fields() {
        // Flag with id 10, then a U8 with id 10 holding 5.
        let h = Header::from(decode("8a0a05").unwrap());
        assert_eq!(h.get_int(10), Some(5));
    }

    #[test]
    fn rejects_invalid_fields() {
        let mut h = Header::new();
        assert_eq!(h.add_flag(16).unwrap_err().kind(), ErrorKind::InvalidInput);
        assert_eq!(h.add_int(16, 1).unwrap_err().kind(), ErrorKind::InvalidInput);
        assert_eq!(h.add_bytes(0, &[0; 256]).unwrap_err().kind(), ErrorKind::InvalidInput);
        // Failed adds leave the header untouched.
        assert!(h.pack().is_empty());
    }

    #[test]
    fn enforces_size_limit() {
        let mut h = Header::new();
        // Each 200-byte blob costs 202 bytes (tag + length + payload): five fit, a sixth does not.
        for _ in 0..5 {
            h.add_bytes(0, &[0xab; 200]).unwrap();
        }
        assert_eq!(h.pack().len(), 1010);
        assert_eq!(h.add_bytes(0, &[0xab; 200]).unwrap_err().kind(), ErrorKind::OutOfMemory);
        assert_eq!(h.pack().len(), 1010);
    }
}