use alloc::vec::Vec;
use super::{
super::{InnerNodeInfo, MerklePath},
MmrDelta, MmrError, MmrPath, MmrPeaks, MmrProof,
forest::{Forest, TreeSizeIterator},
nodes_from_mask,
};
use crate::{
Word,
hash::poseidon2::Poseidon2,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
/// A Merkle Mountain Range: a forest of perfect binary Merkle trees that grows by appending
/// leaves.
///
/// The shape of the forest is fully determined by its leaf count: for every set bit `i` of the
/// leaf count there is exactly one tree with `2^i` leaves.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Mmr {
    /// Leaf count of the MMR; its set bits describe which trees are present.
    pub(super) forest: Forest,
    /// Every node of every tree, children before parents (post-order), with larger (older)
    /// trees stored before smaller (newer) ones.
    pub(super) nodes: Vec<Word>,
}
impl Default for Mmr {
fn default() -> Self {
Self::new()
}
}
impl Mmr {
pub fn new() -> Mmr {
Mmr {
forest: Forest::empty(),
nodes: Vec::new(),
}
}
pub fn try_from_iter<T: IntoIterator<Item = Word>>(values: T) -> Result<Self, MmrError> {
Self::try_from_iter_with_limit(values, Forest::MAX_LEAVES)
}
pub(crate) fn try_from_iter_with_limit<T: IntoIterator<Item = Word>>(
values: T,
max_leaves: usize,
) -> Result<Self, MmrError> {
let mut mmr = Mmr::new();
let iter = values.into_iter();
let (lower, _) = iter.size_hint();
if lower > max_leaves {
return Err(MmrError::ForestSizeExceeded { requested: lower, max: max_leaves });
}
let mut count = 0usize;
for v in iter {
count += 1;
if count > max_leaves {
return Err(MmrError::ForestSizeExceeded { requested: count, max: max_leaves });
}
mmr.add(v)?;
}
Ok(mmr)
}
pub const fn forest(&self) -> Forest {
self.forest
}
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
self.open_at(pos, self.forest)
}
pub fn open_at(&self, pos: usize, forest: Forest) -> Result<MmrProof, MmrError> {
if forest > self.forest {
return Err(MmrError::ForestOutOfBounds(forest.num_leaves(), self.forest.num_leaves()));
}
let (leaf, path) = self.collect_merkle_path_and_value(pos, forest)?;
let path = MmrPath::new(forest, pos, MerklePath::new(path));
Ok(MmrProof::new(path, leaf))
}
pub fn get(&self, pos: usize) -> Result<Word, MmrError> {
let (value, _) = self.collect_merkle_path_and_value(pos, self.forest)?;
Ok(value)
}
pub fn add(&mut self, el: Word) -> Result<(), MmrError> {
let old_forest = self.forest;
self.forest.append_leaf()?;
self.nodes.push(el);
let mut left_offset = self.nodes.len().saturating_sub(2);
let mut right = el;
let mut left_tree = 1usize;
while (old_forest.num_leaves() & left_tree) != 0 {
right = Poseidon2::merge(&[self.nodes[left_offset], right]);
self.nodes.push(right);
debug_assert!(left_tree <= Forest::MAX_LEAVES);
let left_nodes = left_tree * 2 - 1;
left_offset = left_offset.saturating_sub(left_nodes);
match left_tree.checked_shl(1) {
Some(next) => left_tree = next,
None => break,
}
}
Ok(())
}
pub fn peaks(&self) -> MmrPeaks {
self.peaks_at(self.forest).expect("failed to get peaks at current forest")
}
pub fn peaks_at(&self, forest: Forest) -> Result<MmrPeaks, MmrError> {
if forest > self.forest {
return Err(MmrError::ForestOutOfBounds(forest.num_leaves(), self.forest.num_leaves()));
}
let peaks: Vec<Word> = TreeSizeIterator::new(forest)
.rev()
.map(Forest::num_nodes)
.scan(0, |offset, el| {
*offset += el;
Some(*offset)
})
.map(|offset| self.nodes[offset - 1])
.collect();
let peaks = MmrPeaks::new(forest, peaks)?;
Ok(peaks)
}
pub fn get_delta(&self, from_forest: Forest, to_forest: Forest) -> Result<MmrDelta, MmrError> {
if to_forest > self.forest {
return Err(MmrError::ForestOutOfBounds(
to_forest.num_leaves(),
self.forest.num_leaves(),
));
}
if from_forest > to_forest {
return Err(MmrError::ForestOutOfBounds(
from_forest.num_leaves(),
to_forest.num_leaves(),
));
}
if from_forest == to_forest {
return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
}
let mut result = Vec::new();
let candidate_mask = to_forest.num_leaves() ^ from_forest.num_leaves();
let mut new_high = super::forest::largest_tree_from_mask(candidate_mask);
let mut merges = from_forest & new_high.all_smaller_trees_unchecked();
let common_trees = from_forest ^ merges;
if !merges.is_empty() {
let mut target = merges.smallest_tree_unchecked();
while target < new_high {
let known_mask =
common_trees.num_leaves() | merges.num_leaves() | target.num_leaves();
let known = nodes_from_mask(known_mask);
let sibling = target.num_nodes();
result.push(self.nodes[known + sibling - 1]);
target = target.next_larger_tree()?;
while !(merges & target).is_empty() {
target = target.next_larger_tree()?;
}
merges ^= merges & target.all_smaller_trees_unchecked();
}
} else {
new_high = Forest::empty();
}
let mut new_peaks = to_forest ^ common_trees ^ new_high;
let old_peaks = to_forest ^ new_peaks;
let mut offset = old_peaks.num_nodes();
while !new_peaks.is_empty() {
let target = new_peaks.largest_tree_unchecked();
offset += target.num_nodes();
result.push(self.nodes[offset - 1]);
new_peaks ^= target;
}
Ok(MmrDelta { forest: to_forest, data: result })
}
pub fn inner_nodes(&self) -> MmrNodes<'_> {
MmrNodes {
mmr: self,
forest: 0,
last_right: 0,
index: 0,
}
}
fn collect_merkle_path_and_value(
&self,
leaf_idx: usize,
forest: Forest,
) -> Result<(Word, Vec<Word>), MmrError> {
let tree_bit = forest
.leaf_to_corresponding_tree(leaf_idx)
.ok_or(MmrError::PositionNotFound(leaf_idx))?;
let forest_before = forest.trees_larger_than(tree_bit);
let index_offset = forest_before.num_nodes();
let relative_pos = leaf_idx - forest_before.num_leaves();
let tree_depth = (tree_bit + 1) as usize;
let mut path = Vec::with_capacity(tree_depth);
let mut forest_target: usize = 1usize << tree_bit;
let mut index = nodes_from_mask(forest_target) - 1;
while forest_target > 1 {
forest_target >>= 1;
let right_offset = index - 1;
let left_offset = right_offset - nodes_from_mask(forest_target);
let left_or_right = relative_pos & forest_target;
let sibling = if left_or_right != 0 {
index = right_offset;
self.nodes[index_offset + left_offset]
} else {
index = left_offset;
self.nodes[index_offset + right_offset]
};
path.push(sibling);
}
debug_assert!(path.len() == tree_depth - 1);
path.reverse();
let value = self.nodes[index_offset + index];
Ok((value, path))
}
}
impl Serializable for Mmr {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.forest.write_into(target);
self.nodes.write_into(target);
}
}
impl Deserializable for Mmr {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let forest = Forest::read_from(source)?;
let nodes = Vec::<Word>::read_from(source)?;
Ok(Self { forest, nodes })
}
}
/// Iterator over the inner (parent) nodes of an [`Mmr`], produced by `Mmr::inner_nodes`.
///
/// Each item carries a parent node together with its left and right children. Parents are
/// yielded in increasing storage index.
pub struct MmrNodes<'a> {
    /// The MMR being traversed.
    mmr: &'a Mmr,
    /// Leaf-count bit mask of completed subtrees, excluding the one recorded in `last_right`.
    /// Iteration ends once this reaches the MMR's leaf count without its single-leaf tree.
    forest: usize,
    /// Leaf count of a completed right subtree whose parent has not been yielded yet, or 0 when
    /// the next step starts from a fresh pair of leaves.
    last_right: usize,
    /// Current position in `mmr.nodes`.
    index: usize,
}
impl Iterator for MmrNodes<'_> {
    type Item = InnerNodeInfo;

    /// Yields the next parent node with its two children, or `None` once all inner nodes have
    /// been returned.
    fn next(&mut self) -> Option<Self::Item> {
        debug_assert!(self.last_right.count_ones() <= 1, "last_right tracks zero or one element");
        // Only parents are yielded, so a trailing single-leaf tree contributes nothing.
        let target = self.mmr.forest.without_single_leaf().num_leaves();
        if self.forest < target {
            if self.last_right == 0 {
                // No pending right subtree: start from a fresh pair of leaves.
                // Step over the left leaf...
                debug_assert!(self.last_right == 0, "left must be before right");
                self.forest |= 1;
                self.index += 1;
                // ...then over the right leaf, which becomes the pending right child.
                debug_assert!((self.forest & 1) == 1, "right must be after left");
                self.last_right |= 1;
                self.index += 1;
            };
            debug_assert!(
                self.forest & self.last_right != 0,
                "parent requires both a left and right",
            );
            // Number of nodes in the right subtree. The left child's root sits just before it.
            let right_nodes = Forest::new(self.last_right).unwrap().num_nodes();
            // The parent covers twice as many leaves as each of its children.
            let parent = self.last_right << 1;
            // The pair is consumed by this parent: clear the left subtree's bit.
            self.forest ^= self.last_right;
            if self.forest & parent == 0 {
                // No same-size subtree is waiting on the left, so this parent is itself a left
                // child. Record it; the next call starts a fresh pair of leaves.
                debug_assert!(self.forest & 1 == 0, "next iteration yields a left leaf");
                self.last_right = 0;
                self.forest ^= parent;
            } else {
                // A same-size left subtree already exists, so this parent is a right child.
                // The next call yields the parent of the two.
                self.last_right = parent;
            }
            // Post-order layout: the parent is at `index`, its right child immediately before
            // it, and its left child immediately before the whole right subtree.
            let value = self.mmr.nodes[self.index];
            let right = self.mmr.nodes[self.index - 1];
            let left = self.mmr.nodes[self.index - 1 - right_nodes];
            self.index += 1;
            let node = InnerNodeInfo { value, left, right };
            Some(node)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use super::super::nodes_from_mask;
    use crate::{
        Felt, Word, ZERO,
        merkle::mmr::{Forest, Mmr},
        utils::{Deserializable, DeserializationError, Serializable},
    };

    /// A 128-leaf MMR must survive a serialize/deserialize round trip unchanged.
    #[test]
    fn test_serialization() {
        let mut leaves = Vec::with_capacity(128);
        for value in 0u64..128u64 {
            leaves.push(Word::new([ZERO, ZERO, ZERO, Felt::new_unchecked(value)]));
        }
        let original = Mmr::try_from_iter(leaves).unwrap();

        let encoded = original.to_bytes();
        let restored = Mmr::read_from_bytes(&encoded).unwrap();

        assert_eq!(original.forest, restored.forest);
        assert_eq!(original.nodes, restored.nodes);
    }

    /// Decoding must reject a forest one leaf beyond the supported maximum.
    #[test]
    fn test_deserialization_rejects_large_forest() {
        // Oversized leaf count followed by an empty node list.
        let too_many_leaves = Forest::MAX_LEAVES + 1;
        let mut bytes = too_many_leaves.to_bytes();
        bytes.extend_from_slice(&0usize.to_bytes());

        let outcome = Mmr::read_from_bytes(&bytes);
        assert!(matches!(outcome, Err(DeserializationError::InvalidValue(_))));
    }

    /// `nodes_from_mask` must match `2 * leaves - popcount(leaves)` at the leaf limit.
    #[test]
    fn test_nodes_from_mask_at_max_leaves() {
        // Compute the reference in u128 so the reference itself cannot overflow.
        let leaves = Forest::MAX_LEAVES as u128;
        let peak_count = Forest::MAX_LEAVES.count_ones() as u128;
        let expected = leaves.saturating_mul(2).saturating_sub(peak_count);

        assert!(expected <= usize::MAX as u128);
        assert_eq!(nodes_from_mask(Forest::MAX_LEAVES), expected as usize);
    }
}