use super::{
decode_helpers::{consume_u16, consume_u8},
EofDecodeError,
};
use std::vec::Vec;
/// Decoded EOF container header.
///
/// Holds only the section sizes read by [`EofHeader::decode`] and written back
/// by [`EofHeader::encode`]; the section bodies themselves are not stored here.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EofHeader {
    /// Size in bytes of the types (code info) section. `decode` requires
    /// `types_size / 4` to equal the number of code sections.
    pub types_size: u16,
    /// Size of each code section, encoded as a big-endian `u16` per entry.
    /// `decode` rejects an empty list and zero-sized entries.
    pub code_sizes: Vec<u16>,
    /// Size of each container section, encoded as a big-endian `u32` per entry.
    /// Empty means the header has no container section: `encode` omits it and
    /// `decode` leaves it empty when the data kind directly follows the code sizes.
    pub container_sizes: Vec<u32>,
    /// Size of the data section as declared in the header.
    pub data_size: u16,
    /// Sum of `code_sizes`, cached by `decode` and read by `body_size`.
    /// `encode` does not read it, so hand-built headers must keep it in sync.
    pub sum_code_sizes: usize,
    /// Sum of `container_sizes`, cached by `decode` and read by `body_size`.
    /// `encode` does not read it, so hand-built headers must keep it in sync.
    pub sum_container_sizes: usize,
}
/// Section-kind byte that terminates the header.
pub const KIND_TERMINAL: u8 = 0;
/// Section-kind byte introducing the types (code info) section size.
pub const KIND_CODE_INFO: u8 = 1;
/// Section-kind byte introducing the code section size table.
pub const KIND_CODE: u8 = 2;
/// Section-kind byte introducing the optional container section size table.
pub const KIND_CONTAINER: u8 = 3;
/// Section-kind byte introducing the data section size.
pub const KIND_DATA: u8 = 0xff;
/// Bytes used to encode one code section size (big-endian `u16`).
pub const CODE_SECTION_SIZE: usize = 2;
/// Bytes used to encode one container section size (big-endian `u32`).
pub const CONTAINER_SECTION_SIZE: usize = 4;
/// Parses the code-section size table that follows the `KIND_CODE` byte.
///
/// Layout: a big-endian `u16` section count, then `count` big-endian `u16`
/// sizes (`CODE_SECTION_SIZE` bytes each).
///
/// Returns `(remaining_input, sizes, sum_of_sizes)`.
///
/// # Errors
///
/// * [`EofDecodeError::NonSizes`] when the count is zero.
/// * [`EofDecodeError::ShortInputForSizes`] when fewer bytes follow the count
///   than the size table it announces.
/// * [`EofDecodeError::ZeroSize`] when any listed size is zero.
/// * Any error `consume_u16` reports while reading the count.
#[inline]
fn consume_header_code_section(
    input: &[u8],
) -> Result<(&[u8], Vec<u16>, usize), EofDecodeError> {
    let (after_count, section_count) = consume_u16(input)?;
    if section_count == 0 {
        return Err(EofDecodeError::NonSizes);
    }

    let table_len = usize::from(section_count) * CODE_SECTION_SIZE;
    if after_count.len() < table_len {
        return Err(EofDecodeError::ShortInputForSizes);
    }
    let (table, remaining) = after_count.split_at(table_len);

    // Decode every entry, stopping at the first zero-sized section.
    let sizes = table
        .chunks_exact(CODE_SECTION_SIZE)
        .map(|entry| match u16::from_be_bytes([entry[0], entry[1]]) {
            0 => Err(EofDecodeError::ZeroSize),
            size => Ok(size),
        })
        .collect::<Result<Vec<u16>, EofDecodeError>>()?;
    let total: usize = sizes.iter().map(|&size| usize::from(size)).sum();

    Ok((remaining, sizes, total))
}
/// Parses the container-section size table that follows the `KIND_CONTAINER` byte.
///
/// Layout: a big-endian `u16` section count, then `count` big-endian `u32`
/// sizes (`CONTAINER_SECTION_SIZE` bytes each).
///
/// Returns `(remaining_input, sizes, sum_of_sizes)`. This function does not
/// cap the number of containers.
///
/// # Errors
///
/// * [`EofDecodeError::NonSizes`] when the count is zero.
/// * [`EofDecodeError::ShortInputForSizes`] when fewer bytes follow the count
///   than the size table it announces.
/// * [`EofDecodeError::ZeroSize`] when any listed size is zero.
/// * Any error `consume_u16` reports while reading the count.
#[inline]
fn consume_header_container_section(
    input: &[u8],
) -> Result<(&[u8], Vec<u32>, usize), EofDecodeError> {
    let (rest, count) = consume_u16(input)?;
    let count = match count {
        0 => return Err(EofDecodeError::NonSizes),
        n => usize::from(n),
    };

    let table_len = count * CONTAINER_SECTION_SIZE;
    if rest.len() < table_len {
        return Err(EofDecodeError::ShortInputForSizes);
    }

    let mut sizes: Vec<u32> = Vec::with_capacity(count);
    let mut total: usize = 0;
    for entry in rest[..table_len].chunks_exact(CONTAINER_SECTION_SIZE) {
        let size = u32::from_be_bytes([entry[0], entry[1], entry[2], entry[3]]);
        if size == 0 {
            return Err(EofDecodeError::ZeroSize);
        }
        total += size as usize;
        sizes.push(size);
    }

    Ok((&rest[table_len..], sizes, total))
}
impl EofHeader {
pub fn size(&self) -> usize {
2 + 1 + 3 + 3 + CODE_SECTION_SIZE * self.code_sizes.len() + if self.container_sizes.is_empty() { 0 } else { 3 + CONTAINER_SECTION_SIZE * self.container_sizes.len() } + 3 + 1 }
pub fn data_size_raw_i(&self) -> usize {
self.size() - 3
}
pub fn types_count(&self) -> usize {
self.types_size as usize / 4
}
pub fn body_size(&self) -> usize {
self.types_size as usize
+ self.sum_code_sizes
+ self.sum_container_sizes
+ self.data_size as usize
}
pub fn eof_size(&self) -> usize {
self.size() + self.body_size()
}
pub fn encode(&self, buffer: &mut Vec<u8>) {
buffer.extend_from_slice(&0xEF00u16.to_be_bytes());
buffer.push(0x01);
buffer.push(KIND_CODE_INFO);
buffer.extend_from_slice(&self.types_size.to_be_bytes());
buffer.push(KIND_CODE);
buffer.extend_from_slice(&(self.code_sizes.len() as u16).to_be_bytes());
for size in &self.code_sizes {
buffer.extend_from_slice(&size.to_be_bytes());
}
if !self.container_sizes.is_empty() {
buffer.push(KIND_CONTAINER);
buffer.extend_from_slice(&(self.container_sizes.len() as u16).to_be_bytes());
for size in &self.container_sizes {
buffer.extend_from_slice(&size.to_be_bytes());
}
}
buffer.push(KIND_DATA);
buffer.extend_from_slice(&self.data_size.to_be_bytes());
buffer.push(KIND_TERMINAL);
}
pub fn decode(input: &[u8]) -> Result<(Self, &[u8]), EofDecodeError> {
let mut header = EofHeader::default();
let (input, kind) = consume_u16(input)?;
if kind != 0xEF00 {
return Err(EofDecodeError::InvalidEOFMagicNumber);
}
let (input, version) = consume_u8(input)?;
if version != 0x01 {
return Err(EofDecodeError::InvalidEOFVersion);
}
let (input, kind_code_info) = consume_u8(input)?;
if kind_code_info != KIND_CODE_INFO {
return Err(EofDecodeError::InvalidTypesKind);
}
let (input, types_size) = consume_u16(input)?;
header.types_size = types_size;
if header.types_size % CODE_SECTION_SIZE as u16 != 0 {
return Err(EofDecodeError::InvalidCodeInfo);
}
let (input, kind_code) = consume_u8(input)?;
if kind_code != KIND_CODE {
return Err(EofDecodeError::InvalidCodeKind);
}
let (input, sizes, sum) = consume_header_code_section(input)?;
if sizes.len() > 0x0400 {
return Err(EofDecodeError::TooManyCodeSections);
}
if sizes.is_empty() {
return Err(EofDecodeError::ZeroCodeSections);
}
if sizes.len() != (types_size / 4) as usize {
return Err(EofDecodeError::MismatchCodeAndInfoSize);
}
header.code_sizes = sizes;
header.sum_code_sizes = sum;
let (input, kind_container_or_data) = consume_u8(input)?;
let input = match kind_container_or_data {
KIND_CONTAINER => {
let (input, sizes, sum) = consume_header_container_section(input)?;
if sizes.len() > 0x0100 {
return Err(EofDecodeError::TooManyContainerSections);
}
header.container_sizes = sizes;
header.sum_container_sizes = sum;
let (input, kind_data) = consume_u8(input)?;
if kind_data != KIND_DATA {
return Err(EofDecodeError::InvalidDataKind);
}
input
}
KIND_DATA => input,
invalid_kind => return Err(EofDecodeError::InvalidKindAfterCode { invalid_kind }),
};
let (input, data_size) = consume_u16(input)?;
header.data_size = data_size;
let (input, terminator) = consume_u8(input)?;
if terminator != KIND_TERMINAL {
return Err(EofDecodeError::InvalidTerminalByte);
}
Ok((header, input))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use primitives::hex;
    use std::vec;

    #[test]
    fn sanity_header_decode() {
        let input = hex!("ef00010100040200010001ff00000000800000fe");
        let (header, _) = EofHeader::decode(&input).unwrap();
        assert_eq!(header.types_size, 4);
        assert_eq!(header.code_sizes, vec![1]);
        assert_eq!(header.container_sizes, Vec::new());
        assert_eq!(header.data_size, 0);
    }

    #[test]
    fn decode_header_not_terminated() {
        let input = hex!("ef0001010004");
        assert_eq!(EofHeader::decode(&input), Err(EofDecodeError::MissingInput));
    }

    /// Header with one container section (formerly named `failing_test`; it
    /// decodes successfully and now checks the decoded fields).
    #[test]
    fn decode_header_with_container_section() {
        let input = hex!("ef0001010004020001000603000100000014ff000200008000016000e0000000ef000101000402000100010400000000800000fe");
        let (header, _) = EofHeader::decode(&input).unwrap();
        assert_eq!(header.types_size, 4);
        assert_eq!(header.code_sizes, vec![6]);
        assert_eq!(header.container_sizes, vec![0x14]);
        assert_eq!(header.sum_container_sizes, 0x14);
        assert_eq!(header.data_size, 2);
        assert_eq!(header.size(), 22);
    }

    /// Code section count of 0x8000 with no size table after it.
    #[test]
    fn cut_header() {
        let input = hex!("ef0001010000028000");
        assert_eq!(
            EofHeader::decode(&input),
            Err(EofDecodeError::ShortInputForSizes)
        );
    }

    /// One code section announced but its 2-byte size entry is missing.
    #[test]
    fn short_input() {
        let input = hex!("ef0001010004020001");
        assert_eq!(
            EofHeader::decode(&input),
            Err(EofDecodeError::ShortInputForSizes)
        );
    }

    /// Two containers announced but only one 4-byte size entry present.
    #[test]
    fn short_container_sizes() {
        let input = hex!("ef0001010004020001000103000200000001");
        assert_eq!(
            EofHeader::decode(&input),
            Err(EofDecodeError::ShortInputForSizes)
        );
    }

    #[test]
    fn zero_code_section_size() {
        let input = hex!("ef00010100040200010000");
        assert_eq!(EofHeader::decode(&input), Err(EofDecodeError::ZeroSize));
    }

    #[test]
    fn invalid_kind_after_code() {
        let input = hex!("ef000101000402000100010400");
        assert_eq!(
            EofHeader::decode(&input),
            Err(EofDecodeError::InvalidKindAfterCode { invalid_kind: 0x04 })
        );
    }

    #[test]
    fn encode_roundtrip() {
        let input = hex!("ef00010100040200010001ff00000000800000fe");
        let (header, _) = EofHeader::decode(&input).unwrap();
        let mut buf = Vec::new();
        header.encode(&mut buf);
        assert_eq!(buf.len(), header.size());
        assert_eq!(&buf[..], &input[..header.size()]);
        assert_eq!(header.eof_size(), input.len());
    }

    #[test]
    fn encode_roundtrip_with_container() {
        let input = hex!("ef0001010004020001000603000100000014ff000200");
        let (header, rest) = EofHeader::decode(&input).unwrap();
        assert!(rest.is_empty());
        let mut buf = Vec::new();
        header.encode(&mut buf);
        assert_eq!(&buf[..], &input[..]);
    }
}