use std::mem;
use nimiq_collections::BitSet;
use nimiq_hash::{Blake2bHash, HashOutput, Hasher, SerializeContent};
use nimiq_serde::{Deserialize, Serialize};
/// Builds a binary Merkle tree leaf by leaf and hands out [`IncrementalMerkleProof`]s
/// that each cover a fixed number of consecutive leaves (a "chunk").
pub struct IncrementalMerkleProofBuilder<H>
where
    H: HashOutput,
{
    /// Tree levels from the bottom up: `tree[0]` holds the leaf hashes and every
    /// following level holds the hashes of adjacent pairs of the level below (an
    /// unpaired last node is carried up unchanged). The first node of the last level
    /// is the root.
    tree: Vec<Vec<H>>,
    /// Number of leaves covered by each chunk proof.
    chunk_size: usize,
}
impl<H: HashOutput> IncrementalMerkleProofBuilder<H> {
    /// Creates an empty builder whose proofs each cover `chunk_size` leaves.
    ///
    /// # Errors
    ///
    /// Returns [`IncrementalMerkleProofError::InvalidChunkSize`] if `chunk_size` is zero.
    pub fn new(chunk_size: usize) -> Result<Self, IncrementalMerkleProofError> {
        if chunk_size == 0 {
            return Err(IncrementalMerkleProofError::InvalidChunkSize);
        }
        Ok(IncrementalMerkleProofBuilder {
            tree: vec![vec![]],
            chunk_size,
        })
    }

    /// Returns the number of leaves currently in the tree.
    pub fn len(&self) -> usize {
        self.tree[0].len()
    }

    /// Returns `true` if the tree has no leaves.
    pub fn is_empty(&self) -> bool {
        self.tree[0].is_empty()
    }

    /// Hashes `value` with `H`'s hasher and appends the digest as a new leaf.
    ///
    /// Returns the same value as [`Self::push`].
    pub fn push_item<T: SerializeContent>(&mut self, value: &T) -> usize {
        self.push(H::Builder::default().chain(value).finish())
    }

    /// Appends `hash` as a new leaf and refreshes the inner nodes above it.
    ///
    /// Returns `len() / chunk_size` as measured after the push.
    // NOTE(review): that equals the index of the chunk containing the new leaf only when
    // the leaf does not complete its chunk; when it does (e.g. every push with
    // `chunk_size == 1`) the value is one past it. `(len() - 1) / chunk_size` would be
    // the containing chunk — confirm what callers expect before changing it.
    pub fn push(&mut self, hash: H) -> usize {
        self.tree[0].push(hash);
        self.update();
        self.tree[0].len() / self.chunk_size
    }

    /// Removes and returns the most recently pushed leaf, or `None` if the tree is empty.
    pub fn pop(&mut self) -> Option<H> {
        let result = self.tree[0].pop()?;
        self.update();
        Some(result)
    }

    /// Recomputes the nodes on the path from the last leaf up to the root after a single
    /// `push` or `pop`, adding or removing tree levels as needed.
    fn update(&mut self) {
        if self.tree[0].is_empty() {
            return;
        }
        let mut current_level = 0;
        let mut current_pos = self.tree[0].len() - 1;
        while current_pos > 0 {
            // After a `pop`, this level may still hold one node past the new last position.
            if current_pos + 1 != self.tree[current_level].len() {
                self.tree[current_level].pop();
            }
            let level = &self.tree[current_level];
            // A right child is hashed together with its left sibling; an unpaired left
            // child is carried up unchanged.
            let hash = if current_pos % 2 == 1 {
                H::Builder::default()
                    .chain(&level[current_pos - 1])
                    .chain(&level[current_pos])
                    .finish()
            } else {
                level[current_pos].clone()
            };
            current_level += 1;
            current_pos /= 2;
            if current_level >= self.tree.len() {
                self.tree.push(vec![]);
            }
            if current_pos >= self.tree[current_level].len() {
                self.tree[current_level].push(hash);
            } else {
                self.tree[current_level][current_pos] = hash;
            }
        }
        // `current_level` is now the root level. Drop any level above it (left over after a
        // `pop` reduced the depth) and any stale node after the root on this level: the loop
        // above stops at position 0 and never trims the root level itself, and a stale node
        // there would leak into the proofs produced by `chunk`.
        self.tree.truncate(current_level + 1);
        self.tree[current_level].truncate(1);
    }

