commonware_storage/mmr/
iterator.rs

1//! Iterators for traversing MMRs of a given size, and functions for computing various MMR
2//! properties from their output.
3
/// A PeakIterator returns a (position, height) tuple for each peak in an MMR with the given size,
/// in decreasing order of height.
///
/// For example, for an MMR containing 19 nodes (11 leaves), the PeakIterator would yield:
/// ```text
/// [(14, 3), (17, 1), (18, 0)]
/// ```
#[derive(Debug, Default)]
pub struct PeakIterator {
    size: u64,     // number of nodes in the MMR at the point the iterator was initialized
    node_pos: u64, // position of the current node
    two_h: u64,    // 2^(height+1) of the current node
}

impl PeakIterator {
    /// Return a new PeakIterator over the peaks of a MMR with the given number of nodes.
    ///
    /// A `size` of 0 produces an iterator that yields nothing.
    ///
    /// `size` must be below 2^63: for larger values the shift used to compute the starting
    /// height overflows.
    pub fn new(size: u64) -> PeakIterator {
        if size == 0 {
            return PeakIterator::default();
        }
        // Compute the position at which to start the search for peaks. This starting position will
        // not be in the MMR unless it happens to be a single perfect binary tree, but that's OK as
        // we will descend leftward until we find the first peak.
        let start = u64::MAX >> size.leading_zeros();
        let two_h = 1 << start.trailing_ones();
        PeakIterator {
            size,
            node_pos: start - 1,
            two_h,
        }
    }

    /// Return the position of the last leaf in an MMR of the given size, or 0 if the MMR is
    /// empty.
    ///
    /// This is an O(log2(n)) operation.
    pub fn last_leaf_pos(size: u64) -> u64 {
        if size == 0 {
            return 0;
        }

        let (peak_pos, peak_height) = PeakIterator::new(size)
            .last()
            .expect("a non-empty MMR has at least one peak");
        // Each step from a node to its right child decreases the position by one, so the
        // rightmost leaf under the last peak sits `height` positions before it.
        peak_pos - u64::from(peak_height)
    }

    /// Return if an MMR of the given `size` has a valid structure.
    ///
    /// The implementation verifies that peaks in the MMR of the given size have strictly decreasing
    /// height, which is a necessary condition for MMR validity. Every `u64` value is accepted as
    /// input.
    pub const fn check_validity(size: u64) -> bool {
        if size == 0 {
            return true;
        }
        // Number of nodes in the smallest perfect binary tree holding at least `size` nodes.
        let start = u64::MAX >> size.leading_zeros();
        if size == start {
            // The MMR is exactly one perfect binary tree.
            return true;
        }
        // Here `size < start`, so the root of that tree is not in the MMR and the search can begin
        // at its left child. Starting one level down keeps `two_h` representable even when the
        // top bit of `size` is set. `start` has at least two trailing ones in this branch (a
        // single one would mean `size == start == 1`), so the subtraction cannot underflow.
        let mut two_h = 1u64 << (start.trailing_ones() - 1);
        let mut node_pos = start - 1 - two_h;
        while two_h > 1 {
            if node_pos < size {
                if two_h == 2 {
                    // If this peak is a leaf yet there are more nodes remaining, then this MMR is
                    // invalid.
                    return node_pos == size - 1;
                }
                // move to the right sibling
                node_pos += two_h - 1;
                if node_pos < size {
                    // If the right sibling is in the MMR, then it is invalid.
                    return false;
                }
                continue;
            }
            // descend to the left child
            two_h >>= 1;
            node_pos -= two_h;
        }
        true
    }

