use core::marker::PhantomData;
use super::{SparseMerkleInternalNode, SparseMerkleLeafNode, SparseMerkleNode};
use crate::{
storage::Node,
types::nibble::nibble_path::{skip_common_prefix, NibblePath},
Bytes32Ext, KeyHash, RootHash, SimpleHasher, ValueHash, SPARSE_MERKLE_PLACEHOLDER_HASH,
};
use alloc::vec::Vec;
use anyhow::{bail, ensure, format_err, Result};
use serde::{Deserialize, Serialize};
/// A proof that a key is, or is not, present in a sparse Merkle tree, checkable against a root
/// hash with [`SparseMerkleProof::verify`].
#[derive(Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
pub struct SparseMerkleProof<H: SimpleHasher> {
    /// The leaf the proven path ends at, if any.
    ///
    /// - `Some(leaf)` whose key equals the queried key: an inclusion proof.
    /// - `Some(leaf)` with a different key: a non-inclusion proof; the queried key would have
    ///   landed in the position `leaf` occupies.
    /// - `None`: a non-inclusion proof whose path ends at an empty position (hashed as the
    ///   placeholder hash during verification).
    #[serde(bound(serialize = "", deserialize = ""))]
    leaf: Option<SparseMerkleLeafNode>,
    /// Siblings along the path, ordered from the level next to the leaf up to the root.
    siblings: Vec<SparseMerkleNode>,
    /// Type-level marker for the hash function `H`; carries no data.
    #[borsh(bound(serialize = "", deserialize = ""))]
    phantom_hasher: PhantomData<H>,
}
impl<H: SimpleHasher> core::fmt::Debug for SparseMerkleProof<H> {
    /// Renders every field, the hasher marker included, as `SparseMerkleProof { .. }`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            leaf,
            siblings,
            phantom_hasher,
        } = self;
        let mut builder = f.debug_struct("SparseMerkleProof");
        builder.field("leaf", leaf);
        builder.field("siblings", siblings);
        builder.field("phantom_hasher", phantom_hasher);
        builder.finish()
    }
}
impl<H: SimpleHasher> PartialEq for SparseMerkleProof<H> {
    // The hasher marker holds no data, so two proofs are equal exactly when their leaf and
    // sibling list agree (the leaf is compared first, the siblings only if the leaves match).
    fn eq(&self, other: &Self) -> bool {
        (&self.leaf, &self.siblings) == (&other.leaf, &other.siblings)
    }
}
impl<H: SimpleHasher> Clone for SparseMerkleProof<H> {
fn clone(&self) -> Self {
Self {
leaf: self.leaf.clone(),
siblings: self.siblings.clone(),
phantom_hasher: Default::default(),
}
}
}
impl<H: SimpleHasher> SparseMerkleProof<H> {
    /// Creates a proof from an optional leaf and the siblings along its path.
    ///
    /// `siblings` must be ordered from the level next to `leaf` up to the root; every root
    /// computation in this impl folds them in that order.
    pub(crate) fn new(leaf: Option<SparseMerkleLeafNode>, siblings: Vec<SparseMerkleNode>) -> Self {
        SparseMerkleProof {
            leaf,
            siblings,
            phantom_hasher: Default::default(),
        }
    }

    /// Returns a copy of the leaf this proof ends at, or `None` if the path ends at an empty
    /// position.
    pub fn leaf(&self) -> Option<SparseMerkleLeafNode> {
        self.leaf.clone()
    }

    /// Returns the siblings along the proven path, the level next to the leaf first.
    pub(crate) fn siblings(&self) -> &[SparseMerkleNode] {
        &self.siblings
    }

    /// Consumes the proof and returns its siblings, the level next to the leaf first.
    pub(crate) fn take_siblings(self) -> Vec<SparseMerkleNode> {
        self.siblings
    }

    /// Verifies that `element_key` is present with value `element_value` in the tree whose
    /// root is `expected_root_hash`.
    ///
    /// # Errors
    /// Same conditions as [`Self::verify`] called with `Some(element_value)`.
    pub fn verify_existence<V: AsRef<[u8]>>(
        &self,
        expected_root_hash: RootHash,
        element_key: KeyHash,
        element_value: V,
    ) -> Result<()> {
        self.verify(expected_root_hash, element_key, Some(element_value))
    }