    /// Returns the root hash, or `None` if the tree is empty.
    pub fn root(&self) -> Option<&H> {
        self.tree.last()?.first()
    }

    /// Builds the proof for every chunk, in chunk order.
    pub fn chunks(&self) -> Vec<IncrementalMerkleProof<H>> {
        let num_chunks = self.len().div_ceil(self.chunk_size);
        (0..num_chunks)
            .map(|i| {
                self.chunk(i)
                    .expect("every chunk index below num_chunks starts at an existing leaf")
            })
            .collect()
    }

    /// Builds the proof for chunk `index`.
    ///
    /// The proof holds, from the bottom up, the right-hand siblings on the path from the
    /// chunk's last leaf to the root. Returns `None` if the chunk would start at or past
    /// the end of the leaves.
    pub fn chunk(&self, index: usize) -> Option<IncrementalMerkleProof<H>> {
        let len = self.len();
        // `checked_mul` keeps an absurdly large `index` from overflowing.
        let first_leaf = index.checked_mul(self.chunk_size)?;
        if first_leaf >= len {
            return None;
        }
        let mut proof = IncrementalMerkleProof::empty(len);
        // Position of the chunk's last leaf. A final, partial chunk is clamped to the last
        // existing leaf; nothing lies to its right in either case.
        let mut current_pos = first_leaf.saturating_add(self.chunk_size - 1).min(len - 1);
        let mut current_level = 0;
        while current_pos + 1 < self.tree[current_level].len() {
            // Only a left child has a right-hand sibling the verifier cannot compute itself.
            if current_pos.is_multiple_of(2) {
                proof
                    .nodes
                    .push(self.tree[current_level][current_pos + 1].clone());
            }
            current_level += 1;
            current_pos /= 2;
            if current_level >= self.tree.len() {
                break;
            }
        }
        Some(proof)
    }
}
/// Proof for one chunk of leaves of a tree built by `IncrementalMerkleProofBuilder`.
///
/// Verify it with `compute_root` / `compute_root_from_values`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(bound(deserialize = "H: HashOutput"))]
pub struct IncrementalMerkleProof<H: HashOutput> {
    /// Total number of leaves in the tree the proof was built from.
    // NOTE(review): field order is part of the serialized form — do not reorder.
    total_len: u32,
    /// Right-hand sibling hashes along the path from the chunk's last leaf to the root,
    /// bottom-up.
    nodes: Vec<H>,
}
/// Outcome of verifying one chunk with `IncrementalMerkleProof::compute_root`.
///
/// Pass it as `previous_result` when verifying the next chunk.
#[derive(Debug)]
pub struct IncrementalMerkleProofResult<H>
where
    H: HashOutput,
{
    /// Root hash reconstructed from the chunk and its proof.
    root: H,
    /// Nodes lying to the left of the next chunk, consumed when verifying it.
    helper_nodes: Vec<H>,
    /// Leaf index at which the next chunk starts.
    next_index: usize,
}
impl<H: HashOutput> IncrementalMerkleProofResult<H> {
    /// Borrows the reconstructed root hash.
    #[inline]
    pub fn root(&self) -> &H {
        let Self { root, .. } = self;
        root
    }

    /// Borrows the nodes handed over to the verification of the next chunk.
    #[inline]
    pub fn helper_nodes(&self) -> &[H] {
        self.helper_nodes.as_slice()
    }

    /// Leaf index at which the next chunk starts.
    #[inline]
    pub fn next_index(&self) -> usize {
        let Self { next_index, .. } = self;
        *next_index
    }
}
impl<H: HashOutput> IncrementalMerkleProof<H> {
    /// Creates a proof with no nodes for a tree of `total_len` leaves.
    // NOTE(review): `total_len` is narrowed with `as u32`, which silently truncates values
    // above `u32::MAX`.
    pub fn empty(total_len: usize) -> Self {
        IncrementalMerkleProof {
            total_len: total_len as u32,
            nodes: Vec::new(),
        }
    }

