Skip to main content

commonware_storage/mmr/
iterator.rs

//! Iterators for traversing MMRs of a given size, and functions for computing various MMR
//! properties from their output. These are lower-level methods that are useful for implementing
//! new MMR variants or extensions.
4
5use super::Position;
6use alloc::vec::Vec;
7
8/// A PeakIterator returns a (position, height) tuple for each peak in an MMR with the given size,
9/// in decreasing order of height.
10///
11/// For the example MMR depicted at the top of this file, the PeakIterator would yield:
12/// ```text
13/// [(14, 3), (17, 1), (18, 0)]
14/// ```
15#[derive(Default)]
16pub struct PeakIterator {
17    size: Position, // number of nodes in the MMR at the point the iterator was initialized
18    node_pos: Position, // position of the current node
19    two_h: u64,     // 2^(height+1) of the current node
20}
21
22impl PeakIterator {
23    /// Return a new PeakIterator over the peaks of a MMR with the given number of nodes.
24    ///
25    /// # Panics
26    ///
27    /// Iteration will panic if size is not a valid MMR size. If used on untrusted input, call
28    /// [Position::is_mmr_size] first.
29    pub fn new(size: Position) -> Self {
30        if size == 0 {
31            return Self::default();
32        }
33        // Compute the position at which to start the search for peaks. This starting position will
34        // not be in the MMR unless it happens to be a single perfect binary tree, but that's OK as
35        // we will descend leftward until we find the first peak.
36        let start = u64::MAX >> size.leading_zeros();
37        assert_ne!(start, u64::MAX, "size overflow");
38        let two_h = 1 << start.trailing_ones();
39        Self {
40            size,
41            node_pos: Position::new(start - 1),
42            two_h,
43        }
44    }
45
46    /// Returns the largest valid MMR size that is no greater than the given size.
47    ///
48    /// This is an O(log2(n)) operation using binary search on the number of leaves.
49    ///
50    /// # Panics
51    ///
52    /// Panics if `size` exceeds [crate::mmr::MAX_POSITION].
53    pub fn to_nearest_size(size: Position) -> Position {
54        assert!(
55            size <= crate::mmr::MAX_POSITION,
56            "size exceeds MAX_POSITION"
57        );
58
59        // Algorithm: A valid MMR size corresponds to a specific number of leaves N, where:
60        // mmr_size(N) = 2*N - popcount(N)
61        // This formula comes from the fact that N leaves require N-1 internal nodes, but merging
62        // creates popcount(N)-1 additional nodes. We binary search for the largest N where
63        // mmr_size(N) <= size.
64
65        if size == 0 {
66            return size;
67        }
68
69        // Binary search for the largest N (number of leaves) such that
70        // mmr_size(N) = 2*N - popcount(N) <= size
71        let size_val = size.as_u64();
72        let mut low = 0u64;
73        let mut high = size_val; // MMR size >= leaf count, so N <= size
74
75        while low < high {
76            // Use div_ceil for upper-biased midpoint in binary search
77            let mid = (low + high).div_ceil(2);
78            let mmr_size = 2 * mid - mid.count_ones() as u64;
79
80            if mmr_size <= size_val {
81                low = mid;
82            } else {
83                high = mid - 1;
84            }
85        }
86
87        // low is the largest N where mmr_size(N) <= size
88        let result = 2 * low - low.count_ones() as u64;
89        Position::new(result)
90    }
91}
92
93impl Iterator for PeakIterator {
94    type Item = (Position, u32); // (peak, height)
95
96    fn next(&mut self) -> Option<Self::Item> {
97        while self.two_h > 1 {
98            if self.node_pos < self.size {
99                // found a peak
100                let peak_item = (self.node_pos, self.two_h.trailing_zeros() - 1);
101                // move to the right sibling
102                self.node_pos += self.two_h - 1;
103                assert!(self.node_pos >= self.size); // sibling shouldn't be in the MMR if MMR is valid
104                return Some(peak_item);
105            }
106            // descend to the left child
107            self.two_h >>= 1;
108            self.node_pos -= self.two_h;
109        }
110        None
111    }
112}
113
114/// Returns the set of peaks that will require a new parent after adding the next leaf to an MMR
115/// with the given peaks. This set is non-empty only if there is a height-0 (leaf) peak in the MMR.
116/// The result will contain this leaf peak plus the other MMR peaks with contiguously increasing
117/// height. Nodes in the result are ordered by decreasing height.
118pub(crate) fn nodes_needing_parents(peak_iterator: PeakIterator) -> Vec<Position> {
119    let mut peaks = Vec::new();
120    let mut last_height = u32::MAX;
121
122    for (peak_pos, height) in peak_iterator {
123        assert!(last_height > 0);
124        assert!(height < last_height);
125        if height != last_height - 1 {
126            peaks.clear();
127        }
128        peaks.push(peak_pos);
129        last_height = height;
130    }
131    if last_height != 0 {
132        // there is no peak that is a leaf
133        peaks.clear();
134    }
135    peaks
136}
137
/// Returns the height of the node at position `pos` in an MMR.
///
/// Starting from the smallest all-ones value (2^k - 1) with as many significant bits as `pos`,
/// each successively halved value is subtracted from the position whenever it fits. What is
/// left after the last subtraction is returned as the height.
#[cfg(any(feature = "std", test))]
pub(crate) const fn pos_to_height(pos: Position) -> u32 {
    let mut remainder = pos.as_u64();

    // Position 0 is handled up front; the shift below would overflow for it.
    if remainder == 0 {
        return 0;
    }

    let mut tree_size = u64::MAX >> remainder.leading_zeros();
    loop {
        if tree_size == 0 {
            break;
        }
        if remainder >= tree_size {
            remainder -= tree_size;
        }
        tree_size >>= 1;
    }

    remainder as u32
}
157
158/// A PathIterator returns a (parent_pos, sibling_pos) tuple for the sibling of each node along the
159/// path from a given perfect binary tree peak to a designated leaf, not including the peak itself.
160///
161/// For example, consider the tree below and the path from the peak to leaf node 3. Nodes on the
162/// path are [6, 5, 3] and tagged with '*' in the diagram):
163///
164/// ```text
165///
166///          6*
167///        /   \
168///       2     5*
169///      / \   / \
170///     0   1 3*  4
171///
172/// A PathIterator for this example yields:
173///    [(6, 2), (5, 4)]
174/// ```
175#[derive(Debug)]
176pub struct PathIterator {
177    leaf_pos: Position, // position of the leaf node in the path
178    node_pos: Position, // current node position in the path from peak to leaf
179    two_h: u64,         // 2^height of the current node
180}
181
182impl PathIterator {
183    /// Return a PathIterator over the siblings of nodes along the path from peak to leaf in the
184    /// perfect binary tree with peak `peak_pos` and having height `height`, not including the peak
185    /// itself.
186    pub const fn new(leaf_pos: Position, peak_pos: Position, height: u32) -> Self {
187        Self {
188            leaf_pos,
189            node_pos: peak_pos,
190            two_h: 1 << height,
191        }
192    }
193}
194
195impl Iterator for PathIterator {
196    type Item = (Position, Position); // (parent_pos, sibling_pos)
197
198    fn next(&mut self) -> Option<Self::Item> {
199        if self.two_h <= 1 {
200            return None;
201        }
202
203        let left_pos = self.node_pos - self.two_h;
204        let right_pos = self.node_pos - 1;
205        self.two_h >>= 1;
206
207        if left_pos < self.leaf_pos {
208            let r = Some((self.node_pos, left_pos));
209            self.node_pos = right_pos;
210            return r;
211        }
212        let r = Some((self.node_pos, right_pos));
213        self.node_pos = left_pos;
214        r
215    }
216}
217
218/// Return the list of pruned (pos < `start_pos`) node positions that are still required for
219/// proving any retained node.
220///
221/// This set consists of every pruned node that is either (1) a peak, or (2) has no descendent
222/// in the retained section, but its immediate parent does. (A node meeting condition (2) can be
223/// shown to always be the left-child of its parent.)
224///
225/// This set of nodes does not change with the MMR's size, only the pruning boundary. For a
226/// given pruning boundary that happens to be a valid MMR size, one can prove that this set is
227/// exactly the set of peaks for an MMR whose size equals the pruning boundary. If the pruning
228/// boundary is not a valid MMR size, then the set corresponds to the peaks of the largest MMR
229/// whose size is less than the pruning boundary.
230pub(crate) fn nodes_to_pin(start_pos: Position) -> impl Iterator<Item = Position> {
231    PeakIterator::new(PeakIterator::to_nearest_size(start_pos)).map(|(pos, _)| pos)
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{hasher::Standard, mem::Mmr, Location};
    use commonware_cryptography::Sha256;

    /// Round-trips every leaf of a 1000-leaf MMR between its position and its location, and
    /// checks that the positions between consecutive leaves do not convert to a location.
    #[test]
    fn test_leaf_loc_calculation() {
        // Build MMR with 1000 leaves and make sure we can correctly convert each leaf position to
        // its number and back again.
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];
        let (changeset, loc_to_pos) = {
            let mut batch = mmr.new_batch();
            let positions: Vec<_> = (0..1000).map(|_| batch.add(&mut hasher, &digest)).collect();
            (batch.merkleize(&mut hasher).finalize(), positions)
        };
        mmr.apply(changeset).unwrap();

        let mut last_leaf_pos = 0;
        for (leaf_loc_expected, leaf_pos) in loc_to_pos.into_iter().enumerate() {
            let leaf_loc_got = Location::try_from(leaf_pos).unwrap();
            assert_eq!(leaf_loc_got, Location::new(leaf_loc_expected as u64));
            let leaf_pos_got = Position::try_from(leaf_loc_got).unwrap();
            assert_eq!(leaf_pos_got, *leaf_pos);
            // Positions strictly between two consecutive leaves are not leaves, so the
            // conversion to a location must fail for each of them.
            for i in last_leaf_pos + 1..*leaf_pos {
                assert!(Location::try_from(Position::new(i)).is_err());
            }
            last_leaf_pos = *leaf_pos;
        }
    }

    /// Inputs beyond MAX_POSITION must be rejected with a panic.
    #[test]
    #[should_panic(expected = "size exceeds MAX_POSITION")]
    fn test_to_nearest_size_panic() {
        PeakIterator::to_nearest_size(crate::mmr::MAX_POSITION + 1);
    }

    /// Checks `to_nearest_size` against `is_mmr_size` for inputs near every size reached while
    /// growing an MMR to 1000 leaves.
    #[test]
    fn test_to_nearest_size() {
        // Build an MMR incrementally and verify to_nearest_size for all intermediate values
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];

        for _ in 0..1000 {
            let current_size = mmr.size();

            // Test positions from current size up to current size + 10
            for test_pos in *current_size..=*current_size + 10 {
                let rounded = PeakIterator::to_nearest_size(Position::new(test_pos));

                // Verify rounded is a valid MMR size
                assert!(
                    rounded.is_mmr_size(),
                    "rounded size {rounded} should be valid (test_pos: {test_pos}, current: {current_size})",
                );

                // Verify rounded <= test_pos
                assert!(
                    rounded <= test_pos,
                    "rounded {rounded} should be <= test_pos {test_pos} (current: {current_size})",
                );

                // Verify rounded is the largest valid size <= test_pos: no value in
                // (rounded, test_pos] may be a valid MMR size. Checking only `rounded + 1` would
                // miss a too-small result whose successor happens to be invalid (e.g. returning
                // 4 for an input of 7). The range is empty when rounded == test_pos.
                for candidate in rounded.as_u64() + 1..=test_pos {
                    assert!(
                        !Position::new(candidate).is_mmr_size(),
                        "rounded {rounded} should be largest valid size <= {test_pos}, but {candidate} is valid (current: {current_size})",
                    );
                }
            }

            let changeset = {
                let mut batch = mmr.new_batch();
                batch.add(&mut hasher, &digest);
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
        }
    }

    /// Spot-checks `to_nearest_size` on small inputs, one large input, and MAX_POSITION.
    #[test]
    fn test_to_nearest_size_specific_cases() {
        // Test edge cases
        assert_eq!(PeakIterator::to_nearest_size(Position::new(0)), 0);
        assert_eq!(PeakIterator::to_nearest_size(Position::new(1)), 1);

        // Test consecutive values: `expected` tracks the most recent valid size seen so far.
        let mut expected = Position::new(0);
        for size in 0..=20 {
            let rounded = PeakIterator::to_nearest_size(Position::new(size));
            assert_eq!(rounded, expected);
            if Position::new(size + 1).is_mmr_size() {
                expected = Position::new(size + 1);
            }
        }

        // Test with large value
        let large_size = Position::new(1_000_000);
        let rounded = PeakIterator::to_nearest_size(large_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= large_size);

        // Test maximum allowed input
        let largest_valid_size = crate::mmr::MAX_POSITION;
        let rounded = PeakIterator::to_nearest_size(largest_valid_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= largest_valid_size);
    }
}