miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use super::{Felt, MerkleError, RpoDigest};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
6// NODE INDEX
7// ================================================================================================
8
/// Position of a single node inside a binary tree, expressed in level order.
///
/// A position is the pair `(depth, value)`: `depth` selects the level of the tree and `value`
/// selects the node within that level, counting from the left. A level at depth `d` holds the
/// values $0..(2^d)-1$. Example:
///
/// ```ignore
/// depth
/// 0             0
/// 1         0        1
/// 2      0    1    2    3
/// 3     0 1  2 3  4 5  6 7
/// ```
///
/// The root is the pair $(0, 0)$; its left child is $(1, 0)$ and its right child is $(1, 1)$.
///
/// The derived ordering compares `depth` first and `value` second, following the field order.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeIndex {
    // Level of the node; the root is at depth 0.
    depth: u8,
    // Left-to-right position of the node within its level.
    value: u64,
}
30
31impl NodeIndex {
32    // CONSTRUCTORS
33    // --------------------------------------------------------------------------------------------
34
35    /// Creates a new node index.
36    ///
37    /// # Errors
38    /// Returns an error if the `value` is greater than or equal to 2^{depth}.
39    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40        if (64 - value.leading_zeros()) > depth as u32 {
41            Err(MerkleError::InvalidNodeIndex { depth, value })
42        } else {
43            Ok(Self { depth, value })
44        }
45    }
46
47    /// Creates a new node index without checking its validity.
48    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50        Self { depth, value }
51    }
52
53    /// Creates a new node index for testing purposes.
54    ///
55    /// # Panics
56    /// Panics if the `value` is greater than or equal to 2^{depth}.
57    #[cfg(test)]
58    pub fn make(depth: u8, value: u64) -> Self {
59        Self::new(depth, value).unwrap()
60    }
61
62    /// Creates a node index from a pair of field elements representing the depth and value.
63    ///
64    /// # Errors
65    /// Returns an error if:
66    /// - `depth` doesn't fit in a `u8`.
67    /// - `value` is greater than or equal to 2^{depth}.
68    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69        let depth = depth.as_int();
70        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71        let value = value.as_int();
72        Self::new(depth, value)
73    }
74
75    /// Creates a new node index pointing to the root of the tree.
76    pub const fn root() -> Self {
77        Self { depth: 0, value: 0 }
78    }
79
80    /// Computes sibling index of the current node.
81    pub const fn sibling(mut self) -> Self {
82        self.value ^= 1;
83        self
84    }
85
86    /// Returns left child index of the current node.
87    pub const fn left_child(mut self) -> Self {
88        self.depth += 1;
89        self.value <<= 1;
90        self
91    }
92
93    /// Returns right child index of the current node.
94    pub const fn right_child(mut self) -> Self {
95        self.depth += 1;
96        self.value = (self.value << 1) + 1;
97        self
98    }
99
100    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
101    /// a new value instead of mutating `self`.
102    pub const fn parent(mut self) -> Self {
103        self.depth = self.depth.saturating_sub(1);
104        self.value >>= 1;
105        self
106    }
107
108    // PROVIDERS
109    // --------------------------------------------------------------------------------------------
110
111    /// Builds a node to be used as input of a hash function when computing a Merkle path.
112    ///
113    /// Will evaluate the parity of the current instance to define the result.
114    pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
115        if self.is_value_odd() {
116            [sibling, slf]
117        } else {
118            [slf, sibling]
119        }
120    }
121
122    /// Returns the scalar representation of the depth/value pair.
123    ///
124    /// It is computed as `2^depth + value`.
125    pub const fn to_scalar_index(&self) -> u64 {
126        (1 << self.depth as u64) + self.value
127    }
128
129    /// Returns the depth of the current instance.
130    pub const fn depth(&self) -> u8 {
131        self.depth
132    }
133
134    /// Returns the value of this index.
135    pub const fn value(&self) -> u64 {
136        self.value
137    }
138
139    /// Returns `true` if the current instance points to a right sibling node.
140    pub const fn is_value_odd(&self) -> bool {
141        (self.value & 1) == 1
142    }
143
144    /// Returns `true` if the depth is `0`.
145    pub const fn is_root(&self) -> bool {
146        self.depth == 0
147    }
148
149    // STATE MUTATORS
150    // --------------------------------------------------------------------------------------------
151
152    /// Traverses one level towards the root, decrementing the depth by `1`.
153    pub fn move_up(&mut self) {
154        self.depth = self.depth.saturating_sub(1);
155        self.value >>= 1;
156    }
157
158    /// Traverses towards the root until the specified depth is reached.
159    ///
160    /// Assumes that the specified depth is smaller than the current depth.
161    pub fn move_up_to(&mut self, depth: u8) {
162        debug_assert!(depth < self.depth);
163        let delta = self.depth.saturating_sub(depth);
164        self.depth = self.depth.saturating_sub(delta);
165        self.value >>= delta as u32;
166    }
167}
168
169impl Display for NodeIndex {
170    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
171        write!(f, "depth={}, value={}", self.depth, self.value)
172    }
173}
174
175impl Serializable for NodeIndex {
176    fn write_into<W: ByteWriter>(&self, target: &mut W) {
177        target.write_u8(self.depth);
178        target.write_u64(self.value);
179    }
180}
181
182impl Deserializable for NodeIndex {
183    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
184        let depth = source.read_u8()?;
185        let value = source.read_u64()?;
186        NodeIndex::new(depth, value)
187            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
188    }
189}
190
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    // For each depth, the largest value of the level is accepted and the next one is rejected.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // Use the bit length of `value` as the depth: it is the smallest depth accepted by
            // `NodeIndex::new`, so the unwrap never panics. This also covers `value == 0` (where
            // `ilog2` would panic) and exact powers of two (which need one level more than
            // `ilog2` reports).
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}