use super::{
EMPTY_BLAKE2_TRIE_MERKLE_VALUE, EMPTY_KECCAK256_TRIE_MERKLE_VALUE, HashFunction,
TrieEntryVersion, branch_search,
nibble::{Nibble, nibbles_to_bytes_suffix_extend},
trie_node,
};
use alloc::vec::Vec;
use core::array;
/// Starts the calculation of the Merkle value of the root node of a trie.
///
/// The returned state machine must be driven by the caller, who answers each
/// [`RootMerkleValueCalculation::NextKey`] and [`RootMerkleValueCalculation::StorageValue`]
/// request, until [`RootMerkleValueCalculation::Finished`] is reached.
pub fn root_merkle_value(hash_function: HashFunction) -> RootMerkleValueCalculation {
    let initial_state = CalcInner {
        stack: Vec::with_capacity(8),
        hash_function,
    };
    initial_state.next()
}
/// Current state of a trie root Merkle value calculation.
///
/// Returned by [`root_merkle_value`] and by every method that feeds data into the calculation.
#[must_use]
pub enum RootMerkleValueCalculation {
    /// The calculation is over.
    Finished {
        /// Merkle value of the root node of the trie.
        hash: [u8; 32],
    },
    /// The calculation needs to know the next key of the storage. See [`NextKey`].
    NextKey(NextKey),
    /// The calculation needs to know a storage value. See [`StorageValue`].
    StorageValue(StorageValue),
}
/// Internal state of the calculation, shared by all the public state-machine types.
struct CalcInner {
    /// Nodes currently being explored. Index 0 is the root node, and the last element is the
    /// deepest node whose children are being discovered.
    stack: Vec<Node>,
    /// Hash function used when computing Merkle values.
    hash_function: HashFunction,
}
/// Trie node kept on the stack while its children are being explored.
#[derive(Debug)]
struct Node {
    /// Nibbles of the node's key that come after the key of its parent and the nibble of the
    /// child index within that parent.
    partial_key: Vec<Nibble>,
    /// Merkle values of the children discovered so far, in order of child index. `None` stands
    /// for an absent child. The node is complete once this holds 16 entries.
    children: arrayvec::ArrayVec<Option<trie_node::MerkleValueOutput>, 16>,
}
impl CalcInner {
    /// Returns the key, in nibbles, of the position the calculation is currently at.
    ///
    /// For each node of the stack, starting from the root (index 0), this yields the node's
    /// partial key followed by the nibble of the child currently being explored. A node whose
    /// 16 children are all known contributes only its partial key. Consequently, when the top
    /// of the stack is complete, the result is the full key of that node itself.
    fn current_iter_node_full_key(&self) -> impl Iterator<Item = Nibble> {
        self.stack.iter().flat_map(|node| {
            // The number of children already known is the index of the child to explore next.
            let child_nibble = if node.children.len() == 16 {
                None
            } else {
                let child_index = u8::try_from(node.children.len())
                    .expect("a node has at most 16 children");
                Some(Nibble::try_from(child_index).expect("child index is below 16"))
            };
            node.partial_key.iter().copied().chain(child_nibble)
        })
    }

    /// Advances the calculation as far as possible without any input from the user.
    ///
    /// Each time the node at the top of the stack has all 16 of its children known, it is
    /// popped and its Merkle value is appended to the children of its parent.
    ///
    /// - Only keys with an even number of nibbles can be turned into a byte key. For such a node,
    ///   the user is asked for its storage value through
    ///   [`RootMerkleValueCalculation::StorageValue`].
    /// - Nodes with an odd number of nibbles are hashed immediately, without a storage value.
    ///
    /// When the top of the stack still has an unknown child, or the stack is empty, the user is
    /// asked for the next key through [`RootMerkleValueCalculation::NextKey`].
    fn next(mut self) -> RootMerkleValueCalculation {
        loop {
            let top_is_complete = self
                .stack
                .last()
                .is_some_and(|node| node.children.len() == 16);

            if !top_is_complete {
                // Look for the closest descendant of the first unknown child of the top of the
                // stack (or of the root position if the stack is empty).
                break RootMerkleValueCalculation::NextKey(NextKey {
                    branch_search: branch_search::start_branch_search(branch_search::Config {
                        key_before: self.current_iter_node_full_key(),
                        or_equal: true,
                        prefix: self.current_iter_node_full_key(),
                        no_branch_search: false,
                    }),
                    calculation: self,
                });
            }

            // The top of the stack is complete. If its key is byte-aligned, the user must provide
            // its storage value before its Merkle value can be computed.
            if self.current_iter_node_full_key().count() % 2 == 0 {
                break RootMerkleValueCalculation::StorageValue(StorageValue {
                    calculation: self,
                });
            }

            let calculated_elem = self
                .stack
                .pop()
                .expect("top_is_complete implies a non-empty stack");
            let is_root_node = self.stack.is_empty();
            let merkle_value = trie_node::calculate_merkle_value(
                trie_node::Decoded {
                    children: array::from_fn(|n| calculated_elem.children[n].as_ref()),
                    partial_key: calculated_elem.partial_key.iter().copied(),
                    storage_value: trie_node::StorageValue::None,
                },
                self.hash_function,
                is_root_node,
            )
            .unwrap_or_else(|_| unreachable!());

            match self.stack.last_mut() {
                Some(parent) => parent.children.push(Some(merkle_value)),
                None => {
                    // `is_root_node` was `true` for this calculation, which is relied upon to
                    // produce a 32-byte hash.
                    let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref())
                        .expect("the Merkle value of the root node is 32 bytes");
                    break RootMerkleValueCalculation::Finished { hash };
                }
            }
        }
    }
}
/// Request for the next key of the storage. Answer it with [`NextKey::inject_key`].
#[must_use]
pub struct NextKey {
    /// State of the root calculation, resumed once the key is injected.
    calculation: CalcInner,
    /// Search for the closest descendant of the position being explored.
    branch_search: branch_search::NextKey,
}
impl NextKey {
    /// Returns the key, as bytes, from which the search for the next key starts.
    ///
    /// If [`NextKey::or_equal`] returns `true`, this key itself is an acceptable answer.
    pub fn key_before(&self) -> impl Iterator<Item = u8> {
        self.branch_search.key_before()
    }

