use std::mem;
use super::Predicate;
use error::DecodeResult;
pub use error::DecodeError;
pub use error::PredicateError;
mod error;
#[cfg(test)]
mod tests;
/// Number of bytes used to encode a single program length.
///
/// Each length is written as a big-endian `u16` (see `encode_bytes_length` and
/// `decode_len`), so this is always 2.
pub(super) const LEN_SIZE_BYTES: usize = core::mem::size_of::<u16>();
/// An encoded header: the fixed-size count bytes followed by the encoded
/// program-length bytes.
///
/// Iterating it (via `IntoIterator`) yields the fixed-size header bytes first
/// and then every byte of `lens`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EncodedHeader {
    /// The state-read count byte followed by the constraint count byte.
    pub fixed_size_header: EncodedFixedSizeHeader,
    /// Encoded program lengths. Presumably the output of
    /// `encode_program_lengths` (two big-endian bytes per program) — TODO
    /// confirm at the construction site.
    pub lens: Vec<u8>,
}
/// The fixed-size prefix of an encoded header.
///
/// Holds `SIZE` bytes: byte 0 is the number of state reads and byte 1 is the
/// number of constraints (see `EncodedFixedSizeHeader::new`).
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EncodedFixedSizeHeader(pub [u8; Self::SIZE]);
/// The decoded form of the fixed-size header: how many state-read and
/// constraint length entries follow it.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FixedSizeHeader {
    /// Number of state reads; encoded as a single byte.
    pub num_state_reads: u8,
    /// Number of constraints; encoded as a single byte.
    pub num_constraints: u8,
}
/// A header decoded from a byte buffer, borrowing the raw length bytes from it.
///
/// When produced by `DecodedHeader::decode`, each slice holds one big-endian
/// `u16` length per program (`LEN_SIZE_BYTES` bytes each).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DecodedHeader<'a> {
    /// Length entries for the state reads.
    pub state_reads: &'a [u8],
    /// Length entries for the constraints.
    pub constraints: &'a [u8],
}
/// Inputs to `encoded_size`: program counts plus the summed program lengths.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) struct EncodedSize {
    /// Number of state-read programs.
    pub num_state_reads: usize,
    /// Number of constraint programs.
    pub num_constraints: usize,
    /// Sum of all state-read program lengths, in bytes.
    pub state_read_lens_sum: usize,
    /// Sum of all constraint program lengths, in bytes.
    pub constraint_lens_sum: usize,
}
/// Inputs to `check_predicate_bounds`: program counts plus iterators over each
/// program's length in bytes.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) struct PredicateBounds<S, C> {
    /// Number of state-read programs.
    pub num_state_reads: usize,
    /// Number of constraint programs.
    pub num_constraints: usize,
    /// Iterator (bound `Iterator<Item = usize>` at the use site) over
    /// state-read lengths.
    pub state_read_lens: S,
    /// Iterator (bound `Iterator<Item = usize>` at the use site) over
    /// constraint lengths.
    pub constraint_lens: C,
}
/// Returns the total encoded size, in bytes, of a predicate with the given shape.
///
/// The total is the fixed-size header, plus one `LEN_SIZE_BYTES` length entry
/// per state read and per constraint, plus the bytes of every program.
///
/// All arithmetic saturates at `usize::MAX`. `check_predicate_bounds` builds
/// the two `*_lens_sum` inputs with `saturating_add`, so a saturated input must
/// stay saturated here. With plain `+`, such an input would panic in debug
/// builds. In release builds it would wrap to a small value that could slip
/// under a size limit.
pub(super) fn encoded_size(sizes: &EncodedSize) -> usize {
    EncodedFixedSizeHeader::SIZE
        .saturating_add(sizes.num_state_reads.saturating_mul(LEN_SIZE_BYTES))
        .saturating_add(sizes.num_constraints.saturating_mul(LEN_SIZE_BYTES))
        .saturating_add(sizes.state_read_lens_sum)
        .saturating_add(sizes.constraint_lens_sum)
}
pub(super) fn check_predicate_bounds<S, C>(
mut bounds: PredicateBounds<S, C>,
) -> Result<(), PredicateError>
where
S: Iterator<Item = usize>,
C: Iterator<Item = usize>,
{
if bounds.num_state_reads > Predicate::MAX_STATE_READS {
return Err(PredicateError::TooManyStateReads(bounds.num_state_reads));
}
if bounds.num_constraints > Predicate::MAX_CONSTRAINTS {
return Err(PredicateError::TooManyConstraints(bounds.num_constraints));
}
let mut state_read_lens_sum: usize = 0;
if let Some(err) = bounds.state_read_lens.find_map(|len| {
state_read_lens_sum = state_read_lens_sum.saturating_add(len);
(len > Predicate::MAX_STATE_READ_SIZE_BYTES)
.then_some(PredicateError::StateReadTooLarge(len))
}) {
return Err(err);
}
let mut constraint_lens_sum: usize = 0;
if let Some(err) = bounds.constraint_lens.find_map(|len| {
constraint_lens_sum = constraint_lens_sum.saturating_add(len);
(len > Predicate::MAX_CONSTRAINT_SIZE_BYTES)
.then_some(PredicateError::ConstraintTooLarge(len))
}) {
return Err(err);
}
let encoded_size = encoded_size(&EncodedSize {
num_state_reads: bounds.num_state_reads,
num_constraints: bounds.num_constraints,
state_read_lens_sum,
constraint_lens_sum,
});
if encoded_size > Predicate::MAX_BYTES {
return Err(PredicateError::PredicateTooLarge(encoded_size));
}
Ok(())
}
pub(super) fn encode_program_lengths(predicate: &Predicate) -> Vec<u8> {
let state_read_lens = predicate
.state_read
.iter()
.map(Vec::as_slice)
.flat_map(encode_bytes_length);
let constraint_lens = predicate
.constraints
.iter()
.map(Vec::as_slice)
.flat_map(encode_bytes_length);
let lengths_size = predicate
.state_read
.len()
.saturating_add(predicate.constraints.len())
.saturating_mul(2);
let mut buf = Vec::with_capacity(lengths_size);
buf.extend(state_read_lens);
buf.extend(constraint_lens);
buf
}
/// Returns `bytes.len()` encoded as a big-endian `u16`.
///
/// The length is converted with `as u16`. Any length above `u16::MAX` is
/// therefore silently truncated to its low 16 bits rather than rejected.
fn encode_bytes_length(bytes: &[u8]) -> [u8; 2] {
    let len = bytes.len() as u16;
    u16::to_be_bytes(len)
}
impl EncodedFixedSizeHeader {
    /// Size in bytes of the fixed-size header: two single-byte counts.
    pub const SIZE: usize = 2 * mem::size_of::<u8>();

