use super::{HashFunction, nibble};
use alloc::vec::Vec;
use core::{cmp, fmt, iter, slice};
/// Encodes the components of a trie node into its node value (the byte string whose hash
/// is the node's merkle value).
///
/// The encoding is returned as an iterator of buffers; concatenating all of them in order
/// yields the full node value.
///
/// # Errors
///
/// Returns [`EncodeError::PartialKeyButNoChildrenNoStorageValue`] if the node has a
/// non-empty partial key but neither children nor a storage value, as such a node has no
/// valid representation.
pub fn encode<'a>(
    decoded: Decoded<
        'a,
        impl ExactSizeIterator<Item = nibble::Nibble> + Clone,
        impl AsRef<[u8]> + Clone + 'a,
    >,
) -> Result<impl Iterator<Item = impl AsRef<[u8]> + Clone> + Clone, EncodeError> {
    // Accumulates everything that precedes the storage value: header byte, partial key
    // length extension bytes, partial key, children bitmap, and (for unhashed values)
    // the storage value length prefix.
    let mut before_storage_value: Vec<u8> = Vec::with_capacity(decoded.partial_key.len() / 2 + 32);

    let has_children = decoded.children.iter().any(Option::is_some);

    {
        // The most significant bits of the first byte identify the node variant; the
        // remaining `pk_len_first_byte_bits` bits hold (the start of) the partial key
        // length. These variants mirror the ones recognized by `decode`.
        let (first_byte_msb, pk_len_first_byte_bits): (u8, _) =
            match (has_children, decoded.storage_value) {
                (false, StorageValue::Unhashed(_)) => (0b01, 6),
                (true, StorageValue::None) => (0b10, 6),
                (true, StorageValue::Unhashed(_)) => (0b11, 6),
                (false, StorageValue::Hashed(_)) => (0b001, 5),
                (true, StorageValue::Hashed(_)) => (0b0001, 4),
                (false, StorageValue::None) => {
                    // A node with neither children nor storage value is only
                    // representable with an empty partial key (header byte 0).
                    if decoded.partial_key.len() != 0 {
                        return Err(EncodeError::PartialKeyButNoChildrenNoStorageValue);
                    } else {
                        (0, 6)
                    }
                }
            };

        // If the partial key length saturates the bits available in the first byte,
        // additional length bytes follow; every extension byte equal to 255 signals
        // that yet another byte follows (matching the decoding loop in `decode`).
        let max_representable_in_first_byte = (1 << pk_len_first_byte_bits) - 1;
        let first_byte = (first_byte_msb << pk_len_first_byte_bits)
            | u8::try_from(cmp::min(
                decoded.partial_key.len(),
                max_representable_in_first_byte,
            ))
            .unwrap();
        before_storage_value.push(first_byte);
        // `None` means the length fit entirely within the first byte.
        let mut remain_pk_len = decoded
            .partial_key
            .len()
            .checked_sub(max_representable_in_first_byte);
        while let Some(pk_len_inner) = remain_pk_len {
            before_storage_value.push(u8::try_from(cmp::min(pk_len_inner, 255)).unwrap());
            remain_pk_len = pk_len_inner.checked_sub(255);
        }
    }

    // Partial key nibbles packed two per byte (prefix-padded when the count is odd).
    before_storage_value.extend(nibble::nibbles_to_bytes_prefix_extend(
        decoded.partial_key.clone(),
    ));

    // Branch nodes carry a 16-bit little-endian bitmap with one bit per child slot.
    if has_children {
        before_storage_value.extend_from_slice(&decoded.children_bitmap().to_le_bytes());
    }

    // Unhashed storage values get a SCALE-compact length prefix (appended to
    // `before_storage_value`) and the value itself becomes a separate buffer.
    // Hashed values are exactly 32 bytes and carry no length prefix.
    let storage_value = match decoded.storage_value {
        StorageValue::Hashed(hash) => &hash[..],
        StorageValue::None => &[][..],
        StorageValue::Unhashed(storage_value) => {
            before_storage_value.extend_from_slice(
                crate::util::encode_scale_compact_usize(storage_value.len()).as_ref(),
            );
            storage_value
        }
    };

    // Each present child is encoded as its SCALE-compact length followed by its value.
    let children_nodes = decoded
        .children
        .into_iter()
        .flatten()
        .flat_map(|child_value| {
            let size = crate::util::encode_scale_compact_usize(child_value.as_ref().len());
            [either::Left(size), either::Right(child_value)].into_iter()
        });

    // `either::Either` unifies the heterogeneous buffer types into one iterator item type.
    Ok(iter::once(either::Left(before_storage_value))
        .chain(iter::once(either::Right(storage_value)))
        .map(either::Left)
        .chain(children_nodes.map(either::Right)))
}
/// Error that can happen while encoding a node with [`encode`].
#[derive(Debug, derive_more::Display, derive_more::Error, Clone)]
pub enum EncodeError {
    /// The node has a non-empty partial key but neither children nor a storage value;
    /// such a node has no valid encoding.
    PartialKeyButNoChildrenNoStorageValue,
}
/// Encodes the components of a node into a single `Vec<u8>` node value.
///
/// Convenience wrapper around [`encode`] that concatenates every produced buffer.
///
/// # Errors
///
/// Returns the same errors as [`encode`].
pub fn encode_to_vec(
    decoded: Decoded<
        '_,
        impl ExactSizeIterator<Item = nibble::Nibble> + Clone,
        impl AsRef<[u8]> + Clone,
    >,
) -> Result<Vec<u8>, EncodeError> {
    // Upper-bound estimate of the encoded size: partial key bytes, storage value,
    // sixteen 32-byte child references, plus headroom for headers and length prefixes.
    let storage_value_len = match decoded.storage_value {
        StorageValue::Hashed(_) => 32,
        StorageValue::None => 0,
        StorageValue::Unhashed(value) => value.len(),
    };
    let capacity = decoded.partial_key.len() / 2 + storage_value_len + 16 * 32 + 32;

    let mut output = Vec::with_capacity(capacity);
    for buffer in encode(decoded)? {
        output.extend_from_slice(buffer.as_ref());
    }

    // NOTE(review): this assumes `Vec::with_capacity` allocated exactly `capacity` and
    // that the encoding never outgrew it; `Vec` only guarantees *at least* the requested
    // capacity — confirm the estimate above truly bounds every possible encoding.
    debug_assert_eq!(output.capacity(), capacity);
    Ok(output)
}
/// Calculates the merkle value of the given node.
///
/// The merkle value is the node value itself when the node is not the root and its
/// encoding is strictly shorter than 32 bytes; otherwise it is the hash (using
/// `hash_function`) of the node value. `is_root_node` must be `true` if the node is the
/// root of the trie, in which case the output is always a hash.
///
/// # Errors
///
/// Returns an error if the node cannot be encoded. See [`encode`].
pub fn calculate_merkle_value(
    decoded: Decoded<
        '_,
        impl ExactSizeIterator<Item = nibble::Nibble> + Clone,
        impl AsRef<[u8]> + Clone,
    >,
    hash_function: HashFunction,
    is_root_node: bool,
) -> Result<MerkleValueOutput, EncodeError> {
    // Sink for the encoded bytes: either collected inline (possible only while the
    // total stays at 31 bytes or less) or streamed into one of the supported hashers.
    enum HashOrInline {
        Inline(arrayvec::ArrayVec<u8, 31>),
        Blake2Hasher(blake2_rfc::blake2b::Blake2b),
        Keccak256Hasher(sha3::Keccak256),
    }

    // The root node is always hashed, so start with a hasher right away; every other
    // node starts inline and upgrades to hashing only if the encoding grows too large.
    let mut merkle_value_sink = match (is_root_node, hash_function) {
        (true, HashFunction::Blake2) => {
            HashOrInline::Blake2Hasher(blake2_rfc::blake2b::Blake2b::new(32))
        }
        (true, HashFunction::Keccak256) => {
            HashOrInline::Keccak256Hasher(<sha3::Keccak256 as sha3::Digest>::new())
        }
        (false, _) => HashOrInline::Inline(arrayvec::ArrayVec::new()),
    };

    for buffer in encode(decoded)? {
        let buffer = buffer.as_ref();
        match &mut merkle_value_sink {
            HashOrInline::Inline(curr) => {
                if curr.try_extend_from_slice(buffer).is_ok() {
                    continue;
                }
                // The encoding no longer fits inline (would reach 32 bytes or more):
                // feed everything accumulated so far, plus the current buffer, into a
                // fresh hasher and switch the sink over permanently.
                match hash_function {
                    HashFunction::Blake2 => {
                        let mut hasher = blake2_rfc::blake2b::Blake2b::new(32);
                        hasher.update(curr);
                        hasher.update(buffer);
                        merkle_value_sink = HashOrInline::Blake2Hasher(hasher);
                    }
                    HashFunction::Keccak256 => {
                        let mut hasher = <sha3::Keccak256 as sha3::Digest>::new();
                        sha3::Digest::update(&mut hasher, curr);
                        sha3::Digest::update(&mut hasher, buffer);
                        merkle_value_sink = HashOrInline::Keccak256Hasher(hasher);
                    }
                }
            }
            HashOrInline::Blake2Hasher(hasher) => {
                hasher.update(buffer);
            }
            HashOrInline::Keccak256Hasher(hasher) => {
                sha3::Digest::update(hasher, buffer);
            }
        }
    }

    Ok(MerkleValueOutput {
        inner: match merkle_value_sink {
            HashOrInline::Inline(b) => MerkleValueOutputInner::Inline(b),
            HashOrInline::Blake2Hasher(h) => MerkleValueOutputInner::Blake2Hasher(h.finalize()),
            HashOrInline::Keccak256Hasher(h) => {
                MerkleValueOutputInner::Keccak256Hasher(sha3::Digest::finalize(h).into())
            }
        },
    })
}
/// Output of [`calculate_merkle_value`]: either a 32-byte hash or, for small non-root
/// nodes, the node value itself (at most 31 bytes).
#[derive(Clone)]
pub struct MerkleValueOutput {
    inner: MerkleValueOutputInner,
}

