1use crate::hash::CryptoHash;
2use crate::types::MerkleHash;
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_schema_checker_lib::ProtocolSchema;
5
6#[derive(
7 Debug,
8 Clone,
9 PartialEq,
10 Eq,
11 BorshSerialize,
12 BorshDeserialize,
13 serde::Serialize,
14 serde::Deserialize,
15 ProtocolSchema,
16)]
17#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
18pub struct MerklePathItem {
19 pub hash: MerkleHash,
20 pub direction: Direction,
21}
22
23pub type MerklePath = Vec<MerklePathItem>;
24
25#[derive(
26 Debug,
27 Clone,
28 PartialEq,
29 Eq,
30 BorshSerialize,
31 BorshDeserialize,
32 serde::Serialize,
33 serde::Deserialize,
34 ProtocolSchema,
35)]
36#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
37pub enum Direction {
38 Left,
39 Right,
40}
41
42pub fn combine_hash(hash1: &MerkleHash, hash2: &MerkleHash) -> MerkleHash {
43 CryptoHash::hash_borsh((hash1, hash2))
44}
45
46pub fn merklize<T: BorshSerialize>(arr: &[T]) -> (MerkleHash, Vec<MerklePath>) {
48 if arr.is_empty() {
49 return (MerkleHash::default(), vec![]);
50 }
51 let mut len = arr.len().next_power_of_two();
52 let mut hashes = arr.iter().map(CryptoHash::hash_borsh).collect::<Vec<_>>();
53
54 if len == 1 {
56 return (hashes[0], vec![vec![]]);
57 }
58 let mut arr_len = arr.len();
59 let mut paths: Vec<MerklePath> = (0..arr_len)
60 .map(|i| {
61 if i % 2 == 0 {
62 if i + 1 < arr_len {
63 vec![MerklePathItem {
64 hash: hashes[(i + 1) as usize],
65 direction: Direction::Right,
66 }]
67 } else {
68 vec![]
69 }
70 } else {
71 vec![MerklePathItem { hash: hashes[(i - 1) as usize], direction: Direction::Left }]
72 }
73 })
74 .collect();
75
76 let mut counter = 1;
77 while len > 1 {
78 len /= 2;
79 counter *= 2;
80 for i in 0..len {
81 let hash = if 2 * i >= arr_len {
82 continue;
83 } else if 2 * i + 1 >= arr_len {
84 hashes[2 * i]
85 } else {
86 combine_hash(&hashes[2 * i], &hashes[2 * i + 1])
87 };
88 hashes[i] = hash;
89 if len > 1 {
90 if i % 2 == 0 {
91 for j in 0..counter {
92 let index = ((i + 1) * counter + j) as usize;
93 if index < arr.len() {
94 paths[index].push(MerklePathItem { hash, direction: Direction::Left });
95 }
96 }
97 } else {
98 for j in 0..counter {
99 let index = ((i - 1) * counter + j) as usize;
100 if index < arr.len() {
101 paths[index].push(MerklePathItem { hash, direction: Direction::Right });
102 }
103 }
104 }
105 }
106 }
107 arr_len = (arr_len + 1) / 2;
108 }
109 (hashes[0], paths)
110}
111
112pub fn verify_path<T: BorshSerialize>(root: MerkleHash, path: &MerklePath, item: T) -> bool {
114 verify_hash(root, path, CryptoHash::hash_borsh(item))
115}
116
117pub fn verify_hash(root: MerkleHash, path: &MerklePath, item_hash: MerkleHash) -> bool {
118 compute_root_from_path(path, item_hash) == root
119}
120
121pub fn verify_path_with_index<T: BorshSerialize>(
122 root: MerkleHash,
123 path: &MerklePath,
124 item: T,
125 part_idx: u64,
126 num_merklized_parts: u64,
127) -> bool {
128 verify_path_matches_index(path, part_idx, num_merklized_parts) && verify_path(root, path, item)
129}
130
131pub fn compute_root_from_path(path: &MerklePath, item_hash: MerkleHash) -> MerkleHash {
132 let mut res = item_hash;
133 for item in path {
134 match item.direction {
135 Direction::Left => {
136 res = combine_hash(&item.hash, &res);
137 }
138 Direction::Right => {
139 res = combine_hash(&res, &item.hash);
140 }
141 }
142 }
143 res
144}
145
146pub fn compute_root_from_path_and_item<T: BorshSerialize>(
147 path: &MerklePath,
148 item: T,
149) -> MerkleHash {
150 compute_root_from_path(path, CryptoHash::hash_borsh(item))
151}
152
153#[derive(
159 Default, Clone, BorshSerialize, BorshDeserialize, Eq, PartialEq, Debug, serde::Serialize,
160)]
161pub struct PartialMerkleTree {
162 path: Vec<MerkleHash>,
164 size: u64,
166}
167
168impl PartialMerkleTree {
169 pub fn is_well_formed(&self) -> bool {
182 self.path.len() == self.size.count_ones() as usize
183 }
184
185 pub fn root(&self) -> MerkleHash {
186 if self.path.is_empty() {
187 CryptoHash::default()
188 } else {
189 let mut res = *self.path.last().unwrap();
190 let len = self.path.len();
191 for i in (0..len - 1).rev() {
192 res = combine_hash(&self.path[i], &res);
193 }
194 res
195 }
196 }
197
198 pub fn insert(&mut self, elem: MerkleHash) {
199 let mut s = self.size;
200 let mut node = elem;
201 while s % 2 == 1 {
202 let last_path_elem = self.path.pop().unwrap();
203 node = combine_hash(&last_path_elem, &node);
204 s /= 2;
205 }
206 self.path.push(node);
207 self.size += 1;
208 }
209
210 pub fn size(&self) -> u64 {
211 self.size
212 }
213
214 pub fn get_path(&self) -> &[MerkleHash] {
215 &self.path
216 }
217
218 pub fn iter_path_from_bottom(&self, mut f: impl FnMut(MerkleHash, u64)) {
221 let mut level = 0;
222 let mut index = self.size;
223 for node in self.path.iter().rev() {
224 if index == 0 {
225 return;
227 }
228 let trailing_zeros = index.trailing_zeros();
229 level += trailing_zeros;
230 index >>= trailing_zeros;
231 index -= 1;
232 f(*node, level as u64);
233 }
234 }
235}
236
237fn verify_path_matches_index(path: &MerklePath, part_idx: u64, num_merklized_parts: u64) -> bool {
238 if part_idx >= num_merklized_parts {
239 return false;
240 }
241
242 let mut used = 0;
243
244 let height = num_merklized_parts.next_power_of_two().ilog2() as usize;
245 for k in 0..height {
246 let block = part_idx >> k;
247 let sibling_block = block ^ 1;
248 let sibling_leaf_start_index = sibling_block << k;
249 if sibling_leaf_start_index < num_merklized_parts {
250 let Some(item) = path.get(used) else {
251 return false;
252 };
253 let expected =
254 if (part_idx >> k) & 1 == 0 { Direction::Right } else { Direction::Left };
255 if item.direction != expected {
256 return false;
257 }
258 used += 1;
259 }
260 }
261 used == path.len()
262}
263
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Merklizes `n` random values and checks that every returned path verifies its element.
    fn test_with_len(n: u32, rng: &mut StdRng) {
        let mut arr: Vec<u32> = vec![];
        for _ in 0..n {
            arr.push(rng.gen_range(0..1000));
        }
        let (root, paths) = merklize(&arr);
        assert_eq!(paths.len() as u32, n);
        for (i, item) in arr.iter().enumerate() {
            assert!(verify_path(root, &paths[i], item));
        }
    }

    #[test]
    fn test_merkle_path() {
        let mut rng: StdRng = SeedableRng::seed_from_u64(1);
        for _ in 0..10 {
            let len: u32 = rng.gen_range(1..100);
            test_with_len(len, &mut rng);
        }
    }

    #[test]
    fn test_incorrect_path() {
        let items = vec![111, 222, 333];
        let (root, paths) = merklize(&items);
        for i in 0..items.len() {
            assert!(!verify_path(root, &paths[(i + 1) % 3], &items[i]))
        }
    }

    #[test]
    fn test_elements_order() {
        let items = vec![1, 2];
        let (root, _) = merklize(&items);
        let items2 = vec![2, 1];
        let (root2, _) = merklize(&items2);
        assert_ne!(root, root2);
    }

    /// Reference root: split at the largest power of two below `len`, recurse, combine.
    fn compute_root(hashes: &[CryptoHash]) -> CryptoHash {
        if hashes.is_empty() {
            CryptoHash::default()
        } else if hashes.len() == 1 {
            hashes[0]
        } else {
            let len = hashes.len();
            let subtree_len = len.next_power_of_two() / 2;
            let left_root = compute_root(&hashes[0..subtree_len]);
            let right_root = compute_root(&hashes[subtree_len..len]);
            combine_hash(&left_root, &right_root)
        }
    }

    #[test]
    fn test_merkle_tree() {
        let mut tree = PartialMerkleTree::default();
        let mut hashes = vec![];
        for i in 0..50 {
            assert_eq!(compute_root(&hashes), tree.root());
            assert!(tree.is_well_formed());

            let mut tree_copy = tree.clone();
            tree_copy.path.push(CryptoHash::hash_bytes(&[i]));
            assert!(!tree_copy.is_well_formed());
            tree_copy.path.pop();
            if !tree_copy.path.is_empty() {
                tree_copy.path.pop();
                assert!(!tree_copy.is_well_formed());
            }

            let cur_hash = CryptoHash::hash_bytes(&[i]);
            hashes.push(cur_hash);
            tree.insert(cur_hash);
        }
    }

    #[test]
    fn test_combine_hash_stability() {
        let a = MerkleHash::default();
        let b = MerkleHash::default();
        let cc = combine_hash(&a, &b);
        assert_eq!(
            cc.0,
            [
                245, 165, 253, 66, 209, 106, 32, 48, 39, 152, 239, 110, 211, 9, 151, 155, 67, 0,
                61, 35, 32, 217, 240, 232, 234, 152, 49, 169, 39, 89, 251, 75
            ]
        );
    }

    /// Covers the early-return branches of `merklize`.
    #[test]
    fn test_merklize_empty_and_single() {
        let (root, paths) = merklize::<u32>(&[]);
        assert_eq!(root, MerkleHash::default());
        assert!(paths.is_empty());

        let (root, paths) = merklize(&[42u32]);
        assert_eq!(root, CryptoHash::hash_borsh(42u32));
        assert_eq!(paths, vec![MerklePath::new()]);
    }

    /// `verify_path_with_index` accepts exactly the path produced for each index and rejects
    /// out-of-range indices, over-long paths and paths with a wrong direction.
    #[test]
    fn test_verify_path_with_index() {
        for n in 1..20u32 {
            let items: Vec<u32> = (0..n).map(|x| x * 31 + 7).collect();
            let (root, paths) = merklize(&items);
            let num_parts = u64::from(n);
            for (i, (item, path)) in items.iter().zip(&paths).enumerate() {
                let idx = i as u64;
                assert!(verify_path_with_index(root, path, item, idx, num_parts));
                assert!(!verify_path_with_index(root, path, item, num_parts, num_parts));

                let mut longer = path.clone();
                longer.push(MerklePathItem { hash: root, direction: Direction::Right });
                assert!(!verify_path_with_index(root, &longer, item, idx, num_parts));

                if let Some(first) = path.first() {
                    let mut flipped = path.clone();
                    flipped[0].direction = match first.direction {
                        Direction::Left => Direction::Right,
                        Direction::Right => Direction::Left,
                    };
                    assert!(!verify_path_with_index(root, &flipped, item, idx, num_parts));
                }
            }
        }
    }

    /// `iter_path_from_bottom` visits stored roots smallest-first, one per set bit of `size`.
    #[test]
    fn test_iter_path_from_bottom() {
        let mut tree = PartialMerkleTree::default();
        for i in 0..40u64 {
            let mut visited = vec![];
            tree.iter_path_from_bottom(|node, level| visited.push((node, level)));
            let expected_levels: Vec<u64> = (0..64u64).filter(|&b| (i >> b) & 1 == 1).collect();
            let expected_nodes: Vec<CryptoHash> = tree.get_path().iter().rev().copied().collect();
            assert_eq!(visited.iter().map(|&(_, level)| level).collect::<Vec<_>>(), expected_levels);
            assert_eq!(visited.iter().map(|&(node, _)| node).collect::<Vec<_>>(), expected_nodes);
            tree.insert(CryptoHash::hash_bytes(&[i as u8]));
        }
    }
}