commonware_storage/mmr/
iterator.rs

//! Iterators for traversing MMRs of a given size, and functions for computing various MMR
//! properties from their output. These are lower-level methods that are useful for implementing
//! new MMR variants or extensions.
4
5use super::Position;
6use alloc::vec::Vec;
7
8/// A PeakIterator returns a (position, height) tuple for each peak in an MMR with the given size,
9/// in decreasing order of height.
10///
11/// For the example MMR depicted at the top of this file, the PeakIterator would yield:
12/// ```text
13/// [(14, 3), (17, 1), (18, 0)]
14/// ```
15#[derive(Default)]
16pub struct PeakIterator {
17    size: Position, // number of nodes in the MMR at the point the iterator was initialized
18    node_pos: Position, // position of the current node
19    two_h: u64,     // 2^(height+1) of the current node
20}
21
22impl PeakIterator {
23    /// Return a new PeakIterator over the peaks of a MMR with the given number of nodes.
24    ///
25    /// # Panics
26    ///
27    /// Iteration will panic if size is not a valid MMR size. If used on untrusted input, call
28    /// [Position::is_mmr_size] first.
29    pub fn new(size: Position) -> Self {
30        if size == 0 {
31            return Self::default();
32        }
33        // Compute the position at which to start the search for peaks. This starting position will
34        // not be in the MMR unless it happens to be a single perfect binary tree, but that's OK as
35        // we will descend leftward until we find the first peak.
36        let start = u64::MAX >> size.leading_zeros();
37        assert_ne!(start, u64::MAX, "size overflow");
38        let two_h = 1 << start.trailing_ones();
39        Self {
40            size,
41            node_pos: Position::new(start - 1),
42            two_h,
43        }
44    }
45
46    /// Return the position of the last leaf in an MMR of the given size.
47    ///
48    /// This is an O(log2(n)) operation.
49    ///
50    /// # Panics
51    ///
52    /// Panics if size is too large (specifically, the topmost bit should be 0).
53    pub fn last_leaf_pos(size: Position) -> Position {
54        if size == 0 {
55            return Position::new(0);
56        }
57
58        let last_peak = Self::new(size)
59            .last()
60            .expect("PeakIterator has at least one peak when size > 0");
61        last_peak.0.checked_sub(last_peak.1 as u64).unwrap()
62    }
63
64    /// Returns the largest valid MMR size that is no greater than the given size.
65    ///
66    /// This is an O(log2(n)) operation using binary search on the number of leaves.
67    ///
68    /// # Panics
69    ///
70    /// Panics if `size` exceeds [crate::mmr::MAX_POSITION].
71    pub fn to_nearest_size(size: Position) -> Position {
72        assert!(
73            size <= crate::mmr::MAX_POSITION,
74            "size exceeds MAX_POSITION"
75        );
76
77        // Algorithm: A valid MMR size corresponds to a specific number of leaves N, where:
78        // mmr_size(N) = 2*N - popcount(N)
79        // This formula comes from the fact that N leaves require N-1 internal nodes, but merging
80        // creates popcount(N)-1 additional nodes. We binary search for the largest N where
81        // mmr_size(N) <= size.
82
83        if size == 0 {
84            return size;
85        }
86
87        // Binary search for the largest N (number of leaves) such that
88        // mmr_size(N) = 2*N - popcount(N) <= size
89        let size_val = size.as_u64();
90        let mut low = 0u64;
91        let mut high = size_val; // MMR size >= leaf count, so N <= size
92
93        while low < high {
94            // Use div_ceil for upper-biased midpoint in binary search
95            let mid = (low + high).div_ceil(2);
96            let mmr_size = 2 * mid - mid.count_ones() as u64;
97
98            if mmr_size <= size_val {
99                low = mid;
100            } else {
101                high = mid - 1;
102            }
103        }
104
105        // low is the largest N where mmr_size(N) <= size
106        let result = 2 * low - low.count_ones() as u64;
107        Position::new(result)
108    }
109}
110
111impl Iterator for PeakIterator {
112    type Item = (Position, u32); // (peak, height)
113
114    fn next(&mut self) -> Option<Self::Item> {
115        while self.two_h > 1 {
116            if self.node_pos < self.size {
117                // found a peak
118                let peak_item = (self.node_pos, self.two_h.trailing_zeros() - 1);
119                // move to the right sibling
120                self.node_pos += self.two_h - 1;
121                assert!(self.node_pos >= self.size); // sibling shouldn't be in the MMR if MMR is valid
122                return Some(peak_item);
123            }
124            // descend to the left child
125            self.two_h >>= 1;
126            self.node_pos -= self.two_h;
127        }
128        None
129    }
130}
131
132/// Returns the set of peaks that will require a new parent after adding the next leaf to an MMR
133/// with the given peaks. This set is non-empty only if there is a height-0 (leaf) peak in the MMR.
134/// The result will contain this leaf peak plus the other MMR peaks with contiguously increasing
135/// height. Nodes in the result are ordered by decreasing height.
136pub(crate) fn nodes_needing_parents(peak_iterator: PeakIterator) -> Vec<Position> {
137    let mut peaks = Vec::new();
138    let mut last_height = u32::MAX;
139
140    for (peak_pos, height) in peak_iterator {
141        assert!(last_height > 0);
142        assert!(height < last_height);
143        if height != last_height - 1 {
144            peaks.clear();
145        }
146        peaks.push(peak_pos);
147        last_height = height;
148    }
149    if last_height != 0 {
150        // there is no peak that is a leaf
151        peaks.clear();
152    }
153    peaks
154}
155
/// Computes the height of the node located at position `pos` in an MMR (leaves have height 0).
#[cfg(any(feature = "std", test))]
pub(crate) const fn pos_to_height(pos: Position) -> u32 {
    let mut remaining = pos.as_u64();

    // Position 0 is a leaf; handling it here also avoids shifting by 64 bits below.
    if remaining == 0 {
        return 0;
    }

    // Start from the all-ones value with the same bit length as the position, then repeatedly
    // subtract each all-ones value that still fits while halving it. What is left over is the
    // height.
    let mut subtree = u64::MAX >> remaining.leading_zeros();
    loop {
        if subtree == 0 {
            break;
        }
        if remaining >= subtree {
            remaining -= subtree;
        }
        subtree >>= 1;
    }

    remaining as u32
}
175
/// A PathIterator returns a (parent_pos, sibling_pos) tuple for the sibling of each node along the
/// path from a given perfect binary tree peak to a designated leaf, not including the peak itself.
///
/// Tuples are produced top-down: the first one describes the level directly below the peak and
/// the last one the level containing the leaf.
///
/// For example, consider the tree below and the path from the peak to leaf node 3. Nodes on the
/// path are [6, 5, 3] (tagged with '*' in the diagram):
///
/// ```text
///
///          6*
///        /   \
///       2     5*
///      / \   / \
///     0   1 3*  4
///
/// A PathIterator for this example yields:
///    [(6, 2), (5, 4)]
/// ```
#[derive(Debug)]
pub struct PathIterator {
    leaf_pos: Position, // position of the leaf node in the path
    node_pos: Position, // current node position in the path from peak to leaf
    two_h: u64,         // 2^height of the current node; iteration stops once this reaches 1
}
199
200impl PathIterator {
201    /// Return a PathIterator over the siblings of nodes along the path from peak to leaf in the
202    /// perfect binary tree with peak `peak_pos` and having height `height`, not including the peak
203    /// itself.
204    pub const fn new(leaf_pos: Position, peak_pos: Position, height: u32) -> Self {
205        Self {
206            leaf_pos,
207            node_pos: peak_pos,
208            two_h: 1 << height,
209        }
210    }
211}
212
213impl Iterator for PathIterator {
214    type Item = (Position, Position); // (parent_pos, sibling_pos)
215
216    fn next(&mut self) -> Option<Self::Item> {
217        if self.two_h <= 1 {
218            return None;
219        }
220
221        let left_pos = self.node_pos - self.two_h;
222        let right_pos = self.node_pos - 1;
223        self.two_h >>= 1;
224
225        if left_pos < self.leaf_pos {
226            let r = Some((self.node_pos, left_pos));
227            self.node_pos = right_pos;
228            return r;
229        }
230        let r = Some((self.node_pos, right_pos));
231        self.node_pos = left_pos;
232        r
233    }
234}
235
236/// Return the list of pruned (pos < `start_pos`) node positions that are still required for
237/// proving any retained node.
238///
239/// This set consists of every pruned node that is either (1) a peak, or (2) has no descendent
240/// in the retained section, but its immediate parent does. (A node meeting condition (2) can be
241/// shown to always be the left-child of its parent.)
242///
243/// This set of nodes does not change with the MMR's size, only the pruning boundary. For a
244/// given pruning boundary that happens to be a valid MMR size, one can prove that this set is
245/// exactly the set of peaks for an MMR whose size equals the pruning boundary. If the pruning
246/// boundary is not a valid MMR size, then the set corresponds to the peaks of the largest MMR
247/// whose size is less than the pruning boundary.
248pub(crate) fn nodes_to_pin(start_pos: Position) -> impl Iterator<Item = Position> {
249    PeakIterator::new(PeakIterator::to_nearest_size(start_pos)).map(|(pos, _)| pos)
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{hasher::Standard, mem::CleanMmr, Location};
    use commonware_cryptography::Sha256;

    /// Leaf positions and leaf locations must round-trip, and non-leaf positions must be rejected.
    #[test]
    fn test_leaf_loc_calculation() {
        // Build MMR with 1000 leaves and make sure we can correctly convert each leaf position to
        // its number and back again.
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = CleanMmr::new(&mut hasher);
        let mut loc_to_pos = Vec::new();
        let digest = [1u8; 32];
        for _ in 0u64..1000 {
            loc_to_pos.push(mmr.add(&mut hasher, &digest));
        }

        let mut last_leaf_pos = 0;
        for (leaf_loc_expected, leaf_pos) in loc_to_pos.into_iter().enumerate() {
            let leaf_loc_got = Location::try_from(leaf_pos).unwrap();
            assert_eq!(
                leaf_loc_got,
                Location::new_unchecked(leaf_loc_expected as u64)
            );
            let leaf_pos_got = Position::try_from(leaf_loc_got).unwrap();
            assert_eq!(leaf_pos_got, *leaf_pos);
            // Every position strictly between two consecutive leaves is an internal node.
            for i in last_leaf_pos + 1..*leaf_pos {
                assert!(Location::try_from(Position::new(i)).is_err());
            }
            last_leaf_pos = *leaf_pos;
        }
    }

    /// Inputs beyond MAX_POSITION must be rejected with a panic.
    #[test]
    #[should_panic(expected = "size exceeds MAX_POSITION")]
    fn test_to_nearest_size_panic() {
        PeakIterator::to_nearest_size(crate::mmr::MAX_POSITION + 1);
    }

    /// to_nearest_size must return the largest valid MMR size not exceeding its input.
    #[test]
    fn test_to_nearest_size() {
        // Build an MMR incrementally and verify to_nearest_size for all intermediate values
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = CleanMmr::new(&mut hasher);
        let digest = [1u8; 32];

        for _ in 0..1000 {
            let current_size = mmr.size();

            // Test positions from current size up to current size + 10
            for test_pos in *current_size..=*current_size + 10 {
                let rounded = PeakIterator::to_nearest_size(Position::new(test_pos));

                // Verify rounded is a valid MMR size
                assert!(
                    rounded.is_mmr_size(),
                    "rounded size {rounded} should be valid (test_pos: {test_pos}, current: {current_size})",
                );

                // Verify rounded <= test_pos
                assert!(
                    rounded <= test_pos,
                    "rounded {rounded} should be <= test_pos {test_pos} (current: {current_size})",
                );

                // Verify rounded is the largest valid size <= test_pos: nothing in
                // (rounded, test_pos] may be a valid size. Checking only `rounded + 1` would miss
                // a valid size further inside the gap.
                for candidate in rounded.as_u64() + 1..=test_pos {
                    assert!(
                        !Position::new(candidate).is_mmr_size(),
                        "rounded {rounded} should be largest valid size <= {test_pos}, but {candidate} is valid (current: {current_size})",
                    );
                }
            }

            mmr.add(&mut hasher, &digest);
        }
    }

    /// Hand-picked inputs: tiny sizes, a consecutive sweep, a large value and the maximum.
    #[test]
    fn test_to_nearest_size_specific_cases() {
        // Test edge cases
        assert_eq!(PeakIterator::to_nearest_size(Position::new(0)), 0);
        assert_eq!(PeakIterator::to_nearest_size(Position::new(1)), 1);

        // Test consecutive values
        let mut expected = Position::new(0);
        for size in 0..=20 {
            let rounded = PeakIterator::to_nearest_size(Position::new(size));
            assert_eq!(rounded, expected);
            if Position::new(size + 1).is_mmr_size() {
                expected = Position::new(size + 1);
            }
        }

        // Test with large value
        let large_size = Position::new(1_000_000);
        let rounded = PeakIterator::to_nearest_size(large_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= large_size);
        // No valid size may exist between the result and the input.
        for candidate in rounded.as_u64() + 1..=large_size.as_u64() {
            assert!(
                !Position::new(candidate).is_mmr_size(),
                "{candidate} is a valid size larger than rounded {rounded}",
            );
        }

        // Test maximum allowed input
        let largest_valid_size = crate::mmr::MAX_POSITION;
        let rounded = PeakIterator::to_nearest_size(largest_valid_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= largest_valid_size);
    }
}