    /// Returns `true` if [`NextKey::key_before`] itself is an acceptable answer. Returns `false`
    /// if the answer must be strictly greater than it.
    pub fn or_equal(&self) -> bool {
        self.branch_search.or_equal()
    }

    /// Returns the prefix, as bytes, that the injected key must start with.
    pub fn prefix(&self) -> impl Iterator<Item = u8> {
        self.branch_search.prefix()
    }

    /// Injects the answer to this request, and returns the next state of the calculation.
    ///
    /// The answer is the first key of the storage that comes after [`NextKey::key_before`] (or
    /// is equal to it if [`NextKey::or_equal`] is `true`) and that starts with
    /// [`NextKey::prefix`]. Pass `None` if there is no such key.
    pub fn inject_key(
        self,
        key: Option<impl Iterator<Item = u8>>,
    ) -> RootMerkleValueCalculation {
        let Self {
            mut calculation,
            branch_search: search,
        } = self;

        let descendant_key = match search.inject(key) {
            // The branch search needs more keys before it can conclude.
            branch_search::BranchSearch::NextKey(next_key) => {
                return RootMerkleValueCalculation::NextKey(NextKey {
                    calculation,
                    branch_search: next_key,
                });
            }
            branch_search::BranchSearch::Found {
                branch_trie_node_key,
            } => branch_trie_node_key,
        };

        match descendant_key {
            // A descendant exists. It becomes the new top of the stack, and its partial key is
            // whatever follows the key of the current position.
            Some(full_key) => {
                let already_known = calculation.current_iter_node_full_key().count();
                let partial_key = full_key.skip(already_known).collect();
                calculation.stack.push(Node {
                    partial_key,
                    children: arrayvec::ArrayVec::new(),
                });
                calculation.next()
            }
            None => match calculation.stack.last_mut() {
                // Nothing below the child being explored: that child is absent.
                Some(parent) => {
                    parent.children.push(None);
                    calculation.next()
                }
                // Nothing below the root position: the trie is empty.
                None => RootMerkleValueCalculation::Finished {
                    hash: match calculation.hash_function {
                        HashFunction::Blake2 => EMPTY_BLAKE2_TRIE_MERKLE_VALUE,
                        HashFunction::Keccak256 => EMPTY_KECCAK256_TRIE_MERKLE_VALUE,
                    },
                },
            },
        }
    }
}
/// Request for the storage value of [`StorageValue::key`]. Answer it with
/// [`StorageValue::inject`].
#[must_use]
pub struct StorageValue {
    /// State of the root calculation. The top of its stack is the node whose storage value is
    /// requested.
    calculation: CalcInner,
}
impl StorageValue {
    /// Returns the key, as bytes, whose storage value must be passed to
    /// [`StorageValue::inject`].
    pub fn key(&self) -> impl Iterator<Item = u8> {
        let nibbles = self.calculation.current_iter_node_full_key();
        debug_assert_eq!(self.calculation.current_iter_node_full_key().count() % 2, 0);
        nibbles_to_bytes_suffix_extend(nibbles)
    }

