use crate::merkle::{
hasher::Hasher,
proof::{self as merkle_proof, Blueprint},
storage::Storage,
Error, Family, Location, Position, Proof,
};
use commonware_cryptography::Digest;
use core::ops::Range;
use futures::future::try_join_all;
use std::collections::{BTreeSet, HashMap};
/// Digests retained from a verified range proof.
///
/// A `ProofStore` is built by [`ProofStore::new`] from a range proof that it verifies against a
/// root. It keeps the node digests extracted during verification so that proofs over parts of
/// the verified range can later be produced without access to the full structure.
pub struct ProofStore<F: Family, D> {
    /// Size of the structure the proof was generated against (the proof's leaf count,
    /// converted to a [`Position`]).
    size: Position<F>,
    /// Number of leading peaks whose digests are folded together into `fold_acc`.
    num_fold_peaks: usize,
    /// The proof's leading fold-prefix digest; `Some` only when `num_fold_peaks > 0`.
    fold_acc: Option<D>,
    /// Node digests extracted while verifying the proof, keyed by node position.
    digests: HashMap<Position<F>, D>,
}
impl<F: Family, D: Digest> ProofStore<F, D> {
pub fn new<H, E>(
hasher: &H,
proof: &Proof<F, D>,
elements: &[E],
start_loc: Location<F>,
root: &D,
) -> Result<Self, Error<F>>
where
H: Hasher<F, Digest = D>,
E: AsRef<[u8]>,
{
let digests =
proof.verify_range_inclusion_and_extract_digests(hasher, elements, start_loc, root)?;
let map: HashMap<Position<F>, D> = digests.into_iter().collect();
let size = Position::try_from(proof.leaves)?;
let num_fold_peaks = Blueprint::<F>::fold_prefix(proof.leaves, start_loc)?.len();
let fold_acc = if num_fold_peaks > 0 {
Some(*proof.digests.first().ok_or(Error::InvalidProof)?)
} else {
None
};
Ok(Self {
size,
digests: map,
fold_acc,
num_fold_peaks,
})
}
pub fn range_proof<H: Hasher<F, Digest = D>>(
&self,
hasher: &H,
range: Range<Location<F>>,
) -> Result<Proof<F, D>, Error<F>> {
let leaves = Location::try_from(self.size)?;
let bp = Blueprint::new(leaves, range)?;
let mut digests: Vec<D> = Vec::new();
if !bp.fold_prefix.is_empty() {
let mut acc = self.fold_acc;
for &pos in bp.fold_prefix.iter().skip(self.num_fold_peaks) {
match self.digests.get(&pos) {
Some(d) => {
acc = Some(acc.map_or(*d, |a| hasher.fold(&a, d)));
}
None => return Err(Error::ElementPruned(pos)),
}
}
digests.push(acc.expect("fold_prefix is non-empty so acc must be set"));
}
for &pos in &bp.fetch_nodes {
match self.digests.get(&pos) {
Some(d) => digests.push(*d),
None => return Err(Error::ElementPruned(pos)),
}
}
Ok(Proof { leaves, digests })
}
pub fn multi_proof(
&self,
locations: &[Location<F>],
peaks: &[(Position<F>, D)],
) -> Result<Proof<F, D>, Error<F>> {
if locations.is_empty() {
return Err(Error::Empty);
}
let leaves = Location::try_from(self.size)?;
let node_positions: BTreeSet<_> =
merkle_proof::nodes_required_for_multi_proof(leaves, locations)?;
let peak_map: HashMap<Position<F>, D> = peaks.iter().copied().collect();
let mut digests = Vec::with_capacity(node_positions.len());
for &pos in &node_positions {
if let Some(d) = self.digests.get(&pos) {
digests.push(*d);
} else if let Some(d) = peak_map.get(&pos) {
digests.push(*d);
} else {
return Err(Error::ElementPruned(pos));
}
}
Ok(Proof { leaves, digests })
}
}
/// Builds a proof for `range` against the current contents of `merkle`.
///
/// Equivalent to [`historical_range_proof`] with the leaf count derived from `merkle`'s
/// current size.
///
/// # Errors
///
/// Fails if the current size cannot be converted to a [`Location`]. Otherwise propagates any
/// error from [`historical_range_proof`].
pub async fn range_proof<
    F: Family,
    D: Digest,
    H: Hasher<F, Digest = D>,
    S: Storage<F, Digest = D>,