    /// Lays out `header` as `[num_state_reads, num_constraints]`.
    pub const fn new(header: FixedSizeHeader) -> Self {
        Self([header.num_state_reads, header.num_constraints])
    }
}
impl FixedSizeHeader {
    /// Byte range of the state-read count: the first byte of the header.
    const fn num_state_reads_ix() -> core::ops::Range<usize> {
        let start = 0;
        start..start + core::mem::size_of::<u8>()
    }

    /// Byte range of the constraint count: the byte right after the state-read
    /// count.
    const fn num_constraints_ix() -> core::ops::Range<usize> {
        let start = Self::num_state_reads_ix().end;
        start..start + core::mem::size_of::<u8>()
    }

    /// Reads the state-read count byte. Panics if `buf` is empty.
    fn get_num_state_reads(buf: &[u8]) -> u8 {
        let ix = Self::num_state_reads_ix().start;
        buf[ix]
    }

    /// Reads the constraint count byte. Panics if `buf` has fewer than 2 bytes.
    fn get_num_constraints(buf: &[u8]) -> u8 {
        let ix = Self::num_constraints_ix().start;
        buf[ix]
    }

    /// Reads both counts from the start of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than `EncodedFixedSizeHeader::SIZE`.
    /// `DecodedHeader::decode` calls `check_len` first to rule this out.
    fn decode(buf: &[u8]) -> Self {
        Self {
            num_state_reads: Self::get_num_state_reads(buf),
            num_constraints: Self::get_num_constraints(buf),
        }
    }

