use std::marker::PhantomData;
use anyhow::{Context, anyhow};
use digest::{Digest, FixedOutput};
use serde::{Deserialize, Serialize};
use crate::StmResult;
use crate::codec;
#[cfg(feature = "future_snark")]
#[allow(dead_code)]
use super::MerklePath;
use super::{MerkleBatchPath, MerkleTreeError, MerkleTreeLeaf, parent, sibling};
#[cfg(feature = "future_snark")]
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MerkleTreeCommitment<D: Digest, L: MerkleTreeLeaf> {
pub root: Vec<u8>,
hasher: PhantomData<D>,
#[serde(skip)]
leaf_type: PhantomData<L>,
}
#[cfg(feature = "future_snark")]
#[allow(dead_code)]
impl<D: Digest + FixedOutput, L: MerkleTreeLeaf> MerkleTreeCommitment<D, L> {
pub(crate) fn new(root: Vec<u8>) -> Self {
MerkleTreeCommitment {
root,
hasher: PhantomData,
leaf_type: PhantomData,
}
}
pub(crate) fn verify_leaf_membership_from_path(
&self,
val: &L,
proof: &MerklePath<D>,
) -> StmResult<()>
where
D: FixedOutput + Clone,
L: MerkleTreeLeaf,
{
let mut idx = proof.index;
let mut h = D::digest(val.as_bytes_for_merkle_tree()).to_vec();
for p in &proof.values {
if (idx & 0b1) == 0 {
h = D::new().chain_update(h).chain_update(p).finalize().to_vec();
} else {
h = D::new().chain_update(p).chain_update(h).finalize().to_vec();
}
idx >>= 1;
}
if h == self.root {
return Ok(());
}
Err(anyhow!(MerkleTreeError::PathInvalid(proof.to_bytes()?)))
}
pub fn to_bytes(&self) -> StmResult<Vec<u8>> {
codec::to_cbor_bytes(self)
}
pub fn from_bytes(bytes: &[u8]) -> StmResult<MerkleTreeCommitment<D, L>> {
if codec::has_cbor_v1_prefix(bytes) {
codec::from_cbor_bytes::<Self>(&bytes[1..]).or_else(|_| Self::from_bytes_legacy(bytes))
} else {
Self::from_bytes_legacy(bytes)
}
}
fn from_bytes_legacy(bytes: &[u8]) -> StmResult<MerkleTreeCommitment<D, L>> {
let root = bytes.to_vec();
Ok(Self {
root,
hasher: PhantomData,
leaf_type: PhantomData,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleTreeBatchCommitment<D: Digest, L: MerkleTreeLeaf> {
pub root: Vec<u8>,
nr_leaves: usize,
hasher: PhantomData<D>,
#[serde(skip)]
leaf_type: PhantomData<L>,
}
impl<D: Digest + FixedOutput, L: MerkleTreeLeaf> MerkleTreeBatchCommitment<D, L> {
pub(crate) fn new(root: Vec<u8>, nr_leaves: usize) -> Self {
Self {
root,
nr_leaves,
hasher: Default::default(),
leaf_type: PhantomData,
}
}
#[cfg(feature = "future_snark")]
#[allow(dead_code)]
pub(crate) fn get_number_of_leaves(&self) -> usize {
self.nr_leaves
}
pub(crate) fn concatenate_with_message(&self, msg: &[u8]) -> Vec<u8> {
let mut msgp = msg.to_vec();
let mut bytes = self.root.clone();
msgp.append(&mut bytes);
msgp
}
pub(crate) fn verify_leaves_membership_from_batch_path(
&self,
batch_val: &[L],
proof: &MerkleBatchPath<D>,
) -> StmResult<()>
where
D: FixedOutput + Clone,
L: MerkleTreeLeaf,
{
if batch_val.len() != proof.indices.len() {
return Err(anyhow!(MerkleTreeError::BatchPathInvalid(
proof.to_bytes()?
)));
}
let mut ordered_indices: Vec<usize> = proof.indices.clone();
ordered_indices.sort_unstable();
if ordered_indices != proof.indices {
return Err(anyhow!(MerkleTreeError::BatchPathInvalid(
proof.to_bytes()?
)));
}
let nr_nodes = self.nr_leaves + self.nr_leaves.next_power_of_two() - 1;
ordered_indices = ordered_indices
.into_iter()
.map(|i| i + self.nr_leaves.next_power_of_two() - 1)
.collect();
let mut idx = ordered_indices[0];
let mut leaves: Vec<Vec<u8>> = batch_val
.iter()
.map(|val| D::digest(val.as_bytes_for_merkle_tree()).to_vec())
.collect();
let mut values = proof.values.clone();
while idx > 0 {
let mut new_hashes = Vec::with_capacity(ordered_indices.len());
let mut new_indices = Vec::with_capacity(ordered_indices.len());
let mut i = 0;
idx = parent(idx);
while i < ordered_indices.len() {
new_indices.push(parent(ordered_indices[i]));
if ordered_indices[i] & 1 == 0 {
new_hashes.push(
D::new()
.chain(
values
.first()
.ok_or(MerkleTreeError::SerializationError)
.with_context(|| {
format!("Could not verify leave membership from batch path for idx = {} and ordered_indices[{}]", idx, i)
})?,
)
.chain(&leaves[i])
.finalize()
.to_vec(),
);
values.remove(0);
} else {
let sibling = sibling(ordered_indices[i]);
if i < ordered_indices.len() - 1 && ordered_indices[i + 1] == sibling {
new_hashes.push(
D::new().chain(&leaves[i]).chain(&leaves[i + 1]).finalize().to_vec(),
);
i += 1;
} else if sibling < nr_nodes {
new_hashes.push(
D::new()
.chain(&leaves[i])
.chain(
values
.first()
.ok_or(MerkleTreeError::SerializationError)
.with_context(|| {
format!(
"Could not verify leave membership from batch path for idx = {} where sibling < nr_nodes", idx
)
})?,
)
.finalize()
.to_vec(),
);
values.remove(0);
} else {
new_hashes.push(
D::new().chain(&leaves[i]).chain(D::digest([0u8])).finalize().to_vec(),
);
}
}
i += 1;
}
leaves.clone_from(&new_hashes);
ordered_indices.clone_from(&new_indices);
}
if leaves.len() == 1 && leaves[0] == self.root {
return Ok(());
}
Err(anyhow!(MerkleTreeError::BatchPathInvalid(
proof.to_bytes()?
)))
}
#[allow(dead_code)]
pub fn to_bytes(&self) -> StmResult<Vec<u8>> {
codec::to_cbor_bytes(self)
}
pub fn from_bytes(bytes: &[u8]) -> StmResult<MerkleTreeBatchCommitment<D, L>> {
codec::from_versioned_bytes(bytes, Self::from_bytes_legacy)
}
fn from_bytes_legacy(bytes: &[u8]) -> StmResult<MerkleTreeBatchCommitment<D, L>> {
let mut u64_bytes = [0u8; 8];
u64_bytes.copy_from_slice(bytes.get(..8).ok_or(MerkleTreeError::SerializationError)?);
let nr_leaves = usize::try_from(u64::from_be_bytes(u64_bytes))
.map_err(|_| MerkleTreeError::SerializationError)?;
let mut root = Vec::new();
root.extend_from_slice(bytes.get(8..).ok_or(MerkleTreeError::SerializationError)?);
Ok(MerkleTreeBatchCommitment {
root,
nr_leaves,
hasher: PhantomData,
leaf_type: Default::default(),
})
}
}
impl<D: Digest, L: MerkleTreeLeaf> PartialEq for MerkleTreeBatchCommitment<D, L> {
fn eq(&self, other: &Self) -> bool {
self.root == other.root && self.nr_leaves == other.nr_leaves
}
}
impl<D: Digest, L: MerkleTreeLeaf> Eq for MerkleTreeBatchCommitment<D, L> {}
#[cfg(test)]
mod tests {
use crate::{
MembershipDigest, MithrilMembershipDigest, VerificationKeyForConcatenation,
membership_commitment::merkle_tree::tree::MerkleTree,
membership_commitment::{MerkleTreeBatchCommitment, MerkleTreeConcatenationLeaf},
};
type ConcatenationHash = <MithrilMembershipDigest as MembershipDigest>::ConcatenationHash;
fn build_batch_commitment(
nr: usize,
) -> MerkleTreeBatchCommitment<ConcatenationHash, MerkleTreeConcatenationLeaf> {
let pks = vec![VerificationKeyForConcatenation::default(); nr];
let leaves = pks
.into_iter()
.enumerate()
.map(|(i, k)| MerkleTreeConcatenationLeaf(k, i as u64))
.collect::<Vec<_>>();
MerkleTree::<ConcatenationHash, MerkleTreeConcatenationLeaf>::new(&leaves)
.to_merkle_tree_batch_commitment()
}
#[cfg(feature = "future_snark")]
mod bytes_codec_ambiguity {
use crate::{
MembershipDigest, MithrilMembershipDigest,
membership_commitment::{MerkleTreeCommitment, MerkleTreeSnarkLeaf},
};
type SnarkHash = <MithrilMembershipDigest as MembershipDigest>::SnarkHash;
#[test]
fn legacy_data_starting_with_0x01_falls_back_correctly() {
let mut legacy_root = vec![0x01];
legacy_root.extend_from_slice(&[0xAA; 31]);
let decoded =
MerkleTreeCommitment::<SnarkHash, MerkleTreeSnarkLeaf>::from_bytes(&legacy_root)
.expect("Legacy data starting with 0x01 should fall back to legacy decoder");
assert_eq!(legacy_root, decoded.root);
}
}
mod golden_cbor {
use super::*;
const GOLDEN_CBOR_BYTES: &[u8; 91] = &[
1, 163, 100, 114, 111, 111, 116, 152, 32, 24, 178, 24, 30, 24, 231, 24, 127, 24, 65,
24, 247, 24, 162, 24, 149, 24, 33, 24, 29, 24, 147, 24, 148, 24, 224, 24, 156, 24, 96,
24, 113, 24, 140, 24, 42, 24, 98, 24, 166, 24, 137, 14, 24, 69, 24, 29, 24, 28, 24,
244, 24, 161, 24, 145, 24, 207, 24, 146, 24, 236, 24, 249, 105, 110, 114, 95, 108, 101,
97, 118, 101, 115, 4, 102, 104, 97, 115, 104, 101, 114, 246,
];
fn golden_value()
-> MerkleTreeBatchCommitment<ConcatenationHash, MerkleTreeConcatenationLeaf> {
build_batch_commitment(4)
}
#[test]
fn cbor_golden_bytes_can_be_decoded() {
let decoded = MerkleTreeBatchCommitment::<
ConcatenationHash,
MerkleTreeConcatenationLeaf,
>::from_bytes(GOLDEN_CBOR_BYTES)
.expect("CBOR golden bytes deserialization should not fail");
assert_eq!(golden_value(), decoded);
}
#[test]
fn cbor_encoding_is_stable() {
let bytes = MerkleTreeBatchCommitment::to_bytes(&golden_value())
.expect("CBOR serialization should not fail");
assert_eq!(GOLDEN_CBOR_BYTES.as_slice(), bytes.as_slice());
}
}
}