// near_primitives/merkle.rs

use crate::hash::CryptoHash;
use crate::types::MerkleHash;
use borsh::{BorshDeserialize, BorshSerialize};
use near_schema_checker_lib::ProtocolSchema;

6#[derive(
7    Debug,
8    Clone,
9    PartialEq,
10    Eq,
11    BorshSerialize,
12    BorshDeserialize,
13    serde::Serialize,
14    serde::Deserialize,
15    ProtocolSchema,
16)]
17#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
18pub struct MerklePathItem {
19    pub hash: MerkleHash,
20    pub direction: Direction,
21}
22
23pub type MerklePath = Vec<MerklePathItem>;
24
25#[derive(
26    Debug,
27    Clone,
28    PartialEq,
29    Eq,
30    BorshSerialize,
31    BorshDeserialize,
32    serde::Serialize,
33    serde::Deserialize,
34    ProtocolSchema,
35)]
36#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
37pub enum Direction {
38    Left,
39    Right,
40}
41
42pub fn combine_hash(hash1: &MerkleHash, hash2: &MerkleHash) -> MerkleHash {
43    CryptoHash::hash_borsh((hash1, hash2))
44}
45
46/// Merklize an array of items. If the array is empty, returns hash of 0
47pub fn merklize<T: BorshSerialize>(arr: &[T]) -> (MerkleHash, Vec<MerklePath>) {
48    if arr.is_empty() {
49        return (MerkleHash::default(), vec![]);
50    }
51    let mut len = arr.len().next_power_of_two();
52    let mut hashes = arr.iter().map(CryptoHash::hash_borsh).collect::<Vec<_>>();
53
54    // degenerate case
55    if len == 1 {
56        return (hashes[0], vec![vec![]]);
57    }
58    let mut arr_len = arr.len();
59    let mut paths: Vec<MerklePath> = (0..arr_len)
60        .map(|i| {
61            if i % 2 == 0 {
62                if i + 1 < arr_len {
63                    vec![MerklePathItem {
64                        hash: hashes[(i + 1) as usize],
65                        direction: Direction::Right,
66                    }]
67                } else {
68                    vec![]
69                }
70            } else {
71                vec![MerklePathItem { hash: hashes[(i - 1) as usize], direction: Direction::Left }]
72            }
73        })
74        .collect();
75
76    let mut counter = 1;
77    while len > 1 {
78        len /= 2;
79        counter *= 2;
80        for i in 0..len {
81            let hash = if 2 * i >= arr_len {
82                continue;
83            } else if 2 * i + 1 >= arr_len {
84                hashes[2 * i]
85            } else {
86                combine_hash(&hashes[2 * i], &hashes[2 * i + 1])
87            };
88            hashes[i] = hash;
89            if len > 1 {
90                if i % 2 == 0 {
91                    for j in 0..counter {
92                        let index = ((i + 1) * counter + j) as usize;
93                        if index < arr.len() {
94                            paths[index].push(MerklePathItem { hash, direction: Direction::Left });
95                        }
96                    }
97                } else {
98                    for j in 0..counter {
99                        let index = ((i - 1) * counter + j) as usize;
100                        if index < arr.len() {
101                            paths[index].push(MerklePathItem { hash, direction: Direction::Right });
102                        }
103                    }
104                }
105            }
106        }
107        arr_len = (arr_len + 1) / 2;
108    }
109    (hashes[0], paths)
110}
111
112/// Verify merkle path for given item and corresponding path.
113pub fn verify_path<T: BorshSerialize>(root: MerkleHash, path: &MerklePath, item: T) -> bool {
114    verify_hash(root, path, CryptoHash::hash_borsh(item))
115}
116
117pub fn verify_hash(root: MerkleHash, path: &MerklePath, item_hash: MerkleHash) -> bool {
118    compute_root_from_path(path, item_hash) == root
119}
120
121pub fn verify_path_with_index<T: BorshSerialize>(
122    root: MerkleHash,
123    path: &MerklePath,
124    item: T,
125    part_idx: u64,
126    num_merklized_parts: u64,
127) -> bool {
128    verify_path_matches_index(path, part_idx, num_merklized_parts) && verify_path(root, path, item)
129}
130
131pub fn compute_root_from_path(path: &MerklePath, item_hash: MerkleHash) -> MerkleHash {
132    let mut res = item_hash;
133    for item in path {
134        match item.direction {
135            Direction::Left => {
136                res = combine_hash(&item.hash, &res);
137            }
138            Direction::Right => {
139                res = combine_hash(&res, &item.hash);
140            }
141        }
142    }
143    res
144}
145
146pub fn compute_root_from_path_and_item<T: BorshSerialize>(
147    path: &MerklePath,
148    item: T,
149) -> MerkleHash {
150    compute_root_from_path(path, CryptoHash::hash_borsh(item))
151}
152
153/// Merkle tree that only maintains the path for the next leaf, i.e,
154/// when a new leaf is inserted, the existing `path` is its proof.
155/// The root can be computed by folding `path` from right but is not explicitly
156/// maintained to save space.
157/// The size of the object is O(log(n)) where n is the number of leaves in the tree, i.e, `size`.
158#[derive(
159    Default, Clone, BorshSerialize, BorshDeserialize, Eq, PartialEq, Debug, serde::Serialize,
160)]
161pub struct PartialMerkleTree {
162    /// Path for the next leaf.
163    path: Vec<MerkleHash>,
164    /// Number of leaves in the tree.
165    size: u64,
166}
167
168impl PartialMerkleTree {
169    /// A PartialMerkleTree is well formed iff the path would be a valid proof for the next block
170    /// of ordinal `size`. This means that the path contains exactly `size.count_ones()` elements.
171    ///
172    /// The <= direction of this statement is easy to prove, as the subtrees whose roots are being
173    /// combined to form the overall root correspond to the binary 1s in the size.
174    ///
175    /// The => direction is proven by observing that the root is computed as
176    /// hash(path[0], hash(path[1], hash(path[2], ... hash(path[n-1], path[n]) ...))
177    /// and there is only one way to provide an array of paths of the exact same size that would
178    /// produce the same result when combined in this way. (This would not have been true if we
179    /// could provide a path of a different size, e.g. if we could provide just one hash, we could
180    /// provide only the root).
181    pub fn is_well_formed(&self) -> bool {
182        self.path.len() == self.size.count_ones() as usize
183    }
184
185    pub fn root(&self) -> MerkleHash {
186        if self.path.is_empty() {
187            CryptoHash::default()
188        } else {
189            let mut res = *self.path.last().unwrap();
190            let len = self.path.len();
191            for i in (0..len - 1).rev() {
192                res = combine_hash(&self.path[i], &res);
193            }
194            res
195        }
196    }
197
198    pub fn insert(&mut self, elem: MerkleHash) {
199        let mut s = self.size;
200        let mut node = elem;
201        while s % 2 == 1 {
202            let last_path_elem = self.path.pop().unwrap();
203            node = combine_hash(&last_path_elem, &node);
204            s /= 2;
205        }
206        self.path.push(node);
207        self.size += 1;
208    }
209
210    pub fn size(&self) -> u64 {
211        self.size
212    }
213
214    pub fn get_path(&self) -> &[MerkleHash] {
215        &self.path
216    }
217
218    /// Iterate over the path from the bottom to the top, calling `f` with the hash and the level.
219    /// The level is 0 for the leaf and increases by 1 for each level in the actual tree.
220    pub fn iter_path_from_bottom(&self, mut f: impl FnMut(MerkleHash, u64)) {
221        let mut level = 0;
222        let mut index = self.size;
223        for node in self.path.iter().rev() {
224            if index == 0 {
225                // shouldn't happen
226                return;
227            }
228            let trailing_zeros = index.trailing_zeros();
229            level += trailing_zeros;
230            index >>= trailing_zeros;
231            index -= 1;
232            f(*node, level as u64);
233        }
234    }
235}
236
237fn verify_path_matches_index(path: &MerklePath, part_idx: u64, num_merklized_parts: u64) -> bool {
238    if part_idx >= num_merklized_parts {
239        return false;
240    }
241
242    let mut used = 0;
243
244    let height = num_merklized_parts.next_power_of_two().ilog2() as usize;
245    for k in 0..height {
246        let block = part_idx >> k;
247        let sibling_block = block ^ 1;
248        let sibling_leaf_start_index = sibling_block << k;
249        if sibling_leaf_start_index < num_merklized_parts {
250            let Some(item) = path.get(used) else {
251                return false;
252            };
253            let expected =
254                if (part_idx >> k) & 1 == 0 { Direction::Right } else { Direction::Left };
255            if item.direction != expected {
256                return false;
257            }
258            used += 1;
259        }
260    }
261    used == path.len()
262}
263
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Merklizes `n` random values and checks every returned proof against the root.
    fn test_with_len(n: u32, rng: &mut StdRng) {
        let mut arr: Vec<u32> = vec![];
        for _ in 0..n {
            arr.push(rng.gen_range(0..1000));
        }
        let (root, paths) = merklize(&arr);
        assert_eq!(paths.len() as u32, n);
        for (i, item) in arr.iter().enumerate() {
            assert!(verify_path(root, &paths[i], item));
        }
    }

    #[test]
    fn test_merkle_path() {
        let mut rng: StdRng = SeedableRng::seed_from_u64(1);
        for _ in 0..10 {
            let len: u32 = rng.gen_range(1..100);
            test_with_len(len, &mut rng);
        }
    }

    #[test]
    fn test_merklize_empty_and_single() {
        let (root, paths) = merklize::<u32>(&[]);
        assert_eq!(root, MerkleHash::default());
        assert!(paths.is_empty());

        let (root, paths) = merklize(&[7u32]);
        assert_eq!(root, CryptoHash::hash_borsh(&7u32));
        assert_eq!(paths.len(), 1);
        assert!(paths[0].is_empty());
    }

    /// Every proof produced by `merklize` must also pass the index-aware check,
    /// and must be rejected for an out-of-range or neighbouring index.
    #[test]
    fn test_merkle_path_with_index() {
        let mut rng: StdRng = SeedableRng::seed_from_u64(2);
        for n in 1..40u64 {
            let arr: Vec<u32> = (0..n).map(|_| rng.gen_range(0..1000)).collect();
            let (root, paths) = merklize(&arr);
            for (i, (item, path)) in arr.iter().zip(&paths).enumerate() {
                let i = i as u64;
                assert!(verify_path_with_index(root, path, item, i, n));
                // Out-of-range index: rejected even though the proof itself is valid.
                assert!(!verify_path_with_index(root, path, item, n, n));
                // The neighbouring leaf expects the opposite direction at level 0.
                if (i ^ 1) < n {
                    assert!(!verify_path_matches_index(path, i ^ 1, n));
                }
            }
        }
    }

    #[test]
    fn test_incorrect_path() {
        let items = vec![111, 222, 333];
        let (root, paths) = merklize(&items);
        for i in 0..items.len() {
            assert!(!verify_path(root, &paths[(i + 1) % items.len()], &items[i]))
        }
    }

    #[test]
    fn test_elements_order() {
        let items = vec![1, 2];
        let (root, _) = merklize(&items);
        let items2 = vec![2, 1];
        let (root2, _) = merklize(&items2);
        assert_ne!(root, root2);
    }

    /// Compute the merkle root of a given array.
    fn compute_root(hashes: &[CryptoHash]) -> CryptoHash {
        if hashes.is_empty() {
            CryptoHash::default()
        } else if hashes.len() == 1 {
            hashes[0]
        } else {
            let len = hashes.len();
            let subtree_len = len.next_power_of_two() / 2;
            let left_root = compute_root(&hashes[0..subtree_len]);
            let right_root = compute_root(&hashes[subtree_len..len]);
            combine_hash(&left_root, &right_root)
        }
    }

    #[test]
    fn test_merkle_tree() {
        let mut tree = PartialMerkleTree::default();
        let mut hashes = vec![];
        for i in 0..50 {
            assert_eq!(compute_root(&hashes), tree.root());
            assert!(tree.is_well_formed());

            let mut tree_copy = tree.clone();
            tree_copy.path.push(CryptoHash::hash_bytes(&[i]));
            assert!(!tree_copy.is_well_formed());
            tree_copy.path.pop();
            if !tree_copy.path.is_empty() {
                tree_copy.path.pop();
                assert!(!tree_copy.is_well_formed());
            }

            let cur_hash = CryptoHash::hash_bytes(&[i]);
            hashes.push(cur_hash);
            tree.insert(cur_hash);
        }
    }

    /// `iter_path_from_bottom` visits the path in reverse, and each entry's level
    /// is the position of the corresponding set bit of `size`.
    #[test]
    fn test_iter_path_from_bottom() {
        let mut tree = PartialMerkleTree::default();
        for i in 0..40u8 {
            let mut visited: Vec<(MerkleHash, u64)> = vec![];
            tree.iter_path_from_bottom(|hash, level| visited.push((hash, level)));

            let size = tree.size();
            let expected_levels: Vec<u64> =
                (0..64u64).filter(|&bit| (size >> bit) & 1 == 1).collect();
            let levels: Vec<u64> = visited.iter().map(|&(_, level)| level).collect();
            assert_eq!(levels, expected_levels);

            let expected_hashes: Vec<MerkleHash> = tree.get_path().iter().rev().copied().collect();
            let hashes: Vec<MerkleHash> = visited.iter().map(|&(hash, _)| hash).collect();
            assert_eq!(hashes, expected_hashes);

            tree.insert(CryptoHash::hash_bytes(&[i]));
        }
    }

    #[test]
    fn test_combine_hash_stability() {
        let a = MerkleHash::default();
        let b = MerkleHash::default();
        let cc = combine_hash(&a, &b);
        assert_eq!(
            cc.0,
            [
                245, 165, 253, 66, 209, 106, 32, 48, 39, 152, 239, 110, 211, 9, 151, 155, 67, 0,
                61, 35, 32, 217, 240, 232, 234, 152, 49, 169, 39, 89, 251, 75
            ]
        );
    }
}