    /// Returns the slice of `buf` holding the state-read length entries.
    ///
    /// These entries start right after the two count bytes and take two bytes
    /// each. Panics if `buf` is too short; `check_header_len_and_program_lens`
    /// guards against that.
    fn get_state_read_lens_bytes<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
        let entry_bytes = core::mem::size_of::<u16>();
        let start = Self::num_constraints_ix().end;
        let end = start + usize::from(self.num_state_reads).saturating_mul(entry_bytes);
        &buf[start..end]
    }

    /// Returns the slice of `buf` holding the constraint length entries.
    ///
    /// These entries come immediately after the state-read length entries.
    /// Panics if `buf` is too short.
    fn get_constraint_lens_bytes<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
        let entry_bytes = core::mem::size_of::<u16>();
        let state_read_entries = usize::from(self.num_state_reads).saturating_mul(entry_bytes);
        let start = Self::num_constraints_ix().end + state_read_entries;
        let end = start + usize::from(self.num_constraints).saturating_mul(entry_bytes);
        &buf[start..end]
    }

    /// Fails with `BufferTooSmall` unless `len` can hold the fixed-size header.
    const fn check_len(len: usize) -> DecodeResult<()> {
        if len >= EncodedFixedSizeHeader::SIZE {
            Ok(())
        } else {
            Err(DecodeError::BufferTooSmall)
        }
    }

    /// Fails with `BufferTooSmall` unless `len` can hold the fixed-size header
    /// plus every length entry declared by `self`.
    fn check_header_len_and_program_lens(&self, len: usize) -> DecodeResult<()> {
        let required = state_len_buffer_offset(
            usize::from(self.num_state_reads),
            usize::from(self.num_constraints),
        );
        if len >= required {
            Ok(())
        } else {
            Err(DecodeError::BufferTooSmall)
        }
    }
}
impl<'a> DecodedHeader<'a> {
pub fn bytes_len(&self) -> usize {
let sr = self
.state_reads
.chunks_exact(LEN_SIZE_BYTES)
.fold(0usize, |acc, chunk| acc.saturating_add(decode_len(chunk)));
let c = self
.constraints
.chunks_exact(LEN_SIZE_BYTES)
.fold(0usize, |acc, chunk| acc.saturating_add(decode_len(chunk)));
let lens = self
.state_reads
.len()
.saturating_add(self.constraints.len());
EncodedFixedSizeHeader::SIZE
.saturating_add(sr)
.saturating_add(c)
.saturating_add(lens)
}
pub(super) fn num_state_reads(&self) -> usize {
self.state_reads.len() / LEN_SIZE_BYTES
}
pub(super) fn num_constraints(&self) -> usize {
self.constraints.len() / LEN_SIZE_BYTES
}
pub fn decode(buf: &'a [u8]) -> DecodeResult<Self> {
use FixedSizeHeader as Fixed;
Fixed::check_len(buf.len())?;
let fh = Fixed::decode(buf);
let num_state_reads = fh.num_state_reads as usize;
let num_constraints = fh.num_constraints as usize;
fh.check_header_len_and_program_lens(buf.len())?;
let header = Self {
state_reads: fh.get_state_read_lens_bytes(buf),
constraints: fh.get_constraint_lens_bytes(buf),
};
header.check_consistency(&fh)?;
let bounds = PredicateBounds {
num_state_reads,
num_constraints,
state_read_lens: header
.state_reads
.chunks_exact(LEN_SIZE_BYTES)
.map(decode_len),
constraint_lens: header
.constraints
.chunks_exact(LEN_SIZE_BYTES)
.map(decode_len),
};
check_predicate_bounds(bounds)?;
Ok(header)
}
fn check_consistency(&self, header: &FixedSizeHeader) -> DecodeResult<()> {
if self.state_reads.len() / LEN_SIZE_BYTES != header.num_state_reads as usize {
return Err(DecodeError::IncorrectBodyLength);
}
if self.constraints.len() / LEN_SIZE_BYTES != header.num_constraints as usize {
return Err(DecodeError::IncorrectBodyLength);
}
Ok(())
}
}
/// Decodes the big-endian `u16` stored in the first two bytes of `chunk`.
///
/// Any bytes after the first two are ignored. Panics if `chunk` has fewer than
/// two bytes.
fn decode_len(chunk: &[u8]) -> usize {
    let high = chunk[0];
    let low = chunk[1];
    usize::from(u16::from_be_bytes([high, low]))
}
/// Returns the number of bytes taken by the fixed-size header plus the length
/// entries for the given counts.
///
/// The length entries are `LEN_SIZE_BYTES` each: one per state read and one
/// per constraint. `FixedSizeHeader::check_header_len_and_program_lens` uses
/// this as the minimum buffer length.
///
/// The arithmetic saturates at `usize::MAX` instead of overflowing. A wrapped
/// result would make a check such as `len < state_len_buffer_offset(..)`
/// wrongly succeed for oversized counts.
pub(super) fn state_len_buffer_offset(num_state_reads: usize, num_constraints: usize) -> usize {
    EncodedFixedSizeHeader::SIZE
        .saturating_add(num_state_reads.saturating_mul(LEN_SIZE_BYTES))
        .saturating_add(num_constraints.saturating_mul(LEN_SIZE_BYTES))
}
impl From<FixedSizeHeader> for EncodedFixedSizeHeader {
fn from(header: FixedSizeHeader) -> Self {
Self::new(header)
}
}
impl IntoIterator for EncodedHeader {
type Item = u8;
type IntoIter = core::iter::Chain<
core::array::IntoIter<u8, { EncodedFixedSizeHeader::SIZE }>,
std::vec::IntoIter<u8>,
>;
fn into_iter(self) -> Self::IntoIter {
self.fixed_size_header.0.into_iter().chain(self.lens)
}
}