commonware_storage/mmr/iterator.rs
1//! Iterators for traversing MMRs of a given size, and functions for computing various MMR
//! properties from their output. These are lower-level methods that are useful for implementing
3//! new MMR variants or extensions.
4
5use super::Position;
6use alloc::vec::Vec;
7
8/// A PeakIterator returns a (position, height) tuple for each peak in an MMR with the given size,
9/// in decreasing order of height.
10///
11/// For the example MMR depicted at the top of this file, the PeakIterator would yield:
12/// ```text
13/// [(14, 3), (17, 1), (18, 0)]
14/// ```
15#[derive(Default)]
16pub struct PeakIterator {
17 size: Position, // number of nodes in the MMR at the point the iterator was initialized
18 node_pos: Position, // position of the current node
19 two_h: u64, // 2^(height+1) of the current node
20}
21
22impl PeakIterator {
23 /// Return a new PeakIterator over the peaks of a MMR with the given number of nodes.
24 ///
25 /// # Panics
26 ///
27 /// Iteration will panic if size is not a valid MMR size. If used on untrusted input, call
28 /// [Position::is_mmr_size] first.
29 pub fn new(size: Position) -> Self {
30 if size == 0 {
31 return Self::default();
32 }
33 // Compute the position at which to start the search for peaks. This starting position will
34 // not be in the MMR unless it happens to be a single perfect binary tree, but that's OK as
35 // we will descend leftward until we find the first peak.
36 let start = u64::MAX >> size.leading_zeros();
37 assert_ne!(start, u64::MAX, "size overflow");
38 let two_h = 1 << start.trailing_ones();
39 Self {
40 size,
41 node_pos: Position::new(start - 1),
42 two_h,
43 }
44 }
45
46 /// Returns the largest valid MMR size that is no greater than the given size.
47 ///
48 /// This is an O(log2(n)) operation using binary search on the number of leaves.
49 ///
50 /// # Panics
51 ///
52 /// Panics if `size` exceeds [crate::mmr::MAX_POSITION].
53 pub fn to_nearest_size(size: Position) -> Position {
54 assert!(
55 size <= crate::mmr::MAX_POSITION,
56 "size exceeds MAX_POSITION"
57 );
58
59 // Algorithm: A valid MMR size corresponds to a specific number of leaves N, where:
60 // mmr_size(N) = 2*N - popcount(N)
61 // This formula comes from the fact that N leaves require N-1 internal nodes, but merging
62 // creates popcount(N)-1 additional nodes. We binary search for the largest N where
63 // mmr_size(N) <= size.
64
65 if size == 0 {
66 return size;
67 }
68
69 // Binary search for the largest N (number of leaves) such that
70 // mmr_size(N) = 2*N - popcount(N) <= size
71 let size_val = size.as_u64();
72 let mut low = 0u64;
73 let mut high = size_val; // MMR size >= leaf count, so N <= size
74
75 while low < high {
76 // Use div_ceil for upper-biased midpoint in binary search
77 let mid = (low + high).div_ceil(2);
78 let mmr_size = 2 * mid - mid.count_ones() as u64;
79
80 if mmr_size <= size_val {
81 low = mid;
82 } else {
83 high = mid - 1;
84 }
85 }
86
87 // low is the largest N where mmr_size(N) <= size
88 let result = 2 * low - low.count_ones() as u64;
89 Position::new(result)
90 }
91}
92
93impl Iterator for PeakIterator {
94 type Item = (Position, u32); // (peak, height)
95
96 fn next(&mut self) -> Option<Self::Item> {
97 while self.two_h > 1 {
98 if self.node_pos < self.size {
99 // found a peak
100 let peak_item = (self.node_pos, self.two_h.trailing_zeros() - 1);
101 // move to the right sibling
102 self.node_pos += self.two_h - 1;
103 assert!(self.node_pos >= self.size); // sibling shouldn't be in the MMR if MMR is valid
104 return Some(peak_item);
105 }
106 // descend to the left child
107 self.two_h >>= 1;
108 self.node_pos -= self.two_h;
109 }
110 None
111 }
112}
113
114/// Returns the set of peaks that will require a new parent after adding the next leaf to an MMR
115/// with the given peaks. This set is non-empty only if there is a height-0 (leaf) peak in the MMR.
116/// The result will contain this leaf peak plus the other MMR peaks with contiguously increasing
117/// height. Nodes in the result are ordered by decreasing height.
118pub(crate) fn nodes_needing_parents(peak_iterator: PeakIterator) -> Vec<Position> {
119 let mut peaks = Vec::new();
120 let mut last_height = u32::MAX;
121
122 for (peak_pos, height) in peak_iterator {
123 assert!(last_height > 0);
124 assert!(height < last_height);
125 if height != last_height - 1 {
126 peaks.clear();
127 }
128 peaks.push(peak_pos);
129 last_height = height;
130 }
131 if last_height != 0 {
132 // there is no peak that is a leaf
133 peaks.clear();
134 }
135 peaks
136}
137
/// Returns the height of the node at position `pos` in an MMR.
///
/// Position 0 always has height 0. For any other position, the function walks the all-ones values
/// (2^k - 1) from the one matching the bit length of `pos` down to 1, subtracting each from the
/// running remainder whenever it fits. What is left at the end is returned as the height.
#[cfg(any(feature = "std", test))]
pub(crate) const fn pos_to_height(pos: Position) -> u32 {
    let mut remainder = pos.as_u64();

    // Handled up front: the shift below would be by 64 bits for a zero position.
    if remainder == 0 {
        return 0;
    }

    let mut all_ones = u64::MAX >> remainder.leading_zeros();
    loop {
        if all_ones == 0 {
            break;
        }
        if remainder >= all_ones {
            remainder -= all_ones;
        }
        all_ones >>= 1;
    }

    remainder as u32
}
157
/// A PathIterator returns a (parent_pos, sibling_pos) tuple for the sibling of each node along the
/// path from a given perfect binary tree peak to a designated leaf, not including the peak itself.
///
/// For example, consider the tree below and the path from the peak to leaf node 3. Nodes on the
/// path are [6, 5, 3] (tagged with '*' in the diagram):
///
/// ```text
///
///          6*
///        /   \
///       2     5*
///      / \   / \
///     0   1 3*  4
///
/// A PathIterator for this example yields:
/// [(6, 2), (5, 4)]
/// ```
///
/// Each yielded tuple pairs a node on the path (the parent) with the child of that node that is
/// NOT on the path (the sibling of the next path node).
#[derive(Debug)]
pub struct PathIterator {
    leaf_pos: Position, // position of the leaf node in the path
    node_pos: Position, // current node position in the path from peak to leaf
    two_h: u64, // 2^height of the current node
}
181
182impl PathIterator {
183 /// Return a PathIterator over the siblings of nodes along the path from peak to leaf in the
184 /// perfect binary tree with peak `peak_pos` and having height `height`, not including the peak
185 /// itself.
186 pub const fn new(leaf_pos: Position, peak_pos: Position, height: u32) -> Self {
187 Self {
188 leaf_pos,
189 node_pos: peak_pos,
190 two_h: 1 << height,
191 }
192 }
193}
194
195impl Iterator for PathIterator {
196 type Item = (Position, Position); // (parent_pos, sibling_pos)
197
198 fn next(&mut self) -> Option<Self::Item> {
199 if self.two_h <= 1 {
200 return None;
201 }
202
203 let left_pos = self.node_pos - self.two_h;
204 let right_pos = self.node_pos - 1;
205 self.two_h >>= 1;
206
207 if left_pos < self.leaf_pos {
208 let r = Some((self.node_pos, left_pos));
209 self.node_pos = right_pos;
210 return r;
211 }
212 let r = Some((self.node_pos, right_pos));
213 self.node_pos = left_pos;
214 r
215 }
216}
217
218/// Return the list of pruned (pos < `start_pos`) node positions that are still required for
219/// proving any retained node.
220///
221/// This set consists of every pruned node that is either (1) a peak, or (2) has no descendent
222/// in the retained section, but its immediate parent does. (A node meeting condition (2) can be
223/// shown to always be the left-child of its parent.)
224///
225/// This set of nodes does not change with the MMR's size, only the pruning boundary. For a
226/// given pruning boundary that happens to be a valid MMR size, one can prove that this set is
227/// exactly the set of peaks for an MMR whose size equals the pruning boundary. If the pruning
228/// boundary is not a valid MMR size, then the set corresponds to the peaks of the largest MMR
229/// whose size is less than the pruning boundary.
230pub(crate) fn nodes_to_pin(start_pos: Position) -> impl Iterator<Item = Position> {
231 PeakIterator::new(PeakIterator::to_nearest_size(start_pos)).map(|(pos, _)| pos)
232}
233
// Unit tests for leaf position/location conversion and for `PeakIterator::to_nearest_size`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{hasher::Standard, mem::Mmr, Location};
    use commonware_cryptography::Sha256;

    /// Round-trips every leaf of a 1000-leaf MMR between its position and its location, and checks
    /// that the positions between consecutive leaves do not convert to a location.
    #[test]
    fn test_leaf_loc_calculation() {
        // Build MMR with 1000 leaves and make sure we can correctly convert each leaf position to
        // its number and back again.
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];
        // Add all leaves in a single batch, remembering the position returned for each one in
        // insertion order.
        let (changeset, loc_to_pos) = {
            let mut batch = mmr.new_batch();
            let positions: Vec<_> = (0..1000).map(|_| batch.add(&mut hasher, &digest)).collect();
            (batch.merkleize(&mut hasher).finalize(), positions)
        };
        mmr.apply(changeset).unwrap();

        let mut last_leaf_pos = 0;
        for (leaf_loc_expected, leaf_pos) in loc_to_pos.into_iter().enumerate() {
            // Position -> location must give the insertion index...
            let leaf_loc_got = Location::try_from(leaf_pos).unwrap();
            assert_eq!(leaf_loc_got, Location::new(leaf_loc_expected as u64));
            // ...and location -> position must give back the original position.
            let leaf_pos_got = Position::try_from(leaf_loc_got).unwrap();
            assert_eq!(leaf_pos_got, *leaf_pos);
            // Every position strictly between the previous leaf and this one must fail to convert
            // to a location.
            for i in last_leaf_pos + 1..*leaf_pos {
                assert!(Location::try_from(Position::new(i)).is_err());
            }
            last_leaf_pos = *leaf_pos;
        }
    }

    /// An input one past `MAX_POSITION` must trigger the documented panic.
    #[test]
    #[should_panic(expected = "size exceeds MAX_POSITION")]
    fn test_to_nearest_size_panic() {
        PeakIterator::to_nearest_size(crate::mmr::MAX_POSITION + 1);
    }

    /// Grows an MMR one leaf at a time and, at each step, checks `to_nearest_size` for the current
    /// size and the ten values that follow it.
    #[test]
    fn test_to_nearest_size() {
        // Build an MMR incrementally and verify to_nearest_size for all intermediate values
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];

        for _ in 0..1000 {
            let current_size = mmr.size();

            // Test positions from current size up to current size + 10
            for test_pos in *current_size..=*current_size + 10 {
                let rounded = PeakIterator::to_nearest_size(Position::new(test_pos));

                // Verify rounded is a valid MMR size
                assert!(
                    rounded.is_mmr_size(),
                    "rounded size {rounded} should be valid (test_pos: {test_pos}, current: {current_size})",
                );

                // Verify rounded <= test_pos
                assert!(
                    rounded <= test_pos,
                    "rounded {rounded} should be <= test_pos {test_pos} (current: {current_size})",
                );

                // Verify rounded is the largest valid size <= test_pos
                // NOTE(review): only `rounded + 1` is checked here, not every value in
                // (rounded, test_pos].
                if rounded < test_pos {
                    assert!(
                        !(rounded + 1).is_mmr_size(),
                        "rounded {rounded} should be largest valid size <= {test_pos} (current: {current_size})",
                    );
                }
            }

            // Append one more leaf before the next round of checks.
            let changeset = {
                let mut batch = mmr.new_batch();
                batch.add(&mut hasher, &digest);
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
        }
    }

    /// Spot-checks `to_nearest_size` on small inputs, a large input, and the maximum input.
    #[test]
    fn test_to_nearest_size_specific_cases() {
        // Test edge cases
        assert_eq!(PeakIterator::to_nearest_size(Position::new(0)), 0);
        assert_eq!(PeakIterator::to_nearest_size(Position::new(1)), 1);

        // Test consecutive values
        // `expected` tracks the largest valid MMR size seen so far; it is advanced after each
        // check whenever the next value is itself a valid size.
        let mut expected = Position::new(0);
        for size in 0..=20 {
            let rounded = PeakIterator::to_nearest_size(Position::new(size));
            assert_eq!(rounded, expected);
            if Position::new(size + 1).is_mmr_size() {
                expected = Position::new(size + 1);
            }
        }

        // Test with large value
        let large_size = Position::new(1_000_000);
        let rounded = PeakIterator::to_nearest_size(large_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= large_size);

        // Test maximum allowed input
        let largest_valid_size = crate::mmr::MAX_POSITION;
        let rounded = PeakIterator::to_nearest_size(largest_valid_size);
        assert!(rounded.is_mmr_size());
        assert!(rounded <= largest_valid_size);
    }
}