use crate::borrow::Cow;
use crate::collections::{btree_map::Entry, BTreeMap};
use crate::helper::{get_peaks, parent_offset, pos_height_in_tree, sibling_offset};
use crate::mmr_store::{MMRBatch, MMRStore};
use crate::vec;
use crate::vec::Vec;
use crate::{Error, Merge, Result};
use core::fmt::Debug;
use core::marker::PhantomData;
/// A Merkle Mountain Range over elements of type `T`.
///
/// `M` supplies the merge (hash-combining) function. Nodes are read and
/// appended through an [`MMRBatch`] that wraps the store `S`.
pub struct MMR<T, M, S: MMRStore<T>> {
/// Number of nodes (leaves and inner nodes) currently in the MMR; `push`
/// advances it past every node it creates.
mmr_size: u64,
/// Wrapper around the store: `push` appends to it and `commit` commits it.
batch: MMRBatch<T, S>,
/// Type-level marker for the merge strategy; no `M` value is stored.
merge: PhantomData<M>,
}
impl<T: Clone + PartialEq + Debug, M: Merge<Item = T>, S: MMRStore<T>> MMR<T, M, S> {
    /// Creates an MMR handle over `store`, which is expected to already hold
    /// `mmr_size` nodes.
    ///
    /// Nodes appended through [`MMR::push`] are handed to an internal
    /// [`MMRBatch`]; call [`MMR::commit`] to commit that batch.
    pub fn new(mmr_size: u64, store: S) -> Self {
        MMR {
            mmr_size,
            batch: MMRBatch::new(store),
            merge: PhantomData,
        }
    }

