use crate::{
collections::{BTreeMap, VecDeque},
error::{Error, Result},
traits::{Hasher, Store, Value},
vec::Vec,
H256,
};
use core::{cmp::max, marker::PhantomData};
/// Capacity hint: `merkle_proof` in this file reserves this many proof entries
/// per requested key before building a proof.
pub const EXPECTED_PATH_SIZE: usize = 16;
/// Height at which proof construction and proof verification stop walking
/// upward; equal to one past the largest height representable as a `u8`.
const TREE_HEIGHT: usize = 1 << 8;
/// A branch record stored under a parent hash.
///
/// `node` is the child on the same side as `key` at the branching height and
/// `sibling` is the child on the other side (see `BranchNode::branch`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BranchNode {
    /// Height at which this record's two children are distinguished.
    pub fork_height: u8,
    /// Key whose bit at the branching height decides which side `node` is on.
    pub key: H256,
    /// Child hash on `key`'s side.
    pub node: H256,
    /// Child hash on the opposite side.
    pub sibling: H256,
}
impl BranchNode {
    /// Orders the two children as `(left, right)` for the given `height`.
    ///
    /// When `self.key` has bit `height` set, `node` is the right child and
    /// `sibling` the left one; otherwise it is the other way round.
    fn branch(&self, height: u8) -> (&H256, &H256) {
        let own = &self.node;
        let other = &self.sibling;
        if self.key.get_bit(height) {
            (other, own)
        } else {
            (own, other)
        }
    }
}
/// A stored leaf: the full key together with its value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LeafNode<V> {
    /// The key this leaf was inserted under.
    pub key: H256,
    /// The value associated with `key`.
    pub value: V,
}
/// Combines two child hashes into their parent hash.
///
/// A zero child is treated as absent: if `lhs` is zero the result is `rhs`
/// (so two zeros yield zero), otherwise if `rhs` is zero the result is `lhs`.
/// Only when both are non-zero are they hashed, `lhs` first.
fn merge<H: Hasher + Default>(lhs: &H256, rhs: &H256) -> H256 {
    match (lhs.is_zero(), rhs.is_zero()) {
        (true, _) => *rhs,
        (false, true) => *lhs,
        (false, false) => {
            let mut state = H::default();
            state.write_h256(lhs);
            state.write_h256(rhs);
            state.finish()
        }
    }
}
/// Hashes a `(key, value)` pair into a leaf hash.
///
/// A zero `value` maps to the zero hash without hashing, so an absent or
/// deleted value contributes nothing to the tree.
fn hash_leaf<H: Hasher + Default>(key: &H256, value: &H256) -> H256 {
    if value.is_zero() {
        H256::zero()
    } else {
        let mut state = H::default();
        state.write_h256(key);
        state.write_h256(value);
        state.finish()
    }
}
/// A sparse Merkle tree whose branch and leaf records live in a `Store`.
///
/// `H` is the hasher used for leaves and branches, `V` the value type and `S`
/// the backing store; the tree itself only keeps the current root hash.
#[derive(Debug, Default)]
pub struct SparseMerkleTree<H, V, S> {
    /// Backing storage for branch and leaf records.
    store: S,
    /// Current root hash; zero means the tree is empty.
    root: H256,
    /// Marks the hasher and value types without storing them.
    phantom: PhantomData<(H, V)>,
}
impl<H: Hasher + Default, V: Value, S: Store<V>> SparseMerkleTree<H, V, S> {
    /// Creates a tree over `store` whose current root hash is `root`.
    ///
    /// The root is taken as given; it is not checked against the store.
    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
        SparseMerkleTree {
            root,
            store,
            phantom: PhantomData,
        }
    }

    /// Returns the current root hash.
    pub fn root(&self) -> &H256 {
        &self.root
    }

    /// Returns `true` when the root hash is zero.
    pub fn is_empty(&self) -> bool {
        self.root.is_zero()
    }

    /// Returns a shared reference to the backing store.
    pub fn store(&self) -> &S {
        &self.store
    }

    /// Sets `key` to `value` and returns the new root hash.
    ///
    /// Descends from the root along `key`, removing every branch record it
    /// passes through and collecting the sibling hash met at each height. It
    /// then drops any existing leaf record for `key`, stores the new leaf
    /// (skipped when `hash_leaf` yields zero, i.e. a zero value), and rebuilds
    /// the branch records bottom-up by merging with the collected siblings.
    ///
    /// # Errors
    /// Propagates any error returned by the backing store.
    pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
        // Sibling hashes keyed by the height at which they are merged back in.
        let mut path: BTreeMap<_, _> = Default::default();
        let mut node = self.root;
        let mut branch = self.store.get_branch(&node)?;
        let mut height = branch
            .as_ref()
            .map(|b| max(b.key.fork_height(&key), b.fork_height))
            .unwrap_or(0);
        while let Some(branch_node) = branch {
            if height > branch_node.fork_height {
                // `key` leaves this subtree above the subtree's own fork
                // height, so the whole subtree becomes a single sibling.
                let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
                path.insert(fork_height, node);
                break;
            }
            // This record will be rewritten on the way back up.
            self.store.remove_branch(&node)?;
            let (left, right) = branch_node.branch(height);
            let (sibling, child) = if key.get_bit(height) {
                if &node == right {
                    break;
                }
                (*left, *right)
            } else {
                if &node == left {
                    break;
                }
                (*right, *left)
            };
            path.insert(height, sibling);
            node = child;
            branch = self.store.get_branch(&node)?;
            if let Some(branch_node) = branch.as_ref() {
                height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
            }
        }
        // Drop the stale leaf record if the reached node is this key's leaf.
        if let Some(leaf) = self.store.get_leaf(&node)? {
            if leaf.key == key {
                self.store.remove_leaf(&node)?;
            }
        }
        let mut node = hash_leaf::<H>(&key, &value.to_h256());
        if !node.is_zero() {
            self.store.insert_leaf(node, LeafNode { key, value })?;
        }
        // NOTE(review): this record is written even when `node` is the zero
        // hash (zero value); confirm storing a branch under H256::zero() is
        // intended, since later lookups of the zero root may find it.
        self.store.insert_branch(
            node,
            BranchNode {
                key,
                fork_height: 0,
                node,
                sibling: H256::zero(),
            },
        )?;
        // A BTreeMap yields its entries in ascending height order, which is
        // exactly the bottom-up order needed to rebuild the path.
        for (height, sibling) in path {
            let parent = if key.get_bit(height) {
                merge::<H>(&sibling, &node)
            } else {
                merge::<H>(&node, &sibling)
            };
            let branch_node = BranchNode {
                fork_height: height,
                sibling,
                node,
                key,
            };
            self.store.insert_branch(parent, branch_node)?;
            node = parent;
        }
        self.root = node;
        Ok(&self.root)
    }

    /// Returns the value stored for `key`, or `V::zero()` when none is found.
    ///
    /// Follows branch records from the root, choosing the child on `key`'s
    /// side at each record's `fork_height`. It stops at a zero hash, at a
    /// missing record, or after a record with `fork_height == 0`.
    ///
    /// # Errors
    /// Propagates any error returned by the backing store.
    pub fn get(&self, key: &H256) -> Result<V> {
        let mut node = self.root;
        while !node.is_zero() {
            let branch_node = match self.store.get_branch(&node)? {
                Some(branch_node) => branch_node,
                None => break,
            };
            let height = branch_node.fork_height;
            let (left, right) = branch_node.branch(height);
            node = if key.get_bit(height) { *right } else { *left };
            if height == 0 {
                break;
            }
        }
        if node.is_zero() {
            return Ok(V::zero());
        }
        // The reached hash only counts if its leaf record belongs to `key`.
        match self.store.get_leaf(&node)? {
            Some(leaf) if &leaf.key == key => Ok(leaf.value),
            _ => Ok(V::zero()),
        }
    }

    /// Walks from the root toward `key`, recording the sibling hashes met on
    /// the way in `cache`.
    ///
    /// Each entry is keyed by `(height, sibling_path)`. At the point where
    /// `key` diverges from an existing subtree, the subtree hash is recorded
    /// only if no earlier key already recorded an entry for the same slot.
    ///
    /// # Errors
    /// Propagates any error returned by the backing store.
    fn fetch_merkle_path(
        &self,
        key: &H256,
        cache: &mut BTreeMap<(usize, H256), H256>,
    ) -> Result<()> {
        let mut node = self.root;
        let mut height = self
            .store
            .get_branch(&node)?
            .map(|b| max(b.key.fork_height(key), b.fork_height))
            .unwrap_or(0);
        while !node.is_zero() {
            let branch_node = match self.store.get_branch(&node)? {
                Some(branch_node) => branch_node,
                None => break,
            };
            if height > branch_node.fork_height {
                // `key` leaves this subtree above its fork height: the whole
                // subtree (hash `node`, non-zero by the loop condition) is the
                // sibling at the divergence height.
                let fork_height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
                let mut sibling_key = key.parent_path(fork_height);
                if !key.get_bit(fork_height) {
                    sibling_key.set_bit(height);
                }
                cache
                    .entry((fork_height as usize, sibling_key))
                    .or_insert(node);
                break;
            }
            let (left, right) = branch_node.branch(height);
            let is_right = key.get_bit(height);
            let (sibling, child) = if is_right {
                if &node == right {
                    break;
                }
                (*left, *right)
            } else {
                if &node == left {
                    break;
                }
                (*right, *left)
            };
            node = child;
            let mut sibling_key = key.parent_path(height);
            if !is_right {
                sibling_key.set_bit(height);
            }
            cache.insert((height as usize, sibling_key), sibling);
            if let Some(branch_node) = self.store.get_branch(&node)? {
                height = max(key.fork_height(&branch_node.key), branch_node.fork_height);
            }
        }
        Ok(())
    }

    /// Builds a proof covering every key in `keys`.
    ///
    /// Keys are sorted first. Sibling hashes along each key's path are
    /// gathered into a cache, then the keys are walked upward level by level.
    /// Siblings that are themselves proven keys are merged without emitting
    /// anything; other siblings are emitted into `proof`. Each leaf's path
    /// records the heights at which it merged.
    ///
    /// # Errors
    /// Returns `Error::EmptyKeys` when `keys` is empty and propagates store
    /// errors.
    pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
        if keys.is_empty() {
            return Err(Error::EmptyKeys);
        }
        keys.sort_unstable();
        let mut cache: BTreeMap<(usize, H256), H256> = Default::default();
        for k in &keys {
            self.fetch_merkle_path(k, &mut cache)?;
        }
        let keys_len = keys.len();
        let mut proof: Vec<(H256, u8)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys_len);
        let mut leaves_path: Vec<Vec<u8>> = Vec::with_capacity(keys_len);
        leaves_path.resize_with(keys_len, Default::default);
        // Work queue of (path key, height, index of originating leaf).
        let mut queue: VecDeque<(H256, usize, usize)> = keys
            .into_iter()
            .enumerate()
            .map(|(i, k)| (k, 0, i))
            .collect();
        while let Some((key, height, leaf_index)) = queue.pop_front() {
            if (queue.is_empty() && cache.is_empty()) || height == TREE_HEIGHT {
                // A leaf that never merged gets u8::MAX as its only entry.
                if leaves_path[leaf_index].is_empty() {
                    leaves_path[leaf_index].push(core::u8::MAX);
                }
                break;
            }
            let mut sibling_key = key.parent_path(height as u8);
            let is_right = key.get_bit(height as u8);
            if is_right {
                sibling_key.clear_bit(height as u8);
            } else {
                sibling_key.set_bit(height as u8);
            }
            let sibling_in_queue = queue
                .front()
                .map(|(front_key, front_height, _)| (front_key, front_height))
                == Some((&sibling_key, &height));
            if sibling_in_queue {
                // The sibling is itself being proven: consume it and record
                // this merge height on its own path.
                let (_, sibling_height, sibling_leaf) = queue
                    .pop_front()
                    .expect("queue front was just checked to be Some");
                leaves_path[sibling_leaf].push(sibling_height as u8);
            } else {
                match cache.remove(&(height, sibling_key)) {
                    Some(sibling) => {
                        debug_assert!(height <= core::u8::MAX as usize);
                        proof.push((sibling, height as u8));
                    }
                    None => {
                        // No sibling at this height: move up one level without
                        // recording anything for this leaf.
                        if !is_right {
                            sibling_key.clear_bit(height as u8);
                        }
                        let parent_key = sibling_key;
                        queue.push_back((parent_key, height + 1, leaf_index));
                        continue;
                    }
                }
            }
            leaves_path[leaf_index].push(height as u8);
            if height < TREE_HEIGHT {
                let parent_key = if is_right { sibling_key } else { key };
                queue.push_back((parent_key, height + 1, leaf_index));
            }
        }
        debug_assert_eq!(leaves_path.len(), keys_len);
        Ok(MerkleProof::new(leaves_path, proof))
    }
}
/// A multi-key Merkle proof.
///
/// Holds, per proven leaf, the list of heights at which it is merged, plus the
/// sibling hashes (each tagged with its height) that `compute_root` consumes
/// in order.
#[derive(Clone, Debug)]
pub struct MerkleProof {
    /// Merge heights for each leaf, indexed by the leaf's position.
    leaves_path: Vec<Vec<u8>>,
    /// Sibling hashes with the height at which each is merged.
    proof: Vec<(H256, u8)>,
}
impl MerkleProof {
pub fn new(leaves_path: Vec<Vec<u8>>, proof: Vec<(H256, u8)>) -> Self {
MerkleProof { leaves_path, proof }
}
pub fn take(self) -> (Vec<Vec<u8>>, Vec<(H256, u8)>) {
let MerkleProof { leaves_path, proof } = self;
(leaves_path, proof)
}
pub fn leaves_count(&self) -> usize {
self.leaves_path.len()
}
pub fn leaves_path(&self) -> &Vec<Vec<u8>> {
&self.leaves_path
}
pub fn proof(&self) -> &Vec<(H256, u8)> {
&self.proof
}
pub fn compute_root<H: Hasher + Default>(self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
if leaves.is_empty() {
return Err(Error::EmptyKeys);
} else if leaves.len() != self.leaves_count() {
return Err(Error::IncorrectNumberOfLeaves {
expected: self.leaves_count(),
actual: leaves.len(),
});
}
let (leaves_path, proof) = self.take();
let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
let mut proof: VecDeque<_> = proof.into();
leaves.sort_unstable_by_key(|(k, _v)| *k);
let mut tree_buf: BTreeMap<_, _> = leaves
.into_iter()
.enumerate()
.map(|(i, (k, v))| ((0, k), (i, hash_leaf::<H>(&k, &v))))
.collect();
while !tree_buf.is_empty() {
let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
tree_buf.remove(&(height, key));
if proof.is_empty() && tree_buf.is_empty() {
return Ok(node);
} else if height == TREE_HEIGHT {
if !proof.is_empty() {
return Err(Error::CorruptedProof);
}
return Ok(node);
}
let mut sibling_key = key.parent_path(height as u8);
if !key.get_bit(height as u8) {
sibling_key.set_bit(height as u8)
}
let (sibling, sibling_height) =
if Some(&(height, sibling_key)) == tree_buf.keys().next() {
let (_leaf_index, sibling) = tree_buf
.remove(&(height, sibling_key))
.expect("pop sibling");
(sibling, height)
} else {
let merge_height = leaves_path[leaf_index]
.front()
.map(|h| *h as usize)
.unwrap_or(height);
if height != merge_height {
debug_assert!(height < merge_height);
let parent_key = key.copy_bits(merge_height as u8..);
tree_buf.insert((merge_height, parent_key), (leaf_index, node));
continue;
}
let (node, height) = proof.pop_front().expect("pop proof");
debug_assert_eq!(height, leaves_path[leaf_index][0]);
(node, height as usize)
};
debug_assert!(height <= sibling_height);
if height < sibling_height {
height = sibling_height;
}
let parent_key = key.parent_path(height as u8);
let parent = if key.get_bit(height as u8) {
merge::<H>(&sibling, &node)
} else {
merge::<H>(&node, &sibling)
};
leaves_path[leaf_index].pop_front();
tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
}
Err(Error::CorruptedProof)
}
pub fn verify<H: Hasher + Default>(
self,
root: &H256,
leaves: Vec<(H256, H256)>,
) -> Result<bool> {
let calculated_root = self.compute_root::<H>(leaves)?;
Ok(&calculated_root == root)
}
}