use std::ops::Index;
use crate::hash_raw_concat;
use jam_types::Hash;
/// An inclusion proof for a Merkle tree padded to a fixed number of leaves, as produced by
/// [`cd_merkle_proof`].
///
/// Holds one sibling hash per tree level. The last entry is the sibling of the proven leaf
/// itself; the first entry is the sibling just below the root.
pub type CdMerkleProof = Vec<Hash>;
/// A node of a Merkle tree as produced by [`merkle_node`].
///
/// Interior nodes and all-zero padding are owned 32-byte values; a real leaf is borrowed
/// straight from the caller's items and used without hashing.
#[derive(Debug, Clone)]
pub enum MerkleNodeRef<'a> {
	/// An owned node value: an interior hash or the all-zero padding value.
	Hash([u8; 32]),
	/// A leaf borrowed from the caller's items.
	Data(&'a [u8]),
}

impl AsRef<[u8]> for MerkleNodeRef<'_> {
	/// Exposes the node's raw bytes, whichever variant holds them.
	fn as_ref(&self) -> &[u8] {
		match *self {
			MerkleNodeRef::Hash(ref bytes) => bytes.as_slice(),
			MerkleNodeRef::Data(bytes) => bytes,
		}
	}
}
/// Returns the node covering `items[left..right]` in a balanced binary Merkle (sub)tree that is
/// padded to `padded_len` leaves.
///
/// - `padded_len == 0`, or `padded_len == 1` with an empty range: the all-zero hash.
/// - `padded_len == 1` with one item: that item's bytes, borrowed without hashing.
/// - Otherwise the left child receives `ceil(padded_len / 2)` padded slots and at most that many
///   real items. The right child receives the rest. The result is
///   `hash_raw_concat(["node", left_child, right_child])`.
///
/// `padded_len` need not be a power of two here; slots without a real item act as all-zero
/// leaves.
///
/// # Panics
/// In debug builds, if `right - left > padded_len`, or if `right < left` (the subtraction
/// underflows). Item lookup goes through `items`' `Index` impl (slices panic when out of range).
pub fn merkle_node<'a>(
	items: &'a (impl Index<usize, Output = impl AsRef<[u8]> + ?Sized + 'a> + ?Sized),
	left: usize,
	right: usize,
	padded_len: usize,
) -> MerkleNodeRef<'a> {
	let count = right - left;
	debug_assert!(count <= padded_len);
	match (padded_len, count) {
		(0, _) | (1, 0) => MerkleNodeRef::Hash([0u8; 32]),
		(1, _) => MerkleNodeRef::Data(items.index(left).as_ref()),
		_ => {
			// The left child takes the larger half of the padded slots and as many real items
			// as fit in it; the right child gets whatever remains.
			let left_slots = padded_len.div_ceil(2);
			let split = left + left_slots.min(count);
			let lhs = merkle_node(items, left, split, left_slots);
			let rhs = merkle_node(items, split, right, padded_len - left_slots);
			MerkleNodeRef::Hash(hash_raw_concat([b"node", lhs.as_ref(), rhs.as_ref()]))
		},
	}
}
pub fn cd_merkle_proof(
items: &[Hash],
padded_len: usize,
mut index: usize,
) -> (CdMerkleProof, Hash) {
debug_assert!(items.len() <= padded_len);
debug_assert!((padded_len == 0) || padded_len.is_power_of_two());
if index >= items.len() {
return (vec![], Default::default());
}
if padded_len == 1 {
return (vec![], items[0]);
}
let mut proof = vec![[0; 32]; padded_len.trailing_zeros() as usize];
let mut proof_iter = proof.iter_mut().rev();
*proof_iter.next().expect("proof is sized correctly") =
items.get(index ^ 1).cloned().unwrap_or_default();
let mut row = (0..padded_len)
.step_by(2)
.map(|x| {
hash_raw_concat([
&b"node"[..],
items.get(x).unwrap_or(&[0; 32]),
items.get(x + 1).unwrap_or(&[0; 32]),
])
})
.collect::<Vec<_>>();
index >>= 1;
while row.len() > 1 {
*proof_iter.next().expect("proof is sized correctly") = row[index ^ 1];
row = row.chunks(2).map(|x| hash_raw_concat([&b"node"[..], &x[0], &x[1]])).collect();
index >>= 1;
}
debug_assert!(proof_iter.next().is_none());
(proof, row[0])
}
/// Computes the root of the Merkle tree over `items`, padded with all-zero leaves to
/// `padded_len` leaves.
///
/// Returns all zeros for an empty tree, and `items[0]` itself when `padded_len == 1` and
/// `items` is non-empty. `padded_len` must be zero or a power of two (debug-asserted), and
/// `items.len()` must not exceed it (debug-asserted inside [`merkle_node`]).
pub fn cd_merkle_root(items: &[Hash], padded_len: usize) -> Hash {
	debug_assert!((padded_len == 0) || padded_len.is_power_of_two());
	let top = merkle_node(items, 0, items.len(), padded_len);
	// Either variant exposes exactly 32 bytes here. `Hash` holds a digest, and `Data` borrows
	// one of the `Hash` items. So copying through `as_ref` yields the root in both cases.
	let mut root = Hash::default();
	root.copy_from_slice(top.as_ref());
	root
}
#[cfg(test)]
mod tests {
	use super::*;

	/// Folds a leaf and its proof back up to a root. Siblings are consumed from the back of
	/// `proof` (leaf level) towards the front (just below the root), which is the order in
	/// which `cd_merkle_proof` fills them.
	fn root_from_proof(leaf: Hash, mut index: usize, proof: &[Hash]) -> Hash {
		proof.iter().rev().fold(leaf, |node, sibling| {
			let parent = if index & 1 == 0 {
				hash_raw_concat([&b"node"[..], &node, sibling])
			} else {
				hash_raw_concat([&b"node"[..], sibling, &node])
			};
			index >>= 1;
			parent
		})
	}

	/// `count` distinct, non-zero leaves.
	fn leaves(count: usize) -> Vec<Hash> {
		(1..=count).map(|i| [i as u8; 32]).collect()
	}

	#[test]
	fn cd_merkle_root_works() {
		let items = [[1; 32], [2; 32], [3; 32], [4; 32]];
		let root1 = cd_merkle_root(&items, 4);
		let (_, root2) = cd_merkle_proof(&items, 4, 0);
		assert_eq!(root1, root2);
	}

	#[test]
	fn every_proof_folds_back_to_the_root() {
		for padded_len in [2usize, 4, 8, 16] {
			// Include partially filled trees so zero padding is exercised.
			for count in 1..=padded_len {
				let items = leaves(count);
				let root = cd_merkle_root(&items, padded_len);
				for index in 0..count {
					let (proof, proof_root) = cd_merkle_proof(&items, padded_len, index);
					assert_eq!(proof.len(), padded_len.trailing_zeros() as usize);
					assert_eq!(proof_root, root);
					assert_eq!(root_from_proof(items[index], index, &proof), root);
				}
			}
		}
	}

	#[test]
	fn degenerate_trees() {
		assert_eq!(cd_merkle_root(&[], 0), Hash::default());
		assert_eq!(cd_merkle_proof(&[], 0, 0), (vec![], Hash::default()));
		let single = [[7; 32]];
		assert_eq!(cd_merkle_root(&single, 1), [7; 32]);
		assert_eq!(cd_merkle_proof(&single, 1, 0), (vec![], [7; 32]));
	}

	#[test]
	fn out_of_range_index_yields_empty_proof_and_zero_root() {
		let items = leaves(3);
		assert_eq!(cd_merkle_proof(&items, 4, 3), (vec![], Hash::default()));
	}
}