    /// Returns the node at `pos`, preferring the not-yet-appended `hashes`.
    ///
    /// `hashes[i]` is treated as the node at position `self.mmr_size + i`, i.e. a
    /// node built by an in-progress [`MMR::push`]. Any other position is read
    /// through the batch.
    ///
    /// # Errors
    /// [`Error::InconsistentStore`] if the batch has no node at `pos`; errors from
    /// the batch lookup itself are propagated.
    fn find_elem<'b>(&self, pos: u64, hashes: &'b [T]) -> Result<Cow<'b, T>> {
        let pos_offset = pos.checked_sub(self.mmr_size);
        if let Some(elem) = pos_offset.and_then(|i| hashes.get(i as usize)) {
            return Ok(Cow::Borrowed(elem));
        }
        let elem = self.batch.get_elem(pos)?.ok_or(Error::InconsistentStore)?;
        Ok(Cow::Owned(elem))
    }

    /// Number of nodes (leaves and inner nodes) currently in the MMR.
    pub fn mmr_size(&self) -> u64 {
        self.mmr_size
    }

    /// Returns `true` when the MMR contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.mmr_size == 0
    }

    /// Appends the leaf `elem`, together with every parent node it completes.
    ///
    /// All new nodes (the leaf followed by its newly completed ancestors) are
    /// appended to the batch in one call, and `mmr_size` advances past the last one.
    ///
    /// # Returns
    /// The position of the new leaf.
    ///
    /// # Errors
    /// Propagates lookup errors for the left siblings needed to build parents.
    pub fn push(&mut self, elem: T) -> Result<u64> {
        let mut elems: Vec<T> = Vec::new();
        let elem_pos = self.mmr_size;
        elems.push(elem);
        let mut height = 0u32;
        let mut pos = elem_pos;
        // The next position is a parent while its height exceeds the current one.
        while pos_height_in_tree(pos + 1) > height {
            pos += 1;
            let left_pos = pos - parent_offset(height);
            let right_pos = left_pos + sibling_offset(height);
            let left_elem = self.find_elem(left_pos, &elems)?;
            let right_elem = self.find_elem(right_pos, &elems)?;
            let parent_elem = M::merge(&left_elem, &right_elem);
            elems.push(parent_elem);
            height += 1
        }
        self.batch.append(elem_pos, elems);
        self.mmr_size = pos + 1;
        Ok(elem_pos)
    }

    /// Computes the root by bagging all current peaks (see `bag_rhs_peaks`).
    ///
    /// # Errors
    /// [`Error::GetRootOnEmpty`] for an empty MMR; [`Error::InconsistentStore`]
    /// if a peak node is missing from the store.
    pub fn get_root(&self) -> Result<T> {
        if self.mmr_size == 0 {
            return Err(Error::GetRootOnEmpty);
        } else if self.mmr_size == 1 {
            return self.batch.get_elem(0)?.ok_or(Error::InconsistentStore);
        }
        let peaks: Vec<T> = get_peaks(self.mmr_size)
            .into_iter()
            .map(|peak_pos| {
                self.batch
                    .get_elem(peak_pos)
                    .and_then(|elem| elem.ok_or(Error::InconsistentStore))
            })
            .collect::<Result<Vec<T>>>()?;
        self.bag_rhs_peaks(peaks)?.ok_or(Error::InconsistentStore)
    }

    /// Folds peaks (ordered left to right) into one hash, starting from the right.
    ///
    /// The two right-most entries are repeatedly replaced by `merge(right, left)`.
    /// Returns `Ok(None)` for an empty input; this function never returns `Err`.
    fn bag_rhs_peaks(&self, mut rhs_peaks: Vec<T>) -> Result<Option<T>> {
        while rhs_peaks.len() > 1 {
            let right_peak = rhs_peaks.pop().expect("pop");
            let left_peak = rhs_peaks.pop().expect("pop");
            rhs_peaks.push(M::merge(&right_peak, &left_peak));
        }
        Ok(rhs_peaks.pop())
    }

    /// Climbs from node `pos` (at `height`) toward `peak_pos`, pushing each
    /// sibling's hash onto `proof`.
    ///
    /// The climb stops, without pushing that step's sibling, when the parent
    /// position is greater than `stop_pos` or is already queued in `tree_buf`.
    /// In the first case the next queued position lies between the current node
    /// and its parent. In the second case the parent is reached from another
    /// path. Either way that sibling is derived rather than proven.
    ///
    /// # Returns
    /// The `(position, height)` reached.
    fn build_sub_merkle_path(
        &self,
        mut pos: u64,
        mut height: u32,
        peak_pos: u64,
        stop_pos: u64,
        tree_buf: &BTreeMap<u64, u32>,
        proof: &mut Vec<T>,
    ) -> Result<(u64, u32)> {
        while pos < peak_pos {
            let pos_height = pos_height_in_tree(pos);
            let next_height = pos_height_in_tree(pos + 1);
            let sib_pos = if next_height > pos_height {
                // `pos` is a right child: parent is the next position.
                let sib_pos = pos - sibling_offset(height);
                pos += 1;
                sib_pos
            } else {
                // `pos` is a left child: parent comes after the right subtree.
                let sib_pos = pos + sibling_offset(height);
                pos += parent_offset(height);
                sib_pos
            };
            height += 1;
            if pos > stop_pos || tree_buf.contains_key(&pos) {
                break;
            }
            proof.push(
                self.batch
                    .get_elem(sib_pos)?
                    .ok_or(Error::InconsistentStore)?,
            );
        }
        Ok((pos, height))
    }

    /// Appends to `proof` the items needed to prove `pos_list` within the tree
    /// rooted at `peak_pos`.
    ///
    /// `pos_list` is expected to be sorted and to contain only positions
    /// `<= peak_pos`. The cases are:
    /// * `pos_list == [peak_pos]`: the peak itself is proven and nothing is pushed.
    /// * `pos_list` is empty: the peak hash itself is pushed.
    /// * Otherwise: positions are processed smallest-first, each climbing
    ///   toward the peak.
    ///
    /// # Errors
    /// [`Error::GenProofForInvalidLeaves`] if a position climbs past `peak_pos`.
    /// This happens when an inner node is passed as if it were a leaf.
    /// [`Error::InconsistentStore`] if a required node is missing.
    fn gen_proof_for_peak(
        &self,
        proof: &mut Vec<T>,
        pos_list: Vec<u64>,
        peak_pos: u64,
    ) -> Result<()> {
        if pos_list == [peak_pos] {
            return Ok(());
        }
        if pos_list.is_empty() {
            proof.push(
                self.batch
                    .get_elem(peak_pos)?
                    .ok_or(Error::InconsistentStore)?,
            );
            return Ok(());
        }
        // Work list: position -> height, all starting as leaves (height 0).
        let mut tree_buf: BTreeMap<u64, u32> =
            pos_list.into_iter().map(|pos| (pos, 0u32)).collect();
        loop {
            let (&pos, &height) = tree_buf
                .iter()
                .next()
                .expect("tree_buf is refilled on every iteration, so it is never empty here");
            tree_buf.remove(&pos);
            if pos == peak_pos {
                break;
            }
            // A position beyond the peak cannot make progress: `build_sub_merkle_path`
            // would return it unchanged and it would be re-queued forever.
            if pos > peak_pos {
                return Err(Error::GenProofForInvalidLeaves);
            }
            let next_pos = *tree_buf
                .iter()
                .next()
                .map(|(pos, _height)| pos)
                .unwrap_or(&peak_pos);
            let (pos, height) =
                self.build_sub_merkle_path(pos, height, peak_pos, next_pos, &tree_buf, proof)?;
            tree_buf.entry(pos).or_insert(height);
        }
        Ok(())
    }

    /// Generates a proof that the nodes at `pos_list` belong to this MMR.
    ///
    /// Proof items are emitted peak by peak, left to right. When more than one
    /// trailing peak contains no requested position, those trailing peak hashes
    /// are bagged (see `bag_rhs_peaks`) into a single item.
    ///
    /// NOTE(review): positions are not checked to be leaves up front. Inner-node
    /// positions are rejected only if they climb past their peak.
    ///
    /// # Errors
    /// [`Error::GenProofForInvalidLeaves`] if `pos_list` is empty, if a position
    /// is not covered by any peak, or if a position cannot be climbed to its peak.
    /// [`Error::InconsistentStore`] if a required node is missing.
    pub fn gen_proof(&self, mut pos_list: Vec<u64>) -> Result<MerkleProof<T, M>> {
        if pos_list.is_empty() {
            return Err(Error::GenProofForInvalidLeaves);
        }
        if self.mmr_size == 1 && pos_list == [0] {
            return Ok(MerkleProof::new(self.mmr_size, Vec::new()));
        }
        pos_list.sort_unstable();
        let peaks = get_peaks(self.mmr_size);
        let mut proof: Vec<T> = Vec::new();
        // Number of consecutive trailing peaks that contain no requested position.
        let mut bagging_track = 0;
        for peak_pos in peaks {
            let pos_list: Vec<_> = take_while_vec(&mut pos_list, |&pos| pos <= peak_pos);
            if pos_list.is_empty() {
                bagging_track += 1;
            } else {
                bagging_track = 0;
            }
            self.gen_proof_for_peak(&mut proof, pos_list, peak_pos)?;
        }
        // Anything left lies beyond the last peak.
        if !pos_list.is_empty() {
            return Err(Error::GenProofForInvalidLeaves);
        }
        if bagging_track > 1 {
            // Each empty trailing peak pushed exactly one item, so this cannot underflow.
            let rhs_peaks = proof.split_off(proof.len() - bagging_track);
            proof.push(self.bag_rhs_peaks(rhs_peaks)?.expect("bagging rhs peaks"));
        }
        Ok(MerkleProof::new(self.mmr_size, proof))
    }

    /// Consumes the MMR and commits its batch to the underlying store.
    pub fn commit(self) -> Result<()> {
        self.batch.commit()
    }
}
/// A proof that a set of nodes belongs to an MMR of a given size.
///
/// Verified by recomputing the root from the claimed leaves plus `proof`.
#[derive(Debug)]
pub struct MerkleProof<T, M> {
/// Size (node count) of the MMR the proof was generated against.
mmr_size: u64,
/// Proof items, in the order the root calculation consumes them.
proof: Vec<T>,
/// Type-level marker for the merge strategy; no `M` value is stored.
merge: PhantomData<M>,
}
impl<T: PartialEq + Debug + Clone, M: Merge<Item = T>> MerkleProof<T, M> {
pub fn new(mmr_size: u64, proof: Vec<T>) -> Self {
MerkleProof {
mmr_size,
proof,
merge: PhantomData,
}
}
pub fn mmr_size(&self) -> u64 {
self.mmr_size
}
pub fn proof_items(&self) -> &[T] {
&self.proof
}
pub fn calculate_root(&self, leaves: Vec<(u64, T)>) -> Result<T> {
calculate_root::<_, M, _>(leaves, self.mmr_size, self.proof.iter())
}
pub fn calculate_root_with_new_leaf(
&self,
mut leaves: Vec<(u64, T)>,
new_pos: u64,
new_elem: T,
new_mmr_size: u64,
) -> Result<T> {
let pos_height = pos_height_in_tree(new_pos);
let next_height = pos_height_in_tree(new_pos + 1);
if next_height > pos_height {
let mut peaks_hashes =
calculate_peaks_hashes::<_, M, _>(leaves, self.mmr_size, self.proof.iter())?;
let peaks_pos = get_peaks(new_mmr_size);
let mut i = 0;
while peaks_pos[i] < new_pos {
i += 1
}
peaks_hashes[i..].reverse();
calculate_root::<_, M, _>(vec![(new_pos, new_elem)], new_mmr_size, peaks_hashes.iter())
} else {
leaves.push((new_pos, new_elem));
calculate_root::<_, M, _>(leaves, new_mmr_size, self.proof.iter())
}
}
pub fn verify(&self, root: T, leaves: Vec<(u64, T)>) -> Result<bool> {
self.calculate_root(leaves)
.map(|calculated_root| calculated_root == root)
}
}
fn calculate_peak_root<
'a,
T: 'a + PartialEq + Debug + Clone,
M: Merge<Item = T>,
I: Iterator<Item = &'a T>,
>(
leaves: Vec<(u64, T)>,
peak_pos: u64,
proof_iter: &mut I,
) -> Result<T> {
debug_assert!(!leaves.is_empty(), "can't be empty");
let mut tree_buf: BTreeMap<u64, (T, u32)> = leaves
.into_iter()
.map(|(pos, item)| (pos, (item, 0u32)))
.collect();
while !tree_buf.is_empty() {
let (pos, _item) = tree_buf.iter().next().unwrap();
let mut pos = *pos;
let (item, mut height) = tree_buf.remove(&pos).unwrap();
if pos == peak_pos {
return Ok(item);
}
let next_pos = tree_buf
.iter()
.next()
.map(|(pos, _item)| *pos)
.unwrap_or(peak_pos);
let mut item = item.clone();
while pos < peak_pos {
let pos_height = pos_height_in_tree(pos);
let next_height = pos_height_in_tree(pos + 1);
let is_right_side = next_height > pos_height;
if is_right_side {
pos += 1;
} else {
pos += parent_offset(height);
}
height += 1;
if pos > next_pos || tree_buf.contains_key(&pos) {
break;
}
let proof = proof_iter.next().ok_or(Error::CorruptedProof)?;
item = if is_right_side {
M::merge(proof, &item)
} else {
M::merge(&item, proof)
};
}
match tree_buf.entry(pos) {
Entry::Vacant(entry) => {
entry.insert((item, height));
}
Entry::Occupied(mut entry) => {
item = M::merge(&entry.get().0, &item);
entry.insert((item, height));
}
}
}
Err(Error::CorruptedProof)
}
/// Recovers the peak hashes of an MMR of `mmr_size` nodes from `leaves` and a proof.
///
/// Special case: a one-node MMR whose only leaf is position 0 returns that leaf
/// directly.
///
/// Otherwise, leaves are sorted by position (stable sort) and grouped under the
/// first peak whose position is `>=` theirs. Peaks are then handled left to right:
/// * A single leaf sitting exactly on the peak is the peak hash.
/// * A peak with no leaves takes the next proof item as its hash. If the proof
///   is exhausted the loop stops; presumably the remaining right-hand peaks were
///   bagged into the item already taken.
/// * Any other peak is rebuilt by `calculate_peak_root`, which consumes proof items.
///
/// One further proof item, if present, is appended as a bagged right-hand-peaks hash.
///
/// # Errors
/// [`Error::CorruptedProof`] if leaves remain unassigned to any peak, if proof
/// items remain unused, or if a peak cannot be rebuilt.
fn calculate_peaks_hashes<
    'a,
    T: 'a + PartialEq + Debug + Clone,
    M: Merge<Item = T>,
    I: Iterator<Item = &'a T>,
