nam_sparse_merkle_tree/
tree.rs

1use crate::{
2    collections::{BTreeMap, VecDeque},
3    error::{Error, Result},
4    merge::{hash_leaf, merge},
5    merkle_proof::MerkleProof,
6    proof_ics23,
7    traits::{Hasher, Store, Value},
8    vec::Vec,
9    Key, InternalKey, EXPECTED_PATH_SIZE, H256,
10};
11#[cfg(feature = "borsh")]
12use borsh::{BorshDeserialize, BorshSerialize};
13use core::{cmp::max, marker::PhantomData};
14use ics23::commitment_proof::Proof;
15use ics23::{CommitmentProof, NonExistenceProof};
16use itertools::Itertools;
17
18/// A branch in the SMT
19#[derive(Debug, Eq, PartialEq, Clone)]
20#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
21pub struct BranchNode<K, const N: usize>
22where
23    K: Key<N>,
24{
25    pub fork_height: usize,
26    pub key: K,
27    pub node: H256,
28    pub sibling: H256,
29}
30
31impl<K, const N: usize> BranchNode<K, N>
32where
33    K: Key<N>,
34{
35    fn branch(&self, height: usize) -> (&H256, &H256) {
36        let is_right = self.key.get_bit(height);
37        if is_right {
38            (&self.sibling, &self.node)
39        } else {
40            (&self.node, &self.sibling)
41        }
42    }
43}
44
45/// A leaf in the SMT
46#[derive(Debug, Eq, PartialEq, Clone)]
47#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
48pub struct LeafNode<K, V, const N: usize>
49where
50    K: Key<N>,
51{
52    pub key: K,
53    pub value: V,
54}
55
56/// Sparse merkle tree
57#[derive(Debug)]
58pub struct SparseMerkleTree<H, K, V, S, const N: usize>
59where
60    H: Hasher + Default,
61    K: Key<N>,
62    V: Value,
63    S: Store<K, V, N>,
64{
65    store: S,
66    root: H256,
67    phantom: PhantomData<(H, K, V)>,
68}
69
70impl<H, K, V, S, const N: usize> Default for SparseMerkleTree<H, K, V, S, N>
71where
72    H: Hasher + Default,
73    K: Key<N>,
74    V: Value + core::cmp::PartialEq,
75    S: Store<K, V, N>,
76{
77    fn default() -> Self {
78        Self::new(H256::default(), S::default())
79    }
80}
81
82impl<H, K, V, S, const N: usize> SparseMerkleTree<H, K, V, S, N>
83where
84    H: Hasher + Default,
85    K: Key<N>,
86    V: Value + core::cmp::PartialEq,
87    S: Store<K, V, N>,
88{
89    /// Build a merkle tree from root and store
90    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, K, V, S, N> {
91        SparseMerkleTree {
92            root,
93            store,
94            phantom: PhantomData,
95        }
96    }
97
98    /// Merkle root
99    pub fn root(&self) -> &H256 {
100        &self.root
101    }
102
103    /// Check empty of the tree
104    pub fn is_empty(&self) -> bool {
105        self.root.is_zero()
106    }
107
108    /// Destroy current tree and retake store
109    pub fn take_store(self) -> S {
110        self.store
111    }
112
113    /// Get backend store
114    pub fn store(&self) -> &S {
115        &self.store
116    }
117
118    /// Get mutable backend store
119    pub fn store_mut(&mut self) -> &mut S {
120        &mut self.store
121    }
122
123    /// Update a leaf, return new merkle root
124    /// set to zero value to delete a key
125    pub fn update(&mut self, key: K, value: V) -> Result<&H256> {
126        // store the path, sparse index will ignore zero members
127        let mut path: BTreeMap<_, _> = Default::default();
128        // walk path from root to leaf
129        let mut node = self.root;
130        let mut branch = self.store.get_branch(&node)?;
131        let mut height = branch
132            .as_ref()
133            .map(|b| max(b.key.fork_height(&key), b.fork_height))
134            .unwrap_or(0);
135        // branch.is_none() represents the descendants are zeros, so we can stop the
136        // loop
137        while branch.is_some() {
138            let branch_node = branch.unwrap();
139            let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
140            if height > branch_node.fork_height {
141                // the merge height is higher than node, so we do not need to remove node's
142                // branch
143                path.insert(fork_height, node);
144                break;
145            }
146            // branch node is parent if height is less than branch_node's height
147            // remove it from store
148            if branch_node.fork_height > 0 {
149                self.store.remove_branch(&node)?;
150            }
151            let (left, right) = branch_node.branch(height);
152            let is_right = key.get_bit(height);
153            let sibling = if is_right {
154                if &node == right {
155                    break;
156                }
157                node = *right;
158                *left
159            } else {
160                if &node == left {
161                    break;
162                }
163                node = *left;
164                *right
165            };
166            path.insert(height, sibling);
167            // get next branch and fork_height
168            branch = self.store.get_branch(&node)?;
169            if let Some(branch_node) = branch.as_ref() {
170                height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
171            }
172        }
173        // delete previous leaf
174        if let Some(leaf) = self.store.get_leaf(&node)? {
175            if leaf.key == key {
176                self.store.remove_leaf(&node)?;
177                self.store.remove_branch(&node)?;
178            }
179        }
180
181        // compute and store new leaf
182        let mut node = hash_leaf::<H, K, V, N>(&key, &value);
183        // notice when value is zero the leaf is deleted, so we do not need to store it
184        if !node.is_zero() {
185            self.store.insert_leaf(node, LeafNode { key, value })?;
186
187            // build at least one branch for leaf
188            self.store.insert_branch(
189                node,
190                BranchNode {
191                    key,
192                    fork_height: 0,
193                    node,
194                    sibling: H256::zero(),
195                },
196            )?;
197        }
198
199        // recompute the tree from top to bottom
200        while !path.is_empty() {
201            // pop from path
202            let height = path.iter().next().map(|(height, _)| *height).unwrap();
203            let sibling = path.remove(&height).unwrap();
204
205            let is_right = key.get_bit(height);
206            let parent = if is_right {
207                merge::<H>(&sibling, &node)
208            } else {
209                merge::<H>(&node, &sibling)
210            };
211
212            if !node.is_zero() {
213                // node exists
214                let branch_node = BranchNode {
215                    fork_height: height,
216                    sibling,
217                    node,
218                    key,
219                };
220                self.store.insert_branch(parent, branch_node)?;
221            }
222            node = parent;
223        }
224        self.root = node;
225        Ok(&self.root)
226    }
227
228    /// Get value of a leaf
229    /// return zero value if leaf not exists
230    pub fn get(&self, key: &K) -> Result<V> {
231        let mut node = self.root;
232        // children must equal zero when parent equals zero
233        while !node.is_zero() {
234            let branch_node = match self.store.get_branch(&node)? {
235                Some(branch_node) => branch_node,
236                None => {
237                    break;
238                }
239            };
240            let is_right = key.get_bit(branch_node.fork_height);
241            let (left, right) = branch_node.branch(branch_node.fork_height);
242            node = if is_right { *right } else { *left };
243            if branch_node.fork_height == 0 {
244                break;
245            }
246        }
247
248        // return zero is leaf_key is zero
249        if node.is_zero() {
250            return Ok(V::zero());
251        }
252        // get leaf node
253        match self.store.get_leaf(&node)? {
254            Some(leaf) if &leaf.key == key => Ok(leaf.value),
255            _ => Ok(V::zero()),
256        }
257    }
258
259    /// fetch merkle path of key into cache
260    /// cache: (height, key) -> node
261    fn fetch_merkle_path(
262        &self,
263        key: &K,
264        cache: &mut BTreeMap<(usize, InternalKey<N>), H256>,
265    ) -> Result<()> {
266        let mut node = self.root;
267        let mut height = self
268            .store
269            .get_branch(&node)?
270            .map(|b| max(b.key.fork_height(key), b.fork_height))
271            .unwrap_or(0);
272        while !node.is_zero() {
273            // the descendants are zeros, so we can break the loop
274            if node.is_zero() {
275                break;
276            }
277            match self.store.get_branch(&node)? {
278                Some(branch_node) => {
279                    if height > branch_node.fork_height {
280                        let fork_height =
281                            max(key.fork_height(&branch_node.key), branch_node.fork_height);
282
283                        let is_right = key.get_bit(fork_height);
284                        let mut sibling_key = key.parent_path(fork_height);
285                        if !is_right {
286                            // mark sibling's index, sibling on the right path.
287                            sibling_key.set_bit(height);
288                        };
289                        if !node.is_zero() {
290                            cache
291                                .entry((fork_height as usize, sibling_key))
292                                .or_insert(node);
293                        }
294                        break;
295                    }
296                    let (left, right) = branch_node.branch(height);
297                    let is_right = key.get_bit(height);
298                    let sibling = if is_right {
299                        if &node == right {
300                            break;
301                        }
302                        node = *right;
303                        *left
304                    } else {
305                        if &node == left {
306                            break;
307                        }
308                        node = *left;
309                        *right
310                    };
311                    let mut sibling_key = key.parent_path(height);
312                    if !is_right {
313                        // mark sibling's index, sibling on the right path.
314                        sibling_key.set_bit(height);
315                    };
316                    cache.insert((height as usize, sibling_key), sibling);
317                    if let Some(branch_node) = self.store.get_branch(&node)? {
318                        let fork_height =
319                            max(key.fork_height(&branch_node.key), branch_node.fork_height);
320                        height = fork_height;
321                    }
322                }
323                None => break,
324            };
325        }
326        Ok(())
327    }
328
329    /// Generate merkle proof
330    pub fn merkle_proof(&self, mut keys: Vec<K>) -> Result<MerkleProof> {
331        if keys.is_empty() {
332            return Err(Error::EmptyKeys);
333        }
334
335        // sort keys
336        keys.sort_unstable_by_key(|k| **k);
337
338        // fetch all merkle path
339        let mut cache: BTreeMap<(usize, _), H256> = Default::default();
340        for k in &keys {
341            self.fetch_merkle_path(k, &mut cache)?;
342        }
343
344        // (node, height)
345        let mut proof: Vec<(H256, usize)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
346        // key_index -> merkle path height
347        let mut leaves_path: Vec<Vec<usize>> = Vec::with_capacity(keys.len());
348        leaves_path.resize_with(keys.len(), Default::default);
349
350        let keys_len = keys.len();
351        // build merkle proofs from bottom to up
352        // (key, height, key_index)
353        let mut queue: VecDeque<(_, usize, usize)> = keys
354            .into_iter()
355            .enumerate()
356            .map(|(i, k)| (*k, 0, i))
357            .collect();
358
359        while let Some((key, height, leaf_index)) = queue.pop_front() {
360            if queue.is_empty() && cache.is_empty() || height == 8 * N {
361                // tree only contains one leaf
362                if leaves_path[leaf_index].is_empty() {
363                    leaves_path[leaf_index].push((8 * N) - 1);
364                }
365                break;
366            }
367            // compute sibling key
368            let mut sibling_key = key.parent_path(height);
369
370            let is_right = key.get_bit(height);
371            if is_right {
372                // sibling on left
373                sibling_key.clear_bit(height);
374            } else {
375                // sibling on right
376                sibling_key.set_bit(height);
377            }
378            if Some((&sibling_key, &height))
379                == queue
380                    .front()
381                    .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
382            {
383                // drop the sibling, mark sibling's merkle path
384                let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
385                leaves_path[leaf_index].push(height);
386            } else {
387                match cache.remove(&(height, sibling_key)) {
388                    Some(sibling) => {
389                        debug_assert!(height < 8 * N);
390                        // save first non-zero sibling's height for leaves
391                        proof.push((sibling, height));
392                    }
393                    None => {
394                        // skip zero siblings
395                        if !is_right {
396                            sibling_key.clear_bit(height);
397                        }
398                        let parent_key = sibling_key;
399                        queue.push_back((parent_key, height + 1, leaf_index));
400                        continue;
401                    }
402                }
403            }
404            // find new non-zero sibling, append to leaf's path
405            leaves_path[leaf_index].push(height);
406            if height < 8 * N {
407                // get parent_key, which k.get_bit(height) is false
408                let parent_key = if is_right { sibling_key } else { key };
409                queue.push_back((parent_key, height + 1, leaf_index));
410            }
411        }
412        debug_assert_eq!(leaves_path.len(), keys_len);
413        Ok(MerkleProof::new(leaves_path, proof))
414    }
415
416    /// Generate ICS 23 commitment proof for the existing key
417    pub fn membership_proof(&self, key: &K) -> Result<CommitmentProof> {
418        let value = self.get(key)?;
419        if value == V::zero() {
420            return Err(Error::ExistenceProof);
421        }
422        let merkle_proof = self.merkle_proof(vec![*key])?;
423        let existence_proof =
424            proof_ics23::convert(merkle_proof, key, &value, H::hash_op())?;
425        Ok(CommitmentProof {
426            proof: Some(Proof::Exist(existence_proof)),
427        })
428    }
429
430    /// Generate ICS 23 commitment proof for the non-existing key
431    pub fn non_membership_proof(&self, key: &K) -> Result<CommitmentProof> {
432        let value = self.get(key)?;
433        if value != V::zero() {
434            return Err(Error::NonExistenceProof);
435        }
436
437        // fetch all merkle path
438        let mut cache: BTreeMap<(usize, _), H256> = Default::default();
439        self.fetch_merkle_path(key, &mut cache)?;
440        let mut left = None;
441        let mut right = None;
442        for (_, node) in cache.iter() {
443            let branch = self
444                .store
445                .get_branch(node)?
446                .expect("the forked branch should exist");
447            let fork_height = key.fork_height(&branch.key);
448            let is_right = key.get_bit(fork_height);
449            if is_right && left.is_none() {
450                // get the left which is the most right in the left subtree
451                let mut n = *node;
452                while let Some(branch) = self.store.get_branch(&n)? {
453                    if branch.fork_height == 0 {
454                        break;
455                    }
456                    let (left_node, right_node) = branch.branch(branch.fork_height);
457                    n = if right_node.is_zero() {
458                        *left_node
459                    } else {
460                        *right_node
461                    };
462                }
463                let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
464                let merkle_proof = self.merkle_proof(vec![leaf.key])?;
465                left = Some(proof_ics23::convert(
466                    merkle_proof,
467                    &leaf.key,
468                    &leaf.value,
469                    H::hash_op(),
470                )?);
471            } else if !is_right && right.is_none() {
472                // get the right which is the most left in the right subtree
473                let mut n = *node;
474                while let Some(branch) = self.store.get_branch(&n)? {
475                    if branch.fork_height == 0 {
476                        break;
477                    }
478                    let (left_node, right_node) = branch.branch(branch.fork_height);
479                    n = if left_node.is_zero() {
480                        *right_node
481                    } else {
482                        *left_node
483                    };
484                }
485                let leaf = self.store.get_leaf(&n)?.expect("the leaf should exist");
486                let merkle_proof = self.merkle_proof(vec![leaf.key])?;
487                right = Some(proof_ics23::convert(
488                    merkle_proof,
489                    &leaf.key,
490                    &leaf.value,
491                    H::hash_op(),
492                )?);
493            }
494            if left.is_some() && right.is_some() {
495                break;
496            }
497        }
498        let proof = NonExistenceProof {
499            key: key.to_vec(),
500            left,
501            right,
502        };
503        Ok(CommitmentProof {
504            proof: Some(Proof::Nonexist(proof)),
505        })
506    }
507
508    /// Recompute the root of the merkle tree from the store. Check if it agrees with the
509    /// root in `self`.
510    pub fn validate(&self) -> bool {
511        // handle case when tree is empty
512        if self.store.size() == 0 {
513            return self.root == H256::zero()
514        }
515
516        let sorted_leaves = self.store
517            .sorted_leaves()
518            .map(|(k, v)| (k, v.clone()))
519            .collect::<Vec<_>>();
520        // iterator over consecutive pairs of leaves
521        let pairs = sorted_leaves
522            .iter()
523            .tuple_windows::<(_, _)>();
524
525        // construct a vector of nodes and distance to next node
526        let mut leaves = Vec::with_capacity(self.store.size());
527        for ((k1, v1), (k2, _)) in pairs {
528            let height = k1.fork_height(k2);
529            let hash = hash_leaf::<H, K, V, N>(k1, v1);
530            leaves.push((hash, height));
531        }
532        let (last_k, last_v) = sorted_leaves
533            .last()
534            .map(|(k, v)| (k, v))
535            .unwrap();
536        let last = hash_leaf::<H, K, V, N>(last_k, last_v);
537        if leaves.is_empty() {
538            return self.root == last;
539        }
540        leaves.push((last, usize::MAX));
541
542        let mut left: usize = 0;
543        let mut right: usize = 1;
544        let mut merged = Default::default();
545
546        // stack of previous `left` indexes that are yet to be merged
547        let mut prev: Vec<usize> = Vec::with_capacity(leaves.len() / 2);
548
549        // Iterate finding the first node `left` such that `left+1` (`right`) is
550        // its closest neighbor and vice versa, merging them until a single node
551        // remains.
552        while right < leaves.len() {
553            if leaves[left].1 < leaves[right].1 {
554                loop {
555                    // perform merge
556                    merged = merge::<H>(&leaves[left].0, &leaves[right].0);
557                    leaves[right].0 = merged;
558
559                    // check previous `left` node next (if present)
560                    match prev.last() {
561                        Some(&idx) if leaves[idx].1 < leaves[right].1 => {
562                            left = idx;
563                            _ = prev.pop();
564                            continue;
565                        }
566                        _ => {
567                            break;
568                        }
569                    }
570                }
571            } else {
572                prev.push(left);
573            }
574            left = right;
575            right += 1;
576        }
577        // check that the recovered root matches the precomputed one
578        merged == self.root
579    }
580}