miden_crypto/merkle/
index.rs

1use core::fmt::Display;
2
3use super::{Felt, MerkleError, RpoDigest};
4use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
5
6// NODE INDEX
7// ================================================================================================
8
/// Position of a node inside a binary tree, expressed in level order.
///
/// A node is identified by the pair `(depth, value)`: `depth` is the distance from the root and
/// `value` is the zero-based offset of the node within that level, so the level at depth `d`
/// holds the offsets $0..(2^d)-1$. Example:
///
/// ```ignore
/// depth
/// 0             0
/// 1         0        1
/// 2      0    1    2    3
/// 3     0 1  2 3  4 5  6 7
/// ```
///
/// The root is the pair $(0, 0)$; its left child is $(1, 0)$ and its right child is $(1, 1)$.
///
/// The derived ordering compares `depth` first and `value` second, and the derived `Default`
/// value is the pair $(0, 0)$, i.e. the root.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Distance from the root; the root itself has depth `0`.
    depth: u8,
    /// Zero-based position of the node within its level.
    value: u64,
}
30
31impl NodeIndex {
32    // CONSTRUCTORS
33    // --------------------------------------------------------------------------------------------
34
35    /// Creates a new node index.
36    ///
37    /// # Errors
38    /// Returns an error if the `value` is greater than or equal to 2^{depth}.
39    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
40        if (64 - value.leading_zeros()) > depth as u32 {
41            Err(MerkleError::InvalidNodeIndex { depth, value })
42        } else {
43            Ok(Self { depth, value })
44        }
45    }
46
47    /// Creates a new node index without checking its validity.
48    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
49        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
50        Self { depth, value }
51    }
52
53    /// Creates a new node index for testing purposes.
54    ///
55    /// # Panics
56    /// Panics if the `value` is greater than or equal to 2^{depth}.
57    #[cfg(test)]
58    pub fn make(depth: u8, value: u64) -> Self {
59        Self::new(depth, value).unwrap()
60    }
61
62    /// Creates a node index from a pair of field elements representing the depth and value.
63    ///
64    /// # Errors
65    /// Returns an error if:
66    /// - `depth` doesn't fit in a `u8`.
67    /// - `value` is greater than or equal to 2^{depth}.
68    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
69        let depth = depth.as_int();
70        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
71        let value = value.as_int();
72        Self::new(depth, value)
73    }
74
75    /// Creates a new node index pointing to the root of the tree.
76    pub const fn root() -> Self {
77        Self { depth: 0, value: 0 }
78    }
79
80    /// Computes sibling index of the current node.
81    pub const fn sibling(mut self) -> Self {
82        self.value ^= 1;
83        self
84    }
85
86    /// Returns left child index of the current node.
87    pub const fn left_child(mut self) -> Self {
88        self.depth += 1;
89        self.value <<= 1;
90        self
91    }
92
93    /// Returns right child index of the current node.
94    pub const fn right_child(mut self) -> Self {
95        self.depth += 1;
96        self.value = (self.value << 1) + 1;
97        self
98    }
99
100    // PROVIDERS
101    // --------------------------------------------------------------------------------------------
102
103    /// Builds a node to be used as input of a hash function when computing a Merkle path.
104    ///
105    /// Will evaluate the parity of the current instance to define the result.
106    pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
107        if self.is_value_odd() {
108            [sibling, slf]
109        } else {
110            [slf, sibling]
111        }
112    }
113
114    /// Returns the scalar representation of the depth/value pair.
115    ///
116    /// It is computed as `2^depth + value`.
117    pub const fn to_scalar_index(&self) -> u64 {
118        (1 << self.depth as u64) + self.value
119    }
120
121    /// Returns the depth of the current instance.
122    pub const fn depth(&self) -> u8 {
123        self.depth
124    }
125
126    /// Returns the value of this index.
127    pub const fn value(&self) -> u64 {
128        self.value
129    }
130
131    /// Returns `true` if the current instance points to a right sibling node.
132    pub const fn is_value_odd(&self) -> bool {
133        (self.value & 1) == 1
134    }
135
136    /// Returns `true` if the depth is `0`.
137    pub const fn is_root(&self) -> bool {
138        self.depth == 0
139    }
140
141    // STATE MUTATORS
142    // --------------------------------------------------------------------------------------------
143
144    /// Traverses one level towards the root, decrementing the depth by `1`.
145    pub fn move_up(&mut self) {
146        self.depth = self.depth.saturating_sub(1);
147        self.value >>= 1;
148    }
149
150    /// Traverses towards the root until the specified depth is reached.
151    ///
152    /// Assumes that the specified depth is smaller than the current depth.
153    pub fn move_up_to(&mut self, depth: u8) {
154        debug_assert!(depth < self.depth);
155        let delta = self.depth.saturating_sub(depth);
156        self.depth = self.depth.saturating_sub(delta);
157        self.value >>= delta as u32;
158    }
159}
160
161impl Display for NodeIndex {
162    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
163        write!(f, "depth={}, value={}", self.depth, self.value)
164    }
165}
166
167impl Serializable for NodeIndex {
168    fn write_into<W: ByteWriter>(&self, target: &mut W) {
169        target.write_u8(self.depth);
170        target.write_u64(self.value);
171    }
172}
173
174impl Deserializable for NodeIndex {
175    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
176        let depth = source.read_u8()?;
177        let value = source.read_u64()?;
178        NodeIndex::new(depth, value)
179            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
180    }
181}
182
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    /// For each depth, the largest valid value is accepted and the next one is rejected.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    /// Depth 64 must accept the full `u64` value range.
    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // Use the smallest depth whose level can hold `value`, i.e. the number of
            // significant bits of `value`. Unlike `ilog2`, this is defined for 0 and also
            // yields a large enough depth for exact powers of two.
            //
            // unwrap never panics: `value < 2^63`, so the depth is in 0..u64::BITS and
            // `value < 2^depth` always holds.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}