    /// Hashes each of `chunk_values` with `H`'s hasher and verifies the digests with
    /// [`Self::compute_root`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::compute_root`].
    pub fn compute_root_from_values<T: SerializeContent>(
        &self,
        chunk_values: &[T],
        previous_result: Option<&IncrementalMerkleProofResult<H>>,
    ) -> Result<IncrementalMerkleProofResult<H>, IncrementalMerkleProofError> {
        let hashes: Vec<H> = chunk_values
            .iter()
            .map(|v| H::Builder::default().chain(v).finish())
            .collect();
        self.compute_root(&hashes, previous_result)
    }

    /// Verifies one chunk of consecutive leaf hashes against this proof and reconstructs
    /// the root.
    ///
    /// The chunk starts at leaf `previous_result.next_index()` (0 when there is no previous
    /// result). Nodes left of the chunk come from `previous_result.helper_nodes()`. Nodes
    /// right of it come from this proof. Every helper node and every proof node must be
    /// consumed. When a previous result is given, the reconstructed root must equal its
    /// root.
    ///
    /// Returns the root, the helper nodes for the next chunk and the next chunk's start.
    ///
    /// # Errors
    ///
    /// - [`IncrementalMerkleProofError::InvalidChunkSize`] if `chunk_hashes` is empty or a
    ///   leaf the computation needs is missing from it.
    /// - [`IncrementalMerkleProofError::InvalidPreviousProof`] if helper nodes are missing
    ///   or left unused.
    /// - [`IncrementalMerkleProofError::InvalidProof`] if proof nodes are missing or left
    ///   unused, the computation does not end in exactly one root, or that root differs
    ///   from the previous result's.
    pub fn compute_root(
        &self,
        chunk_hashes: &[H],
        previous_result: Option<&IncrementalMerkleProofResult<H>>,
    ) -> Result<IncrementalMerkleProofResult<H>, IncrementalMerkleProofError> {
        if chunk_hashes.is_empty() {
            return Err(IncrementalMerkleProofError::InvalidChunkSize);
        }
        // Leaf index of the first hash in this chunk.
        let index_offset = previous_result.map(|r| r.next_index).unwrap_or(0);
        let empty_vec = Vec::new();
        // Left-hand nodes handed over by the previous chunk, consumed in order.
        let helper_nodes = previous_result
            .map(|r| &r.helper_nodes)
            .unwrap_or(&empty_vec);
        let mut helper_index = 0;
        let mut proof_index = 0;
        let mut helper_output = Vec::new();
        let mut current_level = chunk_hashes;
        // With a single leaf there are no hashing levels, so the loop below never runs and
        // that leaf is the root. Seed the level with the chunk for this case (first chunk
        // only). In every other case the value is either replaced by the loop or left empty,
        // which is rejected below.
        let mut current_level_owned = if self.total_len == 1 && index_offset == 0 {
            chunk_hashes.to_vec()
        } else {
            vec![]
        };
        // Positions (relative to `level_offset`) whose value depends on a proof node.
        let mut current_proof_nodes = BitSet::new();
        // Range of positions on the current level computable from the chunk itself.
        let mut level_leftmost = index_offset;
        let mut level_rightmost = index_offset + chunk_hashes.len() - 1;
        // Start at the left child of the pair containing the leftmost known node.
        let mut current_position = (level_leftmost / 2) * 2;
        let mut level_len = self.total_len;
        let mut next_level = vec![];
        let mut next_proof_nodes = BitSet::new();
        let mut level_offset = current_position;
        // Number of hashing levels above the leaves: ceil(log2(total_len)), computed with
        // exact integer arithmetic (0 for 0 or 1 leaves).
        let depth = match self.total_len {
            0 => 0,
            len => u32::BITS - (len - 1).leading_zeros(),
        };
        for _ in 0..depth {
            while current_position <= level_rightmost {
                // Left of the known range: the previous chunk supplied this node.
                let left_hash = if current_position < level_leftmost {
                    helper_index += 1;
                    helper_nodes
                        .get(helper_index - 1)
                        .ok_or(IncrementalMerkleProofError::InvalidPreviousProof)?
                } else {
                    current_level
                        .get(current_position - level_leftmost)
                        .ok_or(IncrementalMerkleProofError::InvalidChunkSize)?
                };
                // Last node of an odd-length level: it has no sibling and moves up unchanged.
                if current_position + 1 >= level_len as usize {
                    if current_proof_nodes.contains(current_position - level_offset) {
                        let next_proof_offset = level_offset / 4 * 2;
                        next_proof_nodes.insert(current_position / 2 - next_proof_offset);
                    }
                    next_level.push(left_hash.clone());
                    break;
                }
                // Right of the known range: this proof supplies the sibling.
                let right_hash = if current_position + 1 > level_rightmost {
                    proof_index += 1;
                    current_proof_nodes.insert(current_position + 1 - level_offset);
                    self.nodes
                        .get(proof_index - 1)
                        .ok_or(IncrementalMerkleProofError::InvalidProof)?
                } else {
                    current_level
                        .get(current_position + 1 - level_leftmost)
                        .ok_or(IncrementalMerkleProofError::InvalidChunkSize)?
                };
                let next_hash = H::Builder::default()
                    .chain(left_hash)
                    .chain(right_hash)
                    .finish();
                // A known left node paired with a proof-derived right node is handed over:
                // the next chunk will see it as a node left of its range.
                if !current_proof_nodes.contains(current_position - level_offset)
                    && current_proof_nodes.contains(current_position + 1 - level_offset)
                {
                    helper_output.push(left_hash.clone());
                }
                // The parent of a proof-derived node is proof-derived as well.
                if current_proof_nodes.contains(current_position - level_offset)
                    || current_proof_nodes.contains(current_position + 1 - level_offset)
                {
                    // Equals the `level_offset` the next level will use.
                    let next_proof_offset = level_offset / 4 * 2;
                    next_proof_nodes.insert(current_position / 2 - next_proof_offset);
                }
                next_level.push(next_hash);
                current_position += 2;
            }
            current_level_owned = mem::take(&mut next_level);
            current_level = &current_level_owned;
            current_proof_nodes = mem::replace(&mut next_proof_nodes, BitSet::new());
            level_leftmost /= 2;
            level_rightmost /= 2;
            level_len = level_len.div_ceil(2);
            current_position = (level_leftmost / 2) * 2;
            level_offset = current_position;
        }
        if helper_index < helper_nodes.len() {
            return Err(IncrementalMerkleProofError::InvalidPreviousProof);
        }
        if proof_index < self.nodes.len() {
            return Err(IncrementalMerkleProofError::InvalidProof);
        }
        let root = current_level_owned
            .pop()
            .ok_or(IncrementalMerkleProofError::InvalidProof)?;
        if !current_level_owned.is_empty() {
            return Err(IncrementalMerkleProofError::InvalidProof);
        }
        if let Some(prev_root) = previous_result.map(|r| &r.root)
            && prev_root != &root
        {
            return Err(IncrementalMerkleProofError::InvalidProof);
        }
        Ok(IncrementalMerkleProofResult {
            root,
            helper_nodes: helper_output,
            next_index: index_offset + chunk_hashes.len(),
        })
    }

    /// Returns the number of nodes in the proof.
    #[inline]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of leaves in the tree the proof belongs to.
    #[inline]
    pub fn total_len(&self) -> usize {
        self.total_len as usize
    }

    /// Returns `true` if the proof has no nodes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}
/// Errors produced when building or verifying incremental Merkle proofs.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum IncrementalMerkleProofError {
    /// The helper nodes taken from the previous result were missing or not all used.
    InvalidPreviousProof,
    /// Proof nodes were missing or unused, or the reconstructed root is inconsistent.
    InvalidProof,
    /// A chunk size of zero, an empty chunk, or a chunk missing leaves the proof needs.
    InvalidChunkSize,
}
/// [`IncrementalMerkleProof`] over Blake2b hashes.
pub type Blake2bIncrementalMerkleProof = IncrementalMerkleProof<Blake2bHash>;