Skip to main content

miden_crypto/merkle/mmr/
forest.rs

1use core::{
2    fmt::{Binary, Display},
3    ops::{BitAnd, BitOr, BitXor, BitXorAssign},
4};
5
6use super::InOrderIndex;
7use crate::{
8    Felt,
9    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
10};
11
/// A compact representation of the trees in a Merkle Mountain Range (MMR) forest.
///
/// The wrapped integer is used as a bit set: every set bit represents one disjoint binary tree,
/// and a set bit at position `k` represents a tree with `2^k` leaves (`2^(k+1) - 1` nodes).
///
/// The forest value has the following interpretations:
/// - its value is the total number of leaves in the forest
/// - it is a version number (an MMR is append-only, so the number of leaves only increases)
/// - the number of set bits is the number of trees in the forest
/// - the position of each set bit is the height of the corresponding tree
///
/// Examples:
/// - `Forest(0)` is a forest with no trees.
/// - `Forest(0b01)` is a forest with a single leaf/node (the smallest tree possible).
/// - `Forest(0b10)` is a forest with a single binary tree with 2 leaves (3 nodes).
/// - `Forest(0b11)` is a forest with two trees: one with 1 leaf (1 node), and one with 2 leaves (3
///   nodes).
/// - `Forest(0b1010)` is a forest with two trees: one with 8 leaves (15 nodes), one with 2 leaves
///   (3 nodes).
/// - `Forest(0b1000)` is a forest with one tree, which has 8 leaves (15 nodes).
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Forest(usize);
35
36impl Forest {
37    /// Creates an empty forest (no trees).
38    pub const fn empty() -> Self {
39        Self(0)
40    }
41
42    /// Creates a forest with `num_leaves` leaves.
43    pub const fn new(num_leaves: usize) -> Self {
44        Self(num_leaves)
45    }
46
47    /// Creates a forest with a given height.
48    ///
49    /// This is equivalent to `Forest::new(1 << height)`.
50    ///
51    /// # Panics
52    ///
53    /// This will panic if `height` is greater than `usize::BITS - 1`.
54    pub const fn with_height(height: usize) -> Self {
55        assert!(height < usize::BITS as usize);
56        Self::new(1 << height)
57    }
58
59    /// Returns true if there are no trees in the forest.
60    pub fn is_empty(self) -> bool {
61        self.0 == 0
62    }
63
64    /// Adds exactly one more leaf to the capacity of this forest.
65    ///
66    /// Some smaller trees might be merged together.
67    pub fn append_leaf(&mut self) {
68        self.0 += 1;
69    }
70
71    /// Returns a count of leaves in the entire underlying forest (MMR).
72    pub fn num_leaves(self) -> usize {
73        self.0
74    }
75
76    /// Return the total number of nodes of a given forest.
77    ///
78    /// # Panics
79    ///
80    /// This will panic if the forest has size greater than `usize::MAX / 2 + 1`.
81    pub const fn num_nodes(self) -> usize {
82        assert!(self.0 <= usize::MAX / 2 + 1);
83        if self.0 <= usize::MAX / 2 {
84            self.0 * 2 - self.num_trees()
85        } else {
86            // If `self.0 > usize::MAX / 2` then we need 128-bit math to double it.
87            let (inner, num_trees) = (self.0 as u128, self.num_trees() as u128);
88            (inner * 2 - num_trees) as usize
89        }
90    }
91
92    /// Return the total number of trees of a given forest (the number of active bits).
93    pub const fn num_trees(self) -> usize {
94        self.0.count_ones() as usize
95    }
96
97    /// Returns the height (bit position) of the largest tree in the forest.
98    ///
99    /// # Panics
100    ///
101    /// This will panic if the forest is empty.
102    pub fn largest_tree_height_unchecked(self) -> usize {
103        // ilog2 is computed with leading zeros, which itself is computed with the intrinsic ctlz.
104        // [Rust 1.67.0] x86 uses the `bsr` instruction. AArch64 uses the `clz` instruction.
105        self.0.ilog2() as usize
106    }
107
108    /// Returns the height (bit position) of the largest tree in the forest.
109    ///
110    /// If the forest cannot be empty, use [`largest_tree_height_unchecked`] for performance.
111    ///
112    /// [`largest_tree_height_unchecked`]: Self::largest_tree_height_unchecked
113    pub fn largest_tree_height(self) -> Option<usize> {
114        if self.is_empty() {
115            return None;
116        }
117
118        Some(self.largest_tree_height_unchecked())
119    }
120
121    /// Returns a forest with only the largest tree present.
122    ///
123    /// # Panics
124    ///
125    /// This will panic if the forest is empty.
126    pub fn largest_tree_unchecked(self) -> Self {
127        Self::with_height(self.largest_tree_height_unchecked())
128    }
129
130    /// Returns a forest with only the largest tree present.
131    ///
132    /// If forest cannot be empty, use `largest_tree` for better performance.
133    pub fn largest_tree(self) -> Self {
134        if self.is_empty() {
135            return Self::empty();
136        }
137
138        self.largest_tree_unchecked()
139    }
140
141    /// Returns the height (bit position) of the smallest tree in the forest.
142    ///
143    /// # Panics
144    ///
145    /// This will panic if the forest is empty.
146    pub fn smallest_tree_height_unchecked(self) -> usize {
147        // Trailing_zeros is computed with the intrinsic cttz. [Rust 1.67.0] x86 uses the `bsf`
148        // instruction. AArch64 uses the `rbit clz` instructions.
149        self.0.trailing_zeros() as usize
150    }
151
152    /// Returns the height (bit position) of the smallest tree in the forest.
153    ///
154    /// If the forest cannot be empty, use [`smallest_tree_height_unchecked`] for better
155    /// performance.
156    ///
157    /// [`smallest_tree_height_unchecked`]: Self::smallest_tree_height_unchecked
158    pub fn smallest_tree_height(self) -> Option<usize> {
159        if self.is_empty() {
160            return None;
161        }
162
163        Some(self.smallest_tree_height_unchecked())
164    }
165
166    /// Returns a forest with only the smallest tree present.
167    ///
168    /// # Panics
169    ///
170    /// This will panic if the forest is empty.
171    pub fn smallest_tree_unchecked(self) -> Self {
172        Self::with_height(self.smallest_tree_height_unchecked())
173    }
174
175    /// Returns a forest with only the smallest tree present.
176    ///
177    /// If forest cannot be empty, use `smallest_tree` for performance.
178    pub fn smallest_tree(self) -> Self {
179        if self.is_empty() {
180            return Self::empty();
181        }
182        self.smallest_tree_unchecked()
183    }
184
185    /// Keeps only trees larger than the reference tree.
186    ///
187    /// For example, if we start with the bit pattern `0b0101_0110`, and keep only the trees larger
188    /// than tree index 1, that targets this bit:
189    /// ```text
190    /// Forest(0b0101_0110).trees_larger_than(1)
191    ///                        ^
192    /// Becomes:      0b0101_0100
193    ///                        ^
194    /// ```
195    /// And keeps only trees *after* that bit, meaning that the tree at `tree_idx` is also removed,
196    /// resulting in `0b0101_0100`.
197    ///
198    /// ```
199    /// # use miden_crypto::merkle::mmr::Forest;
200    /// let range = Forest::new(0b0101_0110);
201    /// assert_eq!(range.trees_larger_than(1), Forest::new(0b0101_0100));
202    /// ```
203    pub fn trees_larger_than(self, tree_idx: u32) -> Self {
204        self & high_bitmask(tree_idx + 1)
205    }
206
207    /// Creates a new forest with all possible trees smaller than the smallest tree in this
208    /// forest.
209    ///
210    /// This forest must have exactly one tree.
211    ///
212    /// # Panics
213    /// With debug assertions enabled, this function panics if this forest does not have
214    /// exactly one tree.
215    ///
216    /// For a non-panicking version of this function, see [`Forest::all_smaller_trees()`].
217    pub fn all_smaller_trees_unchecked(self) -> Self {
218        debug_assert_eq!(self.num_trees(), 1);
219        Self::new(self.0 - 1)
220    }
221
222    /// Creates a new forest with all possible trees smaller than the smallest tree in this
223    /// forest, or returns `None` if this forest has more or less than one tree.
224    ///
225    /// If the forest cannot have more or less than one tree, use
226    /// [`Forest::all_smaller_trees_unchecked()`] for performance.
227    pub fn all_smaller_trees(self) -> Option<Forest> {
228        if self.num_trees() != 1 {
229            return None;
230        }
231        Some(self.all_smaller_trees_unchecked())
232    }
233
234    /// Returns a forest with exactly one tree, one size (depth) larger than the current one.
235    pub fn next_larger_tree(self) -> Self {
236        debug_assert_eq!(self.num_trees(), 1);
237        Forest(self.0 << 1)
238    }
239
240    /// Returns true if the forest contains a single-node tree.
241    pub fn has_single_leaf_tree(self) -> bool {
242        self.0 & 1 != 0
243    }
244
245    /// Add a single-node tree if not already present in the forest.
246    pub fn with_single_leaf(self) -> Self {
247        Self::new(self.0 | 1)
248    }
249
250    /// Remove the single-node tree if present in the forest.
251    pub fn without_single_leaf(self) -> Self {
252        Self::new(self.0 & (usize::MAX - 1))
253    }
254
255    /// Returns a new forest that does not have the trees that `other` has.
256    pub fn without_trees(self, other: Forest) -> Self {
257        self ^ other
258    }
259
260    /// Returns index of the forest tree for a specified leaf index.
261    pub fn tree_index(&self, leaf_idx: usize) -> usize {
262        let root = self
263            .leaf_to_corresponding_tree(leaf_idx)
264            .expect("position must be part of the forest");
265        let smaller_tree_mask = Self::new(2_usize.pow(root) - 1);
266        let num_smaller_trees = (*self & smaller_tree_mask).num_trees();
267        self.num_trees() - num_smaller_trees - 1
268    }
269
270    /// Returns the smallest tree's root element as an [InOrderIndex].
271    ///
272    /// This function takes the smallest tree in this forest, "pretends" that it is a subtree of a
273    /// fully balanced binary tree, and returns the the in-order index of that balanced tree's root
274    /// node.
275    pub fn root_in_order_index(&self) -> InOrderIndex {
276        // Count total size of all trees in the forest.
277        let nodes = self.num_nodes();
278
279        // Add the count for the parent nodes that separate each tree. These are allocated but
280        // currently empty, and correspond to the nodes that will be used once the trees are merged.
281        let open_trees = self.num_trees() - 1;
282
283        // Remove the leaf-count of the rightmost subtree. The target tree root index comes before
284        // the subtree, for the in-order tree walk.
285        let right_subtree_count = self.smallest_tree_unchecked().num_leaves() - 1;
286
287        let idx = nodes + open_trees - right_subtree_count;
288
289        InOrderIndex::new(idx.try_into().unwrap())
290    }
291
292    /// Returns the in-order index of the rightmost element (the smallest tree).
293    pub fn rightmost_in_order_index(&self) -> InOrderIndex {
294        // Count total size of all trees in the forest.
295        let nodes = self.num_nodes();
296
297        // Add the count for the parent nodes that separate each tree. These are allocated but
298        // currently empty, and correspond to the nodes that will be used once the trees are merged.
299        let open_trees = self.num_trees() - 1;
300
301        let idx = nodes + open_trees;
302
303        InOrderIndex::new(idx.try_into().unwrap())
304    }
305
306    /// Checks if an in-order index corresponds to a valid node in the forest.
307    ///
308    /// Returns `true` if the index points to an actual node within one of the trees,
309    /// `false` if the index is:
310    /// - Zero (invalid, as `InOrderIndex` is 1-indexed)
311    /// - Beyond the forest bounds
312    /// - A separator position between trees (these positions are reserved for future parent nodes
313    ///   when trees are merged, but don't correspond to actual nodes yet)
314    ///
315    /// # Example
316    /// For a forest with 7 leaves (0b111 = trees of 4, 2, and 1 leaves):
317    /// - Valid indices: 1-7 (first tree), 9-11 (second tree), 13 (third tree)
318    /// - Invalid separator indices: 8 (between first and second), 12 (between second and third)
319    pub fn is_valid_in_order_index(&self, idx: &InOrderIndex) -> bool {
320        // Index 0 is never valid (InOrderIndex is 1-indexed)
321        if idx.inner() == 0 {
322            return false;
323        }
324
325        // Empty forest has no valid indices
326        if self.is_empty() {
327            return false;
328        }
329
330        let idx_val = idx.inner();
331        let mut offset = 0usize;
332
333        // Iterate through trees from largest to smallest
334        for tree in TreeSizeIterator::new(*self).rev() {
335            let tree_nodes = tree.num_nodes();
336            let tree_start = offset + 1;
337            let tree_end = offset + tree_nodes;
338
339            if idx_val >= tree_start && idx_val <= tree_end {
340                return true;
341            }
342
343            // Move offset past this tree and the separator position
344            offset = tree_end + 1;
345        }
346
347        false
348    }
349
350    /// Given a leaf index in the current forest, return the tree number responsible for the
351    /// leaf.
352    ///
353    /// Note:
354    /// The result is a tree position `p`, it has the following interpretations:
355    /// - `p+1` is the depth of the tree.
356    /// - Because the root element is not part of the proof, `p` is the length of the authentication
357    ///   path.
358    /// - `2^p` is equal to the number of leaves in this particular tree.
359    /// - And `2^(p+1)-1` corresponds to the size of the tree.
360    ///
361    /// For example, given a forest with 6 leaves whose forest is `0b110`:
362    /// ```text
363    ///       __ tree 2 __
364    ///      /            \
365    ///    ____          ____         _ tree 1 _
366    ///   /    \        /    \       /          \
367    ///  0      1      2      3     4            5
368    /// ```
369    ///
370    /// Leaf indices `0..=3` are in the tree at index 2 and leaf indices `4..=5` are in the tree at
371    /// index 1.
372    pub fn leaf_to_corresponding_tree(self, leaf_idx: usize) -> Option<u32> {
373        let forest = self.0;
374
375        if leaf_idx >= forest {
376            None
377        } else {
378            // - each bit in the forest is a unique tree and the bit position is its power-of-two
379            //   size
380            // - each tree is associated to a consecutive range of positions equal to its size from
381            //   left-to-right
382            // - this means the first tree owns from `0` up to the `2^k_0` first positions, where
383            //   `k_0` is the highest set bit position, the second tree from `2^k_0 + 1` up to
384            //   `2^k_1` where `k_1` is the second highest bit, so on.
385            // - this means the highest bits work as a category marker, and the position is owned by
386            //   the first tree which doesn't share a high bit with the position
387            let before = forest & leaf_idx;
388            let after = forest ^ before;
389            let tree_idx = after.ilog2();
390
391            Some(tree_idx)
392        }
393    }
394
395    /// Given a leaf index in the current forest, return the leaf index in the tree to which
396    /// the leaf belongs.
397    pub(super) fn leaf_relative_position(self, leaf_idx: usize) -> Option<usize> {
398        let tree_idx = self.leaf_to_corresponding_tree(leaf_idx)?;
399        let forest_before = self & high_bitmask(tree_idx + 1);
400        Some(leaf_idx - forest_before.0)
401    }
402}
403
404impl Display for Forest {
405    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
406        write!(f, "{}", self.0)
407    }
408}
409
410impl Binary for Forest {
411    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
412        write!(f, "{:b}", self.0)
413    }
414}
415
416impl BitAnd<Forest> for Forest {
417    type Output = Self;
418
419    fn bitand(self, rhs: Self) -> Self::Output {
420        Self::new(self.0 & rhs.0)
421    }
422}
423
424impl BitOr<Forest> for Forest {
425    type Output = Self;
426
427    fn bitor(self, rhs: Self) -> Self::Output {
428        Self::new(self.0 | rhs.0)
429    }
430}
431
432impl BitXor<Forest> for Forest {
433    type Output = Self;
434
435    fn bitxor(self, rhs: Self) -> Self::Output {
436        Self::new(self.0 ^ rhs.0)
437    }
438}
439
440impl BitXorAssign<Forest> for Forest {
441    fn bitxor_assign(&mut self, rhs: Self) {
442        self.0 ^= rhs.0;
443    }
444}
445
446impl From<Felt> for Forest {
447    fn from(value: Felt) -> Self {
448        Self::new(value.as_canonical_u64() as usize)
449    }
450}
451
452impl From<Forest> for Felt {
453    fn from(value: Forest) -> Self {
454        Felt::new(value.0 as u64)
455    }
456}
457
458/// Return a bitmask for the bits including and above the given position.
459pub(crate) const fn high_bitmask(bit: u32) -> Forest {
460    if bit > usize::BITS - 1 {
461        Forest::empty()
462    } else {
463        Forest::new(usize::MAX << bit)
464    }
465}
466
467// SERIALIZATION
468// ================================================================================================
469
470impl Serializable for Forest {
471    fn write_into<W: ByteWriter>(&self, target: &mut W) {
472        self.0.write_into(target);
473    }
474}
475
476impl Deserializable for Forest {
477    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
478        let value = source.read_usize()?;
479        Ok(Self::new(value))
480    }
481}
482
483// TREE SIZE ITERATOR
484// ================================================================================================
485
486/// Iterate over the trees within this `Forest`, from smallest to largest.
487///
488/// Each item is a "sub-forest", containing only one tree.
489pub struct TreeSizeIterator {
490    inner: Forest,
491}
492
493impl TreeSizeIterator {
494    pub fn new(value: Forest) -> TreeSizeIterator {
495        TreeSizeIterator { inner: value }
496    }
497}
498
499impl Iterator for TreeSizeIterator {
500    type Item = Forest;
501
502    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
503        let tree = self.inner.smallest_tree();
504
505        if tree.is_empty() {
506            None
507        } else {
508            self.inner = self.inner.without_trees(tree);
509            Some(tree)
510        }
511    }
512}
513
514impl DoubleEndedIterator for TreeSizeIterator {
515    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
516        let tree = self.inner.largest_tree();
517
518        if tree.is_empty() {
519            None
520        } else {
521            self.inner = self.inner.without_trees(tree);
522            Some(tree)
523        }
524    }
525}