    /// Verifies that `element_key` is absent from the tree whose root is `expected_root_hash`.
    ///
    /// # Errors
    /// Same conditions as [`Self::verify`] called with `None`.
    pub fn verify_nonexistence(
        &self,
        expected_root_hash: RootHash,
        element_key: KeyHash,
    ) -> Result<()> {
        self.verify(expected_root_hash, element_key, None::<&[u8]>)
    }

    /// Verifies this proof for `element_key` against `expected_root_hash`.
    ///
    /// With `Some(value)` the proof is checked as an inclusion proof: its leaf must carry
    /// `element_key` and the hash of `value`. With `None` it is checked as a non-inclusion
    /// proof: either the path ends at an empty position, or it ends at a leaf for a different
    /// key whose common bit prefix with `element_key` (per `common_prefix_bits_len`) is at
    /// least as long as the path, i.e. `element_key` would have landed in that leaf's position.
    ///
    /// The root is then recomputed from the leaf hash (or the placeholder hash when there is
    /// no leaf) and the siblings, steering by the bits of `element_key`.
    ///
    /// # Errors
    /// Returns an error if the proof has more than 256 siblings, if any check above fails,
    /// if an inclusion is expected but the proof has no leaf, or if the recomputed root
    /// differs from `expected_root_hash`.
    pub fn verify<V: AsRef<[u8]>>(
        &self,
        expected_root_hash: RootHash,
        element_key: KeyHash,
        element_value: Option<V>,
    ) -> Result<()> {
        ensure!(
            self.siblings.len() <= 256,
            "Sparse Merkle Tree proof has more than {} ({}) siblings.",
            256,
            self.siblings.len(),
        );
        match (element_value, self.leaf.clone()) {
            (Some(value), Some(leaf)) => {
                // Inclusion: the proof's leaf must be exactly (element_key, hash(value)).
                ensure!(
                    element_key == leaf.key_hash,
                    "Keys do not match. Key in proof: {:?}. Expected key: {:?}.",
                    leaf.key_hash,
                    element_key
                );
                let hash: ValueHash = ValueHash::with::<H>(value);
                ensure!(
                    hash == leaf.value_hash,
                    "Value hashes do not match. Value hash in proof: {:?}. \
                     Expected value hash: {:?}",
                    leaf.value_hash,
                    hash,
                );
            }
            (Some(_value), None) => bail!("Expected inclusion proof. Found non-inclusion proof."),
            (None, Some(leaf)) => {
                // Non-inclusion via another leaf: that leaf must sit exactly where
                // `element_key` would have been placed.
                ensure!(
                    element_key != leaf.key_hash,
                    "Expected non-inclusion proof, but key exists in proof.",
                );
                ensure!(
                    element_key.0.common_prefix_bits_len(&leaf.key_hash.0) >= self.siblings.len(),
                    "Key would not have ended up in the subtree where the provided key in proof \
                     is the only existing key, if it existed. So this is not a valid \
                     non-inclusion proof.",
                );
            }
            (None, None) => {
                // Non-inclusion via an empty position: only the root check below applies.
            }
        }
        let current_hash = self
            .leaf
            .clone()
            .map_or(SPARSE_MERKLE_PLACEHOLDER_HASH, |leaf| leaf.hash::<H>());
        let actual_root_hash = Self::fold_siblings(current_hash, &self.siblings, element_key);
        ensure!(
            actual_root_hash == expected_root_hash.0,
            "Root hashes do not match. Actual root hash: {:?}. Expected root hash: {:?}.",
            actual_root_hash,
            expected_root_hash.0,
        );
        Ok(())
    }

    /// Hashes `start` up through `siblings` (next-to-leaf first) and returns the root.
    ///
    /// Each sibling is paired with one bit of `key`: the bits come from `iter_bits()` reversed,
    /// skipping the first `256 - siblings.len()`. A set bit makes the running hash the right
    /// child of the new internal node, a clear bit the left child.
    ///
    /// # Panics
    /// The `256 - siblings.len()` subtraction overflows (panicking in debug builds) when
    /// `siblings.len() > 256`; `verify` and `check_compute_new_root` reject such proofs first.
    fn fold_siblings(start: [u8; 32], siblings: &[SparseMerkleNode], key: KeyHash) -> [u8; 32] {
        let path_bits = key.0.iter_bits().rev().skip(256 - siblings.len());
        siblings
            .iter()
            .zip(path_bits)
            .fold(start, |hash, (sibling, bit)| {
                let sibling_hash = sibling.hash::<H>();
                if bit {
                    SparseMerkleInternalNode::new(sibling_hash, hash).hash::<H>()
                } else {
                    SparseMerkleInternalNode::new(hash, sibling_hash).hash::<H>()
                }
            })
    }