/// See [`MerkleValueOutput`].
#[derive(Clone)]
enum MerkleValueOutputInner {
    /// Node value stored inline because its encoding was shorter than 32 bytes.
    Inline(arrayvec::ArrayVec<u8, 31>),
    /// 32-byte Blake2b hash of the node value.
    Blake2Hasher(blake2_rfc::blake2b::Blake2bResult),
    /// Keccak-256 hash of the node value.
    Keccak256Hasher([u8; 32]),
    /// Arbitrary bytes supplied through [`MerkleValueOutput::from_bytes`].
    Bytes(arrayvec::ArrayVec<u8, 32>),
}
impl MerkleValueOutput {
    /// Builds a [`MerkleValueOutput`] from a raw slice of bytes.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() > 32`.
    pub fn from_bytes(bytes: &[u8]) -> MerkleValueOutput {
        assert!(bytes.len() <= 32);
        let mut buffer = arrayvec::ArrayVec::new();
        // Cannot fail: the length was checked against the capacity just above.
        buffer.try_extend_from_slice(bytes).unwrap();
        MerkleValueOutput {
            inner: MerkleValueOutputInner::Bytes(buffer),
        }
    }
}
impl AsRef<[u8]> for MerkleValueOutput {
    /// Returns the merkle value as a byte slice, whichever variant holds it.
    fn as_ref(&self) -> &[u8] {
        match &self.inner {
            MerkleValueOutputInner::Bytes(bytes) => bytes.as_slice(),
            MerkleValueOutputInner::Inline(bytes) => bytes.as_slice(),
            MerkleValueOutputInner::Keccak256Hasher(hash) => &hash[..],
            MerkleValueOutputInner::Blake2Hasher(result) => result.as_bytes(),
        }
    }
}
impl TryFrom<MerkleValueOutput> for [u8; 32] {
    type Error = ();

