miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use p3_field::PrimeField64;
4
5use super::{Felt, MerkleError, Word};
6use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
7
8// NODE INDEX
9// ================================================================================================
10
/// Position of a node inside a binary tree, expressed in level order.
///
/// A node is identified by the pair `(depth, value)`: `depth` is the distance from the root and
/// `value` is the node's offset within that level, counted from the left. A level at depth `d`
/// holds the values $0..(2^d)-1$. Example:
///
/// ```text
/// depth
/// 0             0
/// 1         0        1
/// 2      0    1    2    3
/// 3     0 1  2 3  4 5  6 7
/// ```
///
/// The root is therefore $(0, 0)$; its left child is $(1, 0)$ and its right child is $(1, 1)$.
///
/// The derived ordering compares `depth` first and `value` second, following the field order.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Distance of the node from the root (the root has depth `0`).
    depth: u8,
    /// Offset of the node within its level, counted from the left.
    value: u64,
}
32
33impl NodeIndex {
34    // CONSTRUCTORS
35    // --------------------------------------------------------------------------------------------
36
37    /// Creates a new node index.
38    ///
39    /// # Errors
40    /// Returns an error if:
41    /// - `depth` is greater than 64.
42    /// - `value` is greater than or equal to 2^{depth}.
43    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
44        if depth > 64 {
45            Err(MerkleError::DepthTooBig(depth as u64))
46        } else if (64 - value.leading_zeros()) > depth as u32 {
47            Err(MerkleError::InvalidNodeIndex { depth, value })
48        } else {
49            Ok(Self { depth, value })
50        }
51    }
52
53    /// Creates a new node index without checking its validity.
54    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
55        debug_assert!(depth <= 64);
56        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
57        Self { depth, value }
58    }
59
60    /// Creates a new node index for testing purposes.
61    ///
62    /// # Panics
63    /// Panics if the `value` is greater than or equal to 2^{depth}.
64    #[cfg(test)]
65    pub fn make(depth: u8, value: u64) -> Self {
66        Self::new(depth, value).unwrap()
67    }
68
69    /// Creates a node index from a pair of field elements representing the depth and value.
70    ///
71    /// # Errors
72    /// Returns an error if:
73    /// - `depth` is greater than 64.
74    /// - `value` is greater than or equal to 2^{depth}.
75    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
76        let depth = depth.as_canonical_u64();
77        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
78        let value = value.as_canonical_u64();
79        Self::new(depth, value)
80    }
81
82    /// Creates a new node index pointing to the root of the tree.
83    pub const fn root() -> Self {
84        Self { depth: 0, value: 0 }
85    }
86
87    /// Computes sibling index of the current node.
88    pub const fn sibling(mut self) -> Self {
89        self.value ^= 1;
90        self
91    }
92
93    /// Returns left child index of the current node.
94    pub const fn left_child(mut self) -> Self {
95        self.depth += 1;
96        self.value <<= 1;
97        self
98    }
99
100    /// Returns right child index of the current node.
101    pub const fn right_child(mut self) -> Self {
102        self.depth += 1;
103        self.value = (self.value << 1) + 1;
104        self
105    }
106
107    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
108    /// a new value instead of mutating `self`.
109    pub const fn parent(mut self) -> Self {
110        self.depth = self.depth.saturating_sub(1);
111        self.value >>= 1;
112        self
113    }
114
115    // PROVIDERS
116    // --------------------------------------------------------------------------------------------
117
118    /// Builds a node to be used as input of a hash function when computing a Merkle path.
119    ///
120    /// Will evaluate the parity of the current instance to define the result.
121    pub const fn build_node(&self, slf: Word, sibling: Word) -> [Word; 2] {
122        if self.is_value_odd() {
123            [sibling, slf]
124        } else {
125            [slf, sibling]
126        }
127    }
128
129    /// Returns the scalar representation of the depth/value pair.
130    ///
131    /// It is computed as `2^depth + value`.
132    pub const fn to_scalar_index(&self) -> u64 {
133        (1 << self.depth as u64) + self.value
134    }
135
136    /// Returns the depth of the current instance.
137    pub const fn depth(&self) -> u8 {
138        self.depth
139    }
140
141    /// Returns the value of this index.
142    pub const fn value(&self) -> u64 {
143        self.value
144    }
145
146    /// Returns `true` if the current instance points to a right sibling node.
147    pub const fn is_value_odd(&self) -> bool {
148        (self.value & 1) == 1
149    }
150
151    /// Returns `true` if the n-th node on the path points to a right child.
152    pub const fn is_nth_bit_odd(&self, n: u8) -> bool {
153        (self.value >> n) & 1 == 1
154    }
155
156    /// Returns `true` if the depth is `0`.
157    pub const fn is_root(&self) -> bool {
158        self.depth == 0
159    }
160
161    // STATE MUTATORS
162    // --------------------------------------------------------------------------------------------
163
164    /// Traverses one level towards the root, decrementing the depth by `1`.
165    pub fn move_up(&mut self) {
166        self.depth = self.depth.saturating_sub(1);
167        self.value >>= 1;
168    }
169
170    /// Traverses towards the root until the specified depth is reached.
171    ///
172    /// Assumes that the specified depth is smaller than the current depth.
173    pub fn move_up_to(&mut self, depth: u8) {
174        debug_assert!(depth < self.depth);
175        let delta = self.depth.saturating_sub(depth);
176        self.depth = self.depth.saturating_sub(delta);
177        self.value >>= delta as u32;
178    }
179
180    // ITERATORS
181    // --------------------------------------------------------------------------------------------
182
183    /// Return an iterator of the indices required for a Merkle proof of inclusion of a node at
184    /// `self`.
185    ///
186    /// This is *exclusive* on both ends: neither `self` nor the root index are included in the
187    /// returned iterator.
188    pub fn proof_indices(&self) -> impl ExactSizeIterator<Item = NodeIndex> + use<> {
189        ProofIter { next_index: self.sibling() }
190    }
191}
192
193impl Display for NodeIndex {
194    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
195        write!(f, "depth={}, value={}", self.depth, self.value)
196    }
197}
198
199impl Serializable for NodeIndex {
200    fn write_into<W: ByteWriter>(&self, target: &mut W) {
201        target.write_u8(self.depth);
202        target.write_u64(self.value);
203    }
204}
205
206impl Deserializable for NodeIndex {
207    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
208        let depth = source.read_u8()?;
209        let value = source.read_u64()?;
210        NodeIndex::new(depth, value)
211            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
212    }
213}
214
215/// Implementation for [`NodeIndex::proof_indices()`].
216#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
217struct ProofIter {
218    next_index: NodeIndex,
219}
220
221impl Iterator for ProofIter {
222    type Item = NodeIndex;
223
224    fn next(&mut self) -> Option<NodeIndex> {
225        if self.next_index.is_root() {
226            return None;
227        }
228
229        let index = self.next_index;
230        self.next_index = index.parent().sibling();
231
232        Some(index)
233    }
234
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        let remaining = ExactSizeIterator::len(self);
237
238        (remaining, Some(remaining))
239    }
240}
241
242impl ExactSizeIterator for ProofIter {
243    fn len(&self) -> usize {
244        self.next_index.depth() as usize
245    }
246}
247
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    /// For several depths, the largest value on the level is accepted and the next one is
    /// rejected with `InvalidNodeIndex`.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    /// The deepest supported level (64) can hold every `u64` value.
    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        /// Generates a valid index whose depth is the smallest one able to hold `value`.
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The smallest valid depth is the bit length of `value`. Unlike `ilog2` this is
            // defined for `0` (yielding the root) and is also correct for exact powers of two,
            // which need one more level than their logarithm. The sampled range keeps the bit
            // length at most 63, so the unwrap never panics.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        /// Moving up any number of times, including past the root, must never panic.
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}