    /// Builds the proof for `new_element_key` after it is inserted at the position currently
    /// held by `leaf_node`, a leaf for a *different* key, splitting the two leaves apart.
    ///
    /// The new path is, bottom-up: `leaf_node` itself, then one empty sibling for each level
    /// between the end of the existing path and the first bit where the keys diverge, then
    /// the existing siblings.
    ///
    /// The caller must already have verified that the keys differ and share at least
    /// `self.siblings.len()` leading bits; otherwise this loops forever or panics.
    fn compute_new_merkle_path_on_split<V: AsRef<[u8]>>(
        self,
        leaf_node: SparseMerkleLeafNode,
        new_element_key: KeyHash,
        new_element_value: V,
    ) -> SparseMerkleProof<H> {
        let new_key_path = NibblePath::new(new_element_key.0.to_vec());
        let old_key_path = NibblePath::new(leaf_node.key_hash.0.to_vec());
        let mut new_key_nibbles = new_key_path.nibbles();
        let mut old_key_nibbles = old_key_path.nibbles();
        // Whole nibbles the two keys have in common.
        let common_prefix_len = skip_common_prefix(&mut new_key_nibbles, &mut old_key_nibbles);
        // Leading bits the keys share inside the first nibble where they differ. The loop
        // terminates only because the keys are guaranteed to differ somewhere.
        let mut common_bits_in_nibble: usize = 0;
        let mut new_key_bits = new_key_nibbles.bits();
        let mut old_key_bits = old_key_nibbles.bits();
        while new_key_bits.next() == old_key_bits.next() {
            common_bits_in_nibble += 1;
        }
        // Total shared bit prefix minus the current path length: one empty sibling per level
        // the path has to grow by. This is algebraically identical to counting padding nibble
        // by nibble, and it fails in exactly the cases where that counting would underflow.
        let num_siblings = self.siblings.len();
        let num_default_siblings = (4 * common_prefix_len + common_bits_in_nibble)
            .checked_sub(num_siblings)
            .expect("caller verified the keys share at least `num_siblings` leading bits");
        let mut new_siblings = Vec::with_capacity(num_default_siblings + 1 + num_siblings);
        new_siblings.push(SparseMerkleNode::Leaf(leaf_node));
        new_siblings.resize(num_default_siblings + 1, SparseMerkleNode::Null);
        new_siblings.extend(self.siblings);
        SparseMerkleProof::new(
            Some(SparseMerkleLeafNode::new(
                new_element_key,
                ValueHash::with::<H>(new_element_value),
            )),
            new_siblings,
        )
    }

    /// Checks this proof against `old_root_hash` and returns the root hash after one update at
    /// `new_element_key`: a write of the value when `new_element_value` is `Some`, a deletion
    /// when it is `None`.
    ///
    /// # Errors
    /// Returns an error if the proof has more than 256 siblings, if it does not verify against
    /// `old_root_hash` for the requested operation, or if a deletion's proof carries a leaf for
    /// a different key.
    fn check_compute_new_root<V: AsRef<[u8]>>(
        self,
        old_root_hash: RootHash,
        new_element_key: KeyHash,
        new_element_value: Option<V>,
    ) -> Result<RootHash> {
        // Every root computation below subtracts the sibling count from 256; reject oversized
        // (necessarily malformed) proofs before doing arithmetic on untrusted input.
        ensure!(
            self.siblings.len() <= 256,
            "Sparse Merkle Tree proof has more than {} ({}) siblings.",
            256,
            self.siblings.len(),
        );
        match new_element_value {
            Some(value) => self.compute_root_after_write(old_root_hash, new_element_key, value),
            None => self.compute_root_after_delete(old_root_hash, new_element_key),
        }
    }