    /// Returns the largest valid MMR size that is no greater than the given size.
    ///
    /// A valid MMR is a sequence of perfect binary trees of strictly decreasing height, so its
    /// size is a sum of distinct values of the form `2^k - 1`. Greedily taking each such tree
    /// size at most once, from largest to smallest, yields the largest valid size that fits.
    ///
    /// This is an O(log2(n)) operation.
    pub fn to_nearest_size(size: u64) -> u64 {
        if size == 0 {
            return 0;
        }
        // Largest perfect tree size that could possibly fit.
        let mut tree_size = u64::MAX >> size.leading_zeros();
        // Nodes not yet covered by a chosen tree.
        let mut remaining = size;
        while tree_size != 0 {
            if remaining >= tree_size {
                remaining -= tree_size;
            }
            // Next smaller perfect tree: 2^k - 1 becomes 2^(k-1) - 1.
            tree_size >>= 1;
        }
        size - remaining
    }
}

impl Iterator for PeakIterator {
    type Item = (u64, u32); // (peak, height)

    /// Return the next peak as a (position, height) tuple, or `None` once all peaks have been
    /// produced.
    ///
    /// # Panics
    ///
    /// Panics if the size the iterator was created with is not a valid MMR size.
    fn next(&mut self) -> Option<Self::Item> {
        while self.two_h > 1 {
            if self.node_pos < self.size {
                // found a peak
                let peak_item = (self.node_pos, self.two_h.trailing_zeros() - 1);
                // move to the right sibling
                self.node_pos += self.two_h - 1;
                assert!(self.node_pos >= self.size); // sibling shouldn't be in the MMR if MMR is valid
                return Some(peak_item);
            }
            // descend to the left child
            self.two_h >>= 1;
            self.node_pos -= self.two_h;
        }
        None
    }
}
114
115/// Returns the set of peaks that will require a new parent after adding the next leaf to an MMR
116/// with the given peaks. This set is non-empty only if there is a height-0 (leaf) peak in the MMR.
117/// The result will contain this leaf peak plus the other MMR peaks with contiguously increasing
118/// height. Nodes in the result are ordered by decreasing height.
119pub fn nodes_needing_parents(peak_iterator: PeakIterator) -> Vec<u64> {
120    let mut peaks = Vec::new();
121    let mut last_height = u32::MAX;
122
123    for (peak_pos, height) in peak_iterator {
124        assert!(last_height > 0);
125        assert!(height < last_height);
126        if height != last_height - 1 {
127            peaks.clear();
128        }
129        peaks.push(peak_pos);
130        last_height = height;
131    }
132    if last_height != 0 {
133        // there is no peak that is a leaf
134        peaks.clear();
135    }
136    peaks
137}
138
/// Returns the leaf number of the node at position `leaf_pos` in an MMR, or `None` if the node
/// at that position is not a leaf.
///
/// The search walks down from the root of the smallest perfect binary tree containing the
/// position, so it takes O(log2(n)) steps in the given position.
pub const fn leaf_pos_to_num(leaf_pos: u64) -> Option<u64> {
    if leaf_pos == 0 {
        return Some(0);
    }

    // Node count of the smallest perfect binary tree that contains `leaf_pos`.
    let tree_size = u64::MAX >> (leaf_pos + 1).leading_zeros();
    // Distance from the current node back to its left child.
    let mut span = 1 << (tree_size.trailing_ones() - 1);
    let mut node = tree_size - 1;
    // Number of leaves known to precede the target position.
    let mut leaves_before = 0u64;

    while span > 1 {
        if node == leaf_pos {
            // The position is an internal node on the search path.
            return None;
        }
        let left_child = node - span;
        span >>= 1;
        if leaf_pos <= left_child {
            // The target lies in the left subtree.
            node = left_child;
        } else {
            // The target lies in the right subtree; every leaf of the left subtree precedes it.
            leaves_before += span;
            node -= 1; // the right child immediately precedes its parent
        }
    }

    Some(leaves_before)
}
173
/// Returns the position in an MMR of the leaf with number `leaf_num`.
///
/// # Panics
///
/// Panics if doubling `leaf_num` overflows a `u64`.
pub const fn leaf_num_to_pos(leaf_num: u64) -> u64 {
    let doubled = match leaf_num.checked_mul(2) {
        Some(value) => value,
        None => panic!("leaf_num overflow"),
    };
    // Cannot underflow: 2*n is never smaller than the number of set bits in n.
    let set_bits = leaf_num.count_ones() as u64;
    doubled - set_bits
}
179
/// Returns the height of the node at position `pos` in an MMR, where leaves have height 0.
pub const fn pos_to_height(mut pos: u64) -> u32 {
    if pos == 0 {
        return 0;
    }

    // Walk through perfect-tree sizes (2^k - 1) from the largest that could matter down to 1,
    // removing each from `pos` at most once. What is left over is the height.
    let mut tree_size = u64::MAX >> pos.leading_zeros();
    while tree_size > 0 {
        if tree_size <= pos {
            pos -= tree_size;
        }
        tree_size >>= 1;
    }

    pos as u32
}
196
/// A PathIterator walks from the peak of a perfect binary tree down to a designated leaf. At each
/// step it yields a (parent_pos, sibling_pos) tuple: the node being descended from, and the child
/// of that node that is *not* on the path. Nothing is yielded for the leaf itself.
///
/// For example, consider the tree below and the path from the peak to leaf node 3. Nodes on the
/// path are [6, 5, 3] and are tagged with '*' in the diagram:
///
/// ```text
///
///          6*
///        /   \
///       2     5*
///      / \   / \
///     0   1 3*  4
///
/// A PathIterator for this example yields:
///    [(6, 2), (5, 4)]
/// ```
#[derive(Debug)]
pub struct PathIterator {
    leaf_pos: u64, // position of the leaf node in the path
    node_pos: u64, // current node position in the path from peak to leaf
    two_h: u64,    // 2^height of the current node
}

impl PathIterator {
    /// Return a PathIterator that descends toward `leaf_pos` from the peak at `peak_pos` of a
    /// perfect binary tree of height `height`, yielding the off-path sibling at each level.
    pub fn new(leaf_pos: u64, peak_pos: u64, height: u32) -> PathIterator {
        let two_h = 1 << height;
        PathIterator {
            leaf_pos,
            node_pos: peak_pos,
            two_h,
        }
    }
}

impl Iterator for PathIterator {
    type Item = (u64, u64); // (parent_pos, sibling_pos)

    fn next(&mut self) -> Option<Self::Item> {
        // A node of height 0 is the leaf: the path is complete.
        if self.two_h <= 1 {
            return None;
        }

        let parent = self.node_pos;
        let left_child = parent - self.two_h;
        let right_child = parent - 1;
        self.two_h >>= 1;

        // The leaf lies under the right child exactly when it comes after the left child.
        let (sibling, on_path) = if left_child < self.leaf_pos {
            (left_child, right_child)
        } else {
            (right_child, left_child)
        };
        self.node_pos = on_path;
        Some((parent, sibling))
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{hasher::Standard, mem::Mmr};
    use commonware_cryptography::{Hasher, Sha256};
    use commonware_runtime::{deterministic, Runner};

    /// Checks that leaf positions and leaf numbers convert back and forth consistently, and that
    /// every position between two consecutive leaves is reported as a non-leaf.
    #[test]
    fn test_leaf_num_calculation() {
        let digest = Sha256::hash(b"testing");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Add 1000 leaves to a fresh MMR, recording the position each one lands at. The index
            // into `positions` is the leaf's number.
            let mut mmr: Mmr<Sha256> = Mmr::new();
            let mut hasher = Standard::new();
            let positions: Vec<u64> = (0u64..1000)
                .map(|_| mmr.add(&mut hasher, &digest))
                .collect();

            let mut prev_pos = 0;
            for (expected_num, &pos) in positions.iter().enumerate() {
                // position -> number -> position must round-trip.
                let num = leaf_pos_to_num(pos).unwrap();
                assert_eq!(num, expected_num as u64);
                assert_eq!(leaf_num_to_pos(num), pos);

                // Everything strictly between the previous leaf and this one is an internal node.
                for non_leaf in prev_pos + 1..pos {
                    assert!(leaf_pos_to_num(non_leaf).is_none());
                }
                prev_pos = pos;
            }
        });
    }
}