commonware_storage/merkle/
verification.rs1use crate::merkle::{
15 hasher::Hasher,
16 proof::{self as merkle_proof, Blueprint},
17 storage::Storage,
18 Error, Family, Location, Position, Proof,
19};
20use commonware_cryptography::Digest;
21use core::ops::Range;
22use futures::future::try_join_all;
23use std::collections::{BTreeSet, HashMap};
24
25pub struct ProofStore<F: Family, D> {
28 digests: HashMap<Position<F>, D>,
29 size: Position<F>,
30 fold_acc: Option<D>,
32 num_fold_peaks: usize,
34}
35
36impl<F: Family, D: Digest> ProofStore<F, D> {
37 pub fn new<H, E>(
46 hasher: &H,
47 proof: &Proof<F, D>,
48 elements: &[E],
49 start_loc: Location<F>,
50 root: &D,
51 ) -> Result<Self, Error<F>>
52 where
53 H: Hasher<F, Digest = D>,
54 E: AsRef<[u8]>,
55 {
56 let digests =
57 proof.verify_range_inclusion_and_extract_digests(hasher, elements, start_loc, root)?;
58 let map: HashMap<Position<F>, D> = digests.into_iter().collect();
59
60 let size = Position::try_from(proof.leaves)?;
61
62 let num_fold_peaks = Blueprint::<F>::fold_prefix(proof.leaves, start_loc)?.len();
65
66 let fold_acc = if num_fold_peaks > 0 {
67 Some(*proof.digests.first().ok_or(Error::InvalidProof)?)
68 } else {
69 None
70 };
71
72 Ok(Self {
73 size,
74 digests: map,
75 fold_acc,
76 num_fold_peaks,
77 })
78 }
79
80 pub fn range_proof<H: Hasher<F, Digest = D>>(
86 &self,
87 hasher: &H,
88 range: Range<Location<F>>,
89 ) -> Result<Proof<F, D>, Error<F>> {
90 let leaves = Location::try_from(self.size)?;
91 let bp = Blueprint::new(leaves, range)?;
92
93 let mut digests: Vec<D> = Vec::new();
94 if !bp.fold_prefix.is_empty() {
95 let mut acc = self.fold_acc;
97 for &pos in bp.fold_prefix.iter().skip(self.num_fold_peaks) {
99 match self.digests.get(&pos) {
100 Some(d) => {
101 acc = Some(acc.map_or(*d, |a| hasher.fold(&a, d)));
102 }
103 None => return Err(Error::ElementPruned(pos)),
104 }
105 }
106 digests.push(acc.expect("fold_prefix is non-empty so acc must be set"));
107 }
108
109 for &pos in &bp.fetch_nodes {
110 match self.digests.get(&pos) {
111 Some(d) => digests.push(*d),
112 None => return Err(Error::ElementPruned(pos)),
113 }
114 }
115
116 Ok(Proof { leaves, digests })
117 }
118
119 pub fn multi_proof(
126 &self,
127 locations: &[Location<F>],
128 peaks: &[(Position<F>, D)],
129 ) -> Result<Proof<F, D>, Error<F>> {
130 if locations.is_empty() {
131 return Err(Error::Empty);
132 }
133
134 let leaves = Location::try_from(self.size)?;
135 let node_positions: BTreeSet<_> =
136 merkle_proof::nodes_required_for_multi_proof(leaves, locations)?;
137
138 let peak_map: HashMap<Position<F>, D> = peaks.iter().copied().collect();
139
140 let mut digests = Vec::with_capacity(node_positions.len());
141 for &pos in &node_positions {
142 if let Some(d) = self.digests.get(&pos) {
143 digests.push(*d);
144 } else if let Some(d) = peak_map.get(&pos) {
145 digests.push(*d);
146 } else {
147 return Err(Error::ElementPruned(pos));
148 }
149 }
150
151 Ok(Proof { leaves, digests })
152 }
153}
154
155pub async fn range_proof<
164 F: Family,
165 D: Digest,
166 H: Hasher<F, Digest = D>,
167 S: Storage<F, Digest = D>,
168>(
169 hasher: &H,
170 merkle: &S,
171 range: Range<Location<F>>,
172) -> Result<Proof<F, D>, Error<F>> {
173 let leaves = Location::try_from(merkle.size().await)?;
174 historical_range_proof(hasher, merkle, leaves, range).await
175}
176
177pub async fn historical_range_proof<
187 F: Family,
188 D: Digest,
189 H: Hasher<F, Digest = D>,
190 S: Storage<F, Digest = D>,
191>(
192 hasher: &H,
193 merkle: &S,
194 leaves: Location<F>,
195 range: Range<Location<F>>,
196) -> Result<Proof<F, D>, Error<F>> {
197 let bp = Blueprint::new(leaves, range)?;
198
199 let mut digests: Vec<D> = Vec::new();
200 if !bp.fold_prefix.is_empty() {
201 let node_futures = bp.fold_prefix.iter().map(|&pos| merkle.get_node(pos));
202 let results = try_join_all(node_futures).await?;
203 let mut acc = results[0].ok_or(Error::ElementPruned(bp.fold_prefix[0]))?;
204 for (i, &result) in results.iter().enumerate().skip(1) {
205 let d = result.ok_or(Error::ElementPruned(bp.fold_prefix[i]))?;
206 acc = hasher.fold(&acc, &d);
207 }
208 digests.push(acc);
209 }
210
211 let node_futures = bp.fetch_nodes.iter().map(|&pos| merkle.get_node(pos));
212 let results = try_join_all(node_futures).await?;
213 for (i, result) in results.into_iter().enumerate() {
214 match result {
215 Some(d) => digests.push(d),
216 None => return Err(Error::ElementPruned(bp.fetch_nodes[i])),
217 }
218 }
219
220 Ok(Proof { leaves, digests })
221}
222
223pub async fn multi_proof<F: Family, D: Digest, S: Storage<F, Digest = D>>(
235 merkle: &S,
236 locations: &[Location<F>],
237) -> Result<Proof<F, D>, Error<F>> {
238 if locations.is_empty() {
239 return Err(Error::Empty);
241 }
242
243 let size = merkle.size().await;
245 let leaves = Location::try_from(size)?;
246 let node_positions: BTreeSet<_> =
247 merkle_proof::nodes_required_for_multi_proof(leaves, locations)?;
248
249 let node_futures: Vec<_> = node_positions
251 .iter()
252 .map(|&pos| async move { merkle.get_node(pos).await.map(|digest| (pos, digest)) })
253 .collect();
254 let results = try_join_all(node_futures).await?;
255
256 let mut digests = Vec::with_capacity(results.len());
258 for (pos, digest) in results {
259 match digest {
260 Some(digest) => digests.push(digest),
261 None => return Err(Error::ElementPruned(pos)),
262 }
263 }
264
265 Ok(Proof { leaves, digests })
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        merkle::LocationRangeExt as _,
        mmb::{mem::Mmb, Location as MmbLocation},
        mmr::{mem::Mmr, StandardHasher as Standard},
    };
    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
    use commonware_macros::test_traced;
    use commonware_runtime::{deterministic, Runner};

    /// Deterministic element for index `v`: the SHA-256 hash of that single byte.
    fn test_digest(v: u8) -> Digest {
        Sha256::hash(&[v])
    }

    /// Over a 49-leaf MMR, proofs derived from a [ProofStore] built for each symmetric range
    /// (0..49, 1..48, ..., 24..25) verify for every symmetric sub-range inside it.
    #[test_traced]
    fn test_verification_proof_store() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new(&hasher);
            let elements: Vec<_> = (0..49).map(test_digest).collect();
            let batch = elements
                .iter()
                .fold(mmr.new_batch(), |batch, element| batch.add(&hasher, element))
                .merkleize(&mmr, &hasher);
            mmr.apply_batch(&batch).unwrap();
            let root = mmr.root();

            // Start and end always sum to 49, so offsets 0..25 enumerate every range
            // whose start is strictly less than its end.
            for outer in 0u64..25 {
                let range = Location::new(outer)..Location::new(49 - outer);
                let range_proof = mmr.range_proof(&hasher, range.clone()).unwrap();
                let proof_store = ProofStore::new(
                    &hasher,
                    &range_proof,
                    &elements[range.to_usize_range()],
                    range.start,
                    root,
                )
                .unwrap();

                // Sub-ranges shrink the same way while staying inside the outer range.
                for inner in outer..25 {
                    let sub_range = Location::new(inner)..Location::new(49 - inner);
                    let sub_range_proof =
                        proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                    assert!(sub_range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[sub_range.to_usize_range()],
                        sub_range.start,
                        root
                    ));
                }
            }
        });
    }

    /// A store built for MMR range 32..49 (which begins past the first peak) yields
    /// verifying proofs for every non-empty sub-range of it.
    #[test_traced]
    fn test_verification_proof_store_with_fold_prefix() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new(&hasher);
            let elements: Vec<_> = (0..49).map(test_digest).collect();
            let batch = elements
                .iter()
                .fold(mmr.new_batch(), |batch, element| batch.add(&hasher, element))
                .merkleize(&mmr, &hasher);
            mmr.apply_batch(&batch).unwrap();
            let root = mmr.root();

            let range = Location::new(32)..Location::new(49);
            let range_proof = mmr.range_proof(&hasher, range.clone()).unwrap();
            let proof_store = ProofStore::new(
                &hasher,
                &range_proof,
                &elements[range.to_usize_range()],
                range.start,
                root,
            )
            .unwrap();

            let sub_ranges =
                (32u64..49).flat_map(|start| ((start + 1)..=49).map(move |end| (start, end)));
            for (start, end) in sub_ranges {
                let sub_range = Location::new(start)..Location::new(end);
                let sub_proof = proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                assert!(
                    sub_proof.verify_range_inclusion(
                        &hasher,
                        &elements[sub_range.to_usize_range()],
                        sub_range.start,
                        root,
                    ),
                    "sub-proof should verify for range {start}..{end}"
                );
            }
        });
    }

    /// The same fold-prefix scenario on an 8-leaf MMB with source range 4..8.
    #[test_traced]
    fn test_verification_proof_store_with_fold_prefix_mmb() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmb = Mmb::new(&hasher);
            let elements: Vec<_> = (0..8).map(test_digest).collect();
            let batch = elements
                .iter()
                .fold(mmb.new_batch(), |batch, element| batch.add(&hasher, element))
                .merkleize(&mmb, &hasher);
            mmb.apply_batch(&batch).unwrap();
            let root = mmb.root();

            let range = MmbLocation::new(4)..MmbLocation::new(8);
            let range_proof = mmb.range_proof(&hasher, range.clone()).unwrap();
            let proof_store = ProofStore::new(
                &hasher,
                &range_proof,
                &elements[range.to_usize_range()],
                range.start,
                root,
            )
            .unwrap();

            let sub_ranges =
                (4u64..8).flat_map(|start| ((start + 1)..=8).map(move |end| (start, end)));
            for (start, end) in sub_ranges {
                let sub_range = MmbLocation::new(start)..MmbLocation::new(end);
                let sub_proof = proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                assert!(
                    sub_proof.verify_range_inclusion(
                        &hasher,
                        &elements[sub_range.to_usize_range()],
                        sub_range.start,
                        root,
                    ),
                    "sub-proof should verify for MMB range {start}..{end}"
                );
            }
        });
    }
}