    /// Verifies this proof for a write of `value` at `key` and returns the resulting root.
    fn compute_root_after_write<V: AsRef<[u8]>>(
        self,
        old_root_hash: RootHash,
        key: KeyHash,
        value: V,
    ) -> Result<RootHash> {
        match self.leaf.clone() {
            // `key` already exists: only its value hash changes; the path is reused as-is.
            Some(leaf_node) if leaf_node.key_hash == key => {
                self.ensure_root_matches(old_root_hash)?;
                Ok(Self::root_with_leaf(self.siblings, key, value))
            }
            // Another key occupies the slot `key` falls into, so the leaf splits. Checking the
            // full non-inclusion proof (not just the old root) also enforces that `key` shares
            // the proof's path with `leaf_node`, which the split arithmetic relies on.
            Some(leaf_node) => {
                self.verify_nonexistence(old_root_hash, key)?;
                Ok(self
                    .compute_new_merkle_path_on_split(leaf_node, key, value)
                    .root_hash())
            }
            // `key` lands on an empty position: prove it is empty, then fill it.
            None => {
                self.verify_nonexistence(old_root_hash, key)?;
                Ok(Self::root_with_leaf(self.siblings, key, value))
            }
        }
    }

    /// Verifies this proof for a deletion of `key` and returns the resulting root.
    fn compute_root_after_delete(&self, old_root_hash: RootHash, key: KeyHash) -> Result<RootHash> {
        let leaf_node = match self.leaf.clone() {
            Some(leaf_node) => leaf_node,
            None => {
                // Without a leaf this can only be a non-inclusion proof, making the deletion a
                // no-op — but only once the proof is shown to be valid for `old_root_hash`.
                self.verify_nonexistence(RootHash(old_root_hash.0), key)?;
                return Ok(old_root_hash);
            }
        };
        self.ensure_root_matches(old_root_hash)?;
        ensure!(
            key == leaf_node.key_hash,
            "Key {:?} to remove doesn't match the leaf key {:?} supplied with the proof",
            key,
            leaf_node.key_hash
        );

        let siblings = self.siblings();
        let first_non_null = siblings
            .iter()
            .position(|sibling| *sibling != SparseMerkleNode::Null);
        let new_root = match first_non_null {
            // Every sibling is empty: removing the only leaf empties the tree.
            None => SPARSE_MERKLE_PLACEHOLDER_HASH,
            Some(first) => match &siblings[first] {
                SparseMerkleNode::Internal(_) => {
                    // The removed leaf's subtree becomes empty up to the internal sibling,
                    // which stays in place; rehash from that level to the root.
                    Self::fold_siblings(Node::new_null().hash::<H>(), &siblings[first..], key)
                }
                SparseMerkleNode::Leaf(_) => {
                    // The sibling leaf is now alone, so it moves up past every empty sibling
                    // above it and is rehashed from the next non-empty one.
                    let above = &siblings[first + 1..];
                    let skip = above
                        .iter()
                        .position(|sibling| *sibling != SparseMerkleNode::Null)
                        .unwrap_or(above.len());
                    Self::fold_siblings(siblings[first].hash::<H>(), &above[skip..], key)
                }
                SparseMerkleNode::Null => {
                    unreachable!("`position` only stops at non-null siblings")
                }
            },
        };
        Ok(RootHash(new_root))
    }

    /// Fails unless [`Self::root_hash`] equals `expected` (so this proof must carry a leaf).
    fn ensure_root_matches(&self, expected: RootHash) -> Result<()> {
        let actual = self.root_hash();
        ensure!(
            actual == expected,
            "Root hashes do not match. Actual root hash: {:?}. Expected root hash: {:?}.",
            actual,
            expected,
        );
        Ok(())
    }

    /// Root hash of the tree obtained by placing a leaf for (`key`, `value`) at the end of the
    /// path described by `siblings`.
    fn root_with_leaf<V: AsRef<[u8]>>(
        siblings: Vec<SparseMerkleNode>,
        key: KeyHash,
        value: V,
    ) -> RootHash {
        let leaf = SparseMerkleLeafNode::new(key, ValueHash::with::<H>(value));
        Self::new(Some(leaf), siblings).root_hash()
    }

    /// Recomputes the root hash implied by this proof, steering by the proof's own leaf key.
    ///
    /// # Panics
    /// Panics if the proof has no leaf, since there is then no key to steer by. The internal
    /// `256 - len` subtraction also overflows (panicking in debug builds) if the proof has
    /// more than 256 siblings.
    pub fn root_hash(&self) -> RootHash {
        let leaf = self.leaf().expect("need leaf hash for root_hash");
        let key = leaf.key_hash;
        RootHash(Self::fold_siblings(leaf.hash::<H>(), &self.siblings, key))
    }
}
#[derive(Debug, Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
pub struct UpdateMerkleProof<H: SimpleHasher>(
#[borsh(bound(serialize = "", deserialize = ""))] Vec<SparseMerkleProof<H>>,
);
impl<H: SimpleHasher> UpdateMerkleProof<H> {
    /// Wraps an ordered list of per-update proofs.
    pub fn new(merkle_proofs: Vec<SparseMerkleProof<H>>) -> Self {
        Self(merkle_proofs)
    }