    /// Succeeds only when the merkle value is exactly 32 bytes long.
    fn try_from(output: MerkleValueOutput) -> Result<Self, Self::Error> {
        // `TryFrom<&[u8]> for [u8; 32]` errors out iff the slice length isn't 32,
        // matching the explicit length check of the previous implementation.
        <[u8; 32]>::try_from(output.as_ref()).map_err(|_| ())
    }
}
impl fmt::Debug for MerkleValueOutput {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Delegate to the `Debug` implementation of the underlying byte slice so that
        // formatter flags (e.g. `{:#?}`) are propagated unchanged.
        let bytes = self.as_ref();
        fmt::Debug::fmt(bytes, f)
    }
}
/// Decodes a node value into its components.
///
/// The input must be exactly one encoded node value; trailing bytes are rejected with
/// [`Error::TooLong`].
pub fn decode(
    mut node_value: &'_ [u8],
) -> Result<Decoded<'_, DecodedPartialKey<'_>, &'_ [u8]>, Error> {
    if node_value.is_empty() {
        return Err(Error::Empty);
    }

    // The most significant bits of the first byte identify the node variant: whether it
    // has children, and whether the storage value is absent (`None`), inline
    // (`Some(false)`), or represented by a 32-byte hash (`Some(true)`). The remaining
    // `pk_len_first_byte_bits` bits hold the start of the partial key length.
    let (has_children, storage_value_hashed, pk_len_first_byte_bits) = match node_value[0] >> 6 {
        0b00 => {
            if (node_value[0] >> 5) == 0b001 {
                // No children, hashed storage value.
                (false, Some(true), 5)
            } else if (node_value[0] >> 4) == 0b0001 {
                // Children, hashed storage value.
                (true, Some(true), 4)
            } else if node_value[0] == 0 {
                // Header byte 0: node with neither children nor storage value.
                (false, None, 6)
            } else {
                return Err(Error::InvalidHeaderBits);
            }
        }
        // Children, no storage value.
        0b10 => (true, None, 6),
        // No children, inline storage value.
        0b01 => (false, Some(false), 6),
        // Children, inline storage value.
        0b11 => (true, Some(false), 6),
        // `node_value[0] >> 6` can only be 0b00..=0b11.
        _ => unreachable!(),
    };

    // Decode the partial key length, in nibbles. If all the length bits of the first
    // byte are set, extension bytes follow; every extension byte equal to 255 signals
    // that yet another byte follows, and all of them are summed together.
    let pk_len = {
        let mut accumulator = usize::from(node_value[0] & ((1 << pk_len_first_byte_bits) - 1));
        node_value = &node_value[1..];
        let mut continue_iter = accumulator == ((1 << pk_len_first_byte_bits) - 1);
        while continue_iter {
            if node_value.is_empty() {
                return Err(Error::PartialKeyLenTooShort);
            }
            continue_iter = node_value[0] == 255;
            accumulator = accumulator
                .checked_add(usize::from(node_value[0]))
                .ok_or(Error::PartialKeyLenOverflow)?;
            node_value = &node_value[1..];
        }
        accumulator
    };

    // A node with neither children nor storage value is only valid with an empty
    // partial key (it is the encoding `encode` produces for header byte 0).
    if pk_len != 0 && !has_children && storage_value_hashed.is_none() {
        return Err(Error::EmptyTrieWithPartialKey);
    }

    // Extract the raw partial key bytes. An odd number of nibbles is padded with a
    // zero nibble in the most significant half of the first byte.
    let partial_key = {
        // Number of bytes needed for `pk_len` nibbles, rounding up.
        let pk_len_bytes = if pk_len == 0 {
            0
        } else {
            1 + ((pk_len - 1) / 2)
        };
        if node_value.len() < pk_len_bytes {
            return Err(Error::PartialKeyTooShort);
        }
        let pk = &node_value[..pk_len_bytes];
        node_value = &node_value[pk_len_bytes..];
        // The padding nibble, when present, must be zero.
        if (pk_len % 2) == 1 && (pk[0] & 0xf0) != 0 {
            return Err(Error::InvalidPartialKeyPadding);
        }
        pk
    };

    // Branch nodes carry a 16-bit little-endian bitmap with one bit per child slot.
    let children_bitmap = if has_children {
        if node_value.len() < 2 {
            return Err(Error::ChildrenBitmapTooShort);
        }
        let val = u16::from_le_bytes(<[u8; 2]>::try_from(&node_value[..2]).unwrap());
        // A node encoded as having children must have at least one bit set.
        if val == 0 {
            return Err(Error::ZeroChildrenBitmap);
        }
        node_value = &node_value[2..];
        val
    } else {
        0
    };

    // Storage value: inline values are preceded by a SCALE-compact length; hashed
    // values are exactly 32 bytes with no length prefix.
    let storage_value = match storage_value_hashed {
        Some(false) => {
            let (node_value_update, len) = crate::util::nom_scale_compact_usize(node_value)
                .map_err(|_: nom::Err<nom::error::Error<&[u8]>>| Error::StorageValueLenDecode)?;
            node_value = node_value_update;
            if node_value.len() < len {
                return Err(Error::StorageValueTooShort);
            }
            let storage_value = &node_value[..len];
            node_value = &node_value[len..];
            StorageValue::Unhashed(storage_value)
        }
        Some(true) => {
            if node_value.len() < 32 {
                return Err(Error::StorageValueTooShort);
            }
            let storage_value_hash = <&[u8; 32]>::try_from(&node_value[..32]).unwrap();
            node_value = &node_value[32..];
            StorageValue::Hashed(storage_value_hash)
        }
        None => StorageValue::None,
    };

    // One child reference per set bit of the bitmap, each prefixed with its
    // SCALE-compact length and at most 32 bytes long.
    let mut children = [None; 16];
    for (n, child) in children.iter_mut().enumerate() {
        if children_bitmap & (1 << n) == 0 {
            continue;
        }
        let (node_value_update, len) = crate::util::nom_scale_compact_usize(node_value)
            .map_err(|_: nom::Err<nom::error::Error<&[u8]>>| Error::ChildLenDecode)?;
        if len > 32 {
            return Err(Error::ChildTooLarge);
        }
        node_value = node_value_update;
        if node_value.len() < len {
            return Err(Error::ChildrenTooShort);
        }
        *child = Some(&node_value[..len]);
        node_value = &node_value[len..];
    }

    // The whole input must have been consumed.
    if !node_value.is_empty() {
        return Err(Error::TooLong);
    }

    Ok(Decoded {
        // For an odd number of nibbles, skip the leading zero padding nibble.
        partial_key: if (pk_len % 2) == 1 {
            DecodedPartialKey::from_bytes_skip_first(partial_key)
        } else {
            DecodedPartialKey::from_bytes(partial_key)
        },
        children,
        storage_value,
    })
}
/// Decoded representation of a node value.
#[derive(Debug, Clone)]
pub struct Decoded<'a, I, C> {
    /// Iterator over the nibbles of the node's partial key.
    pub partial_key: I,
    /// One entry per child index (nibble `0` to `15`); `Some` holds the child's node
    /// reference — presumably its merkle value ([`decode`] enforces at most 32 bytes
    /// per child).
    pub children: [Option<C>; 16],
    /// Storage value attached to this node, if any.
    pub storage_value: StorageValue<'a>,
}
impl<'a, I, C> Decoded<'a, I, C> {
    /// Returns the 16-bit children bitmap: bit `n` is set iff `children[n]` is `Some`.
    pub fn children_bitmap(&self) -> u16 {
        self.children
            .iter()
            .enumerate()
            .filter(|(_, child)| child.is_some())
            .fold(0u16, |bitmap, (index, _)| bitmap | (1 << index))
    }
}
/// Storage value of a decoded node.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StorageValue<'a> {
    /// Storage value present in clear within the node value.
    Unhashed(&'a [u8]),
    /// The node value only contains this 32-byte hash of the storage value.
    Hashed(&'a [u8; 32]),
    /// No storage value attached to this node.
    None,
}
/// Iterator over the nibbles of a partial key, as produced by [`decode`].
#[derive(Clone)]
pub struct DecodedPartialKey<'a> {
    /// Nibbles of the underlying raw partial key bytes.
    inner: nibble::BytesToNibbles<iter::Copied<slice::Iter<'a, u8>>>,
    /// If `true`, the first nibble (expected to be the zero padding nibble of an
    /// odd-length key) has not been consumed yet and will be skipped.
    skip_first: bool,
}
impl<'a> DecodedPartialKey<'a> {
    /// Builds an iterator yielding every nibble of `bytes`, two per byte.
    pub fn from_bytes(bytes: &'a [u8]) -> Self {
        Self {
            skip_first: false,
            inner: nibble::bytes_to_nibbles(bytes.iter().copied()),
        }
    }

