use super::{get_zero_hash, Hash256, BYTES_PER_CHUNK};
use ethereum_hashing::{hash32_concat, hash_fixed};
/// Computes the Merkle root of `bytes`, read as a sequence of `BYTES_PER_CHUNK`-byte leaves, in a
/// tree that has at least `min_leaves` leaves.
///
/// The leaf count is `max(leaves needed for bytes, min_leaves)` rounded up to the next power of
/// two. Only nodes with at least one descendant built from `bytes` are hashed and stored; when a
/// node has no right-hand sibling in the buffer, `get_zero_hash` supplies it (presumably the
/// cached root of an all-zero subtree of that height — confirm against its definition).
///
/// When `bytes` fits in a single chunk and `min_leaves <= 1`, nothing is hashed: the bytes are
/// right-padded with zeros to one chunk and returned directly.
///
/// # Panics
///
/// Panics if one of the internal invariants asserted in the body is violated.
pub fn merkleize_padded(bytes: &[u8], min_leaves: usize) -> Hash256 {
    // Fast path: a lone chunk is its own root, so just pad it out and return it unhashed.
    if min_leaves <= 1 && bytes.len() <= BYTES_PER_CHUNK {
        let mut single = Vec::with_capacity(BYTES_PER_CHUNK);
        single.extend_from_slice(bytes);
        single.resize(BYTES_PER_CHUNK, 0);
        return Hash256::from_slice(&single);
    }

    assert!(
        bytes.len() > BYTES_PER_CHUNK || min_leaves > 1,
        "Merkle hashing only needs to happen if there is more than one chunk"
    );

    // Two sibling chunks are hashed together to form one parent.
    let pair_len = BYTES_PER_CHUNK * 2;

    // Leaves that carry data from `bytes` (the final one may be partial).
    let leaves_with_values = bytes.len().div_ceil(BYTES_PER_CHUNK);

    // Leaves in the complete tree, padding included.
    let num_leaves = leaves_with_values.max(min_leaves).next_power_of_two();

    // Number of levels in the tree; a single-node tree would have a height of 1.
    let tree_height = num_leaves.trailing_zeros() as usize + 1;
    assert!(tree_height >= 2, "The tree should have two or more heights");

    // Parents with at least one data-carrying leaf. There is always at least one, even when
    // `bytes` is empty, because the tree has more than one node at this point.
    let first_round_parents = usize::max(1, next_even_number(leaves_with_values) / 2);

    // Scratch space holding the current level; it shrinks as hashing moves towards the root.
    let mut chunks = ChunkStore::with_capacity(first_round_parents);

    // First round: hash each pair of leaves taken straight from `bytes`.
    for parent in 0..first_round_parents {
        let offset = parent * pair_len;

        let digest = if let Some(pair) = bytes.get(offset..offset + pair_len) {
            // Both leaves are fully present.
            hash_fixed(pair)
        } else {
            // The input ends inside this pair: zero-pad whatever remains.
            let mut padded = bytes
                .get(offset..)
                .expect("`i` can only be larger than zero if there are bytes to read")
                .to_vec();
            padded.resize(pair_len, 0);
            hash_fixed(&padded)
        };

        assert_eq!(
            digest.len(),
            BYTES_PER_CHUNK,
            "Hashes should be exactly one chunk"
        );

        chunks
            .set(parent, &digest)
            .expect("Buffer should always have capacity for parent nodes");
    }

    // Remaining rounds: level 0 (the leaves) is already done and the top level is the root
    // itself, so neither is visited here.
    for level in 1..tree_height - 1 {
        let parent_count = next_even_number(chunks.len()) / 2;

        for parent in 0..parent_count {
            let left_index = parent * 2;

            let (left, right) = match (chunks.get(left_index), chunks.get(left_index + 1)) {
                (Err(_), Ok(_)) => unreachable!("Parent must have a left child"),
                (Err(_), Err(_)) => unreachable!("Parent must have one child"),
                // No stored right sibling: substitute the zero hash for this level.
                (Ok(left), Err(_)) => (left, get_zero_hash(level)),
                (Ok(left), Ok(right)) => (left, right),
            };

            assert!(
                left.len() == right.len() && right.len() == BYTES_PER_CHUNK,
                "Both children should be `BYTES_PER_CHUNK` bytes."
            );

            // Parents overwrite the front of the buffer in place; slot `parent` is never
            // beyond the children that are still to be read.
            let digest = hash32_concat(left, right);
            chunks
                .set(parent, &digest)
                .expect("Buf is adequate size for parent");
        }

        // Drop the now-stale child entries beyond the freshly written parents.
        chunks.truncate(parent_count);
    }

    let root = chunks.into_vec();
    assert_eq!(root.len(), BYTES_PER_CHUNK, "Only one chunk should remain");
    Hash256::from_slice(&root)
}
/// A flat byte buffer addressed in units of `BYTES_PER_CHUNK`-byte chunks.
///
/// The buffer length is always a whole number of chunks: it is created that way and only ever
/// shortened to a whole number of chunks.
#[derive(Debug)]
struct ChunkStore(Vec<u8>);

impl ChunkStore {
    /// Creates a store holding `chunks` zero-filled chunks.
    ///
    /// Despite the name, the chunks are part of the store's length (not just reserved
    /// capacity), so every index below `chunks` is immediately readable and writable.
    fn with_capacity(chunks: usize) -> Self {
        let byte_len = chunks * BYTES_PER_CHUNK;
        Self(vec![0; byte_len])
    }