    /// Replays `updates` in order, starting from `old_root_hash`, using the i-th proof to
    /// check and apply the i-th update, and succeeds only if the final root equals
    /// `new_root_hash`.
    ///
    /// # Errors
    /// Returns an error if the number of proofs and updates differ, if any proof fails its
    /// check, or if the final computed root does not match `new_root_hash`.
    pub fn verify_update<V: AsRef<[u8]>>(
        self,
        old_root_hash: RootHash,
        new_root_hash: RootHash,
        updates: impl AsRef<[(KeyHash, Option<V>)]>,
    ) -> Result<()> {
        let updates = updates.as_ref();
        let proofs = self.0;
        ensure!(
            updates.len() == proofs.len(),
            "Mismatched number of updates and proofs. Received {} proofs for {} updates",
            proofs.len(),
            updates.len()
        );
        let computed_root = proofs.into_iter().zip(updates).try_fold(
            old_root_hash,
            |root, (proof, (key, value))| proof.check_compute_new_root(root, *key, value.as_ref()),
        )?;
        ensure!(
            computed_root == new_root_hash,
            "Root hashes do not match. Actual root hash: {:?}. Expected root hash: {:?}.",
            computed_root,
            new_root_hash,
        );
        Ok(())
    }
}
/// A proof authenticating the rightmost known leaf of a tree together with caller-supplied left
/// siblings; checked with [`SparseMerkleRangeProof::verify`].
#[derive(Eq, Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
pub struct SparseMerkleRangeProof<H: SimpleHasher> {
    /// Siblings lying to the right of the leaf's path; `verify` consumes one, in order, at each
    /// level where the path bit is clear.
    right_siblings: Vec<SparseMerkleNode>,
    /// Type-level marker for the hash function `H`; carries no data.
    _phantom: PhantomData<H>,
}
impl<H: SimpleHasher> PartialEq for SparseMerkleRangeProof<H> {
    // Only the right siblings carry data; the hasher marker is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.right_siblings[..] == other.right_siblings[..]
    }
}
impl<H: SimpleHasher> Clone for SparseMerkleRangeProof<H> {
fn clone(&self) -> Self {
Self {
right_siblings: self.right_siblings.clone(),
_phantom: self._phantom.clone(),
}
}
}
impl<H: SimpleHasher> core::fmt::Debug for SparseMerkleRangeProof<H> {
    /// Renders both fields as `SparseMerkleRangeProof { .. }` without requiring `H: Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            right_siblings,
            _phantom,
        } = self;
        let mut builder = f.debug_struct("SparseMerkleRangeProof");
        builder.field("right_siblings", right_siblings);
        builder.field("_phantom", _phantom);
        builder.finish()
    }
}
impl<H: SimpleHasher> SparseMerkleRangeProof<H> {
    /// Creates a range proof from the siblings to the right of the rightmost known leaf's path.
    pub(crate) fn new(right_siblings: Vec<SparseMerkleNode>) -> Self {
        Self {
            right_siblings,
            _phantom: Default::default(),
        }
    }

    /// Returns the right siblings carried by this proof.
    pub(crate) fn right_siblings(&self) -> &[SparseMerkleNode] {
        &self.right_siblings
    }

