use crate::merkle::{
mmr::{Family, Position},
Family as _,
};
/// Iterator state used to walk the peaks of a structure containing `size` nodes.
///
/// Each yielded item is a `(Position, u32)` pair: the position of a peak and the
/// height derived from `two_h` at the moment the peak was found.
#[derive(Default)]
pub struct PeakIterator {
    /// Total number of nodes the iterator was created for.
    size: Position,
    /// Position of the node currently being examined as a peak candidate.
    node_pos: Position,
    /// Power of two tracking the level of the candidate node; the height
    /// reported for a peak is `two_h.trailing_zeros() - 1`. Iteration stops
    /// once this value is no longer greater than 1.
    two_h: u64,
}
impl PeakIterator {
    /// Builds an iterator over the peaks of a structure holding `size` nodes.
    ///
    /// A `size` of zero returns the default iterator.
    ///
    /// # Panics
    ///
    /// Panics with "size overflow" when `size` has no leading zero bits.
    pub fn new(size: Position) -> Self {
        if size == 0 {
            return Self::default();
        }

        // All-ones value with the same bit length as `size`. The search for
        // peaks starts at the last position such a value would cover
        // (`all_ones - 1`) and works downwards from there in `next`.
        let all_ones = u64::MAX >> size.leading_zeros();
        assert_ne!(all_ones, u64::MAX, "size overflow");

        Self {
            size,
            node_pos: Position::new(all_ones - 1),
            two_h: 1u64 << all_ones.trailing_ones(),
        }
    }

    /// Rounds `size` down to the largest value of the form
    /// `2 * n - popcount(n)` that does not exceed it.
    ///
    /// A `size` of zero is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics with "size exceeds MAX_NODES" if `size` is greater than
    /// `Family::MAX_NODES`.
    pub fn to_nearest_size(size: Position) -> Position {
        assert!(size <= Family::MAX_NODES, "size exceeds MAX_NODES");
        if size == 0 {
            return size;
        }

        // Node count produced by `leaves` leaves.
        let node_count = |leaves: u64| 2 * leaves - u64::from(leaves.count_ones());

        // Binary search for the largest leaf count whose node count still fits.
        // `node_count` never returns less than its argument, so `limit` itself
        // is a safe upper bound on the answer.
        let limit = size.as_u64();
        let (mut lo, mut hi) = (0u64, limit);
        while lo < hi {
            // Upper-biased midpoint so that `lo = probe` always makes progress.
            let probe = (lo + hi).div_ceil(2);
            if node_count(probe) > limit {
                hi = probe - 1;
            } else {
                lo = probe;
            }
        }

        Position::new(node_count(lo))
    }
}
impl Iterator for PeakIterator {
    /// Position of a peak together with its height.
    type Item = (Position, u32);

    /// Advances to the next peak, or returns `None` once `two_h` has dropped
    /// to 1 or below.
    ///
    /// # Panics
    ///
    /// Panics if, after a peak is found, the position `two_h - 1` further on
    /// is still below `size`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.two_h <= 1 {
                return None;
            }

            if self.node_pos >= self.size {
                // Candidate lies outside the structure: halve the level and
                // step back by the new `two_h` to reach the left child.
                self.two_h >>= 1;
                self.node_pos -= self.two_h;
                continue;
            }

            // Candidate is inside the structure, so it is a peak.
            let peak_pos = self.node_pos;
            let height = self.two_h.trailing_zeros() - 1;

            // Jump to the next candidate at the same level; it must not lie
            // inside the structure.
            self.node_pos += self.two_h - 1;
            assert!(self.node_pos >= self.size);

            return Some((peak_pos, height));
        }
    }
}
/// Returns the height of the node at position `pos`.
///
/// Starting from the all-ones value with the same bit length as `pos`, each
/// successively halved all-ones value is subtracted (at most once) from the
/// running remainder whenever it fits. What is left at the end is the height.
/// Position 0 has height 0.
pub(crate) const fn pos_to_height(pos: Position) -> u32 {
    let mut remainder = pos.as_u64();

    // Zero is handled separately: shifting by all 64 leading zeros is not allowed.
    if remainder != 0 {
        let mut span = u64::MAX >> remainder.leading_zeros();
        while span > 0 {
            if span <= remainder {
                remainder -= span;
            }
            span >>= 1;
        }
    }

    remainder as u32
}
// Unit tests for position/location conversions and `PeakIterator::to_nearest_size`.
#[cfg(test)]
mod tests {
use super::*;
use crate::merkle::mmr::{mem::Mmr, Location, StandardHasher as Standard};
use commonware_cryptography::Sha256;
// Adds 1000 leaves in a single batch, recording the position converted from
// each leaf's location, then checks that location <-> position conversions
// round-trip and that every position strictly between two consecutive leaf
// positions fails to convert to a location.
#[test]
fn test_leaf_loc_calculation() {
let hasher = Standard::<Sha256>::new();
let mut mmr = Mmr::new(&hasher);
let digest = [1u8; 32];
let (batch, loc_to_pos) = {
let mut batch = mmr.new_batch();
let mut positions = Vec::with_capacity(1000);
for _ in 0..1000 {
// Location the next leaf will occupy, captured before it is added.
let loc = batch.leaves();
batch = batch.add(&hasher, &digest);
positions.push(Position::try_from(loc).unwrap());
}
(batch.merkleize(&mmr, &hasher), positions)
};
mmr.apply_batch(&batch).unwrap();
let mut last_leaf_pos = 0;
for (leaf_loc_expected, leaf_pos) in loc_to_pos.into_iter().enumerate() {
let leaf_loc_got = Location::try_from(leaf_pos).unwrap();
assert_eq!(leaf_loc_got, Location::new(leaf_loc_expected as u64));
let leaf_pos_got = Position::try_from(leaf_loc_got).unwrap();
assert_eq!(leaf_pos_got, *leaf_pos);
// Positions between the previous leaf and this one must not map to a location.
for i in last_leaf_pos + 1..*leaf_pos {
assert!(Location::try_from(Position::new(i)).is_err());
}
last_leaf_pos = *leaf_pos;
}
}
// A size one past `Family::MAX_NODES` must trip the assertion in `to_nearest_size`.
#[test]
#[should_panic(expected = "size exceeds MAX_NODES")]
fn test_to_nearest_size_panic() {
PeakIterator::to_nearest_size(Family::MAX_NODES + 1);
}
// Grows the structure one leaf at a time for 1000 iterations. Before each
// addition, every value from the current size through current size + 10 is
// rounded and the result is checked to be a valid size, not above the input,
// and (when it was rounded down) not immediately followed by a valid size.
#[test]
fn test_to_nearest_size() {
let hasher = Standard::<Sha256>::new();
let mut mmr = Mmr::new(&hasher);
let digest = [1u8; 32];
for _ in 0..1000 {
let current_size = mmr.size();
for test_pos in *current_size..=*current_size + 10 {
let rounded = PeakIterator::to_nearest_size(Position::new(test_pos));
assert!(
rounded.is_valid_size(),
"rounded size {rounded} should be valid (test_pos: {test_pos}, current: {current_size})",
);
assert!(
rounded <= test_pos,
"rounded {rounded} should be <= test_pos {test_pos} (current: {current_size})",
);
if rounded < test_pos {
assert!(
!(rounded + 1).is_valid_size(),
"rounded {rounded} should be largest valid size <= {test_pos} (current: {current_size})",
);
}
}
// Append one more leaf before the next round of checks.
let batch = mmr
.new_batch()
.add(&hasher, &digest)
.merkleize(&mmr, &hasher);
mmr.apply_batch(&batch).unwrap();
}
}
// Spot checks: the two smallest inputs, an exhaustive sweep of 0..=20 against
// a running "last valid size seen" value, one large input, and the maximum
// permitted input.
#[test]
fn test_to_nearest_size_specific_cases() {
assert_eq!(PeakIterator::to_nearest_size(Position::new(0)), 0);
assert_eq!(PeakIterator::to_nearest_size(Position::new(1)), 1);
// `expected` tracks the largest valid size that is <= the current `size`.
let mut expected = Position::new(0);
for size in 0..=20 {
let rounded = PeakIterator::to_nearest_size(Position::new(size));
assert_eq!(rounded, expected);
// Prepare the expectation for the next iteration (`size + 1`).
if Position::new(size + 1).is_valid_size() {
expected = Position::new(size + 1);
}
}
let large_size = Position::new(1_000_000);
let rounded = PeakIterator::to_nearest_size(large_size);
assert!(rounded.is_valid_size());
assert!(rounded <= large_size);
let largest_valid_size = Family::MAX_NODES;
let rounded = PeakIterator::to_nearest_size(largest_valid_size);
assert!(rounded.is_valid_size());
assert!(rounded <= largest_valid_size);
}
}