1use crate::{
2 collections::{BTreeMap, VecDeque},
3 error::{Error, Result},
4 merge::{hash_leaf, merge},
5 merkle_proof::MerkleProof,
6 proof_ics23,
7 traits::{Hasher, Store, Value},
8 vec::Vec,
9 Key, InternalKey, EXPECTED_PATH_SIZE, H256,
10};
11#[cfg(feature = "borsh")]
12use borsh::{BorshDeserialize, BorshSerialize};
13use core::{cmp::max, marker::PhantomData};
14use ics23::commitment_proof::Proof;
15use ics23::{CommitmentProof, NonExistenceProof};
16use itertools::Itertools;
17
18#[derive(Debug, Eq, PartialEq, Clone)]
20#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
21pub struct BranchNode<K, const N: usize>
22where
23 K: Key<N>,
24{
25 pub fork_height: usize,
26 pub key: K,
27 pub node: H256,
28 pub sibling: H256,
29}
30
31impl<K, const N: usize> BranchNode<K, N>
32where
33 K: Key<N>,
34{
35 fn branch(&self, height: usize) -> (&H256, &H256) {
36 let is_right = self.key.get_bit(height);
37 if is_right {
38 (&self.sibling, &self.node)
39 } else {
40 (&self.node, &self.sibling)
41 }
42 }
43}
44
45#[derive(Debug, Eq, PartialEq, Clone)]
47#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
48pub struct LeafNode<K, V, const N: usize>
49where
50 K: Key<N>,
51{
52 pub key: K,
53 pub value: V,
54}
55
56#[derive(Debug)]
58pub struct SparseMerkleTree<H, K, V, S, const N: usize>
59where
60 H: Hasher + Default,
61 K: Key<N>,
62 V: Value,
63 S: Store<K, V, N>,
64{
65 store: S,
66 root: H256,
67 phantom: PhantomData<(H, K, V)>,
68}
69
70impl<H, K, V, S, const N: usize> Default for SparseMerkleTree<H, K, V, S, N>
71where
72 H: Hasher + Default,
73 K: Key<N>,
74 V: Value + core::cmp::PartialEq,
75 S: Store<K, V, N>,
76{
77 fn default() -> Self {
78 Self::new(H256::default(), S::default())
79 }
80}
81
82impl<H, K, V, S, const N: usize> SparseMerkleTree<H, K, V, S, N>
83where
84 H: Hasher + Default,
85 K: Key<N>,
86 V: Value + core::cmp::PartialEq,
87 S: Store<K, V, N>,
88{
89 pub fn new(root: H256, store: S) -> SparseMerkleTree<H, K, V, S, N> {
91 SparseMerkleTree {
92 root,
93 store,
94 phantom: PhantomData,
95 }
96 }
97
98 pub fn root(&self) -> &H256 {
100 &self.root
101 }
102
103 pub fn is_empty(&self) -> bool {
105 self.root.is_zero()
106 }
107
108 pub fn take_store(self) -> S {
110 self.store
111 }
112
113 pub fn store(&self) -> &S {
115 &self.store
116 }
117
118 pub fn store_mut(&mut self) -> &mut S {
120 &mut self.store
121 }
122
123 pub fn update(&mut self, key: K, value: V) -> Result<&H256> {
126 let mut path: BTreeMap<_, _> = Default::default();
128 let mut node = self.root;
130 let mut branch = self.store.get_branch(&node)?;
131 let mut height = branch
132 .as_ref()
133 .map(|b| max(b.key.fork_height(&key), b.fork_height))
134 .unwrap_or(0);
135 while branch.is_some() {
138 let branch_node = branch.unwrap();
139 let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
140 if height > branch_node.fork_height {
141 path.insert(fork_height, node);
144 break;
145 }
146 if branch_node.fork_height > 0 {
149 self.store.remove_branch(&node)?;
150 }
151 let (left, right) = branch_node.branch(height);
152 let is_right = key.get_bit(height);
153 let sibling = if is_right {
154 if &node == right {
155 break;
156 }
157 node = *right;
158 *left
159 } else {
160 if &node == left {
161 break;
162 }
163 node = *left;
164 *right
165 };
166 path.insert(height, sibling);
167 branch = self.store.get_branch(&node)?;
169 if let Some(branch_node) = branch.as_ref() {
170 height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
171 }
172 }
173 if let Some(leaf) = self.store.get_leaf(&node)? {
175 if leaf.key == key {
176 self.store.remove_leaf(&node)?;
177 self.store.remove_branch(&node)?;
178 }
179 }
180
181 let mut node = hash_leaf::<H, K, V, N>(&key, &value);
183 if !node.is_zero() {
185 self.store.insert_leaf(node, LeafNode { key, value })?;
186
187 self.store.insert_branch(
189 node,
190 BranchNode {
191 key,
192 fork_height: 0,
193 node,
194 sibling: H256::zero(),
195 },
196 )?;
197 }
198
199 while !path.is_empty() {
201 let height = path.iter().next().map(|(height, _)| *height).unwrap();
203 let sibling = path.remove(&height).unwrap();
204
205 let is_right = key.get_bit(height);
206 let parent = if is_right {
207 merge::<H>(&sibling, &node)
208 } else {
209 merge::<H>(&node, &sibling)
210 };
211
212 if !node.is_zero() {
213 let branch_node = BranchNode {
215 fork_height: height,
216 sibling,
217 node,
218 key,
219 };
220 self.store.insert_branch(parent, branch_node)?;
221 }
222 node = parent;
223 }
224 self.root = node;
225 Ok(&self.root)
226 }
227
228 pub fn get(&self, key: &K) -> Result<V> {
231 let mut node = self.root;
232 while !node.is_zero() {
234 let branch_node = match self.store.get_branch(&node)? {
235 Some(branch_node) => branch_node,
236 None => {
237 break;
238 }
239 };
240 let is_right = key.get_bit(branch_node.fork_height);
241 let (left, right) = branch_node.branch(branch_node.fork_height);
242 node = if is_right { *right } else { *left };
243 if branch_node.fork_height == 0 {
244 break;
245 }
246 }
247
248 if node.is_zero() {
250 return Ok(V::zero());
251 }
252 match self.store.get_leaf(&node)? {
254 Some(leaf) if &leaf.key == key => Ok(leaf.value),
255 _ => Ok(V::zero()),
256 }
257 }
258
259 fn fetch_merkle_path(
262 &self,
263 key: &K,
264 cache: &mut BTreeMap<(usize, InternalKey<N>), H256>,
265 ) -> Result<()> {
266 let mut node = self.root;
267 let mut height = self
268 .store
269 .get_branch(&node)?
270 .map(|b| max(b.key.fork_height(key), b.fork_height))
271 .unwrap_or(0);
272 while !node.is_zero() {
273 if node.is_zero() {
275 break;
276 }
277 match self.store.get_branch(&node)? {
278 Some(branch_node) => {
279 if height > branch_node.fork_height {
280 let fork_height =
281 max(key.fork_height(&branch_node.key), branch_node.fork_height);
282
283 let is_right = key.get_bit(fork_height);
284 let mut sibling_key = key.parent_path(fork_height);
285 if !is_right {
286 sibling_key.set_bit(height);
288 };
289 if !node.is_zero() {
290 cache
291 .entry((fork_height as usize, sibling_key))
292 .or_insert(node);
293 }
294 break;
295 }
296 let (left, right) = branch_node.branch(height);
297 let is_right = key.get_bit(height);
298 let sibling = if is_right {
299 if &node == right {
300 break;
301 }
302 node = *right;
303 *left
304 } else {
305 if &node == left {
306 break;
307 }
308 node = *left;
309 *right
310 };
311 let mut sibling_key = key.parent_path(height);
312 if !is_right {
313 sibling_key.set_bit(height);
315 };
316 cache.insert((height as usize, sibling_key), sibling);
317 if let Some(branch_node) = self.store.get_branch(&node)? {
318 let fork_height =
319 max(key.fork_height(&branch_node.key), branch_node.fork_height);
320 height = fork_height;
321 }
322 }
323 None => break,
324 };
325 }
326 Ok(())
327 }
328
329 pub fn merkle_proof(&self, mut keys: Vec<K>) -> Result<MerkleProof> {
331 if keys.is_empty() {
332 return Err(Error::EmptyKeys);
333 }
334
335 keys.sort_unstable_by_key(|k| **k);
337
338 let mut cache: BTreeMap<(usize, _), H256> = Default::default();
340 for k in &keys {
341 self.fetch_merkle_path(k, &mut cache)?;
342 }
343
344 let mut proof: Vec<(H256, usize)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
346 let mut leaves_path: Vec<Vec<usize>> = Vec::with_capacity(keys.len());
348 leaves_path.resize_with(keys.len(), Default::default);
349
350 let keys_len = keys.len();
351 let mut queue: VecDeque<(_, usize, usize)> = keys
354 .into_iter()
355 .enumerate()
356 .map(|(i, k)| (*k, 0, i))
357 .collect();
358
359 while let Some((key, height, leaf_index)) = queue.pop_front() {
360 if queue.is_empty() && cache.is_empty() || height == 8 * N {
361 if leaves_path[leaf_index].is_empty() {
363 leaves_path[leaf_index].push((8 * N) - 1);
364 }
365 break;
366 }
367 let mut sibling_key = key.parent_path(height);
369
370 let is_right = key.get_bit(height);
371 if is_right {
372 sibling_key.clear_bit(height);
374 } else {
375 sibling_key.set_bit(height);
377 }
378 if Some((&sibling_key, &height))
379 == queue
380 .front()
381 .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
382 {
383 let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
385 leaves_path[leaf_index].push(height);
386 } else {
387 match cache.remove(&(height, sibling_key)) {
388 Some(sibling) => {
389 debug_assert!(height < 8 * N);
390 proof.push((sibling, height));
392 }
393 None => {
394 if !is_right {
396 sibling_key.clear_bit(height);
397 }
398 let parent_key = sibling_key;
399 queue.push_back((parent_key, height + 1, leaf_index));
400 continue;
401 }
402 }
403 }
404 leaves_path[leaf_index].push(height);
406 if height < 8 * N {
407 let parent_key = if is_right { sibling_key } else { key };
409 queue.push_back((parent_key, height + 1, leaf_index));
410 }
411 }
412 debug_assert_eq!(leaves_path.len(), keys_len);
413 Ok(MerkleProof::new(leaves_path, proof))
414 }
415
416 pub fn membership_proof(&self, key: &K) -> Result<CommitmentProof> {
418 let value = self.get(key)?;
419 if value == V::zero() {
420 return Err(Error::ExistenceProof);
421 }
422 let merkle_proof = self.merkle_proof(vec![*key])?;
423 let existence_proof =
424 proof_ics23::convert(merkle_proof, key, &value, H::hash_op())?;
425 Ok(CommitmentProof {
426 proof: Some(Proof::Exist(existence_proof)),
427 })
428 }
429
430 pub fn non_membership_proof(&self, key: &K) -> Result<CommitmentProof> {
432 let value = self.get(key)?;
433 if value != V::zero() {
434 return Err(Error::NonExistenceProof);
435 }
436
437 let mut cache: BTreeMap<(usize, _), H256> = Default::default();
439 self.fetch_merkle_path(key, &mut cache)?;
440 let mut left = None;
441 let mut right = None;
442 for (_, node) in cache.iter() {
443 let branch = self
444 .store
445 .get_branch(node)?
446 .expect("the forked branch should exist");
447 let fork_height = key.fork_height(&branch.key);
448 let is_right = key.get_bit(fork_height);
449 if is_right && left.is_none() {
450 let mut n = *node;
452 while let Some(branch) = self.store.get_branch(&n)? {
453 if branch.fork_height == 0 {
454 break;
455 }
456 let (left_node, right_node) = branch.branch(branch.fork_height);
457 n = if right_node.is_zero() {
458 *left_node
459 } else {
460 *right_node
461 };
462 }
463 let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
464 let merkle_proof = self.merkle_proof(vec![leaf.key])?;
465 left = Some(proof_ics23::convert(
466 merkle_proof,
467 &leaf.key,
468 &leaf.value,
469 H::hash_op(),
470 )?);
471 } else if !is_right && right.is_none() {
472 let mut n = *node;
474 while let Some(branch) = self.store.get_branch(&n)? {
475 if branch.fork_height == 0 {
476 break;
477 }
478 let (left_node, right_node) = branch.branch(branch.fork_height);
479 n = if left_node.is_zero() {
480 *right_node
481 } else {
482 *left_node
483 };
484 }
485 let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
486 let merkle_proof = self.merkle_proof(vec![leaf.key])?;
487 right = Some(proof_ics23::convert(
488 merkle_proof,
489 &leaf.key,
490 &leaf.value,
491 H::hash_op(),
492 )?);
493 }
494 if left.is_some() && right.is_some() {
495 break;
496 }
497 }
498 let proof = NonExistenceProof {
499 key: key.to_vec(),
500 left,
501 right,
502 };
503 Ok(CommitmentProof {
504 proof: Some(Proof::Nonexist(proof)),
505 })
506 }
507
508 pub fn validate(&self) -> bool {
511 if self.store.size() == 0 {
513 return self.root == H256::zero()
514 }
515
516 let sorted_leaves = self.store
517 .sorted_leaves()
518 .map(|(k, v)| (k, v.clone()))
519 .collect::<Vec<_>>();
520 let pairs = sorted_leaves
522 .iter()
523 .tuple_windows::<(_, _)>();
524
525 let mut leaves = Vec::with_capacity(self.store.size());
527 for ((k1, v1), (k2, _)) in pairs {
528 let height = k1.fork_height(k2);
529 let hash = hash_leaf::<H, K, V, N>(k1, v1);
530 leaves.push((hash, height));
531 }
532 let (last_k, last_v) = sorted_leaves
533 .last()
534 .map(|(k, v)| (k, v))
535 .unwrap();
536 let last = hash_leaf::<H, K, V, N>(last_k, last_v);
537 if leaves.is_empty() {
538 return self.root == last;
539 }
540 leaves.push((last, usize::MAX));
541
542 let mut left: usize = 0;
543 let mut right: usize = 1;
544 let mut merged = Default::default();
545
546 let mut prev: Vec<usize> = Vec::with_capacity(leaves.len() / 2);
548
549 while right < leaves.len() {
553 if leaves[left].1 < leaves[right].1 {
554 loop {
555 merged = merge::<H>(&leaves[left].0, &leaves[right].0);
557 leaves[right].0 = merged;
558
559 match prev.last() {
561 Some(&idx) if leaves[idx].1 < leaves[right].1 => {
562 left = idx;
563 _ = prev.pop();
564 continue;
565 }
566 _ => {
567 break;
568 }
569 }
570 }
571 } else {
572 prev.push(left);
573 }
574 left = right;
575 right += 1;
576 }
577 merged == self.root
579 }
580}