restricted_sparse_merkle_tree/
merkle_proof.rs1use crate::{
2 collections::{BTreeMap, VecDeque},
3 error::{Error, Result},
4 merge::{hash_leaf, merge},
5 traits::Hasher,
6 vec::Vec,
7 H256,
8};
9
/// Half-open range of leaf indices (positions in the key-sorted leaf list) that a
/// compiled sub-program covers.
type Range = core::ops::Range<usize>;
11
/// A Merkle proof for a set of leaves of the sparse Merkle tree.
///
/// It can be checked directly with [`MerkleProof::verify`] or turned into a compact
/// bytecode program with [`MerkleProof::compile`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleProof {
    /// One entry per proven leaf, indexed by the leaf's position after sorting the
    /// leaves by key. Each entry is a queue of heights. The front height says where that
    /// leaf's node is merged next, and it is popped after every merge.
    leaves_path: Vec<Vec<u8>>,
    /// Sibling hashes supplied by the prover, each paired with a height. They are
    /// consumed in order whenever a node's sibling is not one of the proven leaves.
    proof: Vec<(H256, u8)>,
}
17
18impl MerkleProof {
19 pub fn new(leaves_path: Vec<Vec<u8>>, proof: Vec<(H256, u8)>) -> Self {
23 MerkleProof { leaves_path, proof }
24 }
25
26 pub fn take(self) -> (Vec<Vec<u8>>, Vec<(H256, u8)>) {
28 let MerkleProof { leaves_path, proof } = self;
29 (leaves_path, proof)
30 }
31
32 pub fn leaves_count(&self) -> usize {
34 self.leaves_path.len()
35 }
36
37 pub fn leaves_path(&self) -> &Vec<Vec<u8>> {
39 &self.leaves_path
40 }
41
42 pub fn proof(&self) -> &Vec<(H256, u8)> {
44 &self.proof
45 }
46
47 pub fn compile(self, mut leaves: Vec<(H256, H256)>) -> Result<CompiledMerkleProof> {
49 if leaves.is_empty() {
50 return Err(Error::EmptyKeys);
51 } else if leaves.len() != self.leaves_count() {
52 return Err(Error::IncorrectNumberOfLeaves {
53 expected: self.leaves_count(),
54 actual: leaves.len(),
55 });
56 }
57
58 let (leaves_path, proof) = self.take();
59 let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
60 let mut proof: VecDeque<_> = proof.into();
61
62 leaves.sort_unstable_by_key(|(k, _v)| *k);
64 let mut tree_buf: BTreeMap<_, _> = leaves
66 .into_iter()
67 .enumerate()
68 .map(|(i, (k, _v))| ((0, k), (i, leaf_program(i))))
69 .collect();
70 while !tree_buf.is_empty() {
72 let &(mut height, key) = tree_buf.keys().next().unwrap();
74 let (leaf_index, program) = tree_buf.remove(&(height, key)).unwrap();
75
76 if proof.is_empty() && tree_buf.is_empty() {
77 return Ok(CompiledMerkleProof(program.0));
78 }
79
80 let mut sibling_key = key.parent_path(height);
81 if !key.get_bit(height) {
82 sibling_key.set_bit(height)
83 }
84
85 let (parent_key, parent_program, height) =
86 if Some(&(height, sibling_key)) == tree_buf.keys().next() {
87 let (_leaf_index, sibling_program) = tree_buf
88 .remove(&(height, sibling_key))
89 .expect("pop sibling");
90 let parent_key = key.parent_path(height);
91 let parent_program = merge_program(&program, &sibling_program, height)?;
92 (parent_key, parent_program, height)
93 } else {
94 let merge_height = leaves_path[leaf_index].front().copied().unwrap_or(height);
95 if height != merge_height {
96 let parent_key = key.copy_bits(merge_height);
97 tree_buf.insert((merge_height, parent_key), (leaf_index, program));
99 continue;
100 }
101 let (proof, proof_height) = proof.pop_front().ok_or(Error::CorruptedProof)?;
102 let proof_height = proof_height;
103 if height < proof_height {
104 height = proof_height;
105 }
106
107 let parent_key = key.parent_path(height);
108 let parent_program = proof_program(&program, proof, height);
109 (parent_key, parent_program, height)
110 };
111
112 if height == core::u8::MAX {
113 if proof.is_empty() {
114 return Ok(CompiledMerkleProof(parent_program.0));
115 } else {
116 return Err(Error::CorruptedProof);
117 }
118 }
119 leaves_path[leaf_index].pop_front();
120 tree_buf.insert((height + 1, parent_key), (leaf_index, parent_program));
121 }
122
123 Err(Error::CorruptedProof)
124 }
125
126 pub fn compute_root<H: Hasher + Default>(self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
132 if leaves.is_empty() {
133 return Err(Error::EmptyKeys);
134 } else if leaves.len() != self.leaves_count() {
135 return Err(Error::IncorrectNumberOfLeaves {
136 expected: self.leaves_count(),
137 actual: leaves.len(),
138 });
139 }
140
141 let (leaves_path, proof) = self.take();
142 let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
143 let mut proof: VecDeque<_> = proof.into();
144
145 for (_k, v) in &leaves {
147 if v.is_zero() {
148 return Err(Error::ForbidZeroValueLeaf);
149 }
150 }
151
152 leaves.sort_unstable_by_key(|(k, _v)| *k);
154 let mut tree_buf: BTreeMap<_, _> = leaves
156 .into_iter()
157 .enumerate()
158 .map(|(i, (k, v))| ((0, k), (i, hash_leaf::<H>(&k, &v))))
159 .collect();
160 while !tree_buf.is_empty() {
162 let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
164 tree_buf.remove(&(height, key));
165
166 if proof.is_empty() && tree_buf.is_empty() {
167 return Ok(node);
168 }
169
170 let mut sibling_key = key.parent_path(height);
171 if !key.get_bit(height) {
172 sibling_key.set_bit(height)
173 }
174 let (sibling, sibling_height) =
175 if Some(&(height, sibling_key)) == tree_buf.keys().next() {
176 let (_leaf_index, sibling) = tree_buf
177 .remove(&(height, sibling_key))
178 .expect("pop sibling");
179 (sibling, height)
180 } else {
181 let merge_height = leaves_path[leaf_index].front().copied().unwrap_or(height);
182 if height != merge_height {
183 let parent_key = key.copy_bits(merge_height);
184 tree_buf.insert((merge_height, parent_key), (leaf_index, node));
186 continue;
187 }
188 let (node, height) = proof.pop_front().ok_or(Error::CorruptedProof)?;
189 (node, height)
190 };
191 if height < sibling_height {
192 height = sibling_height;
193 }
194 let parent_key = key.parent_path(height);
196
197 let parent = if key.get_bit(height) {
198 merge::<H>(&sibling, &node)
199 } else {
200 merge::<H>(&node, &sibling)
201 };
202
203 if height == core::u8::MAX {
204 if proof.is_empty() {
205 return Ok(parent);
206 } else {
207 return Err(Error::CorruptedProof);
208 }
209 } else {
210 leaves_path[leaf_index].pop_front();
211 tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
212 }
213 }
214
215 Err(Error::CorruptedProof)
216 }
217
218 pub fn verify<H: Hasher + Default>(
221 self,
222 root: &H256,
223 leaves: Vec<(H256, H256)>,
224 ) -> Result<bool> {
225 let calculated_root = self.compute_root::<H>(leaves)?;
226 Ok(&calculated_root == root)
227 }
228}
229
230fn leaf_program(leaf_index: usize) -> (Vec<u8>, Option<Range>) {
231 let mut program = Vec::with_capacity(1);
232 program.push(0x4C);
233 (
234 program,
235 Some(Range {
236 start: leaf_index,
237 end: leaf_index + 1,
238 }),
239 )
240}
241
242fn proof_program(
243 child: &(Vec<u8>, Option<Range>),
244 proof: H256,
245 height: u8,
246) -> (Vec<u8>, Option<Range>) {
247 let (child_program, child_range) = child;
248 let mut program = Vec::new();
249 program.resize(34 + child_program.len(), 0x50);
250 program[..child_program.len()].copy_from_slice(child_program);
251 program[child_program.len() + 1] = height;
252 program[child_program.len() + 2..].copy_from_slice(proof.as_slice());
253 (program, child_range.clone())
254}
255
256fn merge_program(
257 a: &(Vec<u8>, Option<Range>),
258 b: &(Vec<u8>, Option<Range>),
259 height: u8,
260) -> Result<(Vec<u8>, Option<Range>)> {
261 let (a_program, a_range) = a;
262 let (b_program, b_range) = b;
263 let (a_comes_first, range) = if a_range.is_none() || b_range.is_none() {
264 let range = if a_range.is_none() { b_range } else { a_range }
265 .clone()
266 .unwrap();
267 (true, range)
268 } else {
269 let a_range = a_range.clone().unwrap();
270 let b_range = b_range.clone().unwrap();
271 if a_range.end == b_range.start {
272 (
273 true,
274 Range {
275 start: a_range.start,
276 end: b_range.end,
277 },
278 )
279 } else {
280 return Err(Error::NonMergableRange);
281 }
282 };
283 let mut program = Vec::new();
284 program.resize(2 + a_program.len() + b_program.len(), 0x48);
285 if a_comes_first {
286 program[..a_program.len()].copy_from_slice(a_program);
287 program[a_program.len()..a_program.len() + b_program.len()].copy_from_slice(b_program);
288 } else {
289 program[..b_program.len()].copy_from_slice(b_program);
290 program[b_program.len()..a_program.len() + b_program.len()].copy_from_slice(a_program);
291 }
292 program[a_program.len() + b_program.len() + 1] = height;
293 Ok((program, Some(range)))
294}
295
/// A Merkle proof compiled into a compact bytecode program.
///
/// The wrapped bytes are the program itself. They are evaluated against the proven
/// leaves by [`CompiledMerkleProof::compute_root`] and [`CompiledMerkleProof::verify`].
#[derive(Debug, Clone)]
pub struct CompiledMerkleProof(pub Vec<u8>);
299
300impl CompiledMerkleProof {
301 pub fn compute_root<H: Hasher + Default>(&self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
302 leaves.sort_unstable_by_key(|(k, _v)| *k);
303 let mut program_index = 0;
304 let mut leave_index = 0;
305 let mut stack = Vec::new();
306 while program_index < self.0.len() {
307 let code = self.0[program_index];
308 program_index += 1;
309 match code {
310 0x4C => {
312 if leave_index >= leaves.len() {
313 return Err(Error::CorruptedStack);
314 }
315 let (k, v) = leaves[leave_index];
316
317 if v.is_zero() {
319 return Err(Error::ForbidZeroValueLeaf);
320 }
321
322 stack.push((k, hash_leaf::<H>(&k, &v)));
323 leave_index += 1;
324 }
325 0x50 => {
327 if stack.is_empty() {
328 return Err(Error::CorruptedStack);
329 }
330 if program_index + 33 > self.0.len() {
331 return Err(Error::CorruptedProof);
332 }
333 let height = self.0[program_index];
334 program_index += 1;
335 let mut data = [0u8; 32];
336 data.copy_from_slice(&self.0[program_index..program_index + 32]);
337 program_index += 32;
338 let proof = H256::from(data);
339 let (key, value) = stack.pop().unwrap();
340 let parent_key = key.parent_path(height);
341 let parent = if key.get_bit(height) {
342 merge::<H>(&proof, &value)
343 } else {
344 merge::<H>(&value, &proof)
345 };
346 stack.push((parent_key, parent));
347 }
348 0x48 => {
350 if stack.len() < 2 {
351 return Err(Error::CorruptedStack);
352 }
353 if program_index >= self.0.len() {
354 return Err(Error::CorruptedProof);
355 }
356 let height = self.0[program_index];
357 program_index += 1;
358 let (key_b, value_b) = stack.pop().unwrap();
359 let (key_a, value_a) = stack.pop().unwrap();
360 let parent_key_a = key_a.copy_bits(height);
361 let parent_key_b = key_b.copy_bits(height);
362 let a_set = key_a.get_bit(height);
363 let b_set = key_b.get_bit(height);
364 let mut sibling_key_a = parent_key_a;
365 if !a_set {
366 sibling_key_a.set_bit(height);
367 }
368 if !(sibling_key_a == parent_key_b && (a_set ^ b_set)) {
370 return Err(Error::NonSiblings);
371 }
372 let parent = if key_a.get_bit(height) {
373 merge::<H>(&value_b, &value_a)
374 } else {
375 merge::<H>(&value_a, &value_b)
376 };
377 stack.push((parent_key_a, parent));
378 }
379 _ => return Err(Error::InvalidCode(code)),
380 }
381 }
382 if stack.len() != 1 {
383 return Err(Error::CorruptedStack);
384 }
385 Ok(stack[0].1)
386 }
387
388 pub fn verify<H: Hasher + Default>(
389 &self,
390 root: &H256,
391 leaves: Vec<(H256, H256)>,
392 ) -> Result<bool> {
393 let calculated_root = self.compute_root::<H>(leaves)?;
394 Ok(&calculated_root == root)
395 }
396}
397
398impl Into<Vec<u8>> for CompiledMerkleProof {
399 fn into(self) -> Vec<u8> {
400 self.0
401 }
402}