    /// Verifies that `rightmost_known_leaf`, combined with the caller's `left_siblings` and
    /// this proof's right siblings, hashes up to `expected_root_hash`.
    ///
    /// The path has `left_siblings.len() + right_siblings.len()` levels. At each level (bits
    /// taken from the leaf key's `iter_bits()` reversed, skipping the first `256 - levels`), a
    /// set bit takes the next left sibling and places the running hash on the right; a clear
    /// bit takes the next right sibling and places the running hash on the left.
    ///
    /// # Errors
    /// Returns an error if the combined sibling count exceeds 256, if either sibling list runs
    /// out before the path is complete, or if the computed root differs from
    /// `expected_root_hash`.
    pub fn verify(
        &self,
        expected_root_hash: RootHash,
        rightmost_known_leaf: SparseMerkleLeafNode,
        left_siblings: Vec<[u8; 32]>,
    ) -> Result<()> {
        let num_siblings = left_siblings.len() + self.right_siblings.len();
        // The path starts `256 - num_siblings` bits into the key. A proof claiming more than
        // 256 levels is malformed; without this check the subtraction overflows (a panic in
        // debug builds, and in release an empty path that compares the bare leaf hash).
        ensure!(
            num_siblings <= 256,
            "Sparse Merkle Tree range proof has more than {} ({}) siblings.",
            256,
            num_siblings,
        );
        let mut left_sibling_iter = left_siblings.iter();
        let mut right_sibling_iter = self.right_siblings().iter();
        let mut current_hash = rightmost_known_leaf.hash::<H>();
        for bit in rightmost_known_leaf
            .key_hash()
            .0
            .iter_bits()
            .rev()
            .skip(256 - num_siblings)
        {
            let (left_hash, right_hash) = if bit {
                (
                    *left_sibling_iter
                        .next()
                        .ok_or_else(|| format_err!("Missing left sibling."))?,
                    current_hash,
                )
            } else {
                (
                    current_hash,
                    right_sibling_iter
                        .next()
                        .ok_or_else(|| format_err!("Missing right sibling."))?
                        .hash::<H>(),
                )
            };
            current_hash = SparseMerkleInternalNode::new(left_hash, right_hash).hash::<H>();
        }
        ensure!(
            current_hash == expected_root_hash.0,
            "Root hashes do not match. Actual root hash: {:?}. Expected root hash: {:?}.",
            current_hash,
            expected_root_hash,
        );
        Ok(())
    }
}
#[cfg(test)]
mod serialization_tests {
    use borsh::BorshDeserialize;
    use sha2::Sha256;

    use super::{SparseMerkleProof, SparseMerkleRangeProof};
    use crate::{
        proof::{SparseMerkleInternalNode, SparseMerkleLeafNode, SparseMerkleNode},
        KeyHash, ValueHash,
    };

    /// The single internal-node sibling shared by both fixtures.
    fn sample_sibling() -> SparseMerkleNode {
        SparseMerkleNode::Internal(SparseMerkleInternalNode::new([3u8; 32], [4u8; 32]))
    }

    /// A proof with one leaf and one internal sibling.
    fn get_test_proof() -> SparseMerkleProof<Sha256> {
        let leaf = SparseMerkleLeafNode::new(KeyHash([1u8; 32]), ValueHash([2u8; 32]));
        SparseMerkleProof {
            leaf: Some(leaf),
            siblings: alloc::vec![sample_sibling()],
            phantom_hasher: Default::default(),
        }
    }

    /// A range proof with one internal right sibling.
    fn get_test_range_proof() -> SparseMerkleRangeProof<Sha256> {
        SparseMerkleRangeProof {
            right_siblings: alloc::vec![sample_sibling()],
            _phantom: Default::default(),
        }
    }

    #[test]
    fn test_sparse_merkle_proof_roundtrip_serde() {
        let expected = get_test_proof();
        let json = serde_json::to_string(&expected).expect("serialization is infallible");
        let actual: SparseMerkleProof<Sha256> =
            serde_json::from_str(&json).expect("serialized proof is valid");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_sparse_merkle_proof_roundtrip_borsh() {
        let expected = get_test_proof();
        let bytes = borsh::to_vec(&expected).expect("serialization is infallible");
        let mut reader = bytes.as_slice();
        let actual = SparseMerkleProof::<Sha256>::deserialize(&mut reader)
            .expect("serialized proof is valid");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_sparse_merkle_range_proof_roundtrip_serde() {
        let expected = get_test_range_proof();
        let json = serde_json::to_string(&expected).expect("serialization is infallible");
        let actual: SparseMerkleRangeProof<Sha256> =
            serde_json::from_str(&json).expect("serialized proof is valid");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_sparse_merkle_range_proof_roundtrip_borsh() {
        let expected = get_test_range_proof();
        let bytes = borsh::to_vec(&expected).expect("serialization is infallible");
        let mut reader = bytes.as_slice();
        let actual = SparseMerkleRangeProof::<Sha256>::deserialize(&mut reader)
            .expect("serialized proof is valid");
        assert_eq!(expected, actual);
    }
}