Skip to main content

miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
// NODE INDEX
// ================================================================================================

/// Address to an arbitrary node in a binary tree using level order form.
///
/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
/// are numbered from $0$ to $2^d - 1$ (inclusive). Example:
///
/// ```text
/// depth
/// 0             0
/// 1         0        1
/// 2      0    1    2    3
/// 3     0 1  2 3  4 5  6 7
/// ```
///
/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
/// $(1, 1)$.
///
/// An index built through [`NodeIndex::new`] always satisfies `depth <= 64` and
/// `position < 2^depth`. The derived ordering compares `depth` first and then `position`,
/// following the field declaration order.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    depth: u8,
    position: u64,
}
30
31impl NodeIndex {
32    // CONSTRUCTORS
33    // --------------------------------------------------------------------------------------------
34
35    /// Creates a new node index.
36    ///
37    /// # Errors
38    /// Returns an error if:
39    /// - `depth` is greater than 64.
40    /// - `position` is greater than or equal to 2^{depth}.
41    pub const fn new(depth: u8, position: u64) -> Result<Self, MerkleError> {
42        if depth > 64 {
43            Err(MerkleError::DepthTooBig(depth as u64))
44        } else if (64 - position.leading_zeros()) > depth as u32 {
45            Err(MerkleError::InvalidNodeIndex { depth, position })
46        } else {
47            Ok(Self { depth, position })
48        }
49    }
50
51    /// Creates a new node index without checking its validity.
52    pub const fn new_unchecked(depth: u8, position: u64) -> Self {
53        debug_assert!(depth <= 64);
54        debug_assert!((64 - position.leading_zeros()) <= depth as u32);
55        Self { depth, position }
56    }
57
58    /// Creates a new node index for testing purposes.
59    ///
60    /// # Panics
61    /// Panics if the `position` is greater than or equal to 2^{depth}.
62    #[cfg(test)]
63    pub fn make(depth: u8, position: u64) -> Self {
64        Self::new(depth, position).unwrap()
65    }
66
67    /// Creates a node index from a pair of field elements representing the depth and position.
68    ///
69    /// # Errors
70    /// Returns an error if:
71    /// - `depth` is greater than 64.
72    /// - `position` is greater than or equal to 2^{depth}.
73    pub fn from_elements(depth: &Felt, position: &Felt) -> Result<Self, MerkleError> {
74        let depth = depth.as_canonical_u64();
75        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
76        let position = position.as_canonical_u64();
77        Self::new(depth, position)
78    }
79
80    /// Creates a new node index pointing to the root of the tree.
81    pub const fn root() -> Self {
82        Self { depth: 0, position: 0 }
83    }
84
85    /// Computes sibling index of the current node.
86    pub const fn sibling(mut self) -> Self {
87        self.position ^= 1;
88        self
89    }
90
91    /// Returns left child index of the current node.
92    pub const fn left_child(mut self) -> Self {
93        self.depth += 1;
94        self.position <<= 1;
95        self
96    }
97
98    /// Returns right child index of the current node.
99    pub const fn right_child(mut self) -> Self {
100        self.depth += 1;
101        self.position = (self.position << 1) + 1;
102        self
103    }
104
105    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
106    /// a new value instead of mutating `self`.
107    pub const fn parent(mut self) -> Self {
108        self.depth = self.depth.saturating_sub(1);
109        self.position >>= 1;
110        self
111    }
112
113    // PROVIDERS
114    // --------------------------------------------------------------------------------------------
115
116    /// Builds a node to be used as input of a hash function when computing a Merkle path.
117    ///
118    /// Will evaluate the parity of the current instance to define the result.
119    pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
120        if self.is_position_odd() {
121            [sibling, slf]
122        } else {
123            [slf, sibling]
124        }
125    }
126
127    /// Returns the scalar representation of the depth/position pair.
128    ///
129    /// It is computed as `2^depth + position`.
130    ///
131    /// # Errors
132    ///
133    /// - [`MerkleError::DepthTooBig`] if the depth is 64 or greater, as the resulting index would
134    ///   overflow.
135    pub const fn to_scalar_index(&self) -> Result<u64, MerkleError> {
136        if self.depth >= 64 {
137            return Err(MerkleError::DepthTooBig(self.depth as u64));
138        }
139        Ok((1u64 << self.depth as u64) + self.position)
140    }
141
142    /// Returns the depth of the current instance.
143    pub const fn depth(&self) -> u8 {
144        self.depth
145    }
146
147    /// Returns the position of this index within its depth layer.
148    pub const fn position(&self) -> u64 {
149        self.position
150    }
151
152    /// Returns `true` if the current instance points to a right sibling node.
153    pub const fn is_position_odd(&self) -> bool {
154        (self.position & 1) == 1
155    }
156
157    /// Returns `true` if the n-th node on the path points to a right child.
158    pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
159        (self.position >> n) & 1 == 1
160    }
161
162    /// Returns `true` if the depth is `0`.
163    pub const fn is_root(&self) -> bool {
164        self.depth == 0
165    }
166
167    // STATE MUTATORS
168    // --------------------------------------------------------------------------------------------
169
170    /// Traverses one level towards the root, decrementing the depth by `1`.
171    pub fn move_up(&mut self) {
172        self.depth = self.depth.saturating_sub(1);
173        self.position >>= 1;
174    }
175
176    /// Traverses towards the root until the specified depth is reached.
177    ///
178    /// Assumes that the specified depth is smaller than the current depth.
179    pub fn move_up_to(&mut self, depth: u8) {
180        debug_assert!(depth < self.depth);
181        let delta = self.depth.saturating_sub(depth);
182        self.depth = self.depth.saturating_sub(delta);
183        self.position >>= delta as u32;
184    }
185
186    // ITERATORS
187    // --------------------------------------------------------------------------------------------
188
189    /// Return an iterator of the indices required for a Merkle proof of inclusion of a node at
190    /// `self`.
191    ///
192    /// This is *exclusive* on both ends: neither `self` nor the root index are included in the
193    /// returned iterator.
194    pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
195        ProofIter { next_index: self.sibling() }
196    }
197}
198
199impl Display for NodeIndex {
200    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
201        write!(f, "depth={}, position={}", self.depth, self.position)
202    }
203}
204
205impl Serializable for NodeIndex {
206    fn write_into<W: ByteWriter>(&self, target: &mut W) {
207        target.write_u8(self.depth);
208        target.write_u64(self.position);
209    }
210}
211
212impl Deserializable for NodeIndex {
213    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
214        let depth = source.read_u8()?;
215        let position = source.read_u64()?;
216        NodeIndex::new(depth, position)
217            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
218    }
219
220    fn min_serialized_size() -> usize {
221        // u8 (depth) + u64 (value)
222        9
223    }
224}
225
226/// Implementation for [`NodeIndex::proof_indices()`].
227#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
228struct ProofIter {
229    next_index: NodeIndex,
230}
231
232impl Iterator for ProofIter {
233    type Item = NodeIndex;
234
235    fn next(&mut self) -> Option<NodeIndex> {
236        if self.next_index.is_root() {
237            return None;
238        }
239
240        let index = self.next_index;
241        self.next_index = index.parent().sibling();
242
243        Some(index)
244    }
245
246    fn size_hint(&self) -> (usize, Option<usize>) {
247        let remaining = ExactSizeIterator::len(self);
248
249        (remaining, Some(remaining))
250    }
251}
252
253impl ExactSizeIterator for ProofIter {
254    fn len(&self) -> usize {
255        self.next_index.depth() as usize
256    }
257}
258
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_position_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, position: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, position: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, position: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, position: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, position: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, position: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, position: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, position: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        // Generates a valid node index whose depth is the smallest one able to hold `position`.
        fn node_index()(position in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The minimal depth is the bit length of `position` (0 for position 0), which is at
            // most 63 for the sampled range, so `unwrap` never panics. (`ilog2` would panic on 0
            // and yields a depth one too small for exact powers of two.)
            let depth = (u64::BITS - position.leading_zeros()) as u8;
            NodeIndex::new(depth, position).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }

        #[test]
        fn to_scalar_index_succeeds_for_depth_lt_64(depth in 0u8..64, position_bits in 0u64..u64::MAX) {
            let position = if depth == 0 { 0 } else { position_bits % (1u64 << depth) };
            let index = NodeIndex::new(depth, position).unwrap();
            assert!(index.to_scalar_index().is_ok());
        }
    }

    #[test]
    fn test_to_scalar_index_depth_64_returns_error() {
        let index = NodeIndex::new(64, 0).unwrap();
        assert_matches!(index.to_scalar_index(), Err(MerkleError::DepthTooBig(64)));

        let index = NodeIndex::new(64, u64::MAX).unwrap();
        assert_matches!(index.to_scalar_index(), Err(MerkleError::DepthTooBig(64)));
    }

    #[test]
    fn test_to_scalar_index_known_values() {
        // Root's children: depth=1, pos=0 → scalar 2; depth=1, pos=1 → scalar 3
        assert_eq!(NodeIndex::make(1, 0).to_scalar_index().unwrap(), 2);
        assert_eq!(NodeIndex::make(1, 1).to_scalar_index().unwrap(), 3);

        // depth=2: scalars 4,5,6,7
        assert_eq!(NodeIndex::make(2, 0).to_scalar_index().unwrap(), 4);
        assert_eq!(NodeIndex::make(2, 3).to_scalar_index().unwrap(), 7);

        // depth=3: scalars 8..15
        assert_eq!(NodeIndex::make(3, 0).to_scalar_index().unwrap(), 8);
        assert_eq!(NodeIndex::make(3, 7).to_scalar_index().unwrap(), 15);
    }

    #[test]
    fn test_to_scalar_index_depth_63_max_position() {
        // 2^63 + (2^63 - 1) = 2^64 - 1 = u64::MAX
        let index = NodeIndex::new(63, (1u64 << 63) - 1).unwrap();
        assert_eq!(index.to_scalar_index().unwrap(), u64::MAX);
    }

    #[test]
    fn test_to_scalar_index_boundary_depths() {
        // depth 0 (root): scalar = 1 + 0 = 1
        assert_eq!(NodeIndex::make(0, 0).to_scalar_index().unwrap(), 1);

        // depth 62, position 0: scalar = 2^62
        assert_eq!(NodeIndex::make(62, 0).to_scalar_index().unwrap(), 1u64 << 62);

        // depth 63, position 0: scalar = 2^63
        assert_eq!(NodeIndex::make(63, 0).to_scalar_index().unwrap(), 1u64 << 63);
    }
}