    /// Injects the storage value of [`StorageValue::key`], or `None` if that key has no value,
    /// together with the trie entry version. Returns the next state of the calculation.
    ///
    /// With [`TrieEntryVersion::V1`], values of 33 bytes or more are included in the node as
    /// their 32-byte hash rather than inline.
    pub fn inject(
        mut self,
        storage_value: Option<(impl AsRef<[u8]>, TrieEntryVersion)>,
    ) -> RootMerkleValueCalculation {
        let node = self.calculation.stack.pop().unwrap();

        // The hash is computed up front so that `trie_node::StorageValue::Hashed` below can
        // borrow from a local that outlives the call.
        // NOTE(review): this always uses BLAKE2b-256, even when `hash_function` is `Keccak256`.
        // If Keccak-256 tries are expected to hash V1 values with Keccak-256, this is wrong.
        // TODO confirm.
        let hashed_value = match &storage_value {
            Some((value, TrieEntryVersion::V1)) if value.as_ref().len() >= 33 => {
                Some(blake2_rfc::blake2b::blake2b(32, &[], value.as_ref()))
            }
            _ => None,
        };

        let node_storage_value = if let Some(digest) = &hashed_value {
            trie_node::StorageValue::Hashed(
                <&[u8; 32]>::try_from(digest.as_bytes()).unwrap_or_else(|_| unreachable!()),
            )
        } else if let Some((value, _)) = &storage_value {
            trie_node::StorageValue::Unhashed(value.as_ref())
        } else {
            trie_node::StorageValue::None
        };

        let is_root_node = self.calculation.stack.is_empty();
        let merkle_value = trie_node::calculate_merkle_value(
            trie_node::Decoded {
                children: array::from_fn(|n| node.children[n].as_ref()),
                partial_key: node.partial_key.iter().copied(),
                storage_value: node_storage_value,
            },
            self.calculation.hash_function,
            is_root_node,
        )
        .unwrap_or_else(|_| unreachable!());

        match self.calculation.stack.last_mut() {
            Some(parent) => {
                parent.children.push(Some(merkle_value));
                self.calculation.next()
            }
            None => {
                let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref()).unwrap();
                RootMerkleValueCalculation::Finished { hash }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::trie::{HashFunction, TrieEntryVersion};
    use alloc::collections::BTreeMap;
    use core::ops::Bound;

    /// Drives a BLAKE2 root calculation to completion against the in-memory `trie`, and returns
    /// the root hash.
    ///
    /// Every storage value request is answered using `version` as the trie entry version.
    fn calculate_root(version: TrieEntryVersion, trie: &BTreeMap<Vec<u8>, Vec<u8>>) -> [u8; 32] {
        let mut calculation = super::root_merkle_value(HashFunction::Blake2);
        loop {
            match calculation {
                super::RootMerkleValueCalculation::Finished { hash } => {
                    return hash;
                }
                super::RootMerkleValueCalculation::NextKey(next_key) => {
                    let key_before = next_key.key_before().collect::<Vec<_>>();
                    let lower_bound = if next_key.or_equal() {
                        Bound::Included(key_before)
                    } else {
                        Bound::Excluded(key_before)
                    };
                    let prefix = next_key.prefix().collect::<Vec<_>>();

                    // The first key after the lower bound is the answer, but only if it really
                    // starts with `prefix`. `starts_with` (unlike zipping the two sequences)
                    // rejects a candidate that is shorter than the prefix.
                    let k = trie
                        .range((lower_bound, Bound::Unbounded))
                        .next()
                        .map(|(k, _)| k)
                        .filter(|k| k.starts_with(&prefix));
                    calculation = next_key.inject_key(k.map(|k| k.iter().copied()));
                }
                super::RootMerkleValueCalculation::StorageValue(value) => {
                    let key = value.key().collect::<Vec<u8>>();
                    calculation = value.inject(trie.get(&key).map(|v| (v, version)));
                }
            }
        }
    }

    #[test]
    fn trie_root_one_node() {
        let mut trie = BTreeMap::new();
        trie.insert(b"abcd".to_vec(), b"hello world".to_vec());
        let expected = [
            122, 177, 134, 89, 211, 178, 120, 158, 242, 64, 13, 16, 113, 4, 199, 212, 251, 147,
            208, 109, 154, 182, 168, 182, 65, 165, 222, 124, 63, 236, 200, 81,
        ];
        assert_eq!(calculate_root(TrieEntryVersion::V0, &trie), &expected[..]);
        assert_eq!(calculate_root(TrieEntryVersion::V1, &trie), &expected[..]);
    }

    #[test]
    fn trie_root_empty() {
        let trie = BTreeMap::new();
        let expected = blake2_rfc::blake2b::blake2b(32, &[], &[0x0]);
        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }

    #[test]
    fn trie_root_single_tuple() {
        let mut trie = BTreeMap::new();
        trie.insert([0xaa].to_vec(), [0xbb].to_vec());
        let expected = blake2_rfc::blake2b::blake2b(32, &[], &[0x42, 0xaa, 1 << 2, 0xbb]);
        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }

    #[test]
    fn trie_root_example() {
        let mut trie = BTreeMap::new();
        trie.insert([0x48, 0x19].to_vec(), [0xfe].to_vec());
        trie.insert([0x13, 0x14].to_vec(), [0xff].to_vec());
        let ex = [
            0x80, 0x12, 0x00, 0x05 << 2, 0x43, 0x03, 0x14, 0x01 << 2, 0xff, 0x05 << 2, 0x43, 0x08,
            0x19, 0x01 << 2, 0xfe,
        ];
        let expected = blake2_rfc::blake2b::blake2b(32, &[], &ex);
        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }
}