    /// Overwrites chunk `i` with `value`.
    ///
    /// Returns `Err(())` and leaves the store untouched if `i` is out of range or `value` is
    /// not exactly `BYTES_PER_CHUNK` bytes long.
    fn set(&mut self, i: usize, value: &[u8]) -> Result<(), ()> {
        if value.len() != BYTES_PER_CHUNK {
            return Err(());
        }
        let slot = self.0.chunks_exact_mut(BYTES_PER_CHUNK).nth(i).ok_or(())?;
        slot.copy_from_slice(value);
        Ok(())
    }

    /// Borrows chunk `i`, or returns `Err(())` if `i` is out of range.
    fn get(&self, i: usize) -> Result<&[u8], ()> {
        self.0.chunks_exact(BYTES_PER_CHUNK).nth(i).ok_or(())
    }

    /// Number of whole chunks currently stored.
    fn len(&self) -> usize {
        self.0.chunks_exact(BYTES_PER_CHUNK).len()
    }

    /// Shrinks the store to at most `num_chunks` chunks; a no-op if it is already that small.
    fn truncate(&mut self, num_chunks: usize) {
        let keep_bytes = num_chunks * BYTES_PER_CHUNK;
        self.0.truncate(keep_bytes);
    }

    /// Consumes the store and returns the underlying bytes.
    fn into_vec(self) -> Vec<u8> {
        let Self(raw) = self;
        raw
    }
}
/// Returns `n` if it is even, otherwise the next even number above it (`n + 1`).
fn next_even_number(n: usize) -> usize {
    if n % 2 == 0 {
        n
    } else {
        n + 1
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::ZERO_HASHES_MAX_INDEX;

    /// Computes the expected root by delegating to `crate::merkleize_standard`.
    pub fn reference_root(bytes: &[u8]) -> Hash256 {
        crate::merkleize_standard(bytes)
    }

    /// Generates the shared test-suite, parameterised over `$get_bytes`: a function that takes a
    /// byte count and returns an input buffer of that length.
    ///
    /// The `zero_value_` prefix on the test names is kept for every instantiation, including the
    /// random-input one.
    macro_rules! common_tests {
        ($get_bytes: ident) => {
            #[test]
            fn zero_value_0_nodes() {
                test_against_reference(&$get_bytes(0 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_1_nodes() {
                test_against_reference(&$get_bytes(1 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_2_nodes() {
                test_against_reference(&$get_bytes(2 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_3_nodes() {
                test_against_reference(&$get_bytes(3 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_4_nodes() {
                test_against_reference(&$get_bytes(4 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_8_nodes() {
                test_against_reference(&$get_bytes(8 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_9_nodes() {
                test_against_reference(&$get_bytes(9 * BYTES_PER_CHUNK), 0);
            }

            #[test]
            fn zero_value_8_nodes_varying_min_length() {
                for i in 0..64 {
                    test_against_reference(&$get_bytes(8 * BYTES_PER_CHUNK), i);
                }
            }

            // Covers every input length in the range, including lengths that are not a whole
            // number of chunks.
            #[test]
            fn zero_value_range_of_nodes() {
                for i in 0..32 * BYTES_PER_CHUNK {
                    test_against_reference(&$get_bytes(i), 0);
                }
            }

            // An all-zero input padded out to `2^ZERO_HASHES_MAX_INDEX` leaves must hash to the
            // zero hash at index `ZERO_HASHES_MAX_INDEX`.
            #[test]
            fn max_tree_depth_min_nodes() {
                let input = vec![0; 10 * BYTES_PER_CHUNK];
                let min_nodes = 2usize.pow(ZERO_HASHES_MAX_INDEX as u32);
                assert_eq!(
                    merkleize_padded(&input, min_nodes).as_slice(),
                    get_zero_hash(ZERO_HASHES_MAX_INDEX)
                );
            }
        };
    }

    mod zero_value {
        use super::*;

        /// Returns `bytes` zero bytes.
        fn zero_bytes(bytes: usize) -> Vec<u8> {
            vec![0; bytes]
        }

        common_tests!(zero_bytes);
    }

    mod random_value {
        use super::*;
        use rand::RngCore;

        /// Returns `bytes` random bytes.
        fn random_bytes(bytes: usize) -> Vec<u8> {
            // `fill_bytes` only writes into the slice's existing length, so the buffer must be
            // allocated with `bytes` initialised elements. A `Vec::with_capacity` buffer has
            // length zero and would come back empty, silently turning every test in this
            // module into an empty-input test.
            let mut buf = vec![0; bytes];
            rand::rng().fill_bytes(&mut buf);
            buf
        }

        common_tests!(random_bytes);
    }

    /// Asserts that `merkleize_padded(input, min_nodes)` equals the reference root of `input`
    /// after explicitly zero-padding it to at least `min_nodes.next_power_of_two()` chunks.
    fn test_against_reference(input: &[u8], min_nodes: usize) {
        let mut reference_input = input.to_vec();
        reference_input.resize(
            std::cmp::max(
                reference_input.len(),
                min_nodes.next_power_of_two() * BYTES_PER_CHUNK,
            ),
            0,
        );

        assert_eq!(
            reference_root(&reference_input),
            merkleize_padded(input, min_nodes),
            "input.len(): {:?}",
            input.len()
        );
    }
}