// nam_sparse_merkle_tree/tree.rs

1use crate::{
2    collections::{BTreeMap, VecDeque},
3    error::{Error, Result},
4    merge::{hash_leaf, merge},
5    merkle_proof::MerkleProof,
6    proof_ics23,
7    traits::{Hasher, Store, Value},
8    vec::Vec,
9    Key, InternalKey, EXPECTED_PATH_SIZE, H256,
10};
11#[cfg(feature = "borsh")]
12use borsh::{BorshDeserialize, BorshSerialize};
13use core::{cmp::max, marker::PhantomData};
14use ics23::commitment_proof::Proof;
15use ics23::{CommitmentProof, NonExistenceProof};
16
17/// A branch in the SMT
18#[derive(Debug, Eq, PartialEq, Clone)]
19#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
20pub struct BranchNode<K, const N: usize>
21where
22    K: Key<N>,
23{
24    pub fork_height: usize,
25    pub key: K,
26    pub node: H256,
27    pub sibling: H256,
28}
29
30impl<K, const N: usize> BranchNode<K, N>
31where
32    K: Key<N>,
33{
34    fn branch(&self, height: usize) -> (&H256, &H256) {
35        let is_right = self.key.get_bit(height);
36        if is_right {
37            (&self.sibling, &self.node)
38        } else {
39            (&self.node, &self.sibling)
40        }
41    }
42}
43
44/// A leaf in the SMT
45#[derive(Debug, Eq, PartialEq, Clone)]
46#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
47pub struct LeafNode<K, V, const N: usize>
48where
49    K: Key<N>,
50{
51    pub key: K,
52    pub value: V,
53}
54
55/// Sparse merkle tree
56#[derive(Debug)]
57pub struct SparseMerkleTree<H, K, V, S, const N: usize>
58where
59    H: Hasher + Default,
60    K: Key<N>,
61    V: Value,
62    S: Store<K, V, N>,
63{
64    store: S,
65    root: H256,
66    phantom: PhantomData<(H, K, V)>,
67}
68
69impl<H, K, V, S, const N: usize> Default for SparseMerkleTree<H, K, V, S, N>
70where
71    H: Hasher + Default,
72    K: Key<N>,
73    V: Value + core::cmp::PartialEq,
74    S: Store<K, V, N>,
75{
76    fn default() -> Self {
77        Self::new(H256::default(), S::default())
78    }
79}
80
81impl<H, K, V, S, const N: usize> SparseMerkleTree<H, K, V, S, N>
82where
83    H: Hasher + Default,
84    K: Key<N>,
85    V: Value + core::cmp::PartialEq,
86    S: Store<K, V, N>,
87{
88    /// Build a merkle tree from root and store
89    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, K, V, S, N> {
90        SparseMerkleTree {
91            root,
92            store,
93            phantom: PhantomData,
94        }
95    }
96
97    /// Merkle root
98    pub fn root(&self) -> &H256 {
99        &self.root
100    }
101
102    /// Check empty of the tree
103    pub fn is_empty(&self) -> bool {
104        self.root.is_zero()
105    }
106
107    /// Destroy current tree and retake store
108    pub fn take_store(self) -> S {
109        self.store
110    }
111
112    /// Get backend store
113    pub fn store(&self) -> &S {
114        &self.store
115    }
116
117    /// Get mutable backend store
118    pub fn store_mut(&mut self) -> &mut S {
119        &mut self.store
120    }
121
122    /// Update a leaf, return new merkle root
123    /// set to zero value to delete a key
124    pub fn update(&mut self, key: K, value: V) -> Result<&H256> {
125        // store the path, sparse index will ignore zero members
126        let mut path: BTreeMap<_, _> = Default::default();
127        // walk path from root to leaf
128        let mut node = self.root;
129        let mut branch = self.store.get_branch(&node)?;
130        let mut height = branch
131            .as_ref()
132            .map(|b| max(b.key.fork_height(&key), b.fork_height))
133            .unwrap_or(0);
134        // branch.is_none() represents the descendants are zeros, so we can stop the
135        // loop
136        while branch.is_some() {
137            let branch_node = branch.unwrap();
138            let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
139            if height > branch_node.fork_height {
140                // the merge height is higher than node, so we do not need to remove node's
141                // branch
142                path.insert(fork_height, node);
143                break;
144            }
145            // branch node is parent if height is less than branch_node's height
146            // remove it from store
147            if branch_node.fork_height > 0 {
148                self.store.remove_branch(&node)?;
149            }
150            let (left, right) = branch_node.branch(height);
151            let is_right = key.get_bit(height);
152            let sibling = if is_right {
153                if &node == right {
154                    break;
155                }
156                node = *right;
157                *left
158            } else {
159                if &node == left {
160                    break;
161                }
162                node = *left;
163                *right
164            };
165            path.insert(height, sibling);
166            // get next branch and fork_height
167            branch = self.store.get_branch(&node)?;
168            if let Some(branch_node) = branch.as_ref() {
169                height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
170            }
171        }
172        // delete previous leaf
173        if let Some(leaf) = self.store.get_leaf(&node)? {
174            if leaf.key == key {
175                self.store.remove_leaf(&node)?;
176                self.store.remove_branch(&node)?;
177            }
178        }
179
180        // compute and store new leaf
181        let mut node = hash_leaf::<H, K, V, N>(&key, &value);
182        // notice when value is zero the leaf is deleted, so we do not need to store it
183        if !node.is_zero() {
184            self.store.insert_leaf(node, LeafNode { key, value })?;
185
186            // build at least one branch for leaf
187            self.store.insert_branch(
188                node,
189                BranchNode {
190                    key,
191                    fork_height: 0,
192                    node,
193                    sibling: H256::zero(),
194                },
195            )?;
196        }
197
198        // recompute the tree from top to bottom
199        while !path.is_empty() {
200            // pop from path
201            let height = path.iter().next().map(|(height, _)| *height).unwrap();
202            let sibling = path.remove(&height).unwrap();
203
204            let is_right = key.get_bit(height);
205            let parent = if is_right {
206                merge::<H>(&sibling, &node)
207            } else {
208                merge::<H>(&node, &sibling)
209            };
210
211            if !node.is_zero() {
212                // node exists
213                let branch_node = BranchNode {
214                    fork_height: height,
215                    sibling,
216                    node,
217                    key,
218                };
219                self.store.insert_branch(parent, branch_node)?;
220            }
221            node = parent;
222        }
223        self.root = node;
224        Ok(&self.root)
225    }
226
227    /// Get value of a leaf
228    /// return zero value if leaf not exists
229    pub fn get(&self, key: &K) -> Result<V> {
230        let mut node = self.root;
231        // children must equal zero when parent equals zero
232        while !node.is_zero() {
233            let branch_node = match self.store.get_branch(&node)? {
234                Some(branch_node) => branch_node,
235                None => {
236                    break;
237                }
238            };
239            let is_right = key.get_bit(branch_node.fork_height);
240            let (left, right) = branch_node.branch(branch_node.fork_height);
241            node = if is_right { *right } else { *left };
242            if branch_node.fork_height == 0 {
243                break;
244            }
245        }
246
247        // return zero is leaf_key is zero
248        if node.is_zero() {
249            return Ok(V::zero());
250        }
251        // get leaf node
252        match self.store.get_leaf(&node)? {
253            Some(leaf) if &leaf.key == key => Ok(leaf.value),
254            _ => Ok(V::zero()),
255        }
256    }
257
258    /// fetch merkle path of key into cache
259    /// cache: (height, key) -> node
260    fn fetch_merkle_path(
261        &self,
262        key: &K,
263        cache: &mut BTreeMap<(usize, InternalKey<N>), H256>,
264    ) -> Result<()> {
265        let mut node = self.root;
266        let mut height = self
267            .store
268            .get_branch(&node)?
269            .map(|b| max(b.key.fork_height(key), b.fork_height))
270            .unwrap_or(0);
271        while !node.is_zero() {
272            // the descendants are zeros, so we can break the loop
273            if node.is_zero() {
274                break;
275            }
276            match self.store.get_branch(&node)? {
277                Some(branch_node) => {
278                    if height > branch_node.fork_height {
279                        let fork_height =
280                            max(key.fork_height(&branch_node.key), branch_node.fork_height);
281
282                        let is_right = key.get_bit(fork_height);
283                        let mut sibling_key = key.parent_path(fork_height);
284                        if !is_right {
285                            // mark sibling's index, sibling on the right path.
286                            sibling_key.set_bit(height);
287                        };
288                        if !node.is_zero() {
289                            cache
290                                .entry((fork_height as usize, sibling_key))
291                                .or_insert(node);
292                        }
293                        break;
294                    }
295                    let (left, right) = branch_node.branch(height);
296                    let is_right = key.get_bit(height);
297                    let sibling = if is_right {
298                        if &node == right {
299                            break;
300                        }
301                        node = *right;
302                        *left
303                    } else {
304                        if &node == left {
305                            break;
306                        }
307                        node = *left;
308                        *right
309                    };
310                    let mut sibling_key = key.parent_path(height);
311                    if !is_right {
312                        // mark sibling's index, sibling on the right path.
313                        sibling_key.set_bit(height);
314                    };
315                    cache.insert((height as usize, sibling_key), sibling);
316                    if let Some(branch_node) = self.store.get_branch(&node)? {
317                        let fork_height =
318                            max(key.fork_height(&branch_node.key), branch_node.fork_height);
319                        height = fork_height;
320                    }
321                }
322                None => break,
323            };
324        }
325        Ok(())
326    }
327
328    /// Generate merkle proof
329    pub fn merkle_proof(&self, mut keys: Vec<K>) -> Result<MerkleProof> {
330        if keys.is_empty() {
331            return Err(Error::EmptyKeys);
332        }
333
334        // sort keys
335        keys.sort_unstable_by_key(|k| **k);
336
337        // fetch all merkle path
338        let mut cache: BTreeMap<(usize, _), H256> = Default::default();
339        for k in &keys {
340            self.fetch_merkle_path(k, &mut cache)?;
341        }
342
343        // (node, height)
344        let mut proof: Vec<(H256, usize)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
345        // key_index -> merkle path height
346        let mut leaves_path: Vec<Vec<usize>> = Vec::with_capacity(keys.len());
347        leaves_path.resize_with(keys.len(), Default::default);
348
349        let keys_len = keys.len();
350        // build merkle proofs from bottom to up
351        // (key, height, key_index)
352        let mut queue: VecDeque<(_, usize, usize)> = keys
353            .into_iter()
354            .enumerate()
355            .map(|(i, k)| (*k, 0, i))
356            .collect();
357
358        while let Some((key, height, leaf_index)) = queue.pop_front() {
359            if queue.is_empty() && cache.is_empty() || height == 8 * N {
360                // tree only contains one leaf
361                if leaves_path[leaf_index].is_empty() {
362                    leaves_path[leaf_index].push((8 * N) - 1);
363                }
364                break;
365            }
366            // compute sibling key
367            let mut sibling_key = key.parent_path(height);
368
369            let is_right = key.get_bit(height);
370            if is_right {
371                // sibling on left
372                sibling_key.clear_bit(height);
373            } else {
374                // sibling on right
375                sibling_key.set_bit(height);
376            }
377            if Some((&sibling_key, &height))
378                == queue
379                    .front()
380                    .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
381            {
382                // drop the sibling, mark sibling's merkle path
383                let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
384                leaves_path[leaf_index].push(height);
385            } else {
386                match cache.remove(&(height, sibling_key)) {
387                    Some(sibling) => {
388                        debug_assert!(height < 8 * N);
389                        // save first non-zero sibling's height for leaves
390                        proof.push((sibling, height));
391                    }
392                    None => {
393                        // skip zero siblings
394                        if !is_right {
395                            sibling_key.clear_bit(height);
396                        }
397                        let parent_key = sibling_key;
398                        queue.push_back((parent_key, height + 1, leaf_index));
399                        continue;
400                    }
401                }
402            }
403            // find new non-zero sibling, append to leaf's path
404            leaves_path[leaf_index].push(height);
405            if height < 8 * N {
406                // get parent_key, which k.get_bit(height) is false
407                let parent_key = if is_right { sibling_key } else { key };
408                queue.push_back((parent_key, height + 1, leaf_index));
409            }
410        }
411        debug_assert_eq!(leaves_path.len(), keys_len);
412        Ok(MerkleProof::new(leaves_path, proof))
413    }
414
415    /// Generate ICS 23 commitment proof for the existing key
416    pub fn membership_proof(&self, key: &K) -> Result<CommitmentProof> {
417        let value = self.get(key)?;
418        if value == V::zero() {
419            return Err(Error::ExistenceProof);
420        }
421        let merkle_proof = self.merkle_proof(vec![*key])?;
422        let existence_proof =
423            proof_ics23::convert(merkle_proof, key, &value, H::hash_op())?;
424        Ok(CommitmentProof {
425            proof: Some(Proof::Exist(existence_proof)),
426        })
427    }
428
429    /// Generate ICS 23 commitment proof for the non-existing key
430    pub fn non_membership_proof(&self, key: &K) -> Result<CommitmentProof> {
431        let value = self.get(key)?;
432        if value != V::zero() {
433            return Err(Error::NonExistenceProof);
434        }
435
436        // fetch all merkle path
437        let mut cache: BTreeMap<(usize, _), H256> = Default::default();
438        self.fetch_merkle_path(key, &mut cache)?;
439        let mut left = None;
440        let mut right = None;
441        for (_, node) in cache.iter() {
442            let branch = self
443                .store
444                .get_branch(node)?
445                .expect("the forked branch should exist");
446            let fork_height = key.fork_height(&branch.key);
447            let is_right = key.get_bit(fork_height);
448            if is_right && left.is_none() {
449                // get the left which is the most right in the left subtree
450                let mut n = *node;
451                while let Some(branch) = self.store.get_branch(&n)? {
452                    if branch.fork_height == 0 {
453                        break;
454                    }
455                    let (left_node, right_node) = branch.branch(branch.fork_height);
456                    n = if right_node.is_zero() {
457                        *left_node
458                    } else {
459                        *right_node
460                    };
461                }
462                let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
463                let merkle_proof = self.merkle_proof(vec![leaf.key])?;
464                left = Some(proof_ics23::convert(
465                    merkle_proof,
466                    &leaf.key,
467                    &leaf.value,
468                    H::hash_op(),
469                )?);
470            } else if !is_right && right.is_none() {
471                // get the right which is the most left in the right subtree
472                let mut n = *node;
473                while let Some(branch) = self.store.get_branch(&n)? {
474                    if branch.fork_height == 0 {
475                        break;
476                    }
477                    let (left_node, right_node) = branch.branch(branch.fork_height);
478                    n = if left_node.is_zero() {
479                        *right_node
480                    } else {
481                        *left_node
482                    };
483                }
484                let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
485                let merkle_proof = self.merkle_proof(vec![leaf.key])?;
486                right = Some(proof_ics23::convert(
487                    merkle_proof,
488                    &leaf.key,
489                    &leaf.value,
490                    H::hash_op(),
491                )?);
492            }
493            if left.is_some() && right.is_some() {
494                break;
495            }
496        }
497        let proof = NonExistenceProof {
498            key: key.to_vec(),
499            left,
500            right,
501        };
502        Ok(CommitmentProof {
503            proof: Some(Proof::Nonexist(proof)),
504        })
505    }
506
507    /// Recompute the root of the merkle tree from the store. Check if it agrees with the
508    /// root in `self`.
509    pub fn validate(&self) -> bool {
510        // create an iterator over consecutive pairs of leaves
511        let pairs = {
512            let sorted_leaves = self.store.sorted_leaves();
513            let mut other = self.store.sorted_leaves();
514            _ = other.next();
515            sorted_leaves.zip(other)
516        };
517
518        // handle case when tree is empty
519        if self.store.size() == 0 {
520            return self.root == H256::zero()
521        }
522
523        // construct a vector of nodes and distance to next node
524        let mut leaves = Vec::with_capacity(self.store.size());
525        for ((k1, v1), (k2, _)) in pairs {
526            let height = k1.fork_height(&k2);
527            let hash = hash_leaf::<H, K, V, N>(&k1, &v1);
528            leaves.push((hash, height));
529        }
530        let (last_k, last_v) = self.store
531            .sorted_leaves()
532            .last()
533            .map(|(k, v)| (k, v))
534            .unwrap();
535        let last = hash_leaf::<H, K, V, N>(&last_k, last_v);
536        if leaves.is_empty() {
537            return self.root == last;
538        }
539        leaves.push((last, usize::MAX));
540
541        // find the next node `n` such that `n+1` is the closest neighbor
542        // of `n` and vice versa. These can be merged
543        let find_next = |leaves: &[(H256, usize)]| {
544            for ix in 0..leaves.len() - 1 {
545                if leaves[ix].1 < leaves[ix + 1].1 {
546                    return ix;
547                }
548            }
549            unreachable!()
550        };
551
552        loop {
553            // find the next pair of nodes to merge
554            let next_left = find_next(&leaves);
555            let next_right = next_left + 1;
556            let merged = merge::<H>(&leaves[next_left].0, &leaves[next_right].0);
557            // perform the merge
558            let (_, dist) = leaves.remove(next_right);
559            leaves[next_left].0 = merged;
560            leaves[next_left].1 = dist;
561            // we have recovered the root
562            if leaves.len() == 1 {
563                break;
564            }
565        }
566
567        // check that the computed root is the same as the precomputed one
568        leaves[0].0 == self.root
569    }
570}