use {
crate::shred::Error, solana_hash::Hash, solana_sha256_hasher::hashv,
static_assertions::const_assert_eq, std::iter::successors,
};
pub(crate) const SIZE_OF_MERKLE_ROOT: usize = std::mem::size_of::<Hash>();
const_assert_eq!(SIZE_OF_MERKLE_ROOT, 32);
const_assert_eq!(SIZE_OF_MERKLE_PROOF_ENTRY, 20);
pub(crate) const SIZE_OF_MERKLE_PROOF_ENTRY: usize = std::mem::size_of::<MerkleProofEntry>();
pub(crate) const PROOF_ENTRIES_FOR_32_32_BATCH: u8 = 6;
pub(crate) const MERKLE_HASH_PREFIX_LEAF: &[u8] = b"\x00SOLANA_MERKLE_SHREDS_LEAF";
pub(crate) const MERKLE_HASH_PREFIX_NODE: &[u8] = b"\x01SOLANA_MERKLE_SHREDS_NODE";
pub(crate) type MerkleProofEntry = [u8; 20];
pub fn make_merkle_tree<I>(shreds: I) -> Result<Vec<Hash>, Error>
where
I: IntoIterator<Item = Result<Hash, Error>>,
<I as IntoIterator>::IntoIter: ExactSizeIterator,
{
let shreds = shreds.into_iter();
let num_shreds = shreds.len();
let capacity = get_merkle_tree_size(num_shreds);
let mut nodes = Vec::with_capacity(capacity);
for shred in shreds {
nodes.push(shred?);
}
let init = (num_shreds > 1).then_some(num_shreds);
for size in successors(init, |&k| (k > 2).then_some((k + 1) >> 1)) {
let offset = nodes.len() - size;
for index in (offset..offset + size).step_by(2) {
let node = &nodes[index];
let other = &nodes[(index + 1).min(offset + size - 1)];
let parent = join_nodes(node, other);
nodes.push(parent);
}
}
debug_assert_eq!(nodes.len(), capacity);
Ok(nodes)
}
pub fn make_merkle_proof(
mut index: usize, mut size: usize, tree: &[Hash],
) -> impl Iterator<Item = Result<&MerkleProofEntry, Error>> {
let mut offset = 0;
if index >= size {
(size, offset) = (0, tree.len());
}
std::iter::from_fn(move || {
if size > 1 {
let Some(node) = tree.get(offset + (index ^ 1).min(size - 1)) else {
return Some(Err(Error::InvalidMerkleProof));
};
offset += size;
size = (size + 1) >> 1;
index >>= 1;
let entry = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
let entry = <&MerkleProofEntry>::try_from(entry).unwrap();
Some(Ok(entry))
} else if offset + 1 == tree.len() {
None
} else {
Some(Err(Error::InvalidMerkleProof))
}
})
}
fn join_nodes<S: AsRef<[u8]>, T: AsRef<[u8]>>(node: S, other: T) -> Hash {
let node = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
let other = &other.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
hashv(&[MERKLE_HASH_PREFIX_NODE, node, other])
}
pub fn get_merkle_root<'a, I>(index: usize, node: Hash, proof: I) -> Result<Hash, Error>
where
I: IntoIterator<Item = &'a MerkleProofEntry>,
{
let (index, root) = proof
.into_iter()
.fold((index, node), |(index, node), other| {
let parent = if index % 2 == 0 {
join_nodes(node, other)
} else {
join_nodes(other, node)
};
(index >> 1, parent)
});
(index == 0)
.then_some(root)
.ok_or(Error::InvalidMerkleProof)
}
pub fn get_merkle_tree_size(num_shreds: usize) -> usize {
successors(Some(num_shreds), |&k| (k > 1).then_some((k + 1) >> 1)).sum()
}
#[cfg(test)]
pub(crate) const fn get_proof_size(num_shreds: usize) -> u8 {
let bits = usize::BITS - num_shreds.leading_zeros();
let proof_size = if num_shreds.is_power_of_two() {
bits.saturating_sub(1)
} else {
bits
};
proof_size as u8
}
#[cfg(test)]
mod tests {
use {
crate::shred::merkle_tree::*, assert_matches::assert_matches, rand::Rng,
std::iter::repeat_with,
};
#[test]
fn test_merkle_proof_entry_from_hash() {
let mut rng = rand::thread_rng();
let bytes: [u8; 32] = rng.gen();
let hash = Hash::from(bytes);
let entry = &hash.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
let entry = MerkleProofEntry::try_from(entry).unwrap();
assert_eq!(entry, &bytes[..SIZE_OF_MERKLE_PROOF_ENTRY]);
}
#[test]
fn test_get_merkle_tree_size() {
const TREE_SIZE: [usize; 15] = [0, 1, 3, 6, 7, 11, 12, 14, 15, 20, 21, 23, 24, 27, 28];
for (num_shreds, size) in TREE_SIZE.into_iter().enumerate() {
assert_eq!(get_merkle_tree_size(num_shreds), size);
}
}
#[test]
fn test_make_merkle_proof_error() {
let mut rng = rand::thread_rng();
let nodes = repeat_with(|| rng.gen::<[u8; 32]>()).map(Hash::from);
let nodes: Vec<_> = nodes.take(5).collect();
let size = nodes.len();
let tree = make_merkle_tree(nodes.into_iter().map(Ok)).unwrap();
for index in size..size + 3 {
assert_matches!(
make_merkle_proof(index, size, &tree).next(),
Some(Err(Error::InvalidMerkleProof))
);
}
}
fn run_merkle_tree_round_trip<R: Rng>(rng: &mut R, size: usize) {
let nodes = repeat_with(|| rng.gen::<[u8; 32]>()).map(Hash::from);
let nodes: Vec<_> = nodes.take(size).collect();
let tree = make_merkle_tree(nodes.iter().cloned().map(Ok)).unwrap();
let root = tree.last().copied().unwrap();
for index in 0..size {
for (k, &node) in nodes.iter().enumerate() {
let proof = make_merkle_proof(index, size, &tree).map(Result::unwrap);
if k == index {
assert_eq!(root, get_merkle_root(k, node, proof).unwrap());
} else {
assert_ne!(root, get_merkle_root(k, node, proof).unwrap());
}
}
}
}
#[test]
fn test_merkle_tree_round_trip_small() {
let mut rng = rand::thread_rng();
for size in 1..=110 {
run_merkle_tree_round_trip(&mut rng, size);
}
}
#[test]
fn test_merkle_tree_round_trip_big() {
let mut rng = rand::thread_rng();
for size in 110..=143 {
run_merkle_tree_round_trip(&mut rng, size);
}
}
#[test]
fn test_get_proof_size() {
assert_eq!(get_proof_size(0), 0);
assert_eq!(get_proof_size(1), 0);
assert_eq!(get_proof_size(2), 1);
assert_eq!(get_proof_size(3), 2);
assert_eq!(get_proof_size(4), 2);
assert_eq!(get_proof_size(5), 3);
assert_eq!(get_proof_size(63), 6);
assert_eq!(get_proof_size(64), 6);
assert_eq!(get_proof_size(65), 7);
assert_eq!(get_proof_size(usize::MAX - 1), 64);
assert_eq!(get_proof_size(usize::MAX), 64);
for proof_size in 1u8..9 {
let max_num_shreds = 1usize << u32::from(proof_size);
let min_num_shreds = (max_num_shreds >> 1) + 1;
for num_shreds in min_num_shreds..=max_num_shreds {
assert_eq!(get_proof_size(num_shreds), proof_size);
}
}
}
}