use moloch_core::proof::{Position, ProofNode};
use moloch_core::{hash_pair, Error, Hash, Result};
use serde::{Deserialize, Serialize};
use crate::store::MmrStore;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MmrProof {
pub pos: u64,
pub leaf: Hash,
pub siblings: Vec<ProofNode>,
pub peak_pos: u64,
pub peaks: Vec<Hash>,
pub mmr_size: u64,
}
#[derive(Clone)]
pub struct Mmr<S: MmrStore> {
store: S,
}
impl<S: MmrStore> Mmr<S> {
pub fn new(store: S) -> Self {
Self { store }
}
pub fn from_store(store: S) -> Self {
Self { store }
}
pub fn leaf_count(&self) -> u64 {
size_to_leaf_count(self.store.size())
}
pub fn size(&self) -> u64 {
self.store.size()
}
pub fn append(&mut self, leaf: Hash) -> Result<u64> {
let leaf_pos = self.store.size();
self.store.insert(leaf_pos, leaf)?;
let mut current_hash = leaf;
let mut current_pos = leaf_pos;
let mut height = 0u32;
while let Some(sibling_pos) = left_sibling(current_pos, height) {
if sibling_pos >= leaf_pos {
break;
}
let sibling_hash = self.store.get(sibling_pos)?.ok_or_else(|| {
Error::internal(format!("missing sibling at pos {}", sibling_pos))
})?;
let parent_pos = current_pos + 1;
let parent_hash = hash_pair(sibling_hash, current_hash);
self.store.insert(parent_pos, parent_hash)?;
current_pos = parent_pos;
current_hash = parent_hash;
height += 1;
}
Ok(leaf_pos)
}
pub fn root(&self) -> Hash {
let peaks = self.peaks().unwrap_or_default();
bag_peaks(&peaks)
}
pub fn peaks(&self) -> Result<Vec<Hash>> {
let positions = peak_positions(self.store.size());
let mut peaks = Vec::with_capacity(positions.len());
for pos in positions {
let hash = self
.store
.get(pos)?
.ok_or_else(|| Error::internal(format!("missing peak at {}", pos)))?;
peaks.push(hash);
}
Ok(peaks)
}
pub fn peak_positions(&self) -> Vec<u64> {
peak_positions(self.store.size())
}
pub fn proof(&self, pos: u64) -> Result<MmrProof> {
let size = self.store.size();
if pos >= size {
return Err(Error::not_found(format!(
"position {} >= size {}",
pos, size
)));
}
let height = pos_height(pos);
if height != 0 {
return Err(Error::invalid_proof(format!("{} is not a leaf", pos)));
}
let leaf = self
.store
.get(pos)?
.ok_or_else(|| Error::not_found(format!("no node at {}", pos)))?;
let mut siblings = Vec::new();
let mut current_pos = pos;
let mut current_height = 0u32;
loop {
if let Some(sib_pos) = right_sibling(current_pos, current_height) {
if sib_pos < size {
let sib_hash = self.store.get(sib_pos)?.ok_or_else(|| {
Error::internal(format!("missing sibling at {}", sib_pos))
})?;
siblings.push(ProofNode {
hash: sib_hash,
position: Position::Right,
});
current_pos = sib_pos + 1;
current_height += 1;
continue;
}
}
if let Some(sib_pos) = left_sibling(current_pos, current_height) {
let sib_hash = self
.store
.get(sib_pos)?
.ok_or_else(|| Error::internal(format!("missing sibling at {}", sib_pos)))?;
siblings.push(ProofNode {
hash: sib_hash,
position: Position::Left,
});
current_pos += 1;
current_height += 1;
continue;
}
break;
}
let peak_pos = current_pos;
let peaks = self.peaks()?;
Ok(MmrProof {
pos,
leaf,
siblings,
peak_pos,
peaks,
mmr_size: size,
})
}
pub fn verify(&self, proof: &MmrProof) -> Result<bool> {
let mut current = proof.leaf;
for node in &proof.siblings {
current = match node.position {
Position::Left => hash_pair(node.hash, current),
Position::Right => hash_pair(current, node.hash),
};
}
let peak_positions = self.peak_positions();
let peak_idx = peak_positions.iter().position(|&p| p == proof.peak_pos);
match peak_idx {
Some(idx) => {
if proof.peaks.get(idx) != Some(¤t) {
return Ok(false);
}
let proof_root = bag_peaks(&proof.peaks);
Ok(proof_root == self.root())
}
None => Err(Error::invalid_proof("peak position not found")),
}
}
pub fn get(&self, pos: u64) -> Result<Option<Hash>> {
self.store.get(pos)
}
pub fn append_batch(&mut self, leaves: &[Hash]) -> Result<Vec<u64>> {
if leaves.is_empty() {
return Ok(vec![]);
}
let mut positions = Vec::with_capacity(leaves.len());
for &leaf in leaves {
let pos = self.append(leaf)?;
positions.push(pos);
}
Ok(positions)
}
pub fn proof_batch(&self, positions: &[u64]) -> Result<Vec<MmrProof>>
where
S: Sync,
{
use rayon::prelude::*;
if positions.len() < 16 {
return positions.iter().map(|&pos| self.proof(pos)).collect();
}
positions.par_iter().map(|&pos| self.proof(pos)).collect()
}
pub fn verify_batch(&self, proofs: &[MmrProof]) -> Result<bool>
where
S: Sync,
{
use rayon::prelude::*;
if proofs.len() < 16 {
for proof in proofs {
if !self.verify(proof)? {
return Ok(false);
}
}
return Ok(true);
}
let results: Result<Vec<bool>> =
proofs.par_iter().map(|proof| self.verify(proof)).collect();
results.map(|v| v.into_iter().all(|b| b))
}
pub fn proof_range(&self, start_leaf: u64, count: u64) -> Result<Vec<MmrProof>>
where
S: Sync,
{
let leaf_count = self.leaf_count();
if start_leaf + count > leaf_count {
return Err(Error::not_found(format!(
"range {}..{} exceeds leaf count {}",
start_leaf,
start_leaf + count,
leaf_count
)));
}
let positions: Vec<u64> = (0..count)
.filter_map(|i| leaf_to_pos(start_leaf + i))
.collect();
self.proof_batch(&positions)
}
}
fn pos_height(pos: u64) -> u32 {
pos_height_in_tree(pos, find_containing_tree_size(pos))
}
fn find_containing_tree_size(pos: u64) -> u64 {
let mut size = 1u64;
while size <= pos {
size = size * 2 + 1; }
size
}
fn pos_height_in_tree(pos: u64, tree_size: u64) -> u32 {
if tree_size == 1 {
return 0; }
let root_pos = tree_size - 1;
if pos == root_pos {
return (tree_size + 1).trailing_zeros() - 1;
}
let subtree_size = (tree_size - 1) / 2;
if pos < subtree_size {
pos_height_in_tree(pos, subtree_size)
} else {
let local_pos = pos - subtree_size;
pos_height_in_tree(local_pos, subtree_size)
}
}
fn left_sibling(pos: u64, height: u32) -> Option<u64> {
let offset = sibling_offset(height);
if pos >= offset {
let sib = pos - offset;
if pos_height(sib) == height {
return Some(sib);
}
}
None
}
fn right_sibling(pos: u64, height: u32) -> Option<u64> {
let offset = sibling_offset(height);
let sib = pos + offset;
if pos_height(sib) == height {
Some(sib)
} else {
None
}
}
fn sibling_offset(height: u32) -> u64 {
(1u64 << (height + 1)) - 1
}
fn size_to_leaf_count(size: u64) -> u64 {
if size == 0 {
return 0;
}
let mut count = 0u64;
let mut remaining = size;
while remaining > 0 {
let bits = 64 - remaining.leading_zeros();
let tree_leaves = 1u64 << (bits - 1);
let tree_size = (tree_leaves << 1) - 1;
if tree_size <= remaining {
count += tree_leaves;
remaining -= tree_size;
} else {
let tree_leaves = tree_leaves >> 1;
let tree_size = if tree_leaves > 0 {
(tree_leaves << 1) - 1
} else {
break;
};
if tree_size <= remaining {
count += tree_leaves;
remaining -= tree_size;
}
}
}
count
}
fn peak_positions(size: u64) -> Vec<u64> {
if size == 0 {
return vec![];
}
let mut peaks = Vec::new();
let mut pos = 0u64;
while pos < size {
let remaining = size - pos;
let bits = 64 - remaining.leading_zeros();
for h in (0..bits).rev() {
let tree_size = (1u64 << (h + 1)) - 1;
if tree_size <= remaining {
peaks.push(pos + tree_size - 1);
pos += tree_size;
break;
}
}
if pos == 0 {
break;
}
}
peaks
}
fn leaf_to_pos(leaf_index: u64) -> Option<u64> {
Some(2 * leaf_index - (leaf_index.count_ones() as u64))
}
fn bag_peaks(peaks: &[Hash]) -> Hash {
if peaks.is_empty() {
return Hash::ZERO;
}
if peaks.len() == 1 {
return peaks[0];
}
let mut root = *peaks.last().unwrap();
for peak in peaks.iter().rev().skip(1) {
root = hash_pair(*peak, root);
}
root
}
#[cfg(test)]
mod tests {
use super::*;
use crate::MemStore;
use moloch_core::hash;
fn make_leaf(s: &str) -> Hash {
hash(s.as_bytes())
}
#[test]
fn test_pos_height() {
assert_eq!(pos_height(0), 0); assert_eq!(pos_height(1), 0); assert_eq!(pos_height(2), 1); assert_eq!(pos_height(3), 0); assert_eq!(pos_height(6), 2); }
#[test]
fn test_single_leaf() {
let mut mmr = Mmr::new(MemStore::new());
let pos = mmr.append(make_leaf("leaf0")).unwrap();
assert_eq!(pos, 0);
assert_eq!(mmr.leaf_count(), 1);
assert_eq!(mmr.size(), 1);
assert_eq!(mmr.peak_positions(), vec![0]);
}
#[test]
fn test_two_leaves() {
let mut mmr = Mmr::new(MemStore::new());
let leaf0 = make_leaf("leaf0");
let leaf1 = make_leaf("leaf1");
let pos0 = mmr.append(leaf0).unwrap();
let pos1 = mmr.append(leaf1).unwrap();
assert_eq!(pos0, 0);
assert_eq!(pos1, 1);
assert_eq!(mmr.leaf_count(), 2);
assert_eq!(mmr.size(), 3);
let peaks = mmr.peak_positions();
assert_eq!(peaks, vec![2]);
let expected = hash_pair(leaf0, leaf1);
assert_eq!(mmr.get(2).unwrap(), Some(expected));
}
#[test]
fn test_three_leaves() {
let mut mmr = Mmr::new(MemStore::new());
mmr.append(make_leaf("leaf0")).unwrap();
mmr.append(make_leaf("leaf1")).unwrap();
mmr.append(make_leaf("leaf2")).unwrap();
assert_eq!(mmr.leaf_count(), 3);
assert_eq!(mmr.size(), 4);
assert_eq!(mmr.peak_positions(), vec![2, 3]);
}
#[test]
fn test_four_leaves() {
let mut mmr = Mmr::new(MemStore::new());
for i in 0..4 {
mmr.append(make_leaf(&format!("leaf{}", i))).unwrap();
}
assert_eq!(mmr.leaf_count(), 4);
assert_eq!(mmr.size(), 7);
assert_eq!(mmr.peak_positions(), vec![6]);
}
#[test]
fn test_five_leaves() {
let mut mmr = Mmr::new(MemStore::new());
for i in 0..5 {
mmr.append(make_leaf(&format!("leaf{}", i))).unwrap();
}
assert_eq!(mmr.leaf_count(), 5);
assert_eq!(mmr.size(), 8);
assert_eq!(mmr.peak_positions(), vec![6, 7]);
}
#[test]
fn test_many_leaves() {
let mut mmr = Mmr::new(MemStore::new());
for i in 0..7 {
mmr.append(make_leaf(&format!("leaf{}", i))).unwrap();
}
assert!(mmr.size() > 7); assert!(!mmr.root().is_zero());
let peaks = mmr.peak_positions();
assert!(!peaks.is_empty());
}
#[test]
fn test_proof_single() {
let mut mmr = Mmr::new(MemStore::new());
let leaf = make_leaf("only");
let pos = mmr.append(leaf).unwrap();
let proof = mmr.proof(pos).unwrap();
assert_eq!(proof.leaf, leaf);
assert!(proof.siblings.is_empty());
assert!(mmr.verify(&proof).unwrap());
}
#[test]
fn test_proof_two_leaves() {
let mut mmr = Mmr::new(MemStore::new());
let leaf0 = make_leaf("leaf0");
let leaf1 = make_leaf("leaf1");
mmr.append(leaf0).unwrap();
mmr.append(leaf1).unwrap();
let proof0 = mmr.proof(0).unwrap();
assert_eq!(proof0.siblings.len(), 1);
assert_eq!(proof0.siblings[0].hash, leaf1);
assert_eq!(proof0.siblings[0].position, Position::Right);
assert!(mmr.verify(&proof0).unwrap());
let proof1 = mmr.proof(1).unwrap();
assert_eq!(proof1.siblings.len(), 1);
assert_eq!(proof1.siblings[0].hash, leaf0);
assert_eq!(proof1.siblings[0].position, Position::Left);
assert!(mmr.verify(&proof1).unwrap());
}
#[test]
fn test_proof_multiple_peaks() {
let mut mmr = Mmr::new(MemStore::new());
for i in 0..5 {
mmr.append(make_leaf(&format!("leaf{}", i))).unwrap();
}
let proof = mmr.proof(7).unwrap();
assert!(mmr.verify(&proof).unwrap());
assert_eq!(proof.peaks.len(), 2);
}
#[test]
fn test_root_deterministic() {
let mut mmr1 = Mmr::new(MemStore::new());
let mut mmr2 = Mmr::new(MemStore::new());
for i in 0..5 {
let leaf = make_leaf(&format!("leaf{}", i));
mmr1.append(leaf).unwrap();
mmr2.append(leaf).unwrap();
}
assert_eq!(mmr1.root(), mmr2.root());
}
#[test]
fn test_root_changes() {
let mut mmr = Mmr::new(MemStore::new());
mmr.append(make_leaf("a")).unwrap();
let r1 = mmr.root();
mmr.append(make_leaf("b")).unwrap();
let r2 = mmr.root();
assert_ne!(r1, r2);
}
#[test]
fn test_bag_peaks() {
let p1 = hash(b"p1");
let p2 = hash(b"p2");
let p3 = hash(b"p3");
assert_eq!(bag_peaks(&[p1]), p1);
assert_eq!(bag_peaks(&[p1, p2]), hash_pair(p1, p2));
let h23 = hash_pair(p2, p3);
assert_eq!(bag_peaks(&[p1, p2, p3]), hash_pair(p1, h23));
}
#[test]
fn test_leaf_to_pos() {
assert_eq!(leaf_to_pos(0), Some(0));
assert_eq!(leaf_to_pos(1), Some(1));
assert_eq!(leaf_to_pos(2), Some(3)); assert_eq!(leaf_to_pos(3), Some(4));
assert_eq!(leaf_to_pos(4), Some(7)); assert_eq!(leaf_to_pos(5), Some(8));
assert_eq!(leaf_to_pos(6), Some(10)); assert_eq!(leaf_to_pos(7), Some(11));
}
#[test]
fn test_append_batch() {
let mut mmr = Mmr::new(MemStore::new());
let leaves: Vec<Hash> = (0..10)
.map(|i| make_leaf(&format!("batch-leaf-{}", i)))
.collect();
let positions = mmr.append_batch(&leaves).unwrap();
assert_eq!(positions.len(), 10);
assert_eq!(mmr.leaf_count(), 10);
for &pos in &positions {
let proof = mmr.proof(pos).unwrap();
assert!(mmr.verify(&proof).unwrap());
}
}
#[test]
fn test_proof_batch() {
let mut mmr = Mmr::new(MemStore::new());
let leaves: Vec<Hash> = (0..20)
.map(|i| make_leaf(&format!("batch-leaf-{}", i)))
.collect();
let positions = mmr.append_batch(&leaves).unwrap();
let proofs = mmr.proof_batch(&positions).unwrap();
assert_eq!(proofs.len(), 20);
for proof in &proofs {
assert!(mmr.verify(proof).unwrap());
}
}
#[test]
fn test_verify_batch() {
let mut mmr = Mmr::new(MemStore::new());
let leaves: Vec<Hash> = (0..20)
.map(|i| make_leaf(&format!("batch-leaf-{}", i)))
.collect();
let positions = mmr.append_batch(&leaves).unwrap();
let proofs = mmr.proof_batch(&positions).unwrap();
assert!(mmr.verify_batch(&proofs).unwrap());
}
#[test]
fn test_proof_range() {
let mut mmr = Mmr::new(MemStore::new());
let leaves: Vec<Hash> = (0..20)
.map(|i| make_leaf(&format!("range-leaf-{}", i)))
.collect();
mmr.append_batch(&leaves).unwrap();
let proofs = mmr.proof_range(5, 10).unwrap();
assert_eq!(proofs.len(), 10);
for proof in &proofs {
assert!(mmr.verify(proof).unwrap());
}
}
#[test]
fn test_proof_range_bounds_check() {
let mut mmr = Mmr::new(MemStore::new());
let leaves: Vec<Hash> = (0..10)
.map(|i| make_leaf(&format!("range-leaf-{}", i)))
.collect();
mmr.append_batch(&leaves).unwrap();
let result = mmr.proof_range(5, 10);
assert!(result.is_err());
}
}