// restricted_sparse_merkle_tree/tree.rs

use crate::{
    collections::{BTreeMap, VecDeque},
    error::{Error, Result},
    merge::{hash_leaf, merge},
    merkle_proof::MerkleProof,
    traits::{Hasher, Store, Value},
    vec::Vec,
    EXPECTED_PATH_SIZE, H256,
};
use core::{cmp::max, marker::PhantomData};
11
12/// A branch in the SMT
13#[derive(Debug, Eq, PartialEq, Clone)]
14pub struct BranchNode {
15    pub fork_height: u8,
16    pub key: H256,
17    pub node_type: NodeType,
18}
19
20impl BranchNode {
21    // get node at a specific height
22    fn node_at(&self, height: u8) -> NodeType {
23        match self.node_type {
24            NodeType::Pair(node, sibling) => {
25                let is_right = self.key.get_bit(height);
26                if is_right {
27                    NodeType::Pair(sibling, node)
28                } else {
29                    NodeType::Pair(node, sibling)
30                }
31            }
32            NodeType::Single(node) => NodeType::Single(node),
33        }
34    }
35
36    fn key(&self) -> &H256 {
37        &self.key
38    }
39}
40
41#[derive(Debug, Eq, PartialEq, Clone)]
42pub enum NodeType {
43    Single(H256),
44    Pair(H256, H256),
45}
46
47/// A leaf in the SMT
48#[derive(Debug, Eq, PartialEq, Clone)]
49pub struct LeafNode<V> {
50    pub key: H256,
51    pub value: V,
52}
53
54/// Sparse merkle tree
55#[derive(Default, Debug)]
56pub struct SparseMerkleTree<H, V, S> {
57    store: S,
58    root: H256,
59    phantom: PhantomData<(H, V)>,
60}
61
62impl<H: Hasher + Default, V: Value, S: Store<V>> SparseMerkleTree<H, V, S> {
63    /// Build a merkle tree from root and store
64    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
65        SparseMerkleTree {
66            root,
67            store,
68            phantom: PhantomData,
69        }
70    }
71
72    /// Merkle root
73    pub fn root(&self) -> &H256 {
74        &self.root
75    }
76
77    /// Check empty of the tree
78    pub fn is_empty(&self) -> bool {
79        self.root.is_zero()
80    }
81
82    /// Destroy current tree and retake store
83    pub fn take_store(self) -> S {
84        self.store
85    }
86
87    /// Get backend store
88    pub fn store(&self) -> &S {
89        &self.store
90    }
91
92    /// Get mutable backend store
93    pub fn store_mut(&mut self) -> &mut S {
94        &mut self.store
95    }
96
97    /// Update a leaf, return new merkle root
98    /// set to zero value to delete a key
99    pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
100        // store the path, sparse index will ignore zero members
101        let mut path = Vec::new();
102        if !self.is_empty() {
103            let mut node = self.root;
104            loop {
105                let branch_node = self
106                    .store
107                    .get_branch(&node)?
108                    .ok_or_else(|| Error::MissingBranch(node))?;
109                let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
110                match branch_node.node_at(height) {
111                    NodeType::Pair(left, right) => {
112                        if height > branch_node.fork_height {
113                            // the merge height is higher than node, so we do not need to remove node's branch
114                            path.push((height, node));
115                            break;
116                        } else {
117                            self.store.remove_branch(&node)?;
118                            let is_right = key.get_bit(height);
119                            if is_right {
120                                node = right;
121                                path.push((height, left));
122                            } else {
123                                node = left;
124                                path.push((height, right));
125                            }
126                        }
127                    }
128                    NodeType::Single(node) => {
129                        if &key == branch_node.key() {
130                            self.store.remove_leaf(&node)?;
131                            self.store.remove_branch(&node)?;
132                        } else {
133                            path.push((height, node));
134                        }
135                        break;
136                    }
137                }
138            }
139        }
140
141        // compute and store new leaf
142        let mut node = hash_leaf::<H>(&key, &value.to_h256());
143        // notice when value is zero the leaf is deleted, so we do not need to store it
144        if !node.is_zero() {
145            self.store.insert_leaf(node, LeafNode { key, value })?;
146
147            // build at least one branch for leaf
148            self.store.insert_branch(
149                node,
150                BranchNode {
151                    key,
152                    fork_height: 0,
153                    node_type: NodeType::Single(node),
154                },
155            )?;
156        }
157
158        // recompute the tree from bottom to top
159        for (height, sibling) in path.into_iter().rev() {
160            let is_right = key.get_bit(height);
161            let parent = if is_right {
162                merge::<H>(&sibling, &node)
163            } else {
164                merge::<H>(&node, &sibling)
165            };
166
167            if !node.is_zero() {
168                // node is exists
169                let branch_node = BranchNode {
170                    key,
171                    fork_height: height,
172                    node_type: NodeType::Pair(node, sibling),
173                };
174                self.store.insert_branch(parent, branch_node)?;
175            }
176            node = parent;
177        }
178        self.root = node;
179        Ok(&self.root)
180    }
181
182    /// Get value of a leaf
183    /// return zero value if leaf not exists
184    pub fn get(&self, key: &H256) -> Result<V> {
185        if self.is_empty() {
186            return Ok(V::zero());
187        }
188
189        let mut node = self.root;
190        loop {
191            let branch_node = self
192                .store
193                .get_branch(&node)?
194                .ok_or_else(|| Error::MissingBranch(node))?;
195
196            match branch_node.node_at(branch_node.fork_height) {
197                NodeType::Pair(left, right) => {
198                    let is_right = key.get_bit(branch_node.fork_height);
199                    node = if is_right { right } else { left };
200                }
201                NodeType::Single(node) => {
202                    if key == branch_node.key() {
203                        return Ok(self
204                            .store
205                            .get_leaf(&node)?
206                            .ok_or_else(|| Error::MissingLeaf(node))?
207                            .value);
208                    } else {
209                        return Ok(V::zero());
210                    }
211                }
212            }
213        }
214    }
215
216    /// fetch merkle path of key into cache
217    /// cache: (height, key) -> node
218    fn fetch_merkle_path(&self, key: &H256, cache: &mut BTreeMap<(u8, H256), H256>) -> Result<()> {
219        let mut node = self.root;
220        loop {
221            let branch_node = self
222                .store
223                .get_branch(&node)?
224                .ok_or_else(|| Error::MissingBranch(node))?;
225            let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
226            let is_right = key.get_bit(height);
227            let mut sibling_key = key.parent_path(height);
228            if !is_right {
229                // mark sibling's index, sibling on the right path.
230                sibling_key.set_bit(height);
231            };
232
233            match branch_node.node_at(height) {
234                NodeType::Pair(left, right) => {
235                    if height > branch_node.fork_height {
236                        cache.entry((height, sibling_key)).or_insert(node);
237                        break;
238                    } else {
239                        let sibling;
240                        if is_right {
241                            if node == right {
242                                break;
243                            }
244                            sibling = left;
245                            node = right;
246                        } else {
247                            if node == left {
248                                break;
249                            }
250                            sibling = right;
251                            node = left;
252                        }
253                        cache.insert((height, sibling_key), sibling);
254                    }
255                }
256                NodeType::Single(node) => {
257                    if key != branch_node.key() {
258                        cache.insert((height, sibling_key), node);
259                    }
260                    break;
261                }
262            }
263        }
264
265        Ok(())
266    }
267
268    /// Generate merkle proof
269    pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
270        if keys.is_empty() {
271            return Err(Error::EmptyKeys);
272        }
273
274        // sort keys
275        keys.sort_unstable();
276
277        // fetch all merkle path
278        let mut cache: BTreeMap<(u8, H256), H256> = Default::default();
279        if !self.is_empty() {
280            for k in &keys {
281                self.fetch_merkle_path(k, &mut cache)?;
282            }
283        }
284
285        // (node, height)
286        let mut proof: Vec<(H256, u8)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
287        // key_index -> merkle path height
288        let mut leaves_path: Vec<Vec<u8>> = Vec::with_capacity(keys.len());
289        leaves_path.resize_with(keys.len(), Default::default);
290
291        let keys_len = keys.len();
292        // build merkle proofs from bottom to up
293        // (key, height, key_index)
294        let mut queue: VecDeque<(H256, u8, usize)> = keys
295            .into_iter()
296            .enumerate()
297            .map(|(i, k)| (k, 0, i))
298            .collect();
299
300        while let Some((key, height, leaf_index)) = queue.pop_front() {
301            if queue.is_empty() && cache.is_empty() {
302                // tree only contains one leaf
303                if leaves_path[leaf_index].is_empty() {
304                    leaves_path[leaf_index].push(core::u8::MAX);
305                }
306                break;
307            }
308            // compute sibling key
309            let mut sibling_key = key.parent_path(height);
310
311            let is_right = key.get_bit(height);
312            if is_right {
313                // sibling on left
314                sibling_key.clear_bit(height);
315            } else {
316                // sibling on right
317                sibling_key.set_bit(height);
318            }
319            if Some((&sibling_key, &height))
320                == queue
321                    .front()
322                    .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
323            {
324                // drop the sibling, mark sibling's merkle path
325                let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
326                leaves_path[leaf_index].push(height);
327            } else {
328                match cache.remove(&(height, sibling_key)) {
329                    Some(sibling) => {
330                        // save first non-zero sibling's height for leaves
331                        proof.push((sibling, height));
332                    }
333                    None => {
334                        // skip zero siblings
335                        if !is_right {
336                            sibling_key.clear_bit(height);
337                        }
338                        if height == core::u8::MAX {
339                            if leaves_path[leaf_index].is_empty() {
340                                leaves_path[leaf_index].push(height);
341                            }
342                            break;
343                        } else {
344                            let parent_key = sibling_key;
345                            queue.push_back((parent_key, height + 1, leaf_index));
346                            continue;
347                        }
348                    }
349                }
350            }
351            // find new non-zero sibling, append to leaf's path
352            leaves_path[leaf_index].push(height);
353            if height == core::u8::MAX {
354                break;
355            } else {
356                // get parent_key, which k.get_bit(height) is false
357                let parent_key = if is_right { sibling_key } else { key };
358                queue.push_back((parent_key, height + 1, leaf_index));
359            }
360        }
361        debug_assert_eq!(leaves_path.len(), keys_len);
362        Ok(MerkleProof::new(leaves_path, proof))
363    }
364}