miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use super::{Felt, MerkleError, Word};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
6// NODE INDEX
7// ================================================================================================
8
/// Address of an arbitrary node in a binary tree, expressed in level-order form.
///
/// A node is identified by the pair `(depth, value)`. At depth `d` there are $2^d$ nodes,
/// numbered left to right from $0$ to $2^d - 1$:
///
/// ```text
/// depth
/// 0             0
/// 1         0        1
/// 2      0    1    2    3
/// 3     0 1  2 3  4 5  6 7
/// ```
///
/// The root is $(0, 0)$; its left child is $(1, 0)$ and its right child is $(1, 1)$.
///
/// Indices built through [`NodeIndex::new`] satisfy `depth <= 64` and `value < 2^depth`. The
/// derived ordering compares `depth` first and then `value` (field declaration order).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Distance from the root; the root sits at depth `0`.
    depth: u8,
    /// Position of the node within its level, counted from the left starting at `0`.
    value: u64,
}
30
31impl NodeIndex {
32    // CONSTRUCTORS
33    // --------------------------------------------------------------------------------------------
34
35    /// Creates a new node index.
36    ///
37    /// # Errors
38    /// Returns an error if:
39    /// - `depth` is greater than 64.
40    /// - `value` is greater than or equal to 2^{depth}.
41    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
42        if depth > 64 {
43            Err(MerkleError::DepthTooBig(depth as u64))
44        } else if (64 - value.leading_zeros()) > depth as u32 {
45            Err(MerkleError::InvalidNodeIndex { depth, value })
46        } else {
47            Ok(Self { depth, value })
48        }
49    }
50
51    /// Creates a new node index without checking its validity.
52    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
53        debug_assert!(depth <= 64);
54        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
55        Self { depth, value }
56    }
57
58    /// Creates a new node index for testing purposes.
59    ///
60    /// # Panics
61    /// Panics if the `value` is greater than or equal to 2^{depth}.
62    #[cfg(test)]
63    pub fn make(depth: u8, value: u64) -> Self {
64        Self::new(depth, value).unwrap()
65    }
66
67    /// Creates a node index from a pair of field elements representing the depth and value.
68    ///
69    /// # Errors
70    /// Returns an error if:
71    /// - `depth` is greater than 64.
72    /// - `value` is greater than or equal to 2^{depth}.
73    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
74        let depth = depth.as_int();
75        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
76        let value = value.as_int();
77        Self::new(depth, value)
78    }
79
80    /// Creates a new node index pointing to the root of the tree.
81    pub const fn root() -> Self {
82        Self { depth: 0, value: 0 }
83    }
84
85    /// Computes sibling index of the current node.
86    pub const fn sibling(mut self) -> Self {
87        self.value ^= 1;
88        self
89    }
90
91    /// Returns left child index of the current node.
92    pub const fn left_child(mut self) -> Self {
93        self.depth += 1;
94        self.value <<= 1;
95        self
96    }
97
98    /// Returns right child index of the current node.
99    pub const fn right_child(mut self) -> Self {
100        self.depth += 1;
101        self.value = (self.value << 1) + 1;
102        self
103    }
104
105    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
106    /// a new value instead of mutating `self`.
107    pub const fn parent(mut self) -> Self {
108        self.depth = self.depth.saturating_sub(1);
109        self.value >>= 1;
110        self
111    }
112
113    // PROVIDERS
114    // --------------------------------------------------------------------------------------------
115
116    /// Builds a node to be used as input of a hash function when computing a Merkle path.
117    ///
118    /// Will evaluate the parity of the current instance to define the result.
119    pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
120        if self.is_value_odd() {
121            [sibling, slf]
122        } else {
123            [slf, sibling]
124        }
125    }
126
127    /// Returns the scalar representation of the depth/value pair.
128    ///
129    /// It is computed as `2^depth + value`.
130    pub const fn to_scalar_index(&self) -> u64 {
131        (1 << self.depth as u64) + self.value
132    }
133
134    /// Returns the depth of the current instance.
135    pub const fn depth(&self) -> u8 {
136        self.depth
137    }
138
139    /// Returns the value of this index.
140    pub const fn value(&self) -> u64 {
141        self.value
142    }
143
144    /// Returns `true` if the current instance points to a right sibling node.
145    pub const fn is_value_odd(&self) -> bool {
146        (self.value & 1) == 1
147    }
148
149    /// Returns `true` if the n-th node on the path points to a right child.
150    pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
151        (self.value >> n) & 1 == 1
152    }
153
154    /// Returns `true` if the depth is `0`.
155    pub const fn is_root(&self) -> bool {
156        self.depth == 0
157    }
158
159    // STATE MUTATORS
160    // --------------------------------------------------------------------------------------------
161
162    /// Traverses one level towards the root, decrementing the depth by `1`.
163    pub fn move_up(&mut self) {
164        self.depth = self.depth.saturating_sub(1);
165        self.value >>= 1;
166    }
167
168    /// Traverses towards the root until the specified depth is reached.
169    ///
170    /// Assumes that the specified depth is smaller than the current depth.
171    pub fn move_up_to(&mut self, depth: u8) {
172        debug_assert!(depth < self.depth);
173        let delta = self.depth.saturating_sub(depth);
174        self.depth = self.depth.saturating_sub(delta);
175        self.value >>= delta as u32;
176    }
177
178    // ITERATORS
179    // --------------------------------------------------------------------------------------------
180
181    /// Return an iterator of the indices required for a Merkle proof of inclusion of a node at
182    /// `self`.
183    ///
184    /// This is *exclusive* on both ends: neither `self` nor the root index are included in the
185    /// returned iterator.
186    pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
187        ProofIter { next_index: self.sibling() }
188    }
189}
190
191impl Display for NodeIndex {
192    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
193        write!(f, "depth={}, value={}", self.depth, self.value)
194    }
195}
196
197impl Serializable for NodeIndex {
198    fn write_into<W: ByteWriter>(&self, target: &mut W) {
199        target.write_u8(self.depth);
200        target.write_u64(self.value);
201    }
202}
203
204impl Deserializable for NodeIndex {
205    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
206        let depth = source.read_u8()?;
207        let value = source.read_u64()?;
208        NodeIndex::new(depth, value)
209            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
210    }
211}
212
213/// Implementation for [`NodeIndex::proof_indices()`].
214#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
215struct ProofIter {
216    next_index: NodeIndex,
217}
218
219impl Iterator for ProofIter {
220    type Item = NodeIndex;
221
222    fn next(&mut self) -> Option<NodeIndex> {
223        if self.next_index.is_root() {
224            return None;
225        }
226
227        let index = self.next_index;
228        self.next_index = index.parent().sibling();
229
230        Some(index)
231    }
232
233    fn size_hint(&self) -> (usize, Option<usize>) {
234        let remaining = ExactSizeIterator::len(self);
235
236        (remaining, Some(remaining))
237    }
238}
239
240impl ExactSizeIterator for ProofIter {
241    fn len(&self) -> usize {
242        self.next_index.depth() as usize
243    }
244}
245
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The smallest valid depth for `value` is its bit length, since `value` must be
            // strictly less than 2^depth. Deriving it from `ilog2` would panic on `0` and pick a
            // depth one too small for exact powers of two (e.g. `4` needs depth 3, not 2).
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            // unwrap never panics: `value < 2^depth` by construction and `depth <= 63`.
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}