    /// Same as [`DecodedPartialKey::from_bytes`], except the very first nibble is not
    /// yielded. The caller must guarantee that this first nibble is `0`.
    pub fn from_bytes_skip_first(bytes: &'a [u8]) -> Self {
        Self {
            skip_first: true,
            inner: nibble::bytes_to_nibbles(bytes.iter().copied()),
        }
    }
}
impl<'a> Iterator for DecodedPartialKey<'a> {
    type Item = nibble::Nibble;

    /// Yields the next nibble, transparently skipping the leading padding nibble when
    /// the key was built with [`DecodedPartialKey::from_bytes_skip_first`].
    fn next(&mut self) -> Option<nibble::Nibble> {
        loop {
            let nibble = self.inner.next()?;
            if self.skip_first {
                // The skipped nibble is the zero padding of an odd-length partial key;
                // the constructors' contract requires it to be 0.
                debug_assert_eq!(u8::from(nibble), 0);
                self.skip_first = false;
                continue;
            }
            break Some(nibble);
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Fix: `saturating_sub` replaces an unchecked `len -= 1`. With `skip_first`
        // set on an empty input (reachable through the public
        // `from_bytes_skip_first(&[])`), the subtraction would panic in debug builds
        // and wrap around in release builds, breaking `ExactSizeIterator`.
        let len = if self.skip_first {
            self.inner.len().saturating_sub(1)
        } else {
            self.inner.len()
        };
        (len, Some(len))
    }
}

impl<'a> ExactSizeIterator for DecodedPartialKey<'a> {}
impl<'a> fmt::Debug for DecodedPartialKey<'a> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Prints the key as `0x` followed by one lowercase hex digit per nibble.
        // `{:x}` on a value in 0..=15 produces exactly one lowercase hex character,
        // matching the lookup-table approach this replaces.
        write!(f, "0x")?;
        for nibble in self.clone() {
            write!(f, "{:x}", u8::from(nibble))?;
        }
        Ok(())
    }
}
/// Error that can happen while decoding a node value with [`decode`].
#[derive(Debug, Clone, derive_more::Display, derive_more::Error)]
pub enum Error {
    /// The node value is empty.
    Empty,
    /// The header bits of the first byte don't correspond to any known node variant.
    InvalidHeaderBits,
    /// Input ended while decoding the partial key length extension bytes.
    PartialKeyLenTooShort,
    /// The partial key length doesn't fit in a `usize`.
    PartialKeyLenOverflow,
    /// Input ended while reading the partial key bytes.
    PartialKeyTooShort,
    /// The padding nibble of an odd-length partial key isn't zero.
    InvalidPartialKeyPadding,
    /// Input ended while reading the 2-byte children bitmap.
    ChildrenBitmapTooShort,
    /// A node encoded as having children has an all-zero children bitmap.
    ZeroChildrenBitmap,
    /// Failed to decode the compact-encoded length of a child reference.
    ChildLenDecode,
    /// Input ended while reading a child reference.
    ChildrenTooShort,
    /// A child reference is longer than the 32 bytes allowed.
    ChildTooLarge,
    /// Failed to decode the compact-encoded length of the storage value.
    StorageValueLenDecode,
    /// Input ended while reading the storage value (or its 32-byte hash).
    StorageValueTooShort,
    /// Trailing bytes remain after the node value was fully decoded.
    TooLong,
    /// A node with a partial key but neither children nor storage value is invalid.
    EmptyTrieWithPartialKey,
}
#[cfg(test)]
mod tests {
    use super::super::nibble;

