1use datasize::DataSize;
2use itertools::Itertools;
3use once_cell::sync::OnceCell;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6
7use casper_types::bytesrepr::{self, FromBytes, ToBytes};
8
9use crate::{
10 error::{MerkleConstructionError, MerkleVerificationError},
11 Digest,
12};
13
14#[derive(DataSize, PartialEq, Eq, Debug, Clone, JsonSchema, Serialize, Deserialize)]
16#[serde(deny_unknown_fields)]
17pub struct IndexedMerkleProof {
18 index: u64,
19 count: u64,
20 merkle_proof: Vec<Digest>,
21 #[serde(skip)]
22 #[data_size(skip)]
23 root_hash: OnceCell<Digest>,
24}
25
26impl ToBytes for IndexedMerkleProof {
27 fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
28 let mut result = bytesrepr::allocate_buffer(self)?;
29 result.append(&mut self.index.to_bytes()?);
30 result.append(&mut self.count.to_bytes()?);
31 result.append(&mut self.merkle_proof.to_bytes()?);
32 Ok(result)
33 }
34
35 fn serialized_length(&self) -> usize {
36 self.index.serialized_length()
37 + self.count.serialized_length()
38 + self.merkle_proof.serialized_length()
39 }
40}
41
42impl FromBytes for IndexedMerkleProof {
43 fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
44 let (index, remainder) = FromBytes::from_bytes(bytes)?;
45 let (count, remainder) = FromBytes::from_bytes(remainder)?;
46 let (merkle_proof, remainder) = FromBytes::from_bytes(remainder)?;
47
48 Ok((
49 IndexedMerkleProof {
50 index,
51 count,
52 merkle_proof,
53 root_hash: OnceCell::new(),
54 },
55 remainder,
56 ))
57 }
58}
59
60impl IndexedMerkleProof {
61 pub(crate) fn new<I>(
62 leaves: I,
63 index: u64,
64 ) -> Result<IndexedMerkleProof, MerkleConstructionError>
65 where
66 I: IntoIterator<Item = Digest>,
67 I::IntoIter: ExactSizeIterator,
68 {
69 use HashOrProof::{Hash, Proof};
70
71 enum HashOrProof {
72 Hash(Digest),
73 Proof(Vec<Digest>),
74 }
75
76 let leaves = leaves.into_iter();
77 let count: u64 =
78 leaves
79 .len()
80 .try_into()
81 .map_err(|_| MerkleConstructionError::TooManyLeaves {
82 count: leaves.len().to_string(),
83 })?;
84
85 let maybe_proof = leaves
86 .enumerate()
87 .map(|(i, hash)| {
88 if i as u64 == index {
89 Proof(vec![hash])
90 } else {
91 Hash(hash)
92 }
93 })
94 .tree_fold1(|x, y| match (x, y) {
95 (Hash(hash_x), Hash(hash_y)) => Hash(Digest::hash_pair(hash_x, hash_y)),
96 (Hash(hash), Proof(mut proof)) | (Proof(mut proof), Hash(hash)) => {
97 proof.push(hash);
98 Proof(proof)
99 }
100 (Proof(_), Proof(_)) => unreachable!(),
101 });
102
103 match maybe_proof {
104 None | Some(Hash(_)) => Err(MerkleConstructionError::IndexOutOfBounds { count, index }),
105 Some(Proof(merkle_proof)) => Ok(IndexedMerkleProof {
106 index,
107 count,
108 merkle_proof,
109 root_hash: OnceCell::new(),
110 }),
111 }
112 }
113
114 pub fn index(&self) -> u64 {
116 self.index
117 }
118
119 pub fn count(&self) -> u64 {
121 self.count
122 }
123
124 pub fn root_hash(&self) -> Digest {
126 let IndexedMerkleProof {
127 index: _,
128 count,
129 merkle_proof,
130 root_hash,
131 } = self;
132
133 *root_hash.get_or_init(|| {
134 let mut hashes = merkle_proof.iter();
135 let raw_root = if let Some(leaf_hash) = hashes.next().cloned() {
136 let mut path: u64 = 0;
140 let mut n = self.count;
141 let mut i = self.index;
142 while n > 1 {
143 path <<= 1;
144 let pivot = 1u64 << (63 - (n - 1).leading_zeros());
145 if i < pivot {
146 n = pivot;
147 } else {
148 path |= 1;
149 n -= pivot;
150 i -= pivot;
151 }
152 }
153
154 hashes.fold(leaf_hash, |acc, hash| {
156 let digest = if (path & 1) == 1 {
157 Digest::hash_pair(hash, acc)
158 } else {
159 Digest::hash_pair(acc, hash)
160 };
161 path >>= 1;
162 digest
163 })
164 } else {
165 Digest::SENTINEL_MERKLE_TREE
166 };
167
168 Digest::hash_merkle_root(*count, raw_root)
170 })
171 }
172
173 pub fn merkle_proof(&self) -> &[Digest] {
175 &self.merkle_proof
176 }
177
178 fn compute_expected_proof_length(&self) -> u8 {
180 if self.count == 0 {
181 return 0;
182 }
183 let mut l = 1;
184 let mut n = self.count;
185 let mut i = self.index;
186 while n > 1 {
187 let pivot = 1u64 << (63 - (n - 1).leading_zeros());
188 if i < pivot {
189 n = pivot;
190 } else {
191 n -= pivot;
192 i -= pivot;
193 }
194 l += 1;
195 }
196 l
197 }
198
199 pub(crate) fn verify(&self) -> Result<(), MerkleVerificationError> {
200 if self.index >= self.count {
201 return Err(MerkleVerificationError::IndexOutOfBounds {
202 count: self.count,
203 index: self.index,
204 });
205 }
206 let expected_proof_length = self.compute_expected_proof_length();
207 if self.merkle_proof.len() != expected_proof_length as usize {
208 return Err(MerkleVerificationError::UnexpectedProofLength {
209 count: self.count,
210 index: self.index,
211 expected_proof_length,
212 actual_proof_length: self.merkle_proof.len(),
213 });
214 }
215 Ok(())
216 }
217
218 #[cfg(test)]
219 pub(crate) fn inject_merkle_proof(&mut self, merkle_proof: Vec<Digest>) {
220 self.merkle_proof = merkle_proof;
221 }
222}
223
224#[cfg(test)]
225mod tests {
226 use once_cell::sync::OnceCell;
227 use proptest::prelude::{prop_assert, prop_assert_eq};
228 use proptest_attr_macro::proptest;
229 use rand::{distributions::Standard, Rng};
230
231 use casper_types::bytesrepr::{self, FromBytes, ToBytes};
232
233 use crate::{error, indexed_merkle_proof::IndexedMerkleProof, Digest};
234
235 fn random_indexed_merkle_proof() -> IndexedMerkleProof {
236 let mut rng = rand::thread_rng();
237 let leaf_count: u64 = rng.gen_range(1..100);
238 let index = rng.gen_range(0..leaf_count);
239 let leaves: Vec<Digest> = (0..leaf_count)
240 .map(|i| Digest::hash(i.to_le_bytes()))
241 .collect();
242 IndexedMerkleProof::new(leaves.iter().cloned(), index)
243 .expect("should create indexed Merkle proof")
244 }
245
246 #[test]
247 fn test_merkle_proofs() {
248 let mut rng = rand::thread_rng();
249 for _ in 0..20 {
250 let leaf_count: u64 = rng.gen_range(1..100);
251 let index = rng.gen_range(0..leaf_count);
252 let leaves: Vec<Digest> = (0..leaf_count)
253 .map(|i| Digest::hash(i.to_le_bytes()))
254 .collect();
255 let root = Digest::hash_merkle_tree(leaves.clone());
256 let indexed_merkle_proof = IndexedMerkleProof::new(leaves.clone(), index).unwrap();
257 assert_eq!(
258 indexed_merkle_proof.compute_expected_proof_length(),
259 indexed_merkle_proof.merkle_proof().len() as u8
260 );
261 assert_eq!(indexed_merkle_proof.verify(), Ok(()));
262 assert_eq!(leaf_count, indexed_merkle_proof.count);
263 assert_eq!(leaves[index as usize], indexed_merkle_proof.merkle_proof[0]);
264 assert_eq!(root, indexed_merkle_proof.root_hash());
265 }
266 }
267
268 #[test]
269 fn out_of_bounds_index() {
270 let out_of_bounds_indexed_merkle_proof = IndexedMerkleProof {
271 index: 23,
272 count: 4,
273 merkle_proof: vec![Digest([0u8; 32]); 3],
274 root_hash: OnceCell::new(),
275 };
276 assert_eq!(
277 out_of_bounds_indexed_merkle_proof.verify(),
278 Err(error::MerkleVerificationError::IndexOutOfBounds {
279 count: 4,
280 index: 23
281 })
282 )
283 }
284
285 #[test]
286 fn unexpected_proof_length() {
287 let out_of_bounds_indexed_merkle_proof = IndexedMerkleProof {
288 index: 1235,
289 count: 5647,
290 merkle_proof: vec![Digest([0u8; 32]); 13],
291 root_hash: OnceCell::new(),
292 };
293 assert_eq!(
294 out_of_bounds_indexed_merkle_proof.verify(),
295 Err(error::MerkleVerificationError::UnexpectedProofLength {
296 count: 5647,
297 index: 1235,
298 expected_proof_length: 14,
299 actual_proof_length: 13
300 })
301 )
302 }
303
304 #[test]
305 fn empty_unexpected_proof_length() {
306 let out_of_bounds_indexed_merkle_proof = IndexedMerkleProof {
307 index: 0,
308 count: 0,
309 merkle_proof: vec![Digest([0u8; 32]); 3],
310 root_hash: OnceCell::new(),
311 };
312 assert_eq!(
313 out_of_bounds_indexed_merkle_proof.verify(),
314 Err(error::MerkleVerificationError::IndexOutOfBounds { count: 0, index: 0 })
315 )
316 }
317
318 #[test]
319 fn empty_out_of_bounds_index() {
320 let out_of_bounds_indexed_merkle_proof = IndexedMerkleProof {
321 index: 23,
322 count: 0,
323 merkle_proof: vec![],
324 root_hash: OnceCell::new(),
325 };
326 assert_eq!(
327 out_of_bounds_indexed_merkle_proof.verify(),
328 Err(error::MerkleVerificationError::IndexOutOfBounds {
329 count: 0,
330 index: 23
331 })
332 )
333 }
334
335 #[test]
336 fn deep_proof_doesnt_kill_stack() {
337 const PROOF_LENGTH: usize = 63;
338 let indexed_merkle_proof = IndexedMerkleProof {
339 index: 42,
340 count: 1 << (PROOF_LENGTH - 1),
341 merkle_proof: vec![Digest([0u8; Digest::LENGTH]); PROOF_LENGTH],
342 root_hash: OnceCell::new(),
343 };
344 let _hash = indexed_merkle_proof.root_hash();
345 }
346
347 #[test]
348 fn empty_proof() {
349 let empty_merkle_root = Digest::hash_merkle_tree(vec![]);
350 assert_eq!(empty_merkle_root, Digest::SENTINEL_MERKLE_TREE);
351 let indexed_merkle_proof = IndexedMerkleProof {
352 index: 0,
353 count: 0,
354 merkle_proof: vec![],
355 root_hash: OnceCell::new(),
356 };
357 assert!(indexed_merkle_proof.verify().is_err());
358 }
359
360 #[proptest]
361 fn expected_proof_length_le_65(index: u64, count: u64) {
362 let indexed_merkle_proof = IndexedMerkleProof {
363 index,
364 count,
365 merkle_proof: vec![],
366 root_hash: OnceCell::new(),
367 };
368 prop_assert!(indexed_merkle_proof.compute_expected_proof_length() <= 65);
369 }
370
371 fn reference_root_from_proof(index: u64, count: u64, proof: &[Digest]) -> Digest {
372 fn compute_raw_root_from_proof(index: u64, leaf_count: u64, proof: &[Digest]) -> Digest {
373 if leaf_count == 0 {
374 return Digest::SENTINEL_MERKLE_TREE;
375 }
376 if leaf_count == 1 {
377 return proof[0];
378 }
379 let half = 1u64 << (63 - (leaf_count - 1).leading_zeros());
380 let last = proof.len() - 1;
381 if index < half {
382 let left = compute_raw_root_from_proof(index, half, &proof[..last]);
383 Digest::hash_pair(left, proof[last])
384 } else {
385 let right =
386 compute_raw_root_from_proof(index - half, leaf_count - half, &proof[..last]);
387 Digest::hash_pair(proof[last], right)
388 }
389 }
390
391 let raw_root = compute_raw_root_from_proof(index, count, proof);
392 Digest::hash_merkle_root(count, raw_root)
393 }
394
395 fn test_indexed_merkle_proof(index: u64, count: u64) -> IndexedMerkleProof {
397 let mut indexed_merkle_proof = IndexedMerkleProof {
398 index,
399 count,
400 merkle_proof: vec![],
401 root_hash: OnceCell::new(),
402 };
403 let expected_proof_length = indexed_merkle_proof.compute_expected_proof_length();
404 indexed_merkle_proof.merkle_proof = rand::thread_rng()
405 .sample_iter(Standard)
406 .take(expected_proof_length as usize)
407 .collect();
408 indexed_merkle_proof
409 }
410
411 #[proptest]
412 fn root_from_proof_agrees_with_recursion(index: u64, count: u64) {
413 let indexed_merkle_proof = test_indexed_merkle_proof(index, count);
414 prop_assert_eq!(
415 indexed_merkle_proof.root_hash(),
416 reference_root_from_proof(
417 indexed_merkle_proof.index,
418 indexed_merkle_proof.count,
419 indexed_merkle_proof.merkle_proof(),
420 ),
421 "Result did not agree with reference implementation.",
422 );
423 }
424
425 #[test]
426 fn root_from_proof_agrees_with_recursion_2147483648_4294967297() {
427 let indexed_merkle_proof = test_indexed_merkle_proof(2147483648, 4294967297);
428 assert_eq!(
429 indexed_merkle_proof.root_hash(),
430 reference_root_from_proof(
431 indexed_merkle_proof.index,
432 indexed_merkle_proof.count,
433 indexed_merkle_proof.merkle_proof(),
434 ),
435 "Result did not agree with reference implementation.",
436 );
437 }
438
439 #[test]
440 fn serde_deserialization_of_malformed_proof_should_work() {
441 let indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
442
443 let json = serde_json::to_string(&indexed_merkle_proof).unwrap();
444 assert_eq!(
445 indexed_merkle_proof,
446 serde_json::from_str::<IndexedMerkleProof>(&json)
447 .expect("should deserialize correctly")
448 );
449
450 let mut indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
452 indexed_merkle_proof.index += 1;
453 let json = serde_json::to_string(&indexed_merkle_proof).unwrap();
454 serde_json::from_str::<IndexedMerkleProof>(&json).expect("should deserialize correctly");
455
456 let mut indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
458 indexed_merkle_proof.merkle_proof.push(Digest::hash("XXX"));
459 let json = serde_json::to_string(&indexed_merkle_proof).unwrap();
460 serde_json::from_str::<IndexedMerkleProof>(&json).expect("should deserialize correctly");
461 }
462
463 #[test]
464 fn bytesrepr_deserialization_of_malformed_proof_should_work() {
465 let indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
466
467 let bytes = indexed_merkle_proof
468 .to_bytes()
469 .expect("should serialize correctly");
470 IndexedMerkleProof::from_bytes(&bytes).expect("should deserialize correctly");
471
472 let mut indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
474 indexed_merkle_proof.index += 1;
475 let bytes = indexed_merkle_proof
476 .to_bytes()
477 .expect("should serialize correctly");
478 IndexedMerkleProof::from_bytes(&bytes).expect("should deserialize correctly");
479
480 let mut indexed_merkle_proof = test_indexed_merkle_proof(10, 10);
482 indexed_merkle_proof.merkle_proof.push(Digest::hash("XXX"));
483 let bytes = indexed_merkle_proof
484 .to_bytes()
485 .expect("should serialize correctly");
486 IndexedMerkleProof::from_bytes(&bytes).expect("should deserialize correctly");
487 }
488
489 #[test]
490 fn bytesrepr_serialization() {
491 let indexed_merkle_proof = random_indexed_merkle_proof();
492 bytesrepr::test_serialization_roundtrip(&indexed_merkle_proof);
493 }
494}