>(
    hasher: &H,
    merkle: &S,
    range: Range<Location<F>>,
) -> Result<Proof<F, D>, Error<F>> {
    let current_size = merkle.size().await;
    let current_leaves = Location::try_from(current_size)?;
    historical_range_proof(hasher, merkle, current_leaves, range).await
}
pub async fn historical_range_proof<
F: Family,
D: Digest,
H: Hasher<F, Digest = D>,
S: Storage<F, Digest = D>,
>(
hasher: &H,
merkle: &S,
leaves: Location<F>,
range: Range<Location<F>>,
) -> Result<Proof<F, D>, Error<F>> {
let bp = Blueprint::new(leaves, range)?;
let mut digests: Vec<D> = Vec::new();
if !bp.fold_prefix.is_empty() {
let node_futures = bp.fold_prefix.iter().map(|&pos| merkle.get_node(pos));
let results = try_join_all(node_futures).await?;
let mut acc = results[0].ok_or(Error::ElementPruned(bp.fold_prefix[0]))?;
for (i, &result) in results.iter().enumerate().skip(1) {
let d = result.ok_or(Error::ElementPruned(bp.fold_prefix[i]))?;
acc = hasher.fold(&acc, &d);
}
digests.push(acc);
}
let node_futures = bp.fetch_nodes.iter().map(|&pos| merkle.get_node(pos));
let results = try_join_all(node_futures).await?;
for (i, result) in results.into_iter().enumerate() {
match result {
Some(d) => digests.push(d),
None => return Err(Error::ElementPruned(bp.fetch_nodes[i])),
}
}
Ok(Proof { leaves, digests })
}
pub async fn multi_proof<F: Family, D: Digest, S: Storage<F, Digest = D>>(
merkle: &S,
locations: &[Location<F>],
) -> Result<Proof<F, D>, Error<F>> {
if locations.is_empty() {
return Err(Error::Empty);
}
let size = merkle.size().await;
let leaves = Location::try_from(size)?;
let node_positions: BTreeSet<_> =
merkle_proof::nodes_required_for_multi_proof(leaves, locations)?;
let node_futures: Vec<_> = node_positions
.iter()
.map(|&pos| async move { merkle.get_node(pos).await.map(|digest| (pos, digest)) })
.collect();
let results = try_join_all(node_futures).await?;
let mut digests = Vec::with_capacity(results.len());
for (pos, digest) in results {
match digest {
Some(digest) => digests.push(digest),
None => return Err(Error::ElementPruned(pos)),
}
}
Ok(Proof { leaves, digests })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        merkle::LocationRangeExt as _,
        mmb::{mem::Mmb, Location as MmbLocation},
        mmr::{mem::Mmr, StandardHasher as Standard},
    };
    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
    use commonware_macros::test_traced;
    use commonware_runtime::{deterministic, Runner};

    /// Deterministic test element: the SHA-256 hash of the single byte `v`.
    fn test_digest(v: u8) -> Digest {
        Sha256::hash(&[v])
    }

    /// Builds a 49-element MMR. For each range in a sequence shrinking from both ends, it builds
    /// a `ProofStore` from that range's proof. It then checks that proofs regenerated for each
    /// nested, likewise-shrinking sub-range verify against the root.
    #[test_traced]
    fn test_verification_proof_store() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new(&hasher);
            let elements: Vec<_> = (0..49).map(test_digest).collect();
            let batch = {
                let mut batch = mmr.new_batch();
                for element in &elements {
                    batch = batch.add(&hasher, element);
                }
                batch.merkleize(&mmr, &hasher)
            };
            mmr.apply_batch(&batch).unwrap();
            let root = mmr.root();
            // Outer loop: shrink the stored range by one element at each end per iteration.
            let mut range_start = Location::new(0);
            let mut range_end = Location::new(49);
            while range_start < range_end {
                let range = range_start..range_end;
                let range_proof = mmr.range_proof(&hasher, range.clone()).unwrap();
                let proof_store = ProofStore::new(
                    &hasher,
                    &range_proof,
                    &elements[range.to_usize_range()],
                    range_start,
                    root,
                )
                .unwrap();
                // Inner loop: sub-ranges shrink the same way, so each stays within `range`.
                let mut subrange_start = range_start;
                let mut subrange_end = range_end;
                while subrange_start < subrange_end {
                    let sub_range = subrange_start..subrange_end;
                    let sub_range_proof =
                        proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                    assert!(sub_range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[sub_range.to_usize_range()],
                        sub_range.start,
                        root
                    ));
                    subrange_start += 1;
                    subrange_end -= 1;
                }
                range_start += 1;
                range_end -= 1;
            }
        });
    }

    /// Builds a `ProofStore` for range 32..49 of a 49-element MMR. It checks that every
    /// non-empty sub-range of that range yields a proof that verifies against the root.
    #[test_traced]
    fn test_verification_proof_store_with_fold_prefix() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new(&hasher);
            let elements: Vec<_> = (0..49).map(test_digest).collect();
            let batch = {
                let mut batch = mmr.new_batch();
                for element in &elements {
                    batch = batch.add(&hasher, element);
                }
                batch.merkleize(&mmr, &hasher)
            };
            mmr.apply_batch(&batch).unwrap();
            let root = mmr.root();
            // NOTE(review): start 32 presumably puts the leftmost peak (leaves 0..32 of the
            // 49-leaf MMR) entirely before the range, exercising the fold-prefix path. Confirm
            // against `Blueprint::fold_prefix`.
            let range = Location::new(32)..Location::new(49);
            let range_proof = mmr.range_proof(&hasher, range.clone()).unwrap();
            let proof_store = ProofStore::new(
                &hasher,
                &range_proof,
                &elements[range.to_usize_range()],
                range.start,
                root,
            )
            .unwrap();
            // Exhaustively check every sub-range start..end with 32 <= start < end <= 49.
            for start in 32u64..49 {
                for end in (start + 1)..=49 {
                    let sub_range = Location::new(start)..Location::new(end);
                    let sub_proof = proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                    assert!(
                        sub_proof.verify_range_inclusion(
                            &hasher,
                            &elements[sub_range.to_usize_range()],
                            sub_range.start,
                            root,
                        ),
                        "sub-proof should verify for range {start}..{end}"
                    );
                }
            }
        });
    }

    /// Same as the MMR fold-prefix test but over an 8-element MMB. The store is built for
    /// range 4..8, and every non-empty sub-range of it must produce a verifying proof.
    #[test_traced]
    fn test_verification_proof_store_with_fold_prefix_mmb() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: Standard<Sha256> = Standard::new();
            let mut mmb = Mmb::new(&hasher);
            let elements: Vec<_> = (0..8).map(test_digest).collect();
            let batch = {
                let mut batch = mmb.new_batch();
                for element in &elements {
                    batch = batch.add(&hasher, element);
                }
                batch.merkleize(&mmb, &hasher)
            };
            mmb.apply_batch(&batch).unwrap();
            let root = mmb.root();
            // NOTE(review): range start 4 is presumably chosen so at least one MMB peak lies
            // before the range (fold-prefix path). Confirm against the MMB peak layout.
            let range = MmbLocation::new(4)..MmbLocation::new(8);
            let range_proof = mmb.range_proof(&hasher, range.clone()).unwrap();
            let proof_store = ProofStore::new(
                &hasher,
                &range_proof,
                &elements[range.to_usize_range()],
                range.start,
                root,
            )
            .unwrap();
            // Exhaustively check every sub-range start..end with 4 <= start < end <= 8.
            for start in 4u64..8 {
                for end in (start + 1)..=8 {
                    let sub_range = MmbLocation::new(start)..MmbLocation::new(end);
                    let sub_proof = proof_store.range_proof(&hasher, sub_range.clone()).unwrap();
                    assert!(
                        sub_proof.verify_range_inclusion(
                            &hasher,
                            &elements[sub_range.to_usize_range()],
                            sub_range.start,
                            root,
                        ),
                        "sub-proof should verify for MMB range {start}..{end}"
                    );
                }
            }
        });
    }
}