miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
6// NODE INDEX
7// ================================================================================================
8
9/// Address to an arbitrary node in a binary tree using level order form.
10///
11/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
12/// are numbered from $0..(2^d)-1$. Example:
13///
14/// ```text
15/// depth
16/// 0             0
17/// 1         0        1
18/// 2      0    1    2    3
19/// 3     0 1  2 3  4 5  6 7
20/// ```
21///
22/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
23/// $(1, 1)$.
24#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
25#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
26pub struct NodeIndex {
27    depth: u8,
28    value: u64,
29}
30
31impl NodeIndex {
32    // CONSTRUCTORS
33    // --------------------------------------------------------------------------------------------
34
35    /// Creates a new node index.
36    ///
37    /// # Errors
38    /// Returns an error if the `value` is greater than or equal to 2^{depth}.
39    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40        if (64 - value.leading_zeros()) > depth as u32 {
41            Err(MerkleError::InvalidNodeIndex { depth, value })
42        } else {
43            Ok(Self { depth, value })
44        }
45    }
46
47    /// Creates a new node index without checking its validity.
48    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50        Self { depth, value }
51    }
52
53    /// Creates a new node index for testing purposes.
54    ///
55    /// # Panics
56    /// Panics if the `value` is greater than or equal to 2^{depth}.
57    #[cfg(test)]
58    pub fn make(depth: u8, value: u64) -> Self {
59        Self::new(depth, value).unwrap()
60    }
61
62    /// Creates a node index from a pair of field elements representing the depth and value.
63    ///
64    /// # Errors
65    /// Returns an error if:
66    /// - `depth` doesn't fit in a `u8`.
67    /// - `value` is greater than or equal to 2^{depth}.
68    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69        let depth = depth.as_int();
70        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71        let value = value.as_int();
72        Self::new(depth, value)
73    }
74
75    /// Creates a new node index pointing to the root of the tree.
76    pub const fn root() -> Self {
77        Self { depth: 0, value: 0 }
78    }
79
80    /// Computes sibling index of the current node.
81    pub const fn sibling(mut self) -> Self {
82        self.value ^= 1;
83        self
84    }
85
86    /// Returns left child index of the current node.
87    pub const fn left_child(mut self) -> Self {
88        self.depth += 1;
89        self.value <<= 1;
90        self
91    }
92
93    /// Returns right child index of the current node.
94    pub const fn right_child(mut self) -> Self {
95        self.depth += 1;
96        self.value = (self.value << 1) + 1;
97        self
98    }
99
100    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
101    /// a new value instead of mutating `self`.
102    pub const fn parent(mut self) -> Self {
103        self.depth = self.depth.saturating_sub(1);
104        self.value >>= 1;
105        self
106    }
107
108    // PROVIDERS
109    // --------------------------------------------------------------------------------------------
110
111    /// Builds a node to be used as input of a hash function when computing a Merkle path.
112    ///
113    /// Will evaluate the parity of the current instance to define the result.
114    pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
115        if self.is_value_odd() {
116            [sibling, slf]
117        } else {
118            [slf, sibling]
119        }
120    }
121
122    /// Returns the scalar representation of the depth/value pair.
123    ///
124    /// It is computed as `2^depth + value`.
125    pub const fn to_scalar_index(&self) -> u64 {
126        (1 << self.depth as u64) + self.value
127    }
128
129    /// Returns the depth of the current instance.
130    pub const fn depth(&self) -> u8 {
131        self.depth
132    }
133
134    /// Returns the value of this index.
135    pub const fn value(&self) -> u64 {
136        self.value
137    }
138
139    /// Returns `true` if the current instance points to a right sibling node.
140    pub const fn is_value_odd(&self) -> bool {
141        (self.value & 1) == 1
142    }
143
144    /// Returns `true` if the depth is `0`.
145    pub const fn is_root(&self) -> bool {
146        self.depth == 0
147    }
148
149    // STATE MUTATORS
150    // --------------------------------------------------------------------------------------------
151
152    /// Traverses one level towards the root, decrementing the depth by `1`.
153    pub fn move_up(&mut self) {
154        self.depth = self.depth.saturating_sub(1);
155        self.value >>= 1;
156    }
157
158    /// Traverses towards the root until the specified depth is reached.
159    ///
160    /// Assumes that the specified depth is smaller than the current depth.
161    pub fn move_up_to(&mut self, depth: u8) {
162        debug_assert!(depth < self.depth);
163        let delta = self.depth.saturating_sub(depth);
164        self.depth = self.depth.saturating_sub(delta);
165        self.value >>= delta as u32;
166    }
167
168    // ITERATORS
169    // --------------------------------------------------------------------------------------------
170
171    /// Return an iterator of the indices required for a Merkle proof of inclusion of a node at
172    /// `self`.
173    ///
174    /// This is *exclusive* on both ends: neither `self` nor the root index are included in the
175    /// returned iterator.
176    pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
177        ProofIter { next_index: self.sibling() }
178    }
179}
180
181impl Display for NodeIndex {
182    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
183        write!(f, "depth={}, value={}", self.depth, self.value)
184    }
185}
186
187impl Serializable for NodeIndex {
188    fn write_into<W: ByteWriter>(&self, target: &mut W) {
189        target.write_u8(self.depth);
190        target.write_u64(self.value);
191    }
192}
193
194impl Deserializable for NodeIndex {
195    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
196        let depth = source.read_u8()?;
197        let value = source.read_u64()?;
198        NodeIndex::new(depth, value)
199            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
200    }
201}
202
203/// Implementation for [`NodeIndex::proof_indices()`].
204#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
205struct ProofIter {
206    next_index: NodeIndex,
207}
208
209impl Iterator for ProofIter {
210    type Item = NodeIndex;
211
212    fn next(&mut self) -> Option<NodeIndex> {
213        if self.next_index.is_root() {
214            return None;
215        }
216
217        let index = self.next_index;
218        self.next_index = index.parent().sibling();
219
220        Some(index)
221    }
222
223    fn size_hint(&self) -> (usize, Option<usize>) {
224        let remaining = ExactSizeIterator::len(self);
225
226        (remaining, Some(remaining))
227    }
228}
229
230impl ExactSizeIterator for ProofIter {
231    fn len(&self) -> usize {
232        self.next_index.depth() as usize
233    }
234}
235
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        // Generates a valid index at the shallowest depth able to hold `value`.
        fn node_index()(value in prop::num::u64::ANY) -> NodeIndex {
            // The shallowest valid depth is the bit length of `value` (0 when `value == 0`).
            // It is at most `u64::BITS`, so it fits in a `u8` and `new` cannot fail. The bit
            // length is used rather than `ilog2`, which panics on 0 and is one too small for
            // exact powers of two.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            let start_depth = index.depth();
            for _ in 0..count {
                index.move_up();
            }
            // Moving up never goes past the root.
            prop_assert_eq!(index.depth(), start_depth.saturating_sub(count));
        }

        #[test]
        fn proof_indices_len_matches_depth(index in node_index()) {
            let indices = index.proof_indices();
            prop_assert_eq!(indices.len(), index.depth() as usize);
            prop_assert_eq!(indices.count(), index.depth() as usize);
        }
    }
}