use crate::error::{Error, Result};
use crate::merkle::{HASH_LEN, Hash, hash_children};
/// Number of trailing zero bits in `x` (64 when `x == 0`).
#[inline]
fn trailing_zeros(x: u64) -> u32 {
    u64::trailing_zeros(x)
}
/// Minimum number of bits needed to represent `x` (0 for `x == 0`).
#[inline]
fn bit_len(x: u64) -> u32 {
    let leading = x.leading_zeros();
    u64::BITS - leading
}
/// Number of set bits (population count) in `x`.
#[inline]
fn ones_count(x: u64) -> u32 {
    u64::count_ones(x)
}
fn check_hash_len(bytes: &[u8]) -> Result<Hash> {
let arr: Hash = bytes.try_into().map_err(|_| Error::InvalidHashLength {
got: bytes.len(),
want: HASH_LEN,
})?;
Ok(arr)
}
/// Bit length of `index XOR (size - 1)`: the number of low levels at which the
/// paths to leaf `index` and to the last leaf `size - 1` still differ.
///
/// Requires `size >= 1`; `size - 1` underflows otherwise.
#[inline]
fn inner_proof_size(index: u64, size: u64) -> u32 {
    let last = size - 1;
    bit_len(last ^ index)
}
/// Splits the inclusion proof for leaf `index` in a tree of `size` leaves into
/// `(inner, border)` node counts.
///
/// `inner` is `inner_proof_size(index, size)`: the bit length of
/// `index ^ (size - 1)`. `border` is the number of set bits of `index` above
/// those `inner` low bits. A well-formed proof has exactly `inner + border`
/// nodes.
///
/// Callers must ensure `index < size`, so `size >= 1`.
#[inline]
fn decomp_incl_proof(index: u64, size: u64) -> (u32, u32) {
    let inner = inner_proof_size(index, size);
    // `inner` reaches 64 when `index` and `size - 1` differ in the top bit
    // (e.g. size = 2^63 + 1). A plain `index >> 64` overflows: it panics in
    // debug builds and is masked to a shift by 0 in release, giving a bogus
    // border count. Shifting out all 64 bits leaves nothing, so border is 0.
    let border = ones_count(index.checked_shr(inner).unwrap_or(0));
    (inner, border)
}
/// Folds `proof` into `seed`, level by level, starting from the lowest level.
///
/// At level `i`, bit `i` of `index` decides the ordering. If the bit is 0 the
/// accumulator is the left child (`hash_children(acc, sibling)`). If it is 1
/// the accumulator is the right child (`hash_children(sibling, acc)`).
fn chain_inner(seed: Hash, proof: &[Hash], index: u64) -> Hash {
    proof
        .iter()
        .enumerate()
        .fold(seed, |acc, (level, sibling)| {
            let acc_is_left = (index >> level) & 1 == 0;
            if acc_is_left {
                hash_children(&acc, sibling)
            } else {
                hash_children(sibling, &acc)
            }
        })
}
/// Like `chain_inner`, but only folds in siblings at levels where bit `i` of
/// `index` is 1.
///
/// Each such sibling is hashed as the left child of the accumulator. Levels
/// whose bit is 0 are skipped entirely.
fn chain_inner_right(seed: Hash, proof: &[Hash], index: u64) -> Hash {
    proof
        .iter()
        .enumerate()
        .filter(|&(level, _)| (index >> level) & 1 == 1)
        .fold(seed, |acc, (_, sibling)| hash_children(sibling, &acc))
}
/// Folds every node of `proof` into `seed`, in order, always placing the proof
/// node as the left child: `acc = hash_children(node, acc)`.
fn chain_border_right(seed: Hash, proof: &[Hash]) -> Hash {
    proof
        .iter()
        .fold(seed, |acc, sibling| hash_children(sibling, &acc))
}
/// Computes the tree root implied by an inclusion proof for the leaf at `index`
/// in a tree of `size` leaves.
///
/// The first `inner` proof nodes are chained with `chain_inner`, steered by the
/// bits of `index`. The remaining `border` nodes are chained with
/// `chain_border_right`. Both counts come from `decomp_incl_proof`.
///
/// # Errors
///
/// Checked in this order:
/// - `Error::IndexBeyondSize` if `index >= size`.
/// - `Error::InvalidHashLength` if `leaf_hash` is not a valid hash length.
/// - `Error::WrongProofSize` if `proof` does not have exactly `inner + border`
///   nodes.
/// - `Error::InvalidHashLength` for the first proof node of invalid length.
pub fn root_from_inclusion_proof(
    index: u64,
    size: u64,
    leaf_hash: &[u8],
    proof: &[Vec<u8>],
) -> Result<Hash> {
    if index >= size {
        return Err(Error::IndexBeyondSize { index, size });
    }
    let seed = check_hash_len(leaf_hash)?;

    let (inner_len, border_len) = decomp_incl_proof(index, size);
    let expected_len = inner_len as usize + border_len as usize;
    if proof.len() != expected_len {
        return Err(Error::WrongProofSize {
            got: proof.len(),
            want: expected_len,
        });
    }

    let mut parsed = Vec::with_capacity(proof.len());
    for node in proof {
        parsed.push(check_hash_len(node)?);
    }

    let (inner_part, border_part) = parsed.split_at(inner_len as usize);
    let subtree_root = chain_inner(seed, inner_part, index);
    Ok(chain_border_right(subtree_root, border_part))
}
pub fn verify_inclusion(
index: u64,
size: u64,
leaf_hash: &[u8],
proof: &[Vec<u8>],
root: &[u8],
) -> Result<()> {
let expected = check_hash_len(root)?;
let calc = root_from_inclusion_proof(index, size, leaf_hash, proof)?;
if calc == expected {
Ok(())
} else {
Err(Error::RootMismatch)
}
}
/// Checks a consistency proof from a tree of `size1` leaves with root `root1`
/// to a tree of `size2` leaves, and returns the `size2` root reconstructed
/// from the proof.
///
/// The proof is first replayed to rebuild `root1`. Only if that matches is the
/// second root computed and returned. When `size1 == size2` the proof must be
/// empty and `root1` is returned unchanged.
///
/// # Errors
///
/// Checked in this order:
/// - `Error::SizeRegression` if `size2 < size1`.
/// - `Error::EmptyTreeConsistency` if `size1 == 0`. This runs before the
///   equal-size check, so `(0, 0)` is rejected too.
///   NOTE(review): confirm that rejecting `(0, 0)` is intended.
/// - `Error::InvalidHashLength` if `root1` is not a valid hash length.
/// - `Error::NonEmptyEqualSizeProof` if `size1 == size2` and `proof` is non-empty.
/// - `Error::WrongProofSize` if `proof` is empty, or has the wrong node count.
///   Note: when `proof[0]` is used as the seed, its hash length is checked
///   first, so `Error::InvalidHashLength` can precede this error.
/// - `Error::InvalidHashLength` for the first remaining proof node of invalid length.
/// - `Error::RootMismatch` if the proof does not reproduce `root1`.
pub fn root_from_consistency_proof(
size1: u64,
size2: u64,
proof: &[Vec<u8>],
root1: &[u8],
) -> Result<Hash> {
if size2 < size1 {
return Err(Error::SizeRegression { size1, size2 });
}
if size1 == 0 {
return Err(Error::EmptyTreeConsistency);
}
let root1 = check_hash_len(root1)?;
if size1 == size2 {
if !proof.is_empty() {
return Err(Error::NonEmptyEqualSizeProof);
}
return Ok(root1);
}
// From here on size1 < size2, so a non-empty proof is required.
if proof.is_empty() {
return Err(Error::WrongProofSize { got: 0, want: 1 });
}
// Decompose as if this were an inclusion proof for leaf `size1 - 1` in the
// `size2` tree (same decomposition as root_from_inclusion_proof).
let (inner_full, border) = decomp_incl_proof(size1 - 1, size2);
// The low `shift` bits of `size1 - 1` are all ones, so the lowest `shift`
// levels of that path are summarized by a single subtree root (the seed).
let shift = trailing_zeros(size1);
// Cannot underflow: size2 - 1 > size1 - 1, and the low `shift` bits of
// size1 - 1 are already all ones, so the two must differ in a bit >= shift,
// making inner_full > shift.
let inner = inner_full - shift;
// If size1 is exactly 2^shift, that subtree is the whole first tree, so its
// root is root1; otherwise the first proof node supplies the subtree root.
let (seed, start): (Hash, usize) = if size1 == (1u64 << shift) {
(root1, 0)
} else {
(check_hash_len(&proof[0])?, 1)
};
let want = start + (inner + border) as usize;
if proof.len() != want {
return Err(Error::WrongProofSize {
got: proof.len(),
want,
});
}
let nodes: Vec<Hash> = proof[start..]
.iter()
.map(|h| check_hash_len(h))
.collect::<Result<_>>()?;
let (inner_nodes, border_nodes) = nodes.split_at(inner as usize);
// Path bits of `size1 - 1` starting at level `shift`, where the seed sits.
let mask = (size1 - 1) >> shift;
// Rebuild root1: chain_inner_right folds in only siblings whose mask bit is
// 1 (siblings on the left), then the border nodes.
let hash1 = chain_inner_right(seed, inner_nodes, mask);
let hash1 = chain_border_right(hash1, border_nodes);
if hash1 != root1 {
return Err(Error::RootMismatch);
}
// Rebuild the size2 root from the same seed using every inner sibling.
let hash2 = chain_inner(seed, inner_nodes, mask);
Ok(chain_border_right(hash2, border_nodes))
}
pub fn verify_consistency(
size1: u64,
size2: u64,
proof: &[Vec<u8>],
root1: &[u8],
root2: &[u8],
) -> Result<()> {
let expected2 = check_hash_len(root2)?;
let calc2 = root_from_consistency_proof(size1, size2, proof, root1)?;
if calc2 == expected2 {
Ok(())
} else {
Err(Error::RootMismatch)
}
}