>(
    mut leaves: Vec<(u64, T)>,
    mmr_size: u64,
    mut proof_iter: I,
) -> Result<Vec<T>> {
    if mmr_size == 1 && leaves.len() == 1 && leaves[0].0 == 0 {
        return Ok(leaves.into_iter().map(|(_pos, item)| item).collect());
    }
    leaves.sort_by_key(|(pos, _)| *pos);
    let peaks = get_peaks(mmr_size);
    // One slot per peak plus a possible bagged right-hand-peaks item.
    let mut peaks_hashes: Vec<T> = Vec::with_capacity(peaks.len() + 1);
    for peak_pos in peaks {
        let mut peak_leaves: Vec<_> = take_while_vec(&mut leaves, |(pos, _)| *pos <= peak_pos);
        let peak_root = if peak_leaves.len() == 1 && peak_leaves[0].0 == peak_pos {
            // The leaf is the peak itself.
            peak_leaves.remove(0).1
        } else if peak_leaves.is_empty() {
            match proof_iter.next() {
                Some(peak_root) => peak_root.clone(),
                // Either the remaining peaks were bagged or the proof is short;
                // the checks below decide which.
                None => break,
            }
        } else {
            calculate_peak_root::<_, M, _>(peak_leaves, peak_pos, &mut proof_iter)?
        };
        // `peak_root` is already owned, so it is moved rather than cloned.
        peaks_hashes.push(peak_root);
    }
    if !leaves.is_empty() {
        return Err(Error::CorruptedProof);
    }
    if let Some(rhs_peaks_hashes) = proof_iter.next() {
        peaks_hashes.push(rhs_peaks_hashes.clone());
    }
    if proof_iter.next().is_some() {
        return Err(Error::CorruptedProof);
    }
    Ok(peaks_hashes)
}
fn bagging_peaks_hashes<'a, T: 'a + PartialEq + Debug + Clone, M: Merge<Item = T>>(
mut peaks_hashes: Vec<T>,
) -> Result<T> {
while peaks_hashes.len() > 1 {
let right_peak = peaks_hashes.pop().expect("pop");
let left_peak = peaks_hashes.pop().expect("pop");
peaks_hashes.push(M::merge(&right_peak, &left_peak));
}
peaks_hashes.pop().ok_or(Error::CorruptedProof)
}
fn calculate_root<
'a,
T: 'a + PartialEq + Debug + Clone,
M: Merge<Item = T>,
I: Iterator<Item = &'a T>,
>(
leaves: Vec<(u64, T)>,
mmr_size: u64,
proof_iter: I,
) -> Result<T> {
let peaks_hashes = calculate_peaks_hashes::<_, M, _>(leaves, mmr_size, proof_iter)?;
bagging_peaks_hashes::<_, M>(peaks_hashes)
}
/// Removes and returns the longest prefix of `v` whose elements all satisfy `p`.
///
/// Everything from the first element that fails `p` onward stays in `v`, in
/// order, even if later elements would satisfy `p`. `p` is evaluated left to
/// right and stops at the first failure.
fn take_while_vec<T, P: Fn(&T) -> bool>(v: &mut Vec<T>, p: P) -> Vec<T> {
    let prefix_len = v.iter().position(|item| !p(item)).unwrap_or(v.len());
    v.drain(..prefix_len).collect()
}