miden_crypto/merkle/mmr/
forest.rs

1use core::{
2    fmt::{Binary, Display},
3    ops::{BitAnd, BitOr, BitXor, BitXorAssign},
4};
5
6use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
7
8use super::InOrderIndex;
9use crate::Felt;
10
/// A compact representation of trees in a forest. Used in the Merkle Mountain Range (MMR).
///
/// Each set bit of the stored number represents a disjoint perfect binary tree whose number of
/// leaves is `2^i`, where `i` is the bit position.
///
/// The forest value has the following interpretations:
/// - its value is the number of leaves in the forest
/// - the version number (MMR is append only so the number of leaves always increases)
/// - its bit count (number of set bits) is the number of trees in the forest
/// - each set bit position is the height of one tree in the forest
///
/// Examples:
/// - `Forest(0)` is a forest with no trees.
/// - `Forest(0b01)` is a forest with a single leaf/node (the smallest tree possible).
/// - `Forest(0b10)` is a forest with a single binary tree with 2 leaves (3 nodes).
/// - `Forest(0b11)` is a forest with two trees: one with 1 leaf (1 node), and one with 2 leaves (3
///   nodes).
/// - `Forest(0b1010)` is a forest with two trees: one with 8 leaves (15 nodes), one with 2 leaves
///   (3 nodes).
/// - `Forest(0b1000)` is a forest with one tree, which has 8 leaves (15 nodes).
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Forest(usize);
34
35impl Forest {
36    /// Creates an empty forest (no trees).
37    pub const fn empty() -> Self {
38        Self(0)
39    }
40
41    /// Creates a forest with `num_leaves` leaves.
42    pub const fn new(num_leaves: usize) -> Self {
43        Self(num_leaves)
44    }
45
46    /// Creates a forest with a given height.
47    ///
48    /// This is equivalent to `Forest::new(1 << height)`.
49    ///
50    /// # Panics
51    ///
52    /// This will panic if `height` is greater than `usize::BITS - 1`.
53    pub const fn with_height(height: usize) -> Self {
54        assert!(height < usize::BITS as usize);
55        Self::new(1 << height)
56    }
57
58    /// Returns true if there are no trees in the forest.
59    pub fn is_empty(self) -> bool {
60        self.0 == 0
61    }
62
63    /// Adds exactly one more leaf to the capacity of this forest.
64    ///
65    /// Some smaller trees might be merged together.
66    pub fn append_leaf(&mut self) {
67        self.0 += 1;
68    }
69
70    /// Returns a count of leaves in the entire underlying forest (MMR).
71    pub fn num_leaves(self) -> usize {
72        self.0
73    }
74
75    /// Return the total number of nodes of a given forest.
76    ///
77    /// # Panics
78    ///
79    /// This will panic if the forest has size greater than `usize::MAX / 2 + 1`.
80    pub const fn num_nodes(self) -> usize {
81        assert!(self.0 <= usize::MAX / 2 + 1);
82        if self.0 <= usize::MAX / 2 {
83            self.0 * 2 - self.num_trees()
84        } else {
85            // If `self.0 > usize::MAX / 2` then we need 128-bit math to double it.
86            let (inner, num_trees) = (self.0 as u128, self.num_trees() as u128);
87            (inner * 2 - num_trees) as usize
88        }
89    }
90
91    /// Return the total number of trees of a given forest (the number of active bits).
92    pub const fn num_trees(self) -> usize {
93        self.0.count_ones() as usize
94    }
95
96    /// Returns the height (bit position) of the largest tree in the forest.
97    ///
98    /// # Panics
99    ///
100    /// This will panic if the forest is empty.
101    pub fn largest_tree_height_unchecked(self) -> usize {
102        // ilog2 is computed with leading zeros, which itself is computed with the intrinsic ctlz.
103        // [Rust 1.67.0] x86 uses the `bsr` instruction. AArch64 uses the `clz` instruction.
104        self.0.ilog2() as usize
105    }
106
107    /// Returns the height (bit position) of the largest tree in the forest.
108    ///
109    /// If the forest cannot be empty, use [`largest_tree_height_unchecked`] for performance.
110    ///
111    /// [`largest_tree_height_unchecked`]: Self::largest_tree_height_unchecked
112    pub fn largest_tree_height(self) -> Option<usize> {
113        if self.is_empty() {
114            return None;
115        }
116
117        Some(self.largest_tree_height_unchecked())
118    }
119
120    /// Returns a forest with only the largest tree present.
121    ///
122    /// # Panics
123    ///
124    /// This will panic if the forest is empty.
125    pub fn largest_tree_unchecked(self) -> Self {
126        Self::with_height(self.largest_tree_height_unchecked())
127    }
128
129    /// Returns a forest with only the largest tree present.
130    ///
131    /// If forest cannot be empty, use `largest_tree` for better performance.
132    pub fn largest_tree(self) -> Self {
133        if self.is_empty() {
134            return Self::empty();
135        }
136
137        self.largest_tree_unchecked()
138    }
139
140    /// Returns the height (bit position) of the smallest tree in the forest.
141    ///
142    /// # Panics
143    ///
144    /// This will panic if the forest is empty.
145    pub fn smallest_tree_height_unchecked(self) -> usize {
146        // Trailing_zeros is computed with the intrinsic cttz. [Rust 1.67.0] x86 uses the `bsf`
147        // instruction. AArch64 uses the `rbit clz` instructions.
148        self.0.trailing_zeros() as usize
149    }
150
151    /// Returns the height (bit position) of the smallest tree in the forest.
152    ///
153    /// If the forest cannot be empty, use [`smallest_tree_height_unchecked`] for better
154    /// performance.
155    ///
156    /// [`smallest_tree_height_unchecked`]: Self::smallest_tree_height_unchecked
157    pub fn smallest_tree_height(self) -> Option<usize> {
158        if self.is_empty() {
159            return None;
160        }
161
162        Some(self.smallest_tree_height_unchecked())
163    }
164
165    /// Returns a forest with only the smallest tree present.
166    ///
167    /// # Panics
168    ///
169    /// This will panic if the forest is empty.
170    pub fn smallest_tree_unchecked(self) -> Self {
171        Self::with_height(self.smallest_tree_height_unchecked())
172    }
173
174    /// Returns a forest with only the smallest tree present.
175    ///
176    /// If forest cannot be empty, use `smallest_tree` for performance.
177    pub fn smallest_tree(self) -> Self {
178        if self.is_empty() {
179            return Self::empty();
180        }
181        self.smallest_tree_unchecked()
182    }
183
184    /// Keeps only trees larger than the reference tree.
185    ///
186    /// For example, if we start with the bit pattern `0b0101_0110`, and keep only the trees larger
187    /// than tree index 1, that targets this bit:
188    /// ```text
189    /// Forest(0b0101_0110).trees_larger_than(1)
190    ///                        ^
191    /// Becomes:      0b0101_0100
192    ///                        ^
193    /// ```
194    /// And keeps only trees *after* that bit, meaning that the tree at `tree_idx` is also removed,
195    /// resulting in `0b0101_0100`.
196    ///
197    /// ```
198    /// # use miden_crypto::merkle::Forest;
199    /// let range = Forest::new(0b0101_0110);
200    /// assert_eq!(range.trees_larger_than(1), Forest::new(0b0101_0100));
201    /// ```
202    pub fn trees_larger_than(self, tree_idx: u32) -> Self {
203        self & high_bitmask(tree_idx + 1)
204    }
205
206    /// Creates a new forest with all possible trees smaller than the smallest tree in this
207    /// forest.
208    ///
209    /// This forest must have exactly one tree.
210    ///
211    /// # Panics
212    /// With debug assertions enabled, this function panics if this forest does not have
213    /// exactly one tree.
214    ///
215    /// For a non-panicking version of this function, see [`Forest::all_smaller_trees()`].
216    pub fn all_smaller_trees_unchecked(self) -> Self {
217        debug_assert_eq!(self.num_trees(), 1);
218        Self::new(self.0 - 1)
219    }
220
221    /// Creates a new forest with all possible trees smaller than the smallest tree in this
222    /// forest, or returns `None` if this forest has more or less than one tree.
223    ///
224    /// If the forest cannot have more or less than one tree, use
225    /// [`Forest::all_smaller_trees_unchecked()`] for performance.
226    pub fn all_smaller_trees(self) -> Option<Forest> {
227        if self.num_trees() != 1 {
228            return None;
229        }
230        Some(self.all_smaller_trees_unchecked())
231    }
232
233    /// Returns a forest with exactly one tree, one size (depth) larger than the current one.
234    pub fn next_larger_tree(self) -> Self {
235        debug_assert_eq!(self.num_trees(), 1);
236        Forest(self.0 << 1)
237    }
238
239    /// Returns true if the forest contains a single-node tree.
240    pub fn has_single_leaf_tree(self) -> bool {
241        self.0 & 1 != 0
242    }
243
244    /// Add a single-node tree if not already present in the forest.
245    pub fn with_single_leaf(self) -> Self {
246        Self::new(self.0 | 1)
247    }
248
249    /// Remove the single-node tree if present in the forest.
250    pub fn without_single_leaf(self) -> Self {
251        Self::new(self.0 & (usize::MAX - 1))
252    }
253
254    /// Returns a new forest that does not have the trees that `other` has.
255    pub fn without_trees(self, other: Forest) -> Self {
256        self ^ other
257    }
258
259    /// Returns index of the forest tree for a specified leaf index.
260    pub fn tree_index(&self, leaf_idx: usize) -> usize {
261        let root = self
262            .leaf_to_corresponding_tree(leaf_idx)
263            .expect("position must be part of the forest");
264        let smaller_tree_mask = Self::new(2_usize.pow(root) - 1);
265        let num_smaller_trees = (*self & smaller_tree_mask).num_trees();
266        self.num_trees() - num_smaller_trees - 1
267    }
268
269    /// Returns the smallest tree's root element as an [InOrderIndex].
270    ///
271    /// This function takes the smallest tree in this forest, "pretends" that it is a subtree of a
272    /// fully balanced binary tree, and returns the the in-order index of that balanced tree's root
273    /// node.
274    pub fn root_in_order_index(&self) -> InOrderIndex {
275        // Count total size of all trees in the forest.
276        let nodes = self.num_nodes();
277
278        // Add the count for the parent nodes that separate each tree. These are allocated but
279        // currently empty, and correspond to the nodes that will be used once the trees are merged.
280        let open_trees = self.num_trees() - 1;
281
282        // Remove the leaf-count of the rightmost subtree. The target tree root index comes before
283        // the subtree, for the in-order tree walk.
284        let right_subtree_count = self.smallest_tree_unchecked().num_leaves() - 1;
285
286        let idx = nodes + open_trees - right_subtree_count;
287
288        InOrderIndex::new(idx.try_into().unwrap())
289    }
290
291    /// Returns the in-order index of the rightmost element (the smallest tree).
292    pub fn rightmost_in_order_index(&self) -> InOrderIndex {
293        // Count total size of all trees in the forest.
294        let nodes = self.num_nodes();
295
296        // Add the count for the parent nodes that separate each tree. These are allocated but
297        // currently empty, and correspond to the nodes that will be used once the trees are merged.
298        let open_trees = self.num_trees() - 1;
299
300        let idx = nodes + open_trees;
301
302        InOrderIndex::new(idx.try_into().unwrap())
303    }
304
305    /// Given a leaf index in the current forest, return the tree number responsible for the
306    /// leaf.
307    ///
308    /// Note:
309    /// The result is a tree position `p`, it has the following interpretations:
310    /// - `p+1` is the depth of the tree.
311    /// - Because the root element is not part of the proof, `p` is the length of the authentication
312    ///   path.
313    /// - `2^p` is equal to the number of leaves in this particular tree.
314    /// - And `2^(p+1)-1` corresponds to the size of the tree.
315    ///
316    /// For example, given a forest with 6 leaves whose forest is `0b110`:
317    /// ```text
318    ///       __ tree 2 __
319    ///      /            \
320    ///    ____          ____         _ tree 1 _
321    ///   /    \        /    \       /          \
322    ///  0      1      2      3     4            5
323    /// ```
324    ///
325    /// Leaf indices `0..=3` are in the tree at index 2 and leaf indices `4..=5` are in the tree at
326    /// index 1.
327    pub fn leaf_to_corresponding_tree(self, leaf_idx: usize) -> Option<u32> {
328        let forest = self.0;
329
330        if leaf_idx >= forest {
331            None
332        } else {
333            // - each bit in the forest is a unique tree and the bit position is its power-of-two
334            //   size
335            // - each tree is associated to a consecutive range of positions equal to its size from
336            //   left-to-right
337            // - this means the first tree owns from `0` up to the `2^k_0` first positions, where
338            //   `k_0` is the highest set bit position, the second tree from `2^k_0 + 1` up to
339            //   `2^k_1` where `k_1` is the second highest bit, so on.
340            // - this means the highest bits work as a category marker, and the position is owned by
341            //   the first tree which doesn't share a high bit with the position
342            let before = forest & leaf_idx;
343            let after = forest ^ before;
344            let tree_idx = after.ilog2();
345
346            Some(tree_idx)
347        }
348    }
349
350    /// Given a leaf index in the current forest, return the leaf index in the tree to which
351    /// the leaf belongs.
352    pub(super) fn leaf_relative_position(self, leaf_idx: usize) -> Option<usize> {
353        let tree_idx = self.leaf_to_corresponding_tree(leaf_idx)?;
354        let forest_before = self & high_bitmask(tree_idx + 1);
355        Some(leaf_idx - forest_before.0)
356    }
357}
358
359impl Display for Forest {
360    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
361        write!(f, "{}", self.0)
362    }
363}
364
365impl Binary for Forest {
366    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
367        write!(f, "{:b}", self.0)
368    }
369}
370
371impl BitAnd<Forest> for Forest {
372    type Output = Self;
373
374    fn bitand(self, rhs: Self) -> Self::Output {
375        Self::new(self.0 & rhs.0)
376    }
377}
378
379impl BitOr<Forest> for Forest {
380    type Output = Self;
381
382    fn bitor(self, rhs: Self) -> Self::Output {
383        Self::new(self.0 | rhs.0)
384    }
385}
386
387impl BitXor<Forest> for Forest {
388    type Output = Self;
389
390    fn bitxor(self, rhs: Self) -> Self::Output {
391        Self::new(self.0 ^ rhs.0)
392    }
393}
394
395impl BitXorAssign<Forest> for Forest {
396    fn bitxor_assign(&mut self, rhs: Self) {
397        self.0 ^= rhs.0;
398    }
399}
400
401impl From<Felt> for Forest {
402    fn from(value: Felt) -> Self {
403        Self::new(value.as_int() as usize)
404    }
405}
406
407impl From<Forest> for Felt {
408    fn from(value: Forest) -> Self {
409        Felt::new(value.0 as u64)
410    }
411}
412
413/// Return a bitmask for the bits including and above the given position.
414pub(crate) const fn high_bitmask(bit: u32) -> Forest {
415    if bit > usize::BITS - 1 {
416        Forest::empty()
417    } else {
418        Forest::new(usize::MAX << bit)
419    }
420}
421
422// SERIALIZATION
423// ================================================================================================
424
425impl Serializable for Forest {
426    fn write_into<W: ByteWriter>(&self, target: &mut W) {
427        self.0.write_into(target);
428    }
429}
430
431impl Deserializable for Forest {
432    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
433        let value = source.read_usize()?;
434        Ok(Self::new(value))
435    }
436}
437
438// TREE SIZE ITERATOR
439// ================================================================================================
440
441/// Iterate over the trees within this `Forest`, from smallest to largest.
442///
443/// Each item is a "sub-forest", containing only one tree.
444pub struct TreeSizeIterator {
445    inner: Forest,
446}
447
448impl TreeSizeIterator {
449    pub fn new(value: Forest) -> TreeSizeIterator {
450        TreeSizeIterator { inner: value }
451    }
452}
453
454impl Iterator for TreeSizeIterator {
455    type Item = Forest;
456
457    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
458        let tree = self.inner.smallest_tree();
459
460        if tree.is_empty() {
461            None
462        } else {
463            self.inner = self.inner.without_trees(tree);
464            Some(tree)
465        }
466    }
467}
468
469impl DoubleEndedIterator for TreeSizeIterator {
470    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
471        let tree = self.inner.largest_tree();
472
473        if tree.is_empty() {
474            None
475        } else {
476            self.inner = self.inner.without_trees(tree);
477            Some(tree)
478        }
479    }
480}