    // Decodes a real branch node value, checks every decoded component, and verifies
    // that re-encoding (both streamed and via `encode_to_vec`) round-trips to the
    // original bytes.
    #[test]
    fn basic() {
        let encoded_bytes = &[
            194, 99, 192, 0, 0, 128, 129, 254, 111, 21, 39, 188, 215, 18, 139, 76, 128, 157, 108,
            33, 139, 232, 34, 73, 0, 21, 202, 54, 18, 71, 145, 117, 47, 222, 189, 93, 119, 68, 128,
            108, 211, 105, 98, 122, 206, 246, 73, 77, 237, 51, 77, 26, 166, 1, 52, 179, 173, 43,
            89, 219, 104, 196, 190, 208, 128, 135, 177, 13, 185, 111, 175,
        ];

        let decoded = super::decode(encoded_bytes).unwrap();

        // Concatenating the streamed encoding must reproduce the input exactly.
        assert_eq!(
            super::encode(decoded.clone())
                .unwrap()
                .fold(Vec::new(), |mut a, b| {
                    a.extend_from_slice(b.as_ref());
                    a
                }),
            encoded_bytes
        );

        // Partial key is the two nibbles 0x6, 0x3.
        assert_eq!(
            decoded.partial_key.clone().collect::<Vec<_>>(),
            vec![
                nibble::Nibble::try_from(0x6).unwrap(),
                nibble::Nibble::try_from(0x3).unwrap()
            ]
        );

        // Storage value is present but empty.
        assert_eq!(
            decoded.storage_value,
            super::StorageValue::Unhashed(&[][..])
        );

        // Exactly two children, at indices 6 and 7.
        assert_eq!(decoded.children.iter().filter(|c| c.is_some()).count(), 2);
        assert_eq!(
            decoded.children[6],
            Some(
                &[
                    129, 254, 111, 21, 39, 188, 215, 18, 139, 76, 128, 157, 108, 33, 139, 232, 34,
                    73, 0, 21, 202, 54, 18, 71, 145, 117, 47, 222, 189, 93, 119, 68
                ][..]
            )
        );
        assert_eq!(
            decoded.children[7],
            Some(
                &[
                    108, 211, 105, 98, 122, 206, 246, 73, 77, 237, 51, 77, 26, 166, 1, 52, 179,
                    173, 43, 89, 219, 104, 196, 190, 208, 128, 135, 177, 13, 185, 111, 175
                ][..]
            )
        );

        assert_eq!(super::encode_to_vec(decoded).unwrap(), encoded_bytes);
    }

    // A node with neither children nor storage value encodes successfully only when its
    // partial key is empty.
    #[test]
    fn no_children_no_storage_value() {
        assert!(
            super::encode(super::Decoded {
                children: [None::<&'static [u8]>; 16],
                storage_value: super::StorageValue::None,
                partial_key: core::iter::empty()
            })
            .is_ok()
        );

        assert!(matches!(
            super::encode(super::Decoded {
                children: [None::<&'static [u8]>; 16],
                storage_value: super::StorageValue::None,
                partial_key: core::iter::once(nibble::Nibble::try_from(2).unwrap())
            }),
            Err(super::EncodeError::PartialKeyButNoChildrenNoStorageValue)
        ));
    }
}