jam_std_common/util/
merkle.rs1use std::ops::Index;
2
3use crate::hash_raw_concat;
4use jam_types::Hash;
5
/// Sibling path returned by [`cd_merkle_proof`].
///
/// The entries are ordered from the top of the tree downwards: the first
/// entry is the sibling directly below the root and the last entry is the
/// sibling of the proven leaf.
pub type CdMerkleProof = Vec<Hash>;
10
/// A Merkle tree node in the form in which it is fed into its parent's hash.
///
/// A node either owns a 32-byte value (a computed digest or the all-zero
/// placeholder used for empty slots) or borrows leaf bytes straight from the
/// caller's items. Both payloads are `Copy`, so the node itself is `Copy` too.
#[derive(Debug, Clone, Copy)]
pub enum MerkleNodeRef<'a> {
    /// An owned 32-byte value.
    Hash([u8; 32]),
    /// Bytes borrowed from the item collection the tree is built over.
    Data(&'a [u8]),
}

impl AsRef<[u8]> for MerkleNodeRef<'_> {
    /// Returns the node's bytes, whichever variant holds them.
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Hash(digest) => digest.as_slice(),
            Self::Data(bytes) => bytes,
        }
    }
}
26
27pub fn merkle_node<'a>(
33 items: &'a (impl Index<usize, Output = impl AsRef<[u8]> + ?Sized + 'a> + ?Sized),
34 left: usize,
35 right: usize,
36 padded_len: usize,
37) -> MerkleNodeRef<'a> {
38 let len = right - left;
39 debug_assert!(len <= padded_len);
40 if padded_len == 0 {
41 return MerkleNodeRef::Hash([0u8; 32]);
42 }
43 if padded_len == 1 {
44 if len == 0 {
45 return MerkleNodeRef::Hash([0u8; 32])
46 }
47 let item = items.index(left);
48 return MerkleNodeRef::Data(item.as_ref())
49 }
50 let pivot = padded_len.div_ceil(2);
51 let clamped_pivot = pivot.min(len);
52 let left_node = merkle_node(items, left, left + clamped_pivot, pivot);
53 let right_node = merkle_node(items, left + clamped_pivot, right, padded_len - pivot);
54 MerkleNodeRef::Hash(hash_raw_concat([b"node", left_node.as_ref(), right_node.as_ref()]))
55}
56
57pub fn cd_merkle_proof(
68 items: &[Hash],
69 padded_len: usize,
70 mut index: usize,
71) -> (CdMerkleProof, Hash) {
72 debug_assert!(items.len() <= padded_len);
73 debug_assert!((padded_len == 0) || padded_len.is_power_of_two());
74
75 if index >= items.len() {
76 return (vec![], Default::default());
77 }
78 if padded_len == 1 {
79 return (vec![], items[0]);
80 }
81
82 let mut proof = vec![[0; 32]; padded_len.trailing_zeros() as usize];
84 let mut proof_iter = proof.iter_mut().rev();
85 *proof_iter.next().expect("proof is sized correctly") =
86 items.get(index ^ 1).cloned().unwrap_or_default();
87 let mut row = (0..padded_len)
88 .step_by(2)
89 .map(|x| {
90 hash_raw_concat([
91 &b"node"[..],
92 items.get(x).unwrap_or(&[0; 32]),
93 items.get(x + 1).unwrap_or(&[0; 32]),
94 ])
95 })
96 .collect::<Vec<_>>();
97 index >>= 1;
98
99 while row.len() > 1 {
100 *proof_iter.next().expect("proof is sized correctly") = row[index ^ 1];
101 row = row.chunks(2).map(|x| hash_raw_concat([&b"node"[..], &x[0], &x[1]])).collect();
102 index >>= 1;
103 }
104 debug_assert!(proof_iter.next().is_none());
105 (proof, row[0])
106}
107
108pub fn cd_merkle_root(items: &[Hash], padded_len: usize) -> Hash {
114 debug_assert!((padded_len == 0) || padded_len.is_power_of_two());
115 match merkle_node(items, 0, items.len(), padded_len) {
116 MerkleNodeRef::Hash(h) => h,
117 MerkleNodeRef::Data(d) => {
118 let mut h = Hash::default();
119 h.copy_from_slice(d);
120 h
121 },
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` distinct leaves; leaf `i` consists of 32 copies of `i + 1`.
    fn leaves(count: usize) -> Vec<Hash> {
        (1..=count)
            .map(|n| [u8::try_from(n).expect("test sizes fit in u8"); 32])
            .collect()
    }

    /// Recomputes the root from a leaf and its proof.
    ///
    /// The proof stores the leaf's sibling last, so it is walked in reverse;
    /// the low bit of `index` tells on which side the current node sits.
    fn root_from_proof(leaf: Hash, mut index: usize, proof: &[Hash]) -> Hash {
        proof.iter().rev().fold(leaf, |node, sibling| {
            let parent = if index & 1 == 0 {
                hash_raw_concat([&b"node"[..], &node, sibling])
            } else {
                hash_raw_concat([&b"node"[..], sibling, &node])
            };
            index >>= 1;
            parent
        })
    }

    #[test]
    fn cd_merkle_root_works() {
        let items = [[1; 32], [2; 32], [3; 32], [4; 32]];
        let root1 = cd_merkle_root(&items, 4);
        let (_, root2) = cd_merkle_proof(&items, 4, 0);
        assert_eq!(root1, root2);
    }

    #[test]
    fn proof_agrees_with_root_for_every_index_and_padding() {
        for padded_len in [1usize, 2, 4, 8, 16] {
            for count in 1..=padded_len {
                let items = leaves(count);
                let expected = cd_merkle_root(&items, padded_len);
                for index in 0..count {
                    let (proof, root) = cd_merkle_proof(&items, padded_len, index);
                    assert_eq!(root, expected, "padded_len={padded_len} count={count}");
                    assert_eq!(proof.len(), padded_len.trailing_zeros() as usize);
                    assert_eq!(
                        root_from_proof(items[index], index, &proof),
                        expected,
                        "padded_len={padded_len} count={count} index={index}"
                    );
                }
            }
        }
    }

    #[test]
    fn out_of_range_index_yields_empty_proof_and_default_root() {
        let items = leaves(3);
        let (proof, root) = cd_merkle_proof(&items, 4, 3);
        assert!(proof.is_empty());
        assert_eq!(root, Hash::default());
    }

    #[test]
    fn degenerate_trees() {
        let none: [Hash; 0] = [];
        assert_eq!(cd_merkle_root(&none, 0), [0u8; 32]);
        assert_eq!(cd_merkle_root(&none, 1), [0u8; 32]);
        assert_eq!(cd_merkle_root(&[[9u8; 32]], 1), [9u8; 32]);
    }
}