1use crate::{
2 collections::{BTreeMap, VecDeque},
3 error::{Error, Result},
4 merge::{hash_leaf, merge},
5 merkle_proof::MerkleProof,
6 proof_ics23,
7 traits::{Hasher, Store, Value},
8 vec::Vec,
9 Key, InternalKey, EXPECTED_PATH_SIZE, H256,
10};
11#[cfg(feature = "borsh")]
12use borsh::{BorshDeserialize, BorshSerialize};
13use core::{cmp::max, marker::PhantomData};
14use ics23::commitment_proof::Proof;
15use ics23::{CommitmentProof, NonExistenceProof};
16
/// A branch entry of the tree, stored in the store under the hash of the
/// subtree it roots (`SparseMerkleTree::update` inserts it under
/// `merge(left, right)`).
///
/// Every leaf additionally gets a branch entry stored under its own leaf
/// hash, with `fork_height` 0, `node` equal to that leaf hash and a zero
/// `sibling`.
#[derive(Debug, Eq, PartialEq, Clone)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
pub struct BranchNode<K, const N: usize>
where
    K: Key<N>,
{
    /// Height at which this branch splits into its two children.
    pub fork_height: usize,
    /// A key routed through this branch; `branch` uses its bit at the given
    /// height to decide on which side `node` lies.
    pub key: K,
    /// Child hash on the side selected by `key`.
    pub node: H256,
    /// Child hash on the side opposite to `key`.
    pub sibling: H256,
}
29
30impl<K, const N: usize> BranchNode<K, N>
31where
32 K: Key<N>,
33{
34 fn branch(&self, height: usize) -> (&H256, &H256) {
35 let is_right = self.key.get_bit(height);
36 if is_right {
37 (&self.sibling, &self.node)
38 } else {
39 (&self.node, &self.sibling)
40 }
41 }
42}
43
44#[derive(Debug, Eq, PartialEq, Clone)]
46#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
47pub struct LeafNode<K, V, const N: usize>
48where
49 K: Key<N>,
50{
51 pub key: K,
52 pub value: V,
53}
54
/// A sparse Merkle tree whose nodes live in the store `S`.
///
/// `H` is the hasher used for leaf hashing and merging. `K` are keys
/// addressed by `8 * N` bit positions (heights), and `V` are the stored
/// values.
#[derive(Debug)]
pub struct SparseMerkleTree<H, K, V, S, const N: usize>
where
    H: Hasher + Default,
    K: Key<N>,
    V: Value,
    S: Store<K, V, N>,
{
    // Backing storage for branch and leaf entries.
    store: S,
    // Current root hash.
    root: H256,
    // Ties the otherwise unused type parameters to the struct.
    phantom: PhantomData<(H, K, V)>,
}
68
69impl<H, K, V, S, const N: usize> Default for SparseMerkleTree<H, K, V, S, N>
70where
71 H: Hasher + Default,
72 K: Key<N>,
73 V: Value + core::cmp::PartialEq,
74 S: Store<K, V, N>,
75{
76 fn default() -> Self {
77 Self::new(H256::default(), S::default())
78 }
79}
80
81impl<H, K, V, S, const N: usize> SparseMerkleTree<H, K, V, S, N>
82where
83 H: Hasher + Default,
84 K: Key<N>,
85 V: Value + core::cmp::PartialEq,
86 S: Store<K, V, N>,
87{
88 pub fn new(root: H256, store: S) -> SparseMerkleTree<H, K, V, S, N> {
90 SparseMerkleTree {
91 root,
92 store,
93 phantom: PhantomData,
94 }
95 }
96
97 pub fn root(&self) -> &H256 {
99 &self.root
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.root.is_zero()
105 }
106
107 pub fn take_store(self) -> S {
109 self.store
110 }
111
112 pub fn store(&self) -> &S {
114 &self.store
115 }
116
117 pub fn store_mut(&mut self) -> &mut S {
119 &mut self.store
120 }
121
122 pub fn update(&mut self, key: K, value: V) -> Result<&H256> {
125 let mut path: BTreeMap<_, _> = Default::default();
127 let mut node = self.root;
129 let mut branch = self.store.get_branch(&node)?;
130 let mut height = branch
131 .as_ref()
132 .map(|b| max(b.key.fork_height(&key), b.fork_height))
133 .unwrap_or(0);
134 while branch.is_some() {
137 let branch_node = branch.unwrap();
138 let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
139 if height > branch_node.fork_height {
140 path.insert(fork_height, node);
143 break;
144 }
145 if branch_node.fork_height > 0 {
148 self.store.remove_branch(&node)?;
149 }
150 let (left, right) = branch_node.branch(height);
151 let is_right = key.get_bit(height);
152 let sibling = if is_right {
153 if &node == right {
154 break;
155 }
156 node = *right;
157 *left
158 } else {
159 if &node == left {
160 break;
161 }
162 node = *left;
163 *right
164 };
165 path.insert(height, sibling);
166 branch = self.store.get_branch(&node)?;
168 if let Some(branch_node) = branch.as_ref() {
169 height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
170 }
171 }
172 if let Some(leaf) = self.store.get_leaf(&node)? {
174 if leaf.key == key {
175 self.store.remove_leaf(&node)?;
176 self.store.remove_branch(&node)?;
177 }
178 }
179
180 let mut node = hash_leaf::<H, K, V, N>(&key, &value);
182 if !node.is_zero() {
184 self.store.insert_leaf(node, LeafNode { key, value })?;
185
186 self.store.insert_branch(
188 node,
189 BranchNode {
190 key,
191 fork_height: 0,
192 node,
193 sibling: H256::zero(),
194 },
195 )?;
196 }
197
198 while !path.is_empty() {
200 let height = path.iter().next().map(|(height, _)| *height).unwrap();
202 let sibling = path.remove(&height).unwrap();
203
204 let is_right = key.get_bit(height);
205 let parent = if is_right {
206 merge::<H>(&sibling, &node)
207 } else {
208 merge::<H>(&node, &sibling)
209 };
210
211 if !node.is_zero() {
212 let branch_node = BranchNode {
214 fork_height: height,
215 sibling,
216 node,
217 key,
218 };
219 self.store.insert_branch(parent, branch_node)?;
220 }
221 node = parent;
222 }
223 self.root = node;
224 Ok(&self.root)
225 }
226
227 pub fn get(&self, key: &K) -> Result<V> {
230 let mut node = self.root;
231 while !node.is_zero() {
233 let branch_node = match self.store.get_branch(&node)? {
234 Some(branch_node) => branch_node,
235 None => {
236 break;
237 }
238 };
239 let is_right = key.get_bit(branch_node.fork_height);
240 let (left, right) = branch_node.branch(branch_node.fork_height);
241 node = if is_right { *right } else { *left };
242 if branch_node.fork_height == 0 {
243 break;
244 }
245 }
246
247 if node.is_zero() {
249 return Ok(V::zero());
250 }
251 match self.store.get_leaf(&node)? {
253 Some(leaf) if &leaf.key == key => Ok(leaf.value),
254 _ => Ok(V::zero()),
255 }
256 }
257
258 fn fetch_merkle_path(
261 &self,
262 key: &K,
263 cache: &mut BTreeMap<(usize, InternalKey<N>), H256>,
264 ) -> Result<()> {
265 let mut node = self.root;
266 let mut height = self
267 .store
268 .get_branch(&node)?
269 .map(|b| max(b.key.fork_height(key), b.fork_height))
270 .unwrap_or(0);
271 while !node.is_zero() {
272 if node.is_zero() {
274 break;
275 }
276 match self.store.get_branch(&node)? {
277 Some(branch_node) => {
278 if height > branch_node.fork_height {
279 let fork_height =
280 max(key.fork_height(&branch_node.key), branch_node.fork_height);
281
282 let is_right = key.get_bit(fork_height);
283 let mut sibling_key = key.parent_path(fork_height);
284 if !is_right {
285 sibling_key.set_bit(height);
287 };
288 if !node.is_zero() {
289 cache
290 .entry((fork_height as usize, sibling_key))
291 .or_insert(node);
292 }
293 break;
294 }
295 let (left, right) = branch_node.branch(height);
296 let is_right = key.get_bit(height);
297 let sibling = if is_right {
298 if &node == right {
299 break;
300 }
301 node = *right;
302 *left
303 } else {
304 if &node == left {
305 break;
306 }
307 node = *left;
308 *right
309 };
310 let mut sibling_key = key.parent_path(height);
311 if !is_right {
312 sibling_key.set_bit(height);
314 };
315 cache.insert((height as usize, sibling_key), sibling);
316 if let Some(branch_node) = self.store.get_branch(&node)? {
317 let fork_height =
318 max(key.fork_height(&branch_node.key), branch_node.fork_height);
319 height = fork_height;
320 }
321 }
322 None => break,
323 };
324 }
325 Ok(())
326 }
327
328 pub fn merkle_proof(&self, mut keys: Vec<K>) -> Result<MerkleProof> {
330 if keys.is_empty() {
331 return Err(Error::EmptyKeys);
332 }
333
334 keys.sort_unstable_by_key(|k| **k);
336
337 let mut cache: BTreeMap<(usize, _), H256> = Default::default();
339 for k in &keys {
340 self.fetch_merkle_path(k, &mut cache)?;
341 }
342
343 let mut proof: Vec<(H256, usize)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
345 let mut leaves_path: Vec<Vec<usize>> = Vec::with_capacity(keys.len());
347 leaves_path.resize_with(keys.len(), Default::default);
348
349 let keys_len = keys.len();
350 let mut queue: VecDeque<(_, usize, usize)> = keys
353 .into_iter()
354 .enumerate()
355 .map(|(i, k)| (*k, 0, i))
356 .collect();
357
358 while let Some((key, height, leaf_index)) = queue.pop_front() {
359 if queue.is_empty() && cache.is_empty() || height == 8 * N {
360 if leaves_path[leaf_index].is_empty() {
362 leaves_path[leaf_index].push((8 * N) - 1);
363 }
364 break;
365 }
366 let mut sibling_key = key.parent_path(height);
368
369 let is_right = key.get_bit(height);
370 if is_right {
371 sibling_key.clear_bit(height);
373 } else {
374 sibling_key.set_bit(height);
376 }
377 if Some((&sibling_key, &height))
378 == queue
379 .front()
380 .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
381 {
382 let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
384 leaves_path[leaf_index].push(height);
385 } else {
386 match cache.remove(&(height, sibling_key)) {
387 Some(sibling) => {
388 debug_assert!(height < 8 * N);
389 proof.push((sibling, height));
391 }
392 None => {
393 if !is_right {
395 sibling_key.clear_bit(height);
396 }
397 let parent_key = sibling_key;
398 queue.push_back((parent_key, height + 1, leaf_index));
399 continue;
400 }
401 }
402 }
403 leaves_path[leaf_index].push(height);
405 if height < 8 * N {
406 let parent_key = if is_right { sibling_key } else { key };
408 queue.push_back((parent_key, height + 1, leaf_index));
409 }
410 }
411 debug_assert_eq!(leaves_path.len(), keys_len);
412 Ok(MerkleProof::new(leaves_path, proof))
413 }
414
415 pub fn membership_proof(&self, key: &K) -> Result<CommitmentProof> {
417 let value = self.get(key)?;
418 if value == V::zero() {
419 return Err(Error::ExistenceProof);
420 }
421 let merkle_proof = self.merkle_proof(vec![*key])?;
422 let existence_proof =
423 proof_ics23::convert(merkle_proof, key, &value, H::hash_op())?;
424 Ok(CommitmentProof {
425 proof: Some(Proof::Exist(existence_proof)),
426 })
427 }
428
429 pub fn non_membership_proof(&self, key: &K) -> Result<CommitmentProof> {
431 let value = self.get(key)?;
432 if value != V::zero() {
433 return Err(Error::NonExistenceProof);
434 }
435
436 let mut cache: BTreeMap<(usize, _), H256> = Default::default();
438 self.fetch_merkle_path(key, &mut cache)?;
439 let mut left = None;
440 let mut right = None;
441 for (_, node) in cache.iter() {
442 let branch = self
443 .store
444 .get_branch(node)?
445 .expect("the forked branch should exist");
446 let fork_height = key.fork_height(&branch.key);
447 let is_right = key.get_bit(fork_height);
448 if is_right && left.is_none() {
449 let mut n = *node;
451 while let Some(branch) = self.store.get_branch(&n)? {
452 if branch.fork_height == 0 {
453 break;
454 }
455 let (left_node, right_node) = branch.branch(branch.fork_height);
456 n = if right_node.is_zero() {
457 *left_node
458 } else {
459 *right_node
460 };
461 }
462 let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
463 let merkle_proof = self.merkle_proof(vec![leaf.key])?;
464 left = Some(proof_ics23::convert(
465 merkle_proof,
466 &leaf.key,
467 &leaf.value,
468 H::hash_op(),
469 )?);
470 } else if !is_right && right.is_none() {
471 let mut n = *node;
473 while let Some(branch) = self.store.get_branch(&n)? {
474 if branch.fork_height == 0 {
475 break;
476 }
477 let (left_node, right_node) = branch.branch(branch.fork_height);
478 n = if left_node.is_zero() {
479 *right_node
480 } else {
481 *left_node
482 };
483 }
484 let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
485 let merkle_proof = self.merkle_proof(vec![leaf.key])?;
486 right = Some(proof_ics23::convert(
487 merkle_proof,
488 &leaf.key,
489 &leaf.value,
490 H::hash_op(),
491 )?);
492 }
493 if left.is_some() && right.is_some() {
494 break;
495 }
496 }
497 let proof = NonExistenceProof {
498 key: key.to_vec(),
499 left,
500 right,
501 };
502 Ok(CommitmentProof {
503 proof: Some(Proof::Nonexist(proof)),
504 })
505 }
506
507 pub fn validate(&self) -> bool {
510 let pairs = {
512 let sorted_leaves = self.store.sorted_leaves();
513 let mut other = self.store.sorted_leaves();
514 _ = other.next();
515 sorted_leaves.zip(other)
516 };
517
518 if self.store.size() == 0 {
520 return self.root == H256::zero()
521 }
522
523 let mut leaves = Vec::with_capacity(self.store.size());
525 for ((k1, v1), (k2, _)) in pairs {
526 let height = k1.fork_height(&k2);
527 let hash = hash_leaf::<H, K, V, N>(&k1, &v1);
528 leaves.push((hash, height));
529 }
530 let (last_k, last_v) = self.store
531 .sorted_leaves()
532 .last()
533 .map(|(k, v)| (k, v))
534 .unwrap();
535 let last = hash_leaf::<H, K, V, N>(&last_k, last_v);
536 if leaves.is_empty() {
537 return self.root == last;
538 }
539 leaves.push((last, usize::MAX));
540
541 let find_next = |leaves: &[(H256, usize)]| {
544 for ix in 0..leaves.len() - 1 {
545 if leaves[ix].1 < leaves[ix + 1].1 {
546 return ix;
547 }
548 }
549 unreachable!()
550 };
551
552 loop {
553 let next_left = find_next(&leaves);
555 let next_right = next_left + 1;
556 let merged = merge::<H>(&leaves[next_left].0, &leaves[next_right].0);
557 let (_, dist) = leaves.remove(next_right);
559 leaves[next_left].0 = merged;
560 leaves[next_left].1 = dist;
561 if leaves.len() == 1 {
563 break;
564 }
565 }
566
567 leaves[0].0 == self.root
569 }
570}