restricted_sparse_merkle_tree/
tree.rs1use crate::{
2 collections::{BTreeMap, VecDeque},
3 error::{Error, Result},
4 merge::{hash_leaf, merge},
5 merkle_proof::MerkleProof,
6 traits::{Hasher, Store, Value},
7 vec::Vec,
8 EXPECTED_PATH_SIZE, H256,
9};
10use core::{cmp::max, marker::PhantomData};
11
12#[derive(Debug, Eq, PartialEq, Clone)]
14pub struct BranchNode {
15 pub fork_height: u8,
16 pub key: H256,
17 pub node_type: NodeType,
18}
19
20impl BranchNode {
21 fn node_at(&self, height: u8) -> NodeType {
23 match self.node_type {
24 NodeType::Pair(node, sibling) => {
25 let is_right = self.key.get_bit(height);
26 if is_right {
27 NodeType::Pair(sibling, node)
28 } else {
29 NodeType::Pair(node, sibling)
30 }
31 }
32 NodeType::Single(node) => NodeType::Single(node),
33 }
34 }
35
36 fn key(&self) -> &H256 {
37 &self.key
38 }
39}
40
41#[derive(Debug, Eq, PartialEq, Clone)]
42pub enum NodeType {
43 Single(H256),
44 Pair(H256, H256),
45}
46
47#[derive(Debug, Eq, PartialEq, Clone)]
49pub struct LeafNode<V> {
50 pub key: H256,
51 pub value: V,
52}
53
54#[derive(Default, Debug)]
56pub struct SparseMerkleTree<H, V, S> {
57 store: S,
58 root: H256,
59 phantom: PhantomData<(H, V)>,
60}
61
62impl<H: Hasher + Default, V: Value, S: Store<V>> SparseMerkleTree<H, V, S> {
63 pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
65 SparseMerkleTree {
66 root,
67 store,
68 phantom: PhantomData,
69 }
70 }
71
72 pub fn root(&self) -> &H256 {
74 &self.root
75 }
76
77 pub fn is_empty(&self) -> bool {
79 self.root.is_zero()
80 }
81
82 pub fn take_store(self) -> S {
84 self.store
85 }
86
87 pub fn store(&self) -> &S {
89 &self.store
90 }
91
92 pub fn store_mut(&mut self) -> &mut S {
94 &mut self.store
95 }
96
97 pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
100 let mut path = Vec::new();
102 if !self.is_empty() {
103 let mut node = self.root;
104 loop {
105 let branch_node = self
106 .store
107 .get_branch(&node)?
108 .ok_or_else(|| Error::MissingBranch(node))?;
109 let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
110 match branch_node.node_at(height) {
111 NodeType::Pair(left, right) => {
112 if height > branch_node.fork_height {
113 path.push((height, node));
115 break;
116 } else {
117 self.store.remove_branch(&node)?;
118 let is_right = key.get_bit(height);
119 if is_right {
120 node = right;
121 path.push((height, left));
122 } else {
123 node = left;
124 path.push((height, right));
125 }
126 }
127 }
128 NodeType::Single(node) => {
129 if &key == branch_node.key() {
130 self.store.remove_leaf(&node)?;
131 self.store.remove_branch(&node)?;
132 } else {
133 path.push((height, node));
134 }
135 break;
136 }
137 }
138 }
139 }
140
141 let mut node = hash_leaf::<H>(&key, &value.to_h256());
143 if !node.is_zero() {
145 self.store.insert_leaf(node, LeafNode { key, value })?;
146
147 self.store.insert_branch(
149 node,
150 BranchNode {
151 key,
152 fork_height: 0,
153 node_type: NodeType::Single(node),
154 },
155 )?;
156 }
157
158 for (height, sibling) in path.into_iter().rev() {
160 let is_right = key.get_bit(height);
161 let parent = if is_right {
162 merge::<H>(&sibling, &node)
163 } else {
164 merge::<H>(&node, &sibling)
165 };
166
167 if !node.is_zero() {
168 let branch_node = BranchNode {
170 key,
171 fork_height: height,
172 node_type: NodeType::Pair(node, sibling),
173 };
174 self.store.insert_branch(parent, branch_node)?;
175 }
176 node = parent;
177 }
178 self.root = node;
179 Ok(&self.root)
180 }
181
182 pub fn get(&self, key: &H256) -> Result<V> {
185 if self.is_empty() {
186 return Ok(V::zero());
187 }
188
189 let mut node = self.root;
190 loop {
191 let branch_node = self
192 .store
193 .get_branch(&node)?
194 .ok_or_else(|| Error::MissingBranch(node))?;
195
196 match branch_node.node_at(branch_node.fork_height) {
197 NodeType::Pair(left, right) => {
198 let is_right = key.get_bit(branch_node.fork_height);
199 node = if is_right { right } else { left };
200 }
201 NodeType::Single(node) => {
202 if key == branch_node.key() {
203 return Ok(self
204 .store
205 .get_leaf(&node)?
206 .ok_or_else(|| Error::MissingLeaf(node))?
207 .value);
208 } else {
209 return Ok(V::zero());
210 }
211 }
212 }
213 }
214 }
215
216 fn fetch_merkle_path(&self, key: &H256, cache: &mut BTreeMap<(u8, H256), H256>) -> Result<()> {
219 let mut node = self.root;
220 loop {
221 let branch_node = self
222 .store
223 .get_branch(&node)?
224 .ok_or_else(|| Error::MissingBranch(node))?;
225 let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
226 let is_right = key.get_bit(height);
227 let mut sibling_key = key.parent_path(height);
228 if !is_right {
229 sibling_key.set_bit(height);
231 };
232
233 match branch_node.node_at(height) {
234 NodeType::Pair(left, right) => {
235 if height > branch_node.fork_height {
236 cache.entry((height, sibling_key)).or_insert(node);
237 break;
238 } else {
239 let sibling;
240 if is_right {
241 if node == right {
242 break;
243 }
244 sibling = left;
245 node = right;
246 } else {
247 if node == left {
248 break;
249 }
250 sibling = right;
251 node = left;
252 }
253 cache.insert((height, sibling_key), sibling);
254 }
255 }
256 NodeType::Single(node) => {
257 if key != branch_node.key() {
258 cache.insert((height, sibling_key), node);
259 }
260 break;
261 }
262 }
263 }
264
265 Ok(())
266 }
267
268 pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
270 if keys.is_empty() {
271 return Err(Error::EmptyKeys);
272 }
273
274 keys.sort_unstable();
276
277 let mut cache: BTreeMap<(u8, H256), H256> = Default::default();
279 if !self.is_empty() {
280 for k in &keys {
281 self.fetch_merkle_path(k, &mut cache)?;
282 }
283 }
284
285 let mut proof: Vec<(H256, u8)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
287 let mut leaves_path: Vec<Vec<u8>> = Vec::with_capacity(keys.len());
289 leaves_path.resize_with(keys.len(), Default::default);
290
291 let keys_len = keys.len();
292 let mut queue: VecDeque<(H256, u8, usize)> = keys
295 .into_iter()
296 .enumerate()
297 .map(|(i, k)| (k, 0, i))
298 .collect();
299
300 while let Some((key, height, leaf_index)) = queue.pop_front() {
301 if queue.is_empty() && cache.is_empty() {
302 if leaves_path[leaf_index].is_empty() {
304 leaves_path[leaf_index].push(core::u8::MAX);
305 }
306 break;
307 }
308 let mut sibling_key = key.parent_path(height);
310
311 let is_right = key.get_bit(height);
312 if is_right {
313 sibling_key.clear_bit(height);
315 } else {
316 sibling_key.set_bit(height);
318 }
319 if Some((&sibling_key, &height))
320 == queue
321 .front()
322 .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
323 {
324 let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
326 leaves_path[leaf_index].push(height);
327 } else {
328 match cache.remove(&(height, sibling_key)) {
329 Some(sibling) => {
330 proof.push((sibling, height));
332 }
333 None => {
334 if !is_right {
336 sibling_key.clear_bit(height);
337 }
338 if height == core::u8::MAX {
339 if leaves_path[leaf_index].is_empty() {
340 leaves_path[leaf_index].push(height);
341 }
342 break;
343 } else {
344 let parent_key = sibling_key;
345 queue.push_back((parent_key, height + 1, leaf_index));
346 continue;
347 }
348 }
349 }
350 }
351 leaves_path[leaf_index].push(height);
353 if height == core::u8::MAX {
354 break;
355 } else {
356 let parent_key = if is_right { sibling_key } else { key };
358 queue.push_back((parent_key, height + 1, leaf_index));
359 }
360 }
361 debug_assert_eq!(leaves_path.len(), keys_len);
362 Ok(MerkleProof::new(leaves_path, proof))
363 }
364}