1use crate::{
2 collections::{BTreeMap, VecDeque},
3 error::{Error, Result},
4 merge::{hash_leaf, merge},
5 traits::{Hasher, Value},
6 vec::Vec,
7 Key, H256, TREE_HEIGHT,
8};
9use core::convert::TryInto;
10
/// Half-open range of leaf indices (positions in the key-sorted leaf list) that a
/// compiled sub-program covers; built by `leaf_program` and joined by `merge_program`.
type Range = core::ops::Range<usize>;
12
/// A Merkle proof covering one or more leaves, checked with `compute_root`/`verify`
/// or turned into bytecode with `compile`.
#[derive(Debug, Clone)]
pub struct MerkleProof {
    // One queue of heights per leaf, indexed by the leaf's position after the leaves are
    // sorted by key. The front entry is the next height at which that leaf's subtree is
    // merged; it is popped after each merge.
    leaves_path: Vec<Vec<usize>>,
    // `(hash, height)` pairs consumed front to back whenever a node has no sibling
    // derived from the supplied leaves.
    proof: Vec<(H256, usize)>,
}
18
19impl MerkleProof {
20 pub fn new(leaves_path: Vec<Vec<usize>>, proof: Vec<(H256, usize)>) -> Self {
24 MerkleProof { leaves_path, proof }
25 }
26
27 pub fn take(self) -> (Vec<Vec<usize>>, Vec<(H256, usize)>) {
29 let MerkleProof { leaves_path, proof } = self;
30 (leaves_path, proof)
31 }
32
33 pub fn leaves_count(&self) -> usize {
35 self.leaves_path.len()
36 }
37
38 pub fn leaves_path(&self) -> &Vec<Vec<usize>> {
40 &self.leaves_path
41 }
42
43 pub fn proof(&self) -> &Vec<(H256, usize)> {
45 &self.proof
46 }
47
48 pub fn compile<K, const N: usize>(
50 self,
51 mut leaves: Vec<(K, H256)>,
52 ) -> Result<CompiledMerkleProof>
53 where
54 K: Key<N>,
55 {
56 if leaves.is_empty() {
57 return Err(Error::EmptyKeys);
58 } else if leaves.len() != self.leaves_count() {
59 return Err(Error::IncorrectNumberOfLeaves {
60 expected: self.leaves_count(),
61 actual: leaves.len(),
62 });
63 }
64
65 let (leaves_path, proof) = self.take();
66 let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
67 let mut proof: VecDeque<_> = proof.into();
68
69 leaves.sort_unstable_by_key(|(k, _v)| **k);
71 let mut tree_buf: BTreeMap<_, _> = leaves
73 .into_iter()
74 .enumerate()
75 .map(|(i, (k, _v))| ((0, *k), (i, leaf_program(i))))
76 .collect();
77 while !tree_buf.is_empty() {
79 let &(mut height, key) = tree_buf.keys().next().unwrap();
81 let (leaf_index, program) = tree_buf.remove(&(height, key)).unwrap();
82
83 if proof.is_empty() && tree_buf.is_empty() {
84 return Ok(CompiledMerkleProof(program.0));
85 } else if height == TREE_HEIGHT {
86 if !proof.is_empty() {
87 return Err(Error::CorruptedProof);
88 }
89 return Ok(CompiledMerkleProof(program.0));
90 }
91
92 let mut sibling_key = key.parent_path(height);
93 if !key.get_bit(height) {
94 sibling_key.set_bit(height)
95 }
96
97 let (parent_key, parent_program, height) =
98 if Some(&(height, sibling_key)) == tree_buf.keys().next() {
99 let (_leaf_index, sibling_program) = tree_buf
100 .remove(&(height, sibling_key))
101 .expect("pop sibling");
102 let parent_key = key.parent_path(height);
103 let parent_program = merge_program(&program, &sibling_program, height)?;
104 (parent_key, parent_program, height)
105 } else {
106 let merge_height = leaves_path[leaf_index]
107 .front()
108 .map(|h| *h as usize)
109 .unwrap_or(height);
110 if height != merge_height {
111 debug_assert!(height < merge_height);
112 let parent_key = key.copy_bits(merge_height..);
113 tree_buf.insert((merge_height, parent_key), (leaf_index, program));
115 continue;
116 }
117 let (proof, proof_height) = proof.pop_front().expect("pop proof");
118 debug_assert_eq!(proof_height, leaves_path[leaf_index][0]);
119 let proof_height = proof_height as usize;
120 debug_assert!(height <= proof_height);
121 if height < proof_height {
122 height = proof_height;
123 }
124
125 let parent_key = key.parent_path(height);
126 let parent_program = proof_program(&program, proof, height);
127 (parent_key, parent_program, height)
128 };
129
130 leaves_path[leaf_index].pop_front();
131 tree_buf.insert((height + 1, parent_key), (leaf_index, parent_program));
132 }
133
134 Err(Error::CorruptedProof)
135 }
136
137 pub fn compute_root<H: Hasher + Default, K, V, const N: usize>(
143 self,
144 mut leaves: Vec<(K, V)>,
145 ) -> Result<H256>
146 where
147 K: Key<N>,
148 V: Value,
149 {
150 if leaves.is_empty() {
151 return Err(Error::EmptyKeys);
152 } else if leaves.len() != self.leaves_count() {
153 return Err(Error::IncorrectNumberOfLeaves {
154 expected: self.leaves_count(),
155 actual: leaves.len(),
156 });
157 }
158
159 let (leaves_path, proof) = self.take();
160 let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
161 let mut proof: VecDeque<_> = proof.into();
162
163 leaves.sort_unstable_by_key(|(k, _v)| **k);
165 let mut tree_buf: BTreeMap<_, _> = leaves
167 .into_iter()
168 .enumerate()
169 .map(|(i, (k, v))| ((0, *k), (i, hash_leaf::<H, K, V, N>(&k, &v))))
170 .collect();
171 while !tree_buf.is_empty() {
173 let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
175 tree_buf.remove(&(height, key));
176
177 if proof.is_empty() && tree_buf.is_empty() {
178 return Ok(node);
179 } else if height == 8 * N {
180 if !proof.is_empty() {
181 return Err(Error::CorruptedProof);
182 }
183 return Ok(node);
184 }
185
186 let mut sibling_key = key.parent_path(height);
187 if !key.get_bit(height) {
188 sibling_key.set_bit(height)
189 }
190 let (sibling, sibling_height) =
191 if Some(&(height, sibling_key)) == tree_buf.keys().next() {
192 let (_leaf_index, sibling) = tree_buf
193 .remove(&(height, sibling_key))
194 .expect("pop sibling");
195 (sibling, height)
196 } else {
197 let merge_height = leaves_path[leaf_index]
198 .front()
199 .map(|h| *h as usize)
200 .unwrap_or(height);
201 if height != merge_height {
202 debug_assert!(height < merge_height);
203 let parent_key = key.copy_bits(merge_height..);
204 tree_buf.insert((merge_height, parent_key), (leaf_index, node));
206 continue;
207 }
208 let (node, height) = proof.pop_front().expect("pop proof");
209 debug_assert_eq!(height, leaves_path[leaf_index][0]);
210 (node, height as usize)
211 };
212 debug_assert!(height <= sibling_height);
213 if height < sibling_height {
214 height = sibling_height;
215 }
216 let parent_key = key.parent_path(height);
218
219 let parent = if key.get_bit(height) {
220 merge::<H>(&sibling, &node)
221 } else {
222 merge::<H>(&node, &sibling)
223 };
224 leaves_path[leaf_index].pop_front();
225 tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
226 }
227
228 Err(Error::CorruptedProof)
229 }
230
231 pub fn verify<H: Hasher + Default, K, V, const N: usize>(
234 self,
235 root: &H256,
236 leaves: Vec<(K, V)>,
237 ) -> Result<bool>
238 where
239 K: Key<N>,
240 V: Value
241 {
242 let calculated_root = self.compute_root::<H, K, V, N>(leaves)?;
243 Ok(&calculated_root == root)
244 }
245}
246
247fn leaf_program(leaf_index: usize) -> (Vec<u8>, Option<Range>) {
248 let program = vec![0x4C];
249 (
250 program,
251 Some(Range {
252 start: leaf_index,
253 end: leaf_index + 1,
254 }),
255 )
256}
257
258fn proof_program(
259 child: &(Vec<u8>, Option<Range>),
260 proof: H256,
261 height: usize,
262) -> (Vec<u8>, Option<Range>) {
263 let (child_program, child_range) = child;
264 let mut program = Vec::new();
265 let height = height as u64;
266 program.resize(41 + child_program.len(), 0x50);
267 program[..child_program.len()].copy_from_slice(child_program);
268 program[child_program.len() + 1..child_program.len() + 9]
269 .copy_from_slice(&height.to_be_bytes());
270 program[child_program.len() + 9..].copy_from_slice(proof.as_slice());
271 (program, child_range.clone())
272}
273
274fn merge_program(
275 a: &(Vec<u8>, Option<Range>),
276 b: &(Vec<u8>, Option<Range>),
277 height: usize,
278) -> Result<(Vec<u8>, Option<Range>)> {
279 let (a_program, a_range) = a;
280 let (b_program, b_range) = b;
281 let (a_comes_first, range) = if a_range.is_none() || b_range.is_none() {
282 let range = if a_range.is_none() { b_range } else { a_range }
283 .clone()
284 .unwrap();
285 (true, range)
286 } else {
287 let a_range = a_range.clone().unwrap();
288 let b_range = b_range.clone().unwrap();
289 if a_range.end == b_range.start {
290 (
291 true,
292 Range {
293 start: a_range.start,
294 end: b_range.end,
295 },
296 )
297 } else {
298 return Err(Error::NonMergableRange);
299 }
300 };
301 let mut program = Vec::new();
302 program.resize(9 + a_program.len() + b_program.len(), 0x48);
303 if a_comes_first {
304 program[..a_program.len()].copy_from_slice(a_program);
305 program[a_program.len()..a_program.len() + b_program.len()].copy_from_slice(b_program);
306 } else {
307 program[..b_program.len()].copy_from_slice(b_program);
308 program[b_program.len()..a_program.len() + b_program.len()].copy_from_slice(a_program);
309 }
310 let height = height as u64;
311 let height_pos = a_program.len() + b_program.len() + 1;
312 program[height_pos..height_pos + 8].copy_from_slice(&height.to_be_bytes());
313 Ok((program, Some(range)))
314}
315
/// Bytecode form of a [`MerkleProof`], produced by `MerkleProof::compile`; the inner
/// bytes are the opcode program executed by `CompiledMerkleProof::compute_root`.
#[derive(Debug, Clone)]
pub struct CompiledMerkleProof(pub Vec<u8>);
319
320impl CompiledMerkleProof {
321 pub fn compute_root<H: Hasher + Default, K, V, const N: usize>(
322 &self,
323 mut leaves: Vec<(K, V)>,
324 ) -> Result<H256>
325 where
326 K: Key<N>,
327 V: Value,
328 {
329 leaves.sort_unstable_by_key(|(k, _v)| **k);
330 let mut program_index = 0;
331 let mut leave_index = 0;
332 let mut stack = Vec::new();
333 while program_index < self.0.len() {
334 let code = self.0[program_index];
335 program_index += 1;
336 match code {
337 0x4C => {
339 if leave_index >= leaves.len() {
340 return Err(Error::CorruptedStack);
341 }
342 let (k, v) = leaves[leave_index].clone();
343 stack.push((*k, hash_leaf::<H, K, V, N>(&k, &v)));
344 leave_index += 1;
345 }
346 0x50 => {
348 if stack.is_empty() {
349 return Err(Error::CorruptedStack);
350 }
351 if program_index + 40 > self.0.len() {
352 return Err(Error::CorruptedProof);
353 }
354 let height: [u8; 8] = self.0[program_index..program_index + 8]
355 .try_into()
356 .expect("8 bytes should fit in an 8 byte array");
357 let height = u64::from_be_bytes(height) as usize;
358 program_index += 8;
359 let mut data = [0u8; 32];
360 data.copy_from_slice(&self.0[program_index..program_index + 32]);
361 program_index += 32;
362 let proof = H256::from(data);
363 let (key, value) = stack.pop().unwrap();
364 let parent_key = key.parent_path(height);
365 let parent = if key.get_bit(height) {
366 merge::<H>(&proof, &value)
367 } else {
368 merge::<H>(&value, &proof)
369 };
370 stack.push((parent_key, parent));
371 }
372 0x48 => {
374 if stack.len() < 2 {
375 return Err(Error::CorruptedStack);
376 }
377 if program_index >= self.0.len() {
378 return Err(Error::CorruptedProof);
379 }
380 let height: [u8; 8] = self.0[program_index..program_index + 8]
381 .try_into()
382 .expect("8 bytes should fit in an 8 byte array");
383 let height = u64::from_be_bytes(height) as usize;
384 program_index += 8;
385 let (key_b, value_b) = stack.pop().unwrap();
386 let (key_a, value_a) = stack.pop().unwrap();
387 let parent_key_a = key_a.copy_bits(height..);
388 let parent_key_b = key_b.copy_bits(height..);
389 let a_set = key_a.get_bit(height);
390 let b_set = key_b.get_bit(height);
391 let mut sibling_key_a = parent_key_a;
392 if !a_set {
393 sibling_key_a.set_bit(height);
394 }
395 if !(sibling_key_a == parent_key_b && (a_set ^ b_set)) {
397 return Err(Error::NonSiblings);
398 }
399 let parent = if key_a.get_bit(height) {
400 merge::<H>(&value_b, &value_a)
401 } else {
402 merge::<H>(&value_a, &value_b)
403 };
404 stack.push((parent_key_a, parent));
405 }
406 _ => return Err(Error::InvalidCode(code)),
407 }
408 }
409 if stack.len() != 1 {
410 return Err(Error::CorruptedStack);
411 }
412 Ok(stack[0].1)
413 }
414
415 pub fn verify<H: Hasher + Default, K, V, const N: usize>(
416 &self,
417 root: &H256,
418 leaves: Vec<(K, V)>,
419 ) -> Result<bool>
420 where
421 K: Key<N>,
422 V: Value,
423 {
424 let calculated_root = self.compute_root::<H, K, V, N>(leaves)?;
425 Ok(&calculated_root == root)
426 }
427}