Skip to main content

miden_crypto/merkle/mmr/
forest.rs

1use core::{
2    fmt::{Binary, Display},
3    ops::{BitAnd, BitOr, BitXor, BitXorAssign},
4};
5
6use super::InOrderIndex;
7use crate::{
8    Felt,
9    field::PrimeField64,
10    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
/// A compact representation of trees in a forest. Used in the Merkle Mountain Range (MMR).
///
/// Each set bit of the stored number represents a disjoint binary tree; a bit at position `k`
/// stands for a tree with `2^k` leaves.
///
/// The forest value has the following interpretations:
/// - its value is the number of leaves in the forest
/// - the version number (MMR is append only so the number of leaves always increases)
/// - bit count corresponds to the number of trees in the forest
/// - each set bit position determines the height of a tree in the forest
///
/// Examples:
/// - `Forest(0)` is a forest with no trees.
/// - `Forest(0b01)` is a forest with a single leaf/node (the smallest tree possible).
/// - `Forest(0b10)` is a forest with a single binary tree with 2 leaves (3 nodes).
/// - `Forest(0b11)` is a forest with two trees: one with 1 leaf (1 node), and one with 2 leaves (3
///   nodes).
/// - `Forest(0b1010)` is a forest with two trees: one with 8 leaves (15 nodes), one with 2 leaves
///   (3 nodes).
/// - `Forest(0b1000)` is a forest with one tree, which has 8 leaves (15 nodes).
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Forest(usize);
36
37impl Forest {
38    /// Creates an empty forest (no trees).
39    pub const fn empty() -> Self {
40        Self(0)
41    }
42
43    /// Creates a forest with `num_leaves` leaves.
44    pub const fn new(num_leaves: usize) -> Self {
45        Self(num_leaves)
46    }
47
48    /// Creates a forest with a given height.
49    ///
50    /// This is equivalent to `Forest::new(1 << height)`.
51    ///
52    /// # Panics
53    ///
54    /// This will panic if `height` is greater than `usize::BITS - 1`.
55    pub const fn with_height(height: usize) -> Self {
56        assert!(height < usize::BITS as usize);
57        Self::new(1 << height)
58    }
59
60    /// Returns true if there are no trees in the forest.
61    pub fn is_empty(self) -> bool {
62        self.0 == 0
63    }
64
65    /// Adds exactly one more leaf to the capacity of this forest.
66    ///
67    /// Some smaller trees might be merged together.
68    pub fn append_leaf(&mut self) {
69        self.0 += 1;
70    }
71
72    /// Returns a count of leaves in the entire underlying forest (MMR).
73    pub fn num_leaves(self) -> usize {
74        self.0
75    }
76
77    /// Return the total number of nodes of a given forest.
78    ///
79    /// # Panics
80    ///
81    /// This will panic if the forest has size greater than `usize::MAX / 2 + 1`.
82    pub const fn num_nodes(self) -> usize {
83        assert!(self.0 <= usize::MAX / 2 + 1);
84        if self.0 <= usize::MAX / 2 {
85            self.0 * 2 - self.num_trees()
86        } else {
87            // If `self.0 > usize::MAX / 2` then we need 128-bit math to double it.
88            let (inner, num_trees) = (self.0 as u128, self.num_trees() as u128);
89            (inner * 2 - num_trees) as usize
90        }
91    }
92
93    /// Return the total number of trees of a given forest (the number of active bits).
94    pub const fn num_trees(self) -> usize {
95        self.0.count_ones() as usize
96    }
97
98    /// Returns the height (bit position) of the largest tree in the forest.
99    ///
100    /// # Panics
101    ///
102    /// This will panic if the forest is empty.
103    pub fn largest_tree_height_unchecked(self) -> usize {
104        // ilog2 is computed with leading zeros, which itself is computed with the intrinsic ctlz.
105        // [Rust 1.67.0] x86 uses the `bsr` instruction. AArch64 uses the `clz` instruction.
106        self.0.ilog2() as usize
107    }
108
109    /// Returns the height (bit position) of the largest tree in the forest.
110    ///
111    /// If the forest cannot be empty, use [`largest_tree_height_unchecked`] for performance.
112    ///
113    /// [`largest_tree_height_unchecked`]: Self::largest_tree_height_unchecked
114    pub fn largest_tree_height(self) -> Option<usize> {
115        if self.is_empty() {
116            return None;
117        }
118
119        Some(self.largest_tree_height_unchecked())
120    }
121
122    /// Returns a forest with only the largest tree present.
123    ///
124    /// # Panics
125    ///
126    /// This will panic if the forest is empty.
127    pub fn largest_tree_unchecked(self) -> Self {
128        Self::with_height(self.largest_tree_height_unchecked())
129    }
130
131    /// Returns a forest with only the largest tree present.
132    ///
133    /// If forest cannot be empty, use `largest_tree` for better performance.
134    pub fn largest_tree(self) -> Self {
135        if self.is_empty() {
136            return Self::empty();
137        }
138
139        self.largest_tree_unchecked()
140    }
141
142    /// Returns the height (bit position) of the smallest tree in the forest.
143    ///
144    /// # Panics
145    ///
146    /// This will panic if the forest is empty.
147    pub fn smallest_tree_height_unchecked(self) -> usize {
148        // Trailing_zeros is computed with the intrinsic cttz. [Rust 1.67.0] x86 uses the `bsf`
149        // instruction. AArch64 uses the `rbit clz` instructions.
150        self.0.trailing_zeros() as usize
151    }
152
153    /// Returns the height (bit position) of the smallest tree in the forest.
154    ///
155    /// If the forest cannot be empty, use [`smallest_tree_height_unchecked`] for better
156    /// performance.
157    ///
158    /// [`smallest_tree_height_unchecked`]: Self::smallest_tree_height_unchecked
159    pub fn smallest_tree_height(self) -> Option<usize> {
160        if self.is_empty() {
161            return None;
162        }
163
164        Some(self.smallest_tree_height_unchecked())
165    }
166
167    /// Returns a forest with only the smallest tree present.
168    ///
169    /// # Panics
170    ///
171    /// This will panic if the forest is empty.
172    pub fn smallest_tree_unchecked(self) -> Self {
173        Self::with_height(self.smallest_tree_height_unchecked())
174    }
175
176    /// Returns a forest with only the smallest tree present.
177    ///
178    /// If forest cannot be empty, use `smallest_tree` for performance.
179    pub fn smallest_tree(self) -> Self {
180        if self.is_empty() {
181            return Self::empty();
182        }
183        self.smallest_tree_unchecked()
184    }
185
186    /// Keeps only trees larger than the reference tree.
187    ///
188    /// For example, if we start with the bit pattern `0b0101_0110`, and keep only the trees larger
189    /// than tree index 1, that targets this bit:
190    /// ```text
191    /// Forest(0b0101_0110).trees_larger_than(1)
192    ///                        ^
193    /// Becomes:      0b0101_0100
194    ///                        ^
195    /// ```
196    /// And keeps only trees *after* that bit, meaning that the tree at `tree_idx` is also removed,
197    /// resulting in `0b0101_0100`.
198    ///
199    /// ```
200    /// # use miden_crypto::merkle::mmr::Forest;
201    /// let range = Forest::new(0b0101_0110);
202    /// assert_eq!(range.trees_larger_than(1), Forest::new(0b0101_0100));
203    /// ```
204    pub fn trees_larger_than(self, tree_idx: u32) -> Self {
205        self & high_bitmask(tree_idx + 1)
206    }
207
208    /// Creates a new forest with all possible trees smaller than the smallest tree in this
209    /// forest.
210    ///
211    /// This forest must have exactly one tree.
212    ///
213    /// # Panics
214    /// With debug assertions enabled, this function panics if this forest does not have
215    /// exactly one tree.
216    ///
217    /// For a non-panicking version of this function, see [`Forest::all_smaller_trees()`].
218    pub fn all_smaller_trees_unchecked(self) -> Self {
219        debug_assert_eq!(self.num_trees(), 1);
220        Self::new(self.0 - 1)
221    }
222
223    /// Creates a new forest with all possible trees smaller than the smallest tree in this
224    /// forest, or returns `None` if this forest has more or less than one tree.
225    ///
226    /// If the forest cannot have more or less than one tree, use
227    /// [`Forest::all_smaller_trees_unchecked()`] for performance.
228    pub fn all_smaller_trees(self) -> Option<Forest> {
229        if self.num_trees() != 1 {
230            return None;
231        }
232        Some(self.all_smaller_trees_unchecked())
233    }
234
235    /// Returns a forest with exactly one tree, one size (depth) larger than the current one.
236    pub fn next_larger_tree(self) -> Self {
237        debug_assert_eq!(self.num_trees(), 1);
238        Forest(self.0 << 1)
239    }
240
241    /// Returns true if the forest contains a single-node tree.
242    pub fn has_single_leaf_tree(self) -> bool {
243        self.0 & 1 != 0
244    }
245
246    /// Add a single-node tree if not already present in the forest.
247    pub fn with_single_leaf(self) -> Self {
248        Self::new(self.0 | 1)
249    }
250
251    /// Remove the single-node tree if present in the forest.
252    pub fn without_single_leaf(self) -> Self {
253        Self::new(self.0 & (usize::MAX - 1))
254    }
255
256    /// Returns a new forest that does not have the trees that `other` has.
257    pub fn without_trees(self, other: Forest) -> Self {
258        self ^ other
259    }
260
261    /// Returns index of the forest tree for a specified leaf index.
262    pub fn tree_index(&self, leaf_idx: usize) -> usize {
263        let root = self
264            .leaf_to_corresponding_tree(leaf_idx)
265            .expect("position must be part of the forest");
266        let smaller_tree_mask = Self::new(2_usize.pow(root) - 1);
267        let num_smaller_trees = (*self & smaller_tree_mask).num_trees();
268        self.num_trees() - num_smaller_trees - 1
269    }
270
271    /// Returns the smallest tree's root element as an [InOrderIndex].
272    ///
273    /// This function takes the smallest tree in this forest, "pretends" that it is a subtree of a
274    /// fully balanced binary tree, and returns the the in-order index of that balanced tree's root
275    /// node.
276    pub fn root_in_order_index(&self) -> InOrderIndex {
277        // Count total size of all trees in the forest.
278        let nodes = self.num_nodes();
279
280        // Add the count for the parent nodes that separate each tree. These are allocated but
281        // currently empty, and correspond to the nodes that will be used once the trees are merged.
282        let open_trees = self.num_trees() - 1;
283
284        // Remove the leaf-count of the rightmost subtree. The target tree root index comes before
285        // the subtree, for the in-order tree walk.
286        let right_subtree_count = self.smallest_tree_unchecked().num_leaves() - 1;
287
288        let idx = nodes + open_trees - right_subtree_count;
289
290        InOrderIndex::new(idx.try_into().unwrap())
291    }
292
293    /// Returns the in-order index of the rightmost element (the smallest tree).
294    pub fn rightmost_in_order_index(&self) -> InOrderIndex {
295        // Count total size of all trees in the forest.
296        let nodes = self.num_nodes();
297
298        // Add the count for the parent nodes that separate each tree. These are allocated but
299        // currently empty, and correspond to the nodes that will be used once the trees are merged.
300        let open_trees = self.num_trees() - 1;
301
302        let idx = nodes + open_trees;
303
304        InOrderIndex::new(idx.try_into().unwrap())
305    }
306
307    /// Given a leaf index in the current forest, return the tree number responsible for the
308    /// leaf.
309    ///
310    /// Note:
311    /// The result is a tree position `p`, it has the following interpretations:
312    /// - `p+1` is the depth of the tree.
313    /// - Because the root element is not part of the proof, `p` is the length of the authentication
314    ///   path.
315    /// - `2^p` is equal to the number of leaves in this particular tree.
316    /// - And `2^(p+1)-1` corresponds to the size of the tree.
317    ///
318    /// For example, given a forest with 6 leaves whose forest is `0b110`:
319    /// ```text
320    ///       __ tree 2 __
321    ///      /            \
322    ///    ____          ____         _ tree 1 _
323    ///   /    \        /    \       /          \
324    ///  0      1      2      3     4            5
325    /// ```
326    ///
327    /// Leaf indices `0..=3` are in the tree at index 2 and leaf indices `4..=5` are in the tree at
328    /// index 1.
329    pub fn leaf_to_corresponding_tree(self, leaf_idx: usize) -> Option<u32> {
330        let forest = self.0;
331
332        if leaf_idx >= forest {
333            None
334        } else {
335            // - each bit in the forest is a unique tree and the bit position is its power-of-two
336            //   size
337            // - each tree is associated to a consecutive range of positions equal to its size from
338            //   left-to-right
339            // - this means the first tree owns from `0` up to the `2^k_0` first positions, where
340            //   `k_0` is the highest set bit position, the second tree from `2^k_0 + 1` up to
341            //   `2^k_1` where `k_1` is the second highest bit, so on.
342            // - this means the highest bits work as a category marker, and the position is owned by
343            //   the first tree which doesn't share a high bit with the position
344            let before = forest & leaf_idx;
345            let after = forest ^ before;
346            let tree_idx = after.ilog2();
347
348            Some(tree_idx)
349        }
350    }
351
352    /// Given a leaf index in the current forest, return the leaf index in the tree to which
353    /// the leaf belongs.
354    pub(super) fn leaf_relative_position(self, leaf_idx: usize) -> Option<usize> {
355        let tree_idx = self.leaf_to_corresponding_tree(leaf_idx)?;
356        let forest_before = self & high_bitmask(tree_idx + 1);
357        Some(leaf_idx - forest_before.0)
358    }
359}
360
361impl Display for Forest {
362    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
363        write!(f, "{}", self.0)
364    }
365}
366
367impl Binary for Forest {
368    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
369        write!(f, "{:b}", self.0)
370    }
371}
372
373impl BitAnd<Forest> for Forest {
374    type Output = Self;
375
376    fn bitand(self, rhs: Self) -> Self::Output {
377        Self::new(self.0 & rhs.0)
378    }
379}
380
381impl BitOr<Forest> for Forest {
382    type Output = Self;
383
384    fn bitor(self, rhs: Self) -> Self::Output {
385        Self::new(self.0 | rhs.0)
386    }
387}
388
389impl BitXor<Forest> for Forest {
390    type Output = Self;
391
392    fn bitxor(self, rhs: Self) -> Self::Output {
393        Self::new(self.0 ^ rhs.0)
394    }
395}
396
397impl BitXorAssign<Forest> for Forest {
398    fn bitxor_assign(&mut self, rhs: Self) {
399        self.0 ^= rhs.0;
400    }
401}
402
403impl From<Felt> for Forest {
404    fn from(value: Felt) -> Self {
405        Self::new(value.as_canonical_u64() as usize)
406    }
407}
408
409impl From<Forest> for Felt {
410    fn from(value: Forest) -> Self {
411        Felt::new(value.0 as u64)
412    }
413}
414
415/// Return a bitmask for the bits including and above the given position.
416pub(crate) const fn high_bitmask(bit: u32) -> Forest {
417    if bit > usize::BITS - 1 {
418        Forest::empty()
419    } else {
420        Forest::new(usize::MAX << bit)
421    }
422}
423
424// SERIALIZATION
425// ================================================================================================
426
427impl Serializable for Forest {
428    fn write_into<W: ByteWriter>(&self, target: &mut W) {
429        self.0.write_into(target);
430    }
431}
432
433impl Deserializable for Forest {
434    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
435        let value = source.read_usize()?;
436        Ok(Self::new(value))
437    }
438}
439
440// TREE SIZE ITERATOR
441// ================================================================================================
442
/// Iterate over the trees within this `Forest`, from smallest to largest.
///
/// Each item is a "sub-forest", containing only one tree.
pub struct TreeSizeIterator {
    /// The trees that have not been yielded yet; each yielded tree is removed from it.
    inner: Forest,
}
449
450impl TreeSizeIterator {
451    pub fn new(value: Forest) -> TreeSizeIterator {
452        TreeSizeIterator { inner: value }
453    }
454}
455
456impl Iterator for TreeSizeIterator {
457    type Item = Forest;
458
459    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
460        let tree = self.inner.smallest_tree();
461
462        if tree.is_empty() {
463            None
464        } else {
465            self.inner = self.inner.without_trees(tree);
466            Some(tree)
467        }
468    }
469}
470
471impl DoubleEndedIterator for TreeSizeIterator {
472    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
473        let tree = self.inner.largest_tree();
474
475        if tree.is_empty() {
476            None
477        } else {
478            self.inner = self.inner.without_trees(tree);
479            Some(